| | |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | def export_rebuild_model(model, **kwargs): |
| | | is_onnx = kwargs.get("type", "onnx") == "onnx" |
| | | encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") |
| | |
| | | |
| | | from funasr.utils.torch_function import sequence_mask |
| | | |
| | | model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False) |
| | | model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False) |
| | | |
| | | model.forward = types.MethodType(export_forward, model) |
| | | model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model) |
| | |
| | | model.export_name = types.MethodType(export_name, model) |
| | | |
| | | return model |
| | | |
| | | |
| | | def export_forward( |
| | | self, |
| | |
| | | |
| | | return decoder_out, pre_token_length, us_alphas, us_cif_peak |
| | | |
| | | |
def export_dummy_inputs(self):
    """Build a fixed-shape sample input pair.

    Returns:
        A ``(speech, speech_lengths)`` tuple where ``speech`` is a random
        normal tensor of shape ``(2, 30, 560)`` and ``speech_lengths`` is
        the int32 tensor ``[6, 30]``.
    """
    # NOTE(review): presumably (batch, frames, feature_dim) -- confirm
    # against the front-end configuration.
    speech = torch.randn(2, 30, 560)
    # One length per batch row; the second row spans all 30 frames.
    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
    return (speech, speech_lengths)
| | | |
| | | |
def export_input_names(self):
    """Return the names given to the model inputs, in positional order.

    Returns:
        ``["speech", "speech_lengths"]``.
    """
    # NOTE(review): presumably forwarded as ``input_names`` to the ONNX
    # exporter -- confirm at the call site.
    return ["speech", "speech_lengths"]
| | | |
| | | |
def export_output_names(self):
    """Return the names given to the model outputs, in positional order.

    Returns:
        ``["logits", "token_num", "us_alphas", "us_cif_peak"]``.
    """
    # NOTE(review): presumably forwarded as ``output_names`` to the ONNX
    # exporter -- confirm at the call site.
    return ["logits", "token_num", "us_alphas", "us_cif_peak"]
| | | |
| | | |
def export_dynamic_axes(self):
    """Return the mapping of tensor names to their variable-size axes.

    Returns:
        A dict keyed by tensor name; each value maps an axis index to a
        symbolic dimension name. Axis 0 is ``"batch_size"`` for every
        listed tensor.
    """
    # NOTE(review): presumably forwarded as ``dynamic_axes`` to the ONNX
    # exporter -- confirm at the call site.
    return {
        "speech": {0: "batch_size", 1: "feats_length"},
        "speech_lengths": {
            0: "batch_size",
        },
        "logits": {0: "batch_size", 1: "logits_length"},
        # Both upsampled outputs share the same symbolic length name.
        "us_alphas": {0: "batch_size", 1: "alphas_length"},
        "us_cif_peak": {0: "batch_size", 1: "alphas_length"},
    }
| | | |
| | | |
def export_name(self):
    """Return the file name used for the exported model.

    Returns:
        The constant string ``"model.onnx"``.
    """
    return "model.onnx"