| | |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | class AbsLM(torch.nn.Module, BatchScorerInterface, ABC): |
| | | """The abstract LM class |
| | |
| | | raise NotImplementedError |
| | | |
| | | |
| | | class LanguageModel(AbsESPnetModel): |
| | | class LanguageModel(FunASRModel): |
| | | def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0): |
| | | assert check_argument_types() |
| | | super().__init__() |