From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' at line 514 to avoid a naming conflict with line 365

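Besides the rename, this patch reformats funasr/models/transformer/model.py,
renames generate() to inference() with the frontend now passed in as an
argument, registers the decoder scorer in init_beam_search, and adjusts the
decoding defaults (beam_size, decoding_ctc_weight).

The rename itself guards against Python's silent name shadowing: a later
assignment to the same name rebinds it for the rest of the scope, hiding the
earlier value. A minimal, hypothetical sketch of the pattern (names and logic
are illustrative, not the actual code at lines 365/514):

    # Hypothetical sketch of 'res' being shadowed; not the real model.py code.
    def decode(scores):
        res = max(scores)                   # first binding, cf. line 365
        res = [s for s in scores if s > 0]  # rebinds 'res', cf. line 514
        return res                          # the max computed above is lost

    print(decode([3, -1, 2]))  # -> [3, 2]; the first 'res' is unreachable
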
---
 funasr/models/transformer/model.py |  257 ++++++++++++++++++++++++++-------------------------
 1 file changed, 132 insertions(+), 125 deletions(-)

diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index e2367a7..adfd525 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -10,6 +10,7 @@
 from funasr.models.ctc.ctc import CTC
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.metrics.compute_acc import th_accuracy
+
 # from funasr.models.e2e_asr_common import ErrorCalculator
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@@ -17,25 +18,23 @@
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables
 
+
 @tables.register("model_classes", "Transformer")
 class Transformer(nn.Module):
     """CTC-attention hybrid Encoder-Decoder model"""
 
-    
     def __init__(
         self,
-        frontend: Optional[str] = None,
-        frontend_conf: Optional[Dict] = None,
-        specaug: Optional[str] = None,
-        specaug_conf: Optional[Dict] = None,
+        specaug: str = None,
+        specaug_conf: dict = None,
         normalize: str = None,
-        normalize_conf: Optional[Dict] = None,
+        normalize_conf: dict = None,
         encoder: str = None,
-        encoder_conf: Optional[Dict] = None,
+        encoder_conf: dict = None,
         decoder: str = None,
-        decoder_conf: Optional[Dict] = None,
+        decoder_conf: dict = None,
         ctc: str = None,
-        ctc_conf: Optional[Dict] = None,
+        ctc_conf: dict = None,
         ctc_weight: float = 0.5,
         interctc_weight: float = 0.0,
         input_size: int = 80,
@@ -59,41 +58,35 @@
 
         super().__init__()
 
-        if frontend is not None:
-            frontend_class = tables.frontend_classes.get_class(frontend)
-            frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get_class(specaug)
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get_class(normalize)
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get_class(encoder)
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
         if decoder is not None:
-            decoder_class = tables.decoder_classes.get_class(decoder)
+            decoder_class = tables.decoder_classes.get(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
                 **decoder_conf,
             )
         if ctc_weight > 0.0:
-            
+
             if ctc_conf is None:
                 ctc_conf = {}
-            
-            ctc = CTC(
-                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
-            )
-    
+
+            ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
+
         self.blank_id = blank_id
         self.sos = sos if sos is not None else vocab_size - 1
         self.eos = eos if eos is not None else vocab_size - 1
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
-        self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
         self.encoder = encoder
@@ -111,7 +104,7 @@
             self.decoder = None
         else:
             self.decoder = decoder
-        
+
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
             padding_idx=ignore_id,
@@ -124,18 +117,19 @@
         #         token_list, sym_space, sym_blank, report_cer, report_wer
         #     )
         #
+        self.error_calculator = None
         if ctc_weight == 0.0:
             self.ctc = None
         else:
             self.ctc = ctc
-            
+
         self.share_embedding = share_embedding
         if self.share_embedding:
             self.decoder.embed = None
-        
+
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
-    
+
     def forward(
         self,
         speech: torch.Tensor,
@@ -151,36 +145,34 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
-        
+
         batch_size = speech.shape[0]
-        
+
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
-        
+
         loss_att, acc_att, cer_att, wer_att = None, None, None, None
         loss_ctc, cer_ctc = None, None
         stats = dict()
-        
+
         # decoder: CTC branch
         if self.ctc_weight != 0.0:
             loss_ctc, cer_ctc = self._calc_ctc_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
-            
+
             # Collect CTC branch stats
             stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
             stats["cer_ctc"] = cer_ctc
-        
+
         # Intermediate CTC (optional)
         loss_interctc = 0.0
         if self.interctc_weight != 0.0 and intermediate_outs is not None:
@@ -191,25 +183,23 @@
                     intermediate_out, encoder_out_lens, text, text_lengths
                 )
                 loss_interctc = loss_interctc + loss_ic
-                
+
                 # Collect Intermediate CTC stats
                 stats["loss_interctc_layer{}".format(layer_idx)] = (
                     loss_ic.detach() if loss_ic is not None else None
                 )
                 stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-            
+
             loss_interctc = loss_interctc / len(intermediate_outs)
-            
+
             # calculate whole encoder loss
-            loss_ctc = (
-                           1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
-        
+            loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
+
         # decoder: Attention decoder branch
         loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
             encoder_out, encoder_out_lens, text, text_lengths
         )
-        
+
         # 3. CTC-Att loss definition
         if self.ctc_weight == 0.0:
             loss = loss_att
@@ -217,25 +207,27 @@
             loss = loss_ctc
         else:
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-        
+
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
         stats["acc"] = acc_att
         stats["cer"] = cer_att
         stats["wer"] = wer_att
-        
+
         # Collect total loss stats
         stats["loss"] = torch.clone(loss.detach())
-        
+
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
             batch_size = int((text_lengths + 1).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
-    
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -248,30 +240,28 @@
             # Data augmentation
             if self.specaug is not None and self.training:
                 speech, speech_lengths = self.specaug(speech, speech_lengths)
-            
+
             # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
             if self.normalize is not None:
                 speech, speech_lengths = self.normalize(speech, speech_lengths)
-        
+
         # Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
         if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                speech, speech_lengths, ctc=self.ctc
-            )
+            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
         else:
             encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
-        
+
         if intermediate_outs is not None:
             return (encoder_out, intermediate_outs), encoder_out_lens
-        
+
         return encoder_out, encoder_out_lens
-    
+
     def _calc_att_loss(
         self,
         encoder_out: torch.Tensor,
@@ -281,12 +271,10 @@
     ):
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
-        
+
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-        
+        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
         acc_att = th_accuracy(
@@ -294,16 +282,16 @@
             ys_out_pad,
             ignore_label=self.ignore_id,
         )
-        
+
         # Compute cer/wer using attention-decoder
         if self.training or self.error_calculator is None:
             cer_att, wer_att = None, None
         else:
             ys_hat = decoder_out.argmax(dim=-1)
             cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-        
+
         return loss_att, acc_att, cer_att, wer_att
-    
+
     def _calc_ctc_loss(
         self,
         encoder_out: torch.Tensor,
@@ -313,49 +301,48 @@
     ):
         # Calc CTC loss
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-        
+
         # Calc CER using CTC
         cer_ctc = None
         if not self.training and self.error_calculator is not None:
             ys_hat = self.ctc.argmax(encoder_out).data
             cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
         return loss_ctc, cer_ctc
-    
-    def init_beam_search(self,
-                         **kwargs,
-                         ):
+
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
         from funasr.models.transformer.search import BeamSearch
         from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
         from funasr.models.transformer.scorers.length_bonus import LengthBonus
-    
+
         # 1. Build ASR model
         scorers = {}
-        
+
         if self.ctc != None:
             ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
-            scorers.update(
-                ctc=ctc
-            )
+            scorers.update(ctc=ctc)
         token_list = kwargs.get("token_list")
         scorers.update(
+            decoder=self.decoder,
             length_bonus=LengthBonus(len(token_list)),
         )
 
-        
         # 3. Build ngram model
         # ngram is not supported now
         ngram = None
         scorers["ngram"] = ngram
-        
+
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
-            ctc=kwargs.get("decoding_ctc_weight", 0.0),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
+            ctc=kwargs.get("decoding_ctc_weight", 0.5),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
             length_bonus=kwargs.get("penalty", 0.0),
         )
         beam_search = BeamSearch(
-            beam_size=kwargs.get("beam_size", 2),
+            beam_size=kwargs.get("beam_size", 10),
             weights=weights,
             scorers=scorers,
             sos=self.sos,
@@ -364,57 +351,73 @@
             token_list=token_list,
             pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
         )
-        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
-        # for scorer in scorers.values():
-        #     if isinstance(scorer, torch.nn.Module):
-        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+
         self.beam_search = beam_search
-        
-    def generate(self,
-             data_in: list,
-             data_lengths: list=None,
-             key: list=None,
-             tokenizer=None,
-             **kwargs,
-             ):
-        
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
-        
+
         # init beamsearch
-        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
-        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+        if self.beam_search is None:
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
-        
+
         meta_data = {}
-        # extract fbank feats
-        time1 = time.perf_counter()
-        audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
-        time2 = time.perf_counter()
-        meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
-        time3 = time.perf_counter()
-        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-        
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = torch.tensor([speech.shape[1]])  # wrap as tensor so .to(device) below works
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
-
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
-        
+
         # c. Pass the encoder result to the beam search
         nbest_hyps = self.beam_search(
-            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
         )
-        
-        nbest_hyps = nbest_hyps[: self.nbest]
 
+        nbest_hyps = nbest_hyps[: self.nbest]
 
         results = []
         b, n, d = encoder_out.size()
@@ -422,31 +425,35 @@
 
             for nbest_idx, hyp in enumerate(nbest_hyps):
                 ibest_writer = None
-                if ibest_writer is None and kwargs.get("output_dir") is not None:
-                    writer = DatadirWriter(kwargs.get("output_dir"))
-                    ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
                 # remove sos/eos and get results
                 last_pos = -1
                 if isinstance(hyp.yseq, list):
                     token_int = hyp.yseq[1:last_pos]
                 else:
                     token_int = hyp.yseq[1:last_pos].tolist()
-                    
+
                 # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-                
+                token_int = list(
+                    filter(
+                        lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                    )
+                )
+
                 # Change integer-ids to tokens
                 token = tokenizer.ids2tokens(token_int)
                 text = tokenizer.tokens2text(token)
-                
+
                 text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                 results.append(result_i)
-                
+
                 if ibest_writer is not None:
                     ibest_writer["token"][key[i]] = " ".join(token)
-                    ibest_writer["text"][key[i]] = text
-                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-        
-        return results, meta_data
+                    ibest_writer["text"][key[i]] = text_postprocessed
 
+        return results, meta_data

--
Gitblit v1.9.1