From 0d9384c8c0161259192cc3d676ca0d60e0d18e5c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 22:32:26 +0800
Subject: [PATCH] Dev gzf (#1474)
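
Summary of the changes to funasr/models/paraformer/model.py:

- Add a file header (shebang, encoding declaration, MIT license notice)
  and reorder the imports; import `to_device` for the new export path.
- Rename the decoding entry point from `generate` to `inference`.
- Treat tensor input as fbank features only when `data_type` is "fbank",
  and squeeze a trailing dimension from `speech_lengths` when it is given.
- Repeat `key` to cover the whole batch when fewer keys than utterances
  are passed in.
- Cache a single `DatadirWriter` on the model instead of creating one per
  hypothesis, and initialize `self.error_calculator` in the constructor.
- Skip `sentence_postprocess` for BPE tokenizers and use the
  `tokenizer.tokens2text` output directly.
- Add ONNX export support: `export`, `export_forward`, dummy inputs,
  input/output names, dynamic axes, and a default export file name.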

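A minimal sketch of how the new export hooks could drive torch.onnx.export
(the registered encoder/predictor/decoder names below are illustrative
assumptions, not part of this patch):

    import torch

    # model: an already-constructed Paraformer instance
    export_model = model.export(
        max_seq_len=512,
        type="onnx",
        device="cpu",
        encoder="SANMEncoder",            # assumed registered name
        predictor="CifPredictorV2",       # assumed registered name
        decoder="ParaformerSANMDecoder",  # assumed registered name
    )
    torch.onnx.export(
        export_model,                        # forward now dispatches to export_forward
        export_model.export_dummy_inputs(),  # (speech, speech_lengths)
        export_model.export_name(),          # "model.onnx"
        input_names=export_model.export_input_names(),
        output_names=export_model.export_output_names(),
        dynamic_axes=export_model.export_dynamic_axes(),
    )
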
---
 funasr/models/paraformer/model.py |  148 +++++++++++++++++++++++++++++++++++++------------
 1 file changed, 111 insertions(+), 37 deletions(-)

diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 2cd9c88..41a1bf7 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -1,35 +1,30 @@
-import os
-import logging
-from typing import Union, Dict, List, Tuple, Optional
-
-import torch
-import torch.nn as nn
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import time
-
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-
-from funasr.models.paraformer.cif_predictor import mae_loss
-
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-
-from funasr.models.paraformer.search import Hypothesis
-
+import torch
+import logging
 from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
 
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables
 from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.train_utils.device_funcs import to_device
 
 @tables.register("model_classes", "Paraformer")
-class Paraformer(nn.Module):
+class Paraformer(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -38,7 +33,6 @@
     
     def __init__(
         self,
-        # token_list: Union[Tuple[str, ...], List[str]],
         specaug: Optional[str] = None,
         specaug_conf: Optional[Dict] = None,
         normalize: str = None,
@@ -160,8 +154,8 @@
         self.predictor_bias = predictor_bias
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
-        # self.step_cur = 0
-        #
+
+
         self.share_embedding = share_embedding
         if self.share_embedding:
             self.decoder.embed = None
@@ -169,6 +163,7 @@
         self.use_1st_decoder_loss = use_1st_decoder_loss
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.error_calculator = None
     
     def forward(
         self,
@@ -439,7 +434,7 @@
         #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
         self.beam_search = beam_search
         
-    def generate(self,
+    def inference(self,
              data_in,
              data_lengths=None,
              key: list=None,
@@ -456,11 +451,13 @@
             self.nbest = kwargs.get("nbest", 1)
         
         meta_data = {}
-        if isinstance(data_in, torch.Tensor): # fbank
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
             speech, speech_lengths = data_in, data_lengths
             if len(speech.shape) < 3:
                 speech = speech[None, :, :]
-            if speech_lengths is None:
+            if speech_lengths is not None:
+                speech_lengths = speech_lengths.squeeze(-1)
+            else:
                 speech_lengths = speech.shape[1]
         else:
             # extract fbank feats
@@ -496,6 +493,8 @@
         b, n, d = decoder_out.size()
         if isinstance(key[0], (list, tuple)):
             key = key[0]
+        if len(key) < b:
+            key = key * b
         for i in range(b):
             x = encoder_out[i, :encoder_out_lens[i], :]
             am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -517,9 +516,10 @@
                 nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
             for nbest_idx, hyp in enumerate(nbest_hyps):
                 ibest_writer = None
-                if ibest_writer is None and kwargs.get("output_dir") is not None:
-                    writer = DatadirWriter(kwargs.get("output_dir"))
-                    ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
                 # remove sos/eos and get results
                 last_pos = -1
                 if isinstance(hyp.yseq, list):
@@ -533,13 +533,12 @@
                 if tokenizer is not None:
                     # Change integer-ids to tokens
                     token = tokenizer.ids2tokens(token_int)
-                    text = tokenizer.tokens2text(token)
-                    
-                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                    text_postprocessed = tokenizer.tokens2text(token)
+                    if not hasattr(tokenizer, "bpemodel"):
+                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                     
                     result_i = {"key": key[i], "text": text_postprocessed}
 
-                    
                     if ibest_writer is not None:
                         ibest_writer["token"][key[i]] = " ".join(token)
                         # ibest_writer["text"][key[i]] = text
@@ -550,3 +549,78 @@
                 
         return results, meta_data
 
+    def export(
+        self,
+        max_seq_len=512,
+        **kwargs,
+    ):
+        self.device = kwargs.get("device")
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+        
+        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
+        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+
+
+        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
+        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
+        
+        from funasr.utils.torch_function import sequence_mask
+
+
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+        
+        self.forward = self.export_forward
+        
+        return self
+
+    def export_forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.floor().type(torch.int32)
+    
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        # sample_ids = decoder_out.argmax(dim=-1)
+    
+        return decoder_out, pre_token_length
+
+    def export_dummy_inputs(self):
+        speech = torch.randn(2, 30, 560)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        return (speech, speech_lengths)
+
+
+    def export_input_names(self):
+        return ['speech', 'speech_lengths']
+
+    def export_output_names(self):
+        return ['logits', 'token_num']
+
+    def export_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+
+    def export_name(self):
+        return "model.onnx"

--
Gitblit v1.9.1