| | |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | if V(torch.__version__) >= V("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | yield |
| | | |
| | | |
| | | class TransducerModel(AbsESPnetModel): |
| | | class TransducerModel(FunASRModel): |
| | | """ESPnet2ASRTransducerModel module definition. |
| | | |
| | | Args: |
| | |
| | | |
| | | return loss_lm |
| | | |
| | | class UnifiedTransducerModel(AbsESPnetModel): |
| | | class UnifiedTransducerModel(FunASRModel): |
| | | """ESPnet2ASRTransducerModel module definition. |
| | | Args: |
| | | vocab_size: Size of complete vocabulary (w/ EOS and blank included). |