From 0271fbe4fd11114c259ec970b1b04938010a16e3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 18:53:33 +0800
Subject: [PATCH] inference

---
 funasr/bin/asr_infer.py            |  251 +++++++++++++++++++++++++++++++++++
 funasr/bin/asr_inference_launch.py |  164 +++++++++++++++++++++++
 2 files changed, 414 insertions(+), 1 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index bcf5718..f8736e9 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -32,6 +32,7 @@
 from funasr.modules.beam_search.beam_search import Hypothesis
 from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
 from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
 from funasr.modules.subsampling import TooShortUttError
@@ -58,7 +59,7 @@
 from funasr.utils.vad_utils import slice_padding_fbank
 from funasr.tasks.vad import VADTask
 from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-
+from funasr.tasks.asr import frontend_choices
 
 class Speech2Text:
     """Speech2Text class
@@ -1599,3 +1600,251 @@
         
         return Speech2Text(**kwargs)
 
+
+class Speech2TextSAASR:
+    """Speech2Text class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+    
+    def __init__(
+        self,
+        asr_train_config: Union[Path, str] = None,
+        asr_model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        lm_train_config: Union[Path, str] = None,
+        lm_file: Union[Path, str] = None,
+        token_type: str = None,
+        bpemodel: str = None,
+        device: str = "cpu",
+        maxlenratio: float = 0.0,
+        minlenratio: float = 0.0,
+        batch_size: int = 1,
+        dtype: str = "float32",
+        beam_size: int = 20,
+        ctc_weight: float = 0.5,
+        lm_weight: float = 1.0,
+        ngram_weight: float = 0.9,
+        penalty: float = 0.0,
+        nbest: int = 1,
+        streaming: bool = False,
+        frontend_conf: dict = None,
+        **kwargs,
+    ):
+        assert check_argument_types()
+        
+        # 1. Build ASR model
+        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            if asr_train_args.frontend == 'wav_frontend':
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            else:
+                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
+        
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+        
+        decoder = asr_model.decoder
+        
+        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        token_list = asr_model.token_list
+        scorers.update(
+            decoder=decoder,
+            ctc=ctc,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+        
+        # 2. Build Language model
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, None, device
+            )
+            scorers["lm"] = lm.lm
+        
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+        
+        # 4. Build BeamSearch object
+        # transducer is not supported now
+        beam_search_transducer = None
+        
+        weights = dict(
+            decoder=1.0 - ctc_weight,
+            ctc=ctc_weight,
+            lm=lm_weight,
+            ngram=ngram_weight,
+            length_bonus=penalty,
+        )
+        beam_search = BeamSearch(
+            beam_size=beam_size,
+            weights=weights,
+            scorers=scorers,
+            sos=asr_model.sos,
+            eos=asr_model.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        )
+        
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+        
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+        
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.converter = converter
+        self.tokenizer = tokenizer
+        self.beam_search = beam_search
+        self.beam_search_transducer = beam_search_transducer
+        self.maxlenratio = maxlenratio
+        self.minlenratio = minlenratio
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+        self.frontend = frontend
+    
+    @torch.no_grad()
+    def __call__(
+        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
+        profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+    ) -> List[
+        Tuple[
+            Optional[str],
+            Optional[str],
+            List[str],
+            List[int],
+            Union[HypothesisSAASR],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, text_id, token, token_int, hyp
+
+        """
+        assert check_argument_types()
+        
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        
+        if isinstance(profile, np.ndarray):
+            profile = torch.tensor(profile)
+        
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            self.asr_model.frontend = None
+        else:
+            feats = speech
+            feats_len = speech_lengths
+        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+        batch = {"speech": feats, "speech_lengths": feats_len}
+        
+        # a. To device
+        batch = to_device(batch, device=self.device)
+        
+        # b. Forward Encoder
+        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
+        if isinstance(asr_enc, tuple):
+            asr_enc = asr_enc[0]
+        if isinstance(spk_enc, tuple):
+            spk_enc = spk_enc[0]
+        assert len(asr_enc) == 1, len(asr_enc)
+        assert len(spk_enc) == 1, len(spk_enc)
+        
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+        )
+        
+        nbest_hyps = nbest_hyps[: self.nbest]
+        
+        results = []
+        for hyp in nbest_hyps:
+            assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
+            
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1: last_pos]
+            else:
+                token_int = hyp.yseq[1: last_pos].tolist()
+            
+            spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
+            
+            token_ori = self.converter.ids2tokens(token_int)
+            text_ori = self.tokenizer.tokens2text(token_ori)
+            
+            text_ori_spklist = text_ori.split('$')
+            cur_index = 0
+            spk_choose = []
+            for i in range(len(text_ori_spklist)):
+                text_ori_split = text_ori_spklist[i]
+                n = len(text_ori_split)
+                spk_weights_local = spk_weigths[cur_index: cur_index + n]
+                cur_index = cur_index + n + 1
+                spk_weights_local = spk_weights_local.mean(dim=0)
+                spk_choose_local = spk_weights_local.argmax(-1)
+                spk_choose.append(spk_choose_local.item() + 1)
+            
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+            
+            # Change integer-ids to tokens
+            token = self.converter.ids2tokens(token_int)
+            
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+            
+            text_spklist = text.split('$')
+            assert len(spk_choose) == len(text_spklist)
+            
+            spk_list = []
+            for i in range(len(text_spklist)):
+                text_split = text_spklist[i]
+                n = len(text_split)
+                spk_list.append(str(spk_choose[i]) * n)
+            
+            text_id = '$'.join(spk_list)
+            
+            assert len(text) == len(text_id)
+            
+            results.append((text, text_id, token, token_int, hyp))
+        
+        assert check_return_type(results)
+        return results
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f9b8571..dbbb3ed 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -77,6 +77,7 @@
 from funasr.bin.punc_infer import Text2Punc
 from funasr.bin.tp_infer import Speech2Timestamp
 from funasr.bin.asr_infer import Speech2TextTransducer
+from funasr.bin.asr_infer import Speech2TextSAASR
 
 def inference_asr(
     maxlenratio: float,
@@ -1444,6 +1445,167 @@
     return _forward
 
 
+def inference_sa_asr(
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # data_path_and_name_and_type,
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    streaming: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    mc: bool = False,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+    
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+    )
+    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+    speech2text = Speech2TextSAASR(**speech2text_kwargs)
+    
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            mc=mc,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+        
+        finish_count = 0
+        file_count = 1
+        # 7 .Start for-loop
+        # FIXME(kamo): The output format should be discussed about
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+        
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                results = speech2text(**batch)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["sil"], [2], hyp]] * nbest
+            
+            # Only supporting batch_size==1
+            key = keys[0]
+            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                # Create a directory: outdir/{n}best_recog
+                if writer is not None:
+                    ibest_writer = writer[f"{n}best_recog"]
+                    
+                    # Write the result to each file
+                    ibest_writer["token"][key] = " ".join(token)
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["score"][key] = str(hyp.score)
+                    ibest_writer["text_id"][key] = text_id
+                
+                if text is not None:
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                    item = {'key': key, 'value': text_postprocessed}
+                    asr_result_list.append(item)
+                    finish_count += 1
+                    asr_utils.print_progress(finish_count / file_count)
+                    if writer is not None:
+                        ibest_writer["text"][key] = text
+                
+                logging.info("uttid: {}".format(key))
+                logging.info("text predictions: {}".format(text))
+                logging.info("text_id predictions: {}\n".format(text_id))
+        return asr_result_list
+    
+    return _forward
+
+
 def inference_launch(**kwargs):
     if 'mode' in kwargs:
         mode = kwargs['mode']
@@ -1464,6 +1626,8 @@
         return inference_mfcca(**kwargs)
     elif mode == "rnnt":
         return inference_transducer(**kwargs)
+    elif mode == "sa_asr":
+        return inference_sa_asr(**kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None

--
Gitblit v1.9.1