From a2a70f776ac46dc8987a05459de260ff2825ffbc Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 22 五月 2023 15:39:09 +0800
Subject: [PATCH] add paraforme infer code

---
 funasr/bin/asr_infer.py             |    4 +++-
 funasr/models/e2e_asr_paraformer.py |   33 ++++++++++++++++++++++++++-------
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 9da7ef7..f9d6bf7 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -305,6 +305,7 @@
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            decoding_ind: int = 0,
             **kwargs,
     ):
         assert check_argument_types()
@@ -415,6 +416,7 @@
         self.nbest = nbest
         self.frontend = frontend
         self.encoder_downsampling_factor = 1
+        self.decoding_ind = decoding_ind
         if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
             self.encoder_downsampling_factor = 4
 
@@ -452,7 +454,7 @@
         batch = to_device(batch, device=self.device)
 
         # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode(**batch)
+        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
         if isinstance(enc, tuple):
             enc = enc[0]
         # assert len(enc) == 1, len(enc)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 9241271..8a4d4a0 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -153,6 +153,7 @@
             speech_lengths: torch.Tensor,
             text: torch.Tensor,
             text_lengths: torch.Tensor,
+            decoding_ind: int = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -176,7 +177,11 @@
         speech = speech[:, :speech_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -272,7 +277,7 @@
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -299,11 +304,25 @@
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
         if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                feats, feats_lengths, ctc=self.ctc
-            )
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                encoder_out, encoder_out_lens, _ = self.encoder(
+                    feats, feats_lengths, ctc=self.ctc, ind=ind
+                )
+                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                            encoder_out_lens,
+                                                                                            chunk_outs=None)
+            else:
+                encoder_out, encoder_out_lens, _ = self.encoder(
+                    feats, feats_lengths, ctc=self.ctc
+                )
         else:
-            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
+                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                            encoder_out_lens,
+                                                                                            chunk_outs=None)
+            else:
+                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -1800,4 +1819,4 @@
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
\ No newline at end of file
+        return var_dict_torch_update

--
Gitblit v1.9.1