zhifu gao
2024-03-11 a7d7a0f3a2e7cd44a337ced34e3536b12ccb534e
funasr/models/paraformer_streaming/model.py
@@ -566,7 +566,7 @@
        max_seq_len=512,
        **kwargs,
    ):
        self.device = kwargs.get("device")
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -612,7 +612,7 @@
    
        return encoder_model, decoder_model
    def _export_encoder_forward(
    def export_encoder_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
@@ -663,7 +663,7 @@
    def export_encoder_name(self):
        return "model.onnx"
    
    def _export_decoder_forward(
    def export_decoder_forward(
        self,
        enc: torch.Tensor,
        enc_len: torch.Tensor,