from typing import Tuple

import torch
import torch.nn as nn


class TargetDelayTransformer(nn.Module):
    """Export-oriented wrapper around a punctuation model.

    Reuses the wrapped model's ``embed`` and ``decoder`` sub-modules and
    replaces its SANM encoder with the export variant of that encoder.
    Besides ``forward``, the class exposes helper methods describing dummy
    inputs, input/output names and dynamic axes for an exporter.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        """Build the export wrapper from an existing model.

        Args:
            model: Source model; must provide ``embed``, ``decoder`` and
                ``encoder`` attributes, with ``encoder`` being a
                ``SANMEncoder`` instance.
            max_seq_len: Accepted for interface compatibility; not used by
                this class.
            model_name: Name stored on the instance as ``model_name``.
            **kwargs: Optional settings. ``onnx`` (default ``False``) is
                forwarded to the export encoder.

        Raises:
            AssertionError: If ``model.encoder`` is not a ``SANMEncoder``.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        self.embed = model.embed
        self.decoder = model.decoder
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name

        # Imported inside the constructor, as in the original code, so the
        # funasr dependency is only resolved when an instance is built.
        from funasr.models.encoder.sanm_encoder import SANMEncoder
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export

        if not isinstance(model.encoder, SANMEncoder):
            # Raised explicitly instead of ``assert False`` so the check is
            # not stripped when Python runs with -O; the exception type and
            # message are unchanged for existing callers.
            raise AssertionError("Only support samn encode.")
        self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Run embedding, encoder and decoder on a batch of token ids.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Per-sequence lengths passed to the
                encoder together with the embedded input.

        Returns:
            torch.Tensor: Output of the decoder applied to the encoder's
            first return value.
        """
        embedded = self.embed(input)
        # The encoder returns a pair; only its first element is decoded.
        encoded, _ = self.encoder(embedded, text_lengths)
        return self.decoder(encoded)

    def get_dummy_inputs(self):
        """Return example ``(input, text_lengths)`` tensors.

        The ids are random integers in ``[0, num_embeddings)`` for a batch of
        two sequences of length 120; the lengths are ``[100, 120]`` (int32).
        """
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        """Return the input names, in ``forward`` argument order."""
        return ['input', 'text_lengths']

    def get_output_names(self):
        """Return the output names."""
        return ['logits']

    def get_dynamic_axes(self):
        """Return a mapping from tensor name to its dynamic axes and labels."""
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length',
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length',
            },
        }