| | |
| | | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | |
| | | |
class LanguageModel(FunASRModel):
    """Model wrapper that holds a language-model network.

    NOTE(review): ``FunASRModel`` and ``AbsLM`` are not imported in the
    visible part of this file -- confirm they are brought into scope
    elsewhere.
    """

    def __init__(
        self,
        lm: AbsLM,
        vocab_size: int,
        ignore_id: int = 0,
    ):
        """Store the wrapped language model and the ``sos`` id.

        Args:
            lm: The language-model object kept on ``self.lm``.
            vocab_size: Accepted, but not read in the visible part of this
                constructor.
            ignore_id: Accepted (defaults to 0), but not read in the visible
                part of this constructor.
        """
        # typeguard validates the call arguments against the annotations.
        assert check_argument_types()
        # Base-class initialisation runs before any attribute is assigned.
        super().__init__()
        self.lm = lm
        # presumably the start-of-sentence token id -- TODO confirm
        self.sos = 1
| | |
| | | class PunctuationModel(FunASRModel): |
| | | |
| | | def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self.punc_model = punc_model |
| | | self.punc_weight = torch.Tensor(punc_weight) |