| | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.utils.register import register_class, registry_tables |
| | | from funasr.register import tables |
| | | |
| | | @register_class("model_classes", "CTTransformer") |
| | | @tables.register("model_classes", "CTTransformer") |
| | | class CTTransformer(nn.Module): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | |
| | | |
| | | |
| | | self.embed = nn.Embedding(vocab_size, embed_unit) |
| | | encoder_class = registry_tables.encoder_classes.get(encoder.lower()) |
| | | encoder_class = tables.encoder_classes.get(encoder.lower()) |
| | | encoder = encoder_class(**encoder_conf) |
| | | |
| | | self.decoder = nn.Linear(att_unit, punc_size) |