From e62e208a5cf632576cea92207d1005c2b7781844 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 三月 2023 19:50:52 +0800
Subject: [PATCH] Merge pull request #251 from alibaba-damo-academy/dev_lhn

---
 funasr/models/e2e_asr_paraformer.py |   74 ++++++++++++++++++++++++++++++++++++
 1 files changed, 73 insertions(+), 1 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 44c9de3..02f60af 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -325,12 +325,76 @@
 
         return encoder_out, encoder_out_lens
 
+    def encode_chunk(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, encoder_out_lens
+
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
                                                                                   ignore_id=self.ignore_id)
+        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+    def calc_predictor_chunk(self, encoder_out, cache=None):
+
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -341,6 +405,14 @@
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
+
+    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+        decoder_outs = self.decoder.forward_chunk(
+            encoder_out, sematic_embeds, cache["decoder"]
+        )
+        decoder_out = decoder_outs
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out
 
     def _extract_feats(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
@@ -1459,4 +1531,4 @@
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
\ No newline at end of file
+        return var_dict_torch_update

--
Gitblit v1.9.1