| | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | |
| | | encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export") |
| | | self.encoder = encoder_class(self.encoder, onnx=is_onnx) |
| | | |
| | | self.forward = self._export_forward |
| | | self.forward = self.export_forward |
| | | |
| | | return self |
| | | |
| | | def _export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor): |
| | | def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor): |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |