zhifu gao
2024-03-11 9d48230c4f8f25bf88c5d6105f97370a36c9cf43
funasr/models/paraformer/model.py
@@ -554,21 +554,23 @@
        max_seq_len=512,
        **kwargs,
    ):
        onnx = kwargs.get("onnx", True)
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
        self.encoder = encoder_class(self.encoder, onnx=onnx)
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        
        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
        self.predictor = predictor_class(self.predictor, onnx=onnx)
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
        self.decoder = decoder_class(self.decoder, onnx=onnx)
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        
        if onnx:
        if is_onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)