| | |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | |
| | | class ESPnetLanguageModel(AbsESPnetModel): |
| | | class LanguageModel(AbsESPnetModel): |
| | | def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | |
| | | |
| | | # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>' |
| | | # text: (Batch, Length) -> x, t: (Batch, Length + 1) |
| | | x = F.pad(text, [1, 0], "constant", self.eos) |
| | | x = F.pad(text, [1, 0], "constant", self.sos) |
| | | t = F.pad(text, [0, 1], "constant", self.ignore_id) |
| | | for i, l in enumerate(text_lengths): |
| | | t[i, l] = self.sos |
| | | t[i, l] = self.eos |
| | | x_lengths = text_lengths + 1 |
| | | |
| | | # 2. Forward Language model |