From 00d0df3a1018c63ec8c5d13e611f53c564c0a7e2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 6 May 2024 22:17:25 +0800
Subject: [PATCH] Dev gzf decoding (#1695)

---
 funasr/datasets/sense_voice_datasets/datasets.py              |   16 
 funasr/models/sense_voice/model.py                            |  217 +++++++++++++-
 funasr/datasets/audio_datasets/update_jsonl.py                |   21 
 funasr/models/sense_voice/decoder.py                          |   70 ++++
 funasr/models/sense_voice/search.py                           |  453 ++++++++++++++++++++++++++++++++
 funasr/train_utils/trainer.py                                 |    7 
 examples/industrial_data_pretraining/conformer/demo.py        |    3 
 examples/industrial_data_pretraining/sense_voice/demo_fsmn.py |   27 +
 8 files changed, 766 insertions(+), 48 deletions(-)

diff --git a/examples/industrial_data_pretraining/conformer/demo.py b/examples/industrial_data_pretraining/conformer/demo.py
index 43cf67d..c2d7682 100644
--- a/examples/industrial_data_pretraining/conformer/demo.py
+++ b/examples/industrial_data_pretraining/conformer/demo.py
@@ -8,6 +8,7 @@
 model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch")
 
 res = model.generate(
-    input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
+    input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav",
+    decoding_ctc_weight=0.0,
 )
 print(res)
diff --git a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
new file mode 100644
index 0000000..e063e1f
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+# NOTE: `model` below is a directory on the author's machine; point it at your own copy.
+model = AutoModel(
+    model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscopeFSMN",
+    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    vad_kwargs={"max_single_segment_time": 30000},
+)
+
+input_wav = (
+    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+)
+
+# Handed to `generate` through its `DecodingOptions` keyword.
+decoding_options = dict(
+    task=("ASR", "AED", "SER"),
+    language="auto",
+    fp16=True,
+    gain_event=True,
+)
+res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=decoding_options, beam_size=5)
+print(res)
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
index 05870fe..bc84708 100644
--- a/funasr/datasets/audio_datasets/update_jsonl.py
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -46,16 +46,17 @@
     data = json.loads(line.strip())
 
     wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
-    waveform, _ = librosa.load(wav_path, sr=16000)
-    sample_num = len(waveform)
-    source_len = int(sample_num / 16000 * 1000 / 10)
-    source_len_old = data["source_len"]
-    # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
-    #     logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
-    data["source_len"] = source_len
-    data["source"] = wav_path
-    jsonl_line = json.dumps(data, ensure_ascii=False)
-    lines[i] = jsonl_line
+    if os.path.exists(wav_path):
+        waveform, _ = librosa.load(wav_path, sr=16000)
+        sample_num = len(waveform)
+        source_len = int(sample_num / 16000 * 1000 / 10)
+        source_len_old = data["source_len"]
+        # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
+        #     logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
+        data["source_len"] = source_len
+        data["source"] = wav_path
+        jsonl_line = json.dumps(data, ensure_ascii=False)
+        lines[i] = jsonl_line
 
 
 def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 1d269dd..5d80956 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -2,7 +2,7 @@
 
 import torch
 import random
-
+import traceback
 from funasr.register import tables
 from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 
@@ -73,15 +73,17 @@
             if idx == 0:
                 index_cur = index
             else:
-                if index <= self.retry:
-                    index_cur = index + idx
-                else:
-                    index_cur = torch.randint(0, index, ()).item()
+                index_cur = torch.randint(0, len(self.index_ds), ()).item()
 
             item = self.index_ds[index_cur]
 
             source = item["source"]
-            data_src = load_audio_text_image_video(source, fs=self.fs)
+            try:
+                data_src = load_audio_text_image_video(source, fs=self.fs)
+            except Exception as e:
+                logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
+                continue
+
             if self.preprocessor_speech:
                 data_src = self.preprocessor_speech(data_src, fs=self.fs)
             speech, speech_lengths = extract_fbank(
@@ -186,7 +188,7 @@
                 )
 
         if self.batch_type != "example":
-            for i in range(3):
+            for i in range(10):
                 outputs = self._filter_badcase(outputs, i=i)
 
         return outputs
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index f5b8825..19d9c16 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+from funasr.models.transformer.utils.mask import subsequent_mask
 
 
 class LayerNorm(nn.LayerNorm):
@@ -336,6 +337,29 @@
 
         return x
 
+    def init_state(self, x):
+        """Return a new, empty state dict; ``x`` is ignored."""
+        # A fresh dict is created on every call.
+        return {}
+
+    def final_score(self, state) -> float:
+        """Score the end of a hypothesis (eos); this decoder always adds 0.0.
+
+        Args:
+            state: Scorer state for the prefix tokens (not used here).
+
+        Returns:
+            float: The final score, which is the constant 0.0.
+
+        """
+        return 0.0
+
+    def score(self, ys, state, x):
+        """Return next-token log-probabilities for prefix ``ys`` together with ``state``."""
+        # Beam search sums scores across steps, so normalise; a no-op on log-probabilities.
+        logits = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        return torch.log_softmax(logits.squeeze(0)[-1, :], dim=-1), state
+
 
 class MultiHeadedAttentionSANMDecoder(nn.Module):
     """Multi-Head Attention layer.
@@ -443,9 +467,19 @@
         kv_cache: Optional[dict] = None,
         **kwargs,
     ):
+        cache = kwargs.get("cache", {})
+        layer = kwargs.get("layer", 0)
         is_pad_mask = kwargs.get("is_pad_mask", False)
         is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
-        x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+
+        fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None
+        # if fsmn_cache is not None:
+        #     x = x[:, -1:]
+        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
+        # if len(cache)>1:
+        #     cache[layer]["fsmn_cache"] = fsmn_cache
+        #     x = x[:, -1:]
+        x = x + att_res
         if self.cross_attn:
             x = (
                 x
@@ -510,10 +544,9 @@
 
         ys_in_lens = kwargs.get("ys_in_lens", None)
 
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
         tgt, memory = x, xa
         tgt[tgt == -1] = 0
-        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
+        tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
         # tgt = self.dropout(tgt)
 
         x = tgt.to(memory.dtype)
@@ -531,9 +564,40 @@
                 memory_mask=memory_mask,
                 is_pad_mask=False,
                 is_pad_memory_mask=True,
+                cache=kwargs.get("cache", None),
+                layer=layer,
             )
 
         x = self.ln(x)
         x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
 
         return x
+
+    def init_state(self, x):
+        """Create the initial beam-search cache for this decoder.
+
+        Returns a dict keyed by block index. Each entry holds the slots
+        ``fsmn_cache``, ``memory_key`` and ``memory_value``, all ``None``.
+        ``x`` is not used.
+        """
+        slot_names = ("fsmn_cache", "memory_key", "memory_value")
+        # dict.fromkeys fills every slot with None; each layer gets its own dict.
+        return {layer: dict.fromkeys(slot_names) for layer in range(len(self.blocks))}
+
+    def final_score(self, state) -> float:
+        """Score the end of a hypothesis (eos); this decoder always adds 0.0.
+
+        Args:
+            state: Scorer state for the prefix tokens (not used here).
+
+        Returns:
+            float: The final score, which is the constant 0.0.
+
+        """
+        return 0.0
+
+    def score(self, ys, state, x):
+        """Return next-token log-probabilities for prefix ``ys`` together with ``state``."""
+        # `forward` yields raw logits; beam search sums scores over steps, so normalise them.
+        logits = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        return torch.log_softmax(logits.squeeze(0)[-1, :], dim=-1), state
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index c12107e..d5e4130 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -15,6 +15,7 @@
 from funasr.train_utils.device_funcs import force_gatherable
 from . import whisper_lib as whisper
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
 
 from funasr.register import tables
 
@@ -395,6 +396,42 @@
 
         return loss_att, acc_att, None, None
 
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        """Create the beam-search decoder and store it on ``self.beam_search``.
+
+        Reads ``beam_size`` (default 5) and ``penalty`` (default 0.0, used as the
+        length-bonus weight) from ``kwargs``. ``sos``/``eos`` are left ``None`` here.
+        """
+        from .search import BeamSearch
+
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        scorers = {
+            "decoder": self.model.decoder,
+            "length_bonus": LengthBonus(self.vocab_size),
+        }
+        # No ctc/lm/ngram scorer is registered above; their weights stay at 0.0.
+        weights = {
+            "decoder": 1.0,
+            "ctc": 0.0,
+            "lm": 0.0,
+            "ngram": 0.0,
+            "length_bonus": kwargs.get("penalty", 0.0),
+        }
+        self.beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=None,
+            eos=None,
+            vocab_size=self.vocab_size,
+            token_list=None,
+            pre_beam_score_key="full",
+        )
+
     def inference(
         self,
         data_in,
@@ -406,6 +443,12 @@
     ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
+
+        # init beamsearch
+        if not hasattr(self, "beam_search") or self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
 
         if frontend is None and not hasattr(self, "frontend"):
             frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -455,25 +498,65 @@
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
         initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-        DecodingOptions["initial_prompt"] = initial_prompt
 
         language = DecodingOptions.get("language", None)
         language = None if language == "auto" else language
-        DecodingOptions["language"] = language
 
-        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+        sos_int = tokenizer.encode(sos, allowed_special="all")
+        eos = kwargs.get("model_conf").get("eos")
+        eos_int = tokenizer.encode(eos, allowed_special="all")
+        self.beam_search.sos = sos_int
+        self.beam_search.eos = eos_int[0]
 
-        if "without_timestamps" not in DecodingOptions:
-            DecodingOptions["without_timestamps"] = True
+        encoder_out, encoder_out_lens = self.encode(
+            speech[None, :, :].permute(0, 2, 1), speech_lengths
+        )
 
-        options = whisper.DecodingOptions(**DecodingOptions)
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
+        )
 
-        result = whisper.decode(self.model, speech, options)
-        text = f"{result.text}"
+        nbest_hyps = nbest_hyps[: self.nbest]
+
         results = []
-        result_i = {"key": key[0], "text": text}
+        b, n, d = encoder_out.size()
+        for i in range(b):
 
-        results.append(result_i)
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # # remove blank symbol id, which is assumed to be 0
+                # token_int = list(
+                #     filter(
+                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                #     )
+                # )
+
+                # Change integer-ids to tokens
+                # token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.decode(token_int)
+
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    # ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
 
         return results, meta_data
 
@@ -497,12 +580,14 @@
         # decoder
         del model.decoder
         decoder = kwargs.get("decoder", "SenseVoiceDecoder")
-        decoder_conf = kwargs.get("decoder_conf", {})
         decoder_class = tables.decoder_classes.get(decoder)
         decoder = decoder_class(
-            vocab_size=dims.n_vocab,
-            encoder_output_size=dims.n_audio_state,
-            **decoder_conf,
+            n_vocab=dims.n_vocab,
+            n_ctx=dims.n_text_ctx,
+            n_state=dims.n_text_state,
+            n_head=dims.n_text_head,
+            n_layer=dims.n_text_layer,
+            **kwargs.get("decoder_conf"),
         )
         model.decoder = decoder
 
@@ -512,7 +597,7 @@
 
         self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
         self.ignore_id = kwargs.get("ignore_id", -1)
-        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.vocab_size = dims.n_vocab
         self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
         self.criterion_att = LabelSmoothingLoss(
             size=self.vocab_size,
@@ -630,6 +715,42 @@
 
         return loss_att, acc_att, None, None
 
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        """Create the beam-search decoder and store it on ``self.beam_search``.
+
+        Reads ``beam_size`` (default 5) and ``penalty`` (default 0.0, used as the
+        length-bonus weight) from ``kwargs``. ``sos``/``eos`` are left ``None`` here.
+        """
+        from .search import BeamSearch
+
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        scorers = {
+            "decoder": self.model.decoder,
+            "length_bonus": LengthBonus(self.vocab_size),
+        }
+        # No ctc/lm/ngram scorer is registered above; their weights stay at 0.0.
+        weights = {
+            "decoder": 1.0,
+            "ctc": 0.0,
+            "lm": 0.0,
+            "ngram": 0.0,
+            "length_bonus": kwargs.get("penalty", 0.0),
+        }
+        self.beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=None,
+            eos=None,
+            vocab_size=self.vocab_size,
+            token_list=None,
+            pre_beam_score_key="full",
+        )
+
     def inference(
         self,
         data_in,
@@ -641,6 +762,12 @@
     ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
+
+        # init beamsearch
+        if not hasattr(self, "beam_search") or self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
 
         if frontend is None and not hasattr(self, "frontend"):
             frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -690,24 +817,64 @@
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
         initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-        DecodingOptions["initial_prompt"] = initial_prompt
 
         language = DecodingOptions.get("language", None)
         language = None if language == "auto" else language
-        DecodingOptions["language"] = language
 
-        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+        sos_int = tokenizer.encode(sos, allowed_special="all")
+        eos = kwargs.get("model_conf").get("eos")
+        eos_int = tokenizer.encode(eos, allowed_special="all")
+        self.beam_search.sos = sos_int
+        self.beam_search.eos = eos_int[0]
 
-        if "without_timestamps" not in DecodingOptions:
-            DecodingOptions["without_timestamps"] = True
+        encoder_out, encoder_out_lens = self.encode(
+            speech[None, :, :].permute(0, 2, 1), speech_lengths
+        )
 
-        options = whisper.DecodingOptions(**DecodingOptions)
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
+        )
 
-        result = whisper.decode(self.model, speech, options)
-        text = f"{result.text}"
+        nbest_hyps = nbest_hyps[: self.nbest]
+
         results = []
-        result_i = {"key": key[0], "text": text}
+        b, n, d = encoder_out.size()
+        for i in range(b):
 
-        results.append(result_i)
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # # remove blank symbol id, which is assumed to be 0
+                # token_int = list(
+                #     filter(
+                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                #     )
+                # )
+
+                # Change integer-ids to tokens
+                # token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.decode(token_int)
+
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    # ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
 
         return results, meta_data
diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
new file mode 100644
index 0000000..98d02db
--- /dev/null
+++ b/funasr/models/sense_voice/search.py
@@ -0,0 +1,453 @@
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.metrics.common import end_detect
+from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
+from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
+
+
+class Hypothesis(NamedTuple):
+    """One (partial) decoding hypothesis tracked by the beam search."""
+
+    yseq: torch.Tensor  # token ids of the prefix
+    score: Union[float, torch.Tensor] = 0  # total score of the prefix
+    scores: Dict[str, Union[float, torch.Tensor]] = dict()  # score per scorer name
+    states: Dict[str, Any] = dict()  # state per scorer name
+
+    def asdict(self) -> dict:
+        """Return the fields as a dict with tensor values turned into plain ones."""
+        plain = self._asdict()
+        plain["yseq"] = self.yseq.tolist()
+        plain["score"] = float(self.score)
+        plain["scores"] = {name: float(val) for name, val in self.scores.items()}
+        return plain
+
+
+class BeamSearch(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos=None,
+        eos=None,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Set up scorers and search configuration.
+
+        Scorers are split into "full" scorers (score the whole vocabulary) and
+        "partial" scorers (instances of ``PartialScorerInterface``). A scorer is
+        dropped when it is ``None`` or its weight is missing or 0.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Scoring modules by name,
+                e.g. decoder, CTC prefix scorer, LM.
+            weights (dict[str, float]): Weight of each scorer, by the same
+                names as ``scorers``.
+            beam_size (int): Number of hypotheses kept during search.
+            vocab_size (int): Size of the vocabulary.
+            sos: Start-of-sequence token id; stored unchanged.
+            eos: End-of-sequence token id. If a list or tuple is given, only
+                its first element is kept.
+            token_list (list[str]): Tokens, for debug logging.
+            pre_beam_ratio (float): The pre-beam size is
+                ``int(pre_beam_ratio * beam_size)``.
+            pre_beam_score_key (str): Which score drives the pre-beam:
+                ``"full"``, the name of a full scorer, or ``None`` to disable.
+
+        Raises:
+            KeyError: If ``pre_beam_score_key`` is neither ``None``, ``"full"``
+                nor the name of a registered full scorer.
+
+        """
+        super().__init__()
+        self.weights = weights
+        self.scorers = {}
+        self.full_scorers = {}
+        self.part_scorers = {}
+        # Holding nn.Module scorers in a ModuleDict lets a recursive cast such as
+        # `self.to(device, dtype)` reach them.
+        self.nn_dict = torch.nn.ModuleDict()
+        for name, scorer in scorers.items():
+            if scorer is None or weights.get(name, 0) == 0:
+                continue
+            self.scorers[name] = scorer
+            if isinstance(scorer, PartialScorerInterface):
+                self.part_scorers[name] = scorer
+            else:
+                self.full_scorers[name] = scorer
+            if isinstance(scorer, torch.nn.Module):
+                self.nn_dict[name] = scorer
+
+        # search configuration
+        self.sos = sos
+        self.eos = eos[0] if isinstance(eos, (list, tuple)) else eos
+        self.token_list = token_list
+        self.beam_size = beam_size
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.n_vocab = vocab_size
+
+        names_a_scorer = pre_beam_score_key not in (None, "full")
+        if names_a_scorer and pre_beam_score_key not in self.full_scorers:
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+        """Build the starting hypothesis list for a search.
+
+        Every registered scorer contributes its initial state (from
+        ``init_state(x)``) and a score of 0.0. The prefix is ``self.sos``
+        turned into a tensor on the device of ``x``.
+
+        Note:
+            If ``self.sos`` is not already a list or tuple it is replaced, on
+            the instance, by a one-element list.
+
+        Args:
+            x (torch.Tensor): The encoder output feature.
+
+        Returns:
+            List[Hypothesis]: A list holding the single initial hypothesis.
+        """
+        init_states = {name: scorer.init_state(x) for name, scorer in self.scorers.items()}
+        init_scores = dict.fromkeys(self.scorers, 0.0)
+        # Wrap a scalar sos in a list so that it becomes a 1-D tensor below.
+        if not isinstance(self.sos, (list, tuple)):
+            self.sos = [self.sos]
+        prefix = torch.tensor(self.sos, device=x.device)
+        first = Hypothesis(yseq=prefix, score=0.0, scores=init_scores, states=init_states)
+        return [first]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Return a copy of ``xs`` with one more token id at the end.
+
+        Args:
+            xs (torch.Tensor): 1-D tensor of prefix token ids.
+            x (int): Token id to append.
+
+        Returns:
+            torch.Tensor: ``xs`` followed by ``x``, keeping the dtype and
+                device of ``xs``.
+        """
+        tail = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat([xs, tail], dim=0)
+
+    def score_full(
+        self, hyp: Hypothesis, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Run every full scorer on the prefix of ``hyp``.
+
+        Args:
+            hyp (Hypothesis): Hypothesis whose ``yseq`` is scored; its
+                ``states`` must hold an entry for each full scorer.
+            x (torch.Tensor): Corresponding input feature.
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Two dicts keyed
+                by full-scorer name. The first maps to the score each scorer
+                returned (expected shape ``(self.n_vocab,)``), the second
+                to the state it returned.
+
+        """
+        scores, states = {}, {}
+        for name, scorer in self.full_scorers.items():
+            result = scorer.score(hyp.yseq, hyp.states[name], x)
+            scores[name], states[name] = result
+        return scores, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Run every partial scorer on the candidate tokens ``ids``.
+
+        Args:
+            hyp (Hypothesis): Hypothesis whose ``yseq`` is scored; its
+                ``states`` must hold an entry for each partial scorer.
+            ids (torch.Tensor): 1D tensor of new partial tokens to score.
+            x (torch.Tensor): Corresponding input feature.
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Two dicts keyed
+                by partial-scorer name. The first maps to the score each
+                scorer returned (expected shape ``(len(ids),)``), the second
+                to the state it returned.
+
+        """
+        scores, states = {}, {}
+        for name, scorer in self.part_scorers.items():
+            result = scorer.score_partial(hyp.yseq, ids, hyp.states[name], x)
+            scores[name], states[name] = result
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Pick the ``beam_size`` best tokens among the candidates ``ids``.
+
+        When ``ids`` and ``weighted_scores`` differ in length,
+        ``weighted_scores`` is modified in place: entries outside ``ids`` are
+        set to ``-inf``.
+
+        Args:
+            weighted_scores (torch.Tensor): Weighted sum of scores per token.
+            ids (torch.Tensor): Candidate token ids to choose from.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The top token ids, and the
+                positions of the top entries within ``ids``.
+        """
+        width = self.beam_size
+        if weighted_scores.size(0) == ids.size(0):
+            # every token is a candidate, so both index spaces coincide
+            _, top_ids = weighted_scores.topk(width)
+            return top_ids, top_ids
+        # keep the candidates' scores and push every other token to -inf
+        kept = weighted_scores[ids]
+        weighted_scores[:] = float("-inf")
+        weighted_scores[ids] = kept
+        _, top_ids = weighted_scores.topk(width)
+        _, local_ids = weighted_scores[ids].topk(width)
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Add the chosen token's scores to the running per-scorer totals.
+
+        Args:
+            prev_scores (Dict[str, float]): Totals of the previous hypothesis,
+                keyed by scorer name.
+            next_full_scores (Dict[str, torch.Tensor]): Scores from the full
+                scorers; indexed with ``full_idx``.
+            full_idx (int): The next token id for `next_full_scores`.
+            next_part_scores (Dict[str, torch.Tensor]): Scores from the
+                partial scorers; indexed with ``part_idx``.
+            part_idx (int): The new token id for `next_part_scores`.
+
+        Returns:
+            Dict[str, torch.Tensor]: New totals, one entry per key of the two
+                ``next_*`` dicts.
+        """
+        merged = {
+            name: prev_scores[name] + vec[full_idx]
+            for name, vec in next_full_scores.items()
+        }
+        for name, vec in next_part_scores.items():
+            merged[name] = prev_scores[name] + vec[part_idx]
+        return merged
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Combine full-scorer states with the selected partial-scorer states.
+
+        Args:
+            states: States returned by `self.full_scorers`, keyed by name.
+            part_states: States returned by `self.part_scorers`, keyed by name.
+            part_idx (int): Index handed to each partial scorer's
+                ``select_state`` to pick the state of the chosen token.
+
+        Returns:
+            Dict[str, Any]: The new state dict. Entries of ``states`` are
+                copied as they are; each partial scorer contributes
+                ``select_state(part_states[name], part_idx)``.
+
+        """
+        # shallow copy: the state objects themselves are shared with `states`
+        merged = {name: state for name, state in states.items()}
+        for name, scorer in self.part_scorers.items():
+            merged[name] = scorer.select_state(part_states[name], part_idx)
+        return merged
+
+    def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
+        """Extend every running hypothesis by one token and keep the best.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam.
+            x (torch.Tensor): Encoded speech feature (T, D).
+
+        Returns:
+            List[Hypothesis]: At most ``beam_size`` extended hypotheses,
+                sorted by score from highest to lowest. Pruning happens after
+                each running hypothesis has been expanded.
+
+        """
+        candidates = []
+        # Used as-is when no pre-beam is performed; replaced per hypothesis otherwise.
+        part_ids = torch.arange(self.n_vocab, device=x.device)
+        for hyp in running_hyps:
+            # weighted sum over the full scorers
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            scores, states = self.score_full(hyp, x)
+            for name in self.full_scorers:
+                weighted_scores += self.weights[name] * scores[name]
+            # narrow the candidate tokens before running the partial scorers
+            if self.do_pre_beam:
+                if self.pre_beam_score_key == "full":
+                    pre_beam_scores = weighted_scores
+                else:
+                    pre_beam_scores = scores[self.pre_beam_score_key]
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for name in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[name] * part_scores[name]
+            # make the scores cumulative over the whole prefix
+            weighted_scores += hyp.score
+
+            # expand this hypothesis by each selected token
+            top_ids, local_ids = self.beam(weighted_scores, part_ids)
+            for j, part_j in zip(top_ids, local_ids):
+                extended = Hypothesis(
+                    score=weighted_scores[j],
+                    yseq=self.append_token(hyp.yseq, j),
+                    scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
+                    states=self.merge_states(states, part_states, part_j),
+                )
+                candidates.append(extended)
+
+            # keep only the best `beam_size` candidates collected so far
+            candidates.sort(key=lambda cand: cand.score, reverse=True)
+            del candidates[self.beam_size :]
+        return candidates
+
+    def forward(
+        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+    ) -> List[Hypothesis]:
+        """Run beam search over the encoded speech and return ended hypotheses.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Controls the maximum number of search steps.
+                0.0 (default): T steps at most, stopping early on `end_detect`.
+                Negative: the absolute value, truncated to int, is the maximum.
+                Positive: `max(1, int(maxlenratio * T))`.
+            minlenratio (float): Input length ratio for the minimum output
+                length; within this method it is only logged and lowered on retry.
+
+        Returns:
+            list[Hypothesis]: Ended hypotheses sorted by descending score;
+                empty when none ended and `minlenratio` is below 0.1
+
+        """
+        # set length bounds
+        n_frames = x.shape[0]
+        if maxlenratio < 0:
+            maxlen = -int(maxlenratio)
+        elif maxlenratio == 0:
+            maxlen = n_frames
+        else:
+            maxlen = max(1, int(maxlenratio * n_frames))
+        minlen = int(minlenratio * n_frames)
+        logging.info(f"decoder input length: {n_frames}")
+        logging.info(f"max output length: {maxlen}")
+        logging.info(f"min output length: {minlen}")
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(x)
+        ended_hyps = []
+        use_end_detect = maxlenratio == 0.0
+        for i in range(maxlen):
+            logging.debug(f"position {i}")
+            best = self.search(running_hyps, x)
+            # post_process receives ended_hyps and returns the hypotheses to continue
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection is only consulted in the automatic-length mode
+            if use_end_detect and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda h: h.score, reverse=True)
+        # no hypothesis ended: retry with a smaller minlenratio, or give up
+        if not nbest_hyps:
+            logging.warning(
+                "there is no N-best results, perform recognition " "again with smaller minlenratio."
+            )
+            if minlenratio < 0.1:
+                return []
+            return self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+
+        # report the best result
+        best = nbest_hyps[0]
+        for name, value in best.scores.items():
+            weight = self.weights[name]
+            logging.info(f"{value:6.2f} * {weight:3} = {value * weight:6.2f} for {name}")
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            best_text = "".join([self.token_list[tok] for tok in best.yseq[1:-1]])
+            logging.info("best hypo: " + best_text + "\n")
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Split hypotheses into ended and still-running after one search step.
+
+        Args:
+            i (int): Zero-based index of the current search step.
+            maxlen (int): The maximum number of search steps.
+            maxlenratio (float): The maximum length ratio (unused in this method).
+            running_hyps (List[Hypothesis]): Hypotheses produced by the search step.
+            ended_hyps (List[Hypothesis]): Ended hypotheses; extended in place.
+
+        Returns:
+            List[Hypothesis]: Hypotheses whose last token is not `self.eos`.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        # Build the debug text only if it will be emitted and a hypothesis exists.
+        debug_on = logging.getLogger().isEnabledFor(logging.DEBUG)
+        if self.token_list is not None and debug_on and len(running_hyps) > 0:
+            best_text = "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            logging.debug("best hypo: " + best_text)
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
+            ]
+
+        # move ended hypotheses to ended_hyps (the beam may shrink below beam_size)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] != self.eos:
+                remained_hyps.append(hyp)
+                continue
+            # e.g., Word LM needs to add final <eos> score
+            for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                s = d.final_score(hyp.states[k])
+                hyp.scores[k] += s
+                hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+            ended_hyps.append(hyp)
+        return remained_hyps
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index dd0ac7a..01e2924 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -308,6 +308,7 @@
                     checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
                 )
                 self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
+                print(checkpoint.get("train_acc_avg", 0))  # key may be missing
                 self.train_acc_avg = (
                     checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
                 )
@@ -464,7 +465,8 @@
                     batch_num_epoch = len(dataloader_train)
                 self.log(
                     epoch,
-                    batch_idx + kwargs.get("start_step", 0),
+                    batch_idx,
+                    log_step=batch_idx + kwargs.get("start_step", 0),
                     step_in_epoch=self.step_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,
@@ -633,11 +635,12 @@
         tag="train",
         data_split_i=0,
         data_split_num=1,
+        log_step=None,
         **kwargs,
     ):
 
         if (batch_idx + 1) % self.log_interval == 0:
-
+            batch_idx = log_step if log_step is not None else batch_idx
             gpu_info = (
                 "GPU, memory: usage: {:.3f} GB, "
                 "peak: {:.3f} GB, "

--
Gitblit v1.9.1