From d79287c37e4e7ae2694a992cbbfb03a5ca4f7670 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 20 二月 2024 14:05:58 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge

---
 funasr/models/transducer/model.py |  155 ++++++++++++++++-----------------------------------
 1 files changed, 48 insertions(+), 107 deletions(-)

diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py
index 906aa60..fd8ad71 100644
--- a/funasr/models/transducer/model.py
+++ b/funasr/models/transducer/model.py
@@ -1,42 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
 import logging
 from contextlib import contextmanager
+from typing import Dict, Optional, Tuple
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.frontends.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.paraformer.cif_predictor import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.paraformer.cif_predictor import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
 
-from funasr.models.model_class_factory import *
+from funasr.register import tables
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.scorers.length_bonus import LengthBonus
+from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
+
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -45,16 +29,10 @@
     @contextmanager
     def autocast(enabled=True):
         yield
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
 
 
-class Transducer(nn.Module):
-    """ESPnet2ASRTransducerModel module definition."""
-
-    
+@tables.register("model_classes", "Transducer")
+class Transducer(torch.nn.Module):
     def __init__(
         self,
         frontend: Optional[str] = None,
@@ -96,35 +74,30 @@
 
         super().__init__()
 
-        if frontend is not None:
-            frontend_class = frontend_classes.get_class(frontend)
-            frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = specaug_classes.get_class(specaug)
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = normalize_classes.get_class(normalize)
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = encoder_classes.get_class(encoder)
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
 
-        decoder_class = decoder_classes.get_class(decoder)
+        decoder_class = tables.decoder_classes.get(decoder)
         decoder = decoder_class(
             vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
             **decoder_conf,
         )
         decoder_output_size = decoder.output_size
 
-        joint_network_class = joint_network_classes.get_class(decoder)
+        joint_network_class = tables.joint_network_classes.get(joint_network)
         joint_network = joint_network_class(
             vocab_size,
             encoder_output_size,
             decoder_output_size,
             **joint_network_conf,
         )
-        
         
         self.criterion_transducer = None
         self.error_calculator = None
@@ -157,23 +130,17 @@
         self.decoder = decoder
         self.joint_network = joint_network
 
-
-        
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
             padding_idx=ignore_id,
             smoothing=lsm_weight,
             normalize_length=length_normalized_loss,
         )
-        #
-        # if report_cer or report_wer:
-        #     self.error_calculator = ErrorCalculator(
-        #         token_list, sym_space, sym_blank, report_cer, report_wer
-        #     )
-        #
 
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.ctc = None
+        self.ctc_weight = 0.0
     
     def forward(
         self,
@@ -190,8 +157,6 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -283,12 +248,7 @@
         # Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                speech, speech_lengths, ctc=self.ctc
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -449,9 +409,6 @@
     def init_beam_search(self,
                          **kwargs,
                          ):
-        from funasr.models.transformer.search import BeamSearch
-        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
-        from funasr.models.transformer.scorers.length_bonus import LengthBonus
     
         # 1. Build ASR model
         scorers = {}
@@ -466,28 +423,16 @@
             length_bonus=LengthBonus(len(token_list)),
         )
 
-        
         # 3. Build ngram model
         # ngram is not supported now
         ngram = None
         scorers["ngram"] = ngram
         
-        weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
-            ctc=kwargs.get("decoding_ctc_weight", 0.0),
-            lm=kwargs.get("lm_weight", 0.0),
-            ngram=kwargs.get("ngram_weight", 0.0),
-            length_bonus=kwargs.get("penalty", 0.0),
-        )
-        beam_search = BeamSearch(
-            beam_size=kwargs.get("beam_size", 2),
-            weights=weights,
-            scorers=scorers,
-            sos=self.sos,
-            eos=self.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        beam_search = BeamSearchTransducer(
+            self.decoder,
+            self.joint_network,
+            kwargs.get("beam_size", 2),
+            nbest=1,
         )
         # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
         # for scorer in scorers.values():
@@ -495,13 +440,13 @@
         #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
         self.beam_search = beam_search
         
-    def generate(self,
-             data_in: list,
-             data_lengths: list=None,
-             key: list=None,
-             tokenizer=None,
-             **kwargs,
-             ):
+    def inference(self,
+                  data_in: list,
+                  data_lengths: list=None,
+                  key: list=None,
+                  tokenizer=None,
+                  **kwargs,
+                  ):
         
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
@@ -509,10 +454,10 @@
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
-            logging.info("enable beam_search")
-            self.init_beam_search(**kwargs)
-            self.nbest = kwargs.get("nbest", 1)
+        # if self.beam_search is None and (is_use_lm or is_use_ctc):
+        logging.info("enable beam_search")
+        self.init_beam_search(**kwargs)
+        self.nbest = kwargs.get("nbest", 1)
         
         meta_data = {}
         # extract fbank feats
@@ -534,12 +479,8 @@
             encoder_out = encoder_out[0]
         
         # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
-        )
-        
+        nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
         nbest_hyps = nbest_hyps[: self.nbest]
-
 
         results = []
         b, n, d = encoder_out.size()
@@ -553,9 +494,9 @@
                 # remove sos/eos and get results
                 last_pos = -1
                 if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
+                    token_int = hyp.yseq#[1:last_pos]
                 else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
+                    token_int = hyp.yseq#[1:last_pos].tolist()
                     
                 # remove blank symbol id, which is assumed to be 0
                 token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

--
Gitblit v1.9.1