From a2a70f776ac46dc8987a05459de260ff2825ffbc Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 22 五月 2023 15:39:09 +0800
Subject: [PATCH] add paraforme infer code
---
funasr/bin/asr_infer.py | 4 +++-
funasr/models/e2e_asr_paraformer.py | 33 ++++++++++++++++++++++++++-------
2 files changed, 29 insertions(+), 8 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 9da7ef7..f9d6bf7 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -305,6 +305,7 @@
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
+ decoding_ind: int = 0,
**kwargs,
):
assert check_argument_types()
@@ -415,6 +416,7 @@
self.nbest = nbest
self.frontend = frontend
self.encoder_downsampling_factor = 1
+ self.decoding_ind = decoding_ind
if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
@@ -452,7 +454,7 @@
batch = to_device(batch, device=self.device)
# b. Forward Encoder
- enc, enc_len = self.asr_model.encode(**batch)
+ enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
if isinstance(enc, tuple):
enc = enc[0]
# assert len(enc) == 1, len(enc)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 9241271..8a4d4a0 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -153,6 +153,7 @@
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
+ decoding_ind: int = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
@@ -176,7 +177,11 @@
speech = speech[:, :speech_lengths.max()]
# 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ else:
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@@ -272,7 +277,7 @@
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
@@ -299,11 +304,25 @@
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder(
- feats, feats_lengths, ctc=self.ctc
- )
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ feats, feats_lengths, ctc=self.ctc, ind=ind
+ )
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ feats, feats_lengths, ctc=self.ctc
+ )
else:
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@@ -1800,4 +1819,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
\ No newline at end of file
+ return var_dict_torch_update
--
Gitblit v1.9.1