From 00d0df3a1018c63ec8c5d13e611f53c564c0a7e2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 06 五月 2024 22:17:25 +0800
Subject: [PATCH] Dev gzf decoding (#1695)
---
funasr/datasets/sense_voice_datasets/datasets.py | 16
funasr/models/sense_voice/model.py | 217 +++++++++++++-
funasr/datasets/audio_datasets/update_jsonl.py | 21
funasr/models/sense_voice/decoder.py | 70 ++++
funasr/models/sense_voice/search.py | 453 ++++++++++++++++++++++++++++++++
funasr/train_utils/trainer.py | 7
examples/industrial_data_pretraining/conformer/demo.py | 3
examples/industrial_data_pretraining/sense_voice/demo_fsmn.py | 27 +
8 files changed, 766 insertions(+), 48 deletions(-)
diff --git a/examples/industrial_data_pretraining/conformer/demo.py b/examples/industrial_data_pretraining/conformer/demo.py
index 43cf67d..c2d7682 100644
--- a/examples/industrial_data_pretraining/conformer/demo.py
+++ b/examples/industrial_data_pretraining/conformer/demo.py
@@ -8,6 +8,7 @@
model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch")
res = model.generate(
- input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
+ input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav",
+ decoding_ctc_weight=0.0,
)
print(res)
diff --git a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
new file mode 100644
index 0000000..e063e1f
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(
+ model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscopeFSMN",
+ vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ vad_kwargs={"max_single_segment_time": 30000},
+)
+
+
+input_wav = (
+ "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+)
+
+DecodingOptions = {
+ "task": ("ASR", "AED", "SER"),
+ "language": "auto",
+ "fp16": True,
+ "gain_event": True,
+}
+
+res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions, beam_size=5)
+print(res)
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
index 05870fe..bc84708 100644
--- a/funasr/datasets/audio_datasets/update_jsonl.py
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -46,16 +46,17 @@
data = json.loads(line.strip())
wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
- waveform, _ = librosa.load(wav_path, sr=16000)
- sample_num = len(waveform)
- source_len = int(sample_num / 16000 * 1000 / 10)
- source_len_old = data["source_len"]
- # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
- # logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
- data["source_len"] = source_len
- data["source"] = wav_path
- jsonl_line = json.dumps(data, ensure_ascii=False)
- lines[i] = jsonl_line
+ if os.path.exists(wav_path):
+ waveform, _ = librosa.load(wav_path, sr=16000)
+ sample_num = len(waveform)
+ source_len = int(sample_num / 16000 * 1000 / 10)
+ source_len_old = data["source_len"]
+ # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
+ # logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
+ data["source_len"] = source_len
+ data["source"] = wav_path
+ jsonl_line = json.dumps(data, ensure_ascii=False)
+ lines[i] = jsonl_line
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 1d269dd..5d80956 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -2,7 +2,7 @@
import torch
import random
-
+import traceback
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@@ -73,15 +73,17 @@
if idx == 0:
index_cur = index
else:
- if index <= self.retry:
- index_cur = index + idx
- else:
- index_cur = torch.randint(0, index, ()).item()
+ index_cur = torch.randint(0, len(self.index_ds), ()).item()
item = self.index_ds[index_cur]
source = item["source"]
- data_src = load_audio_text_image_video(source, fs=self.fs)
+ try:
+ data_src = load_audio_text_image_video(source, fs=self.fs)
+ except Exception as e:
+ logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
+ continue
+
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(
@@ -186,7 +188,7 @@
)
if self.batch_type != "example":
- for i in range(3):
+ for i in range(10):
outputs = self._filter_badcase(outputs, i=i)
return outputs
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index f5b8825..19d9c16 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -15,6 +15,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
+from funasr.models.transformer.utils.mask import subsequent_mask
class LayerNorm(nn.LayerNorm):
@@ -336,6 +337,29 @@
return x
+ def init_state(self, x):
+ state = {}
+
+ return state
+
+ def final_score(self, state) -> float:
+ """Score eos (optional).
+
+ Args:
+ state: Scorer state for prefix tokens
+
+ Returns:
+ float: final score
+
+ """
+ return 0.0
+
+ def score(self, ys, state, x):
+ """Score."""
+ ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
+ logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+ return logp.squeeze(0)[-1, :], state
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.
@@ -443,9 +467,19 @@
kv_cache: Optional[dict] = None,
**kwargs,
):
+ cache = kwargs.get("cache", {})
+ layer = kwargs.get("layer", 0)
is_pad_mask = kwargs.get("is_pad_mask", False)
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
- x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+
+ fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None
+ # if fsmn_cache is not None:
+ # x = x[:, -1:]
+ att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
+ # if len(cache)>1:
+ # cache[layer]["fsmn_cache"] = fsmn_cache
+ # x = x[:, -1:]
+ x = x + att_res
if self.cross_attn:
x = (
x
@@ -510,10 +544,9 @@
ys_in_lens = kwargs.get("ys_in_lens", None)
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
tgt, memory = x, xa
tgt[tgt == -1] = 0
- tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
+ tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
# tgt = self.dropout(tgt)
x = tgt.to(memory.dtype)
@@ -531,9 +564,40 @@
memory_mask=memory_mask,
is_pad_mask=False,
is_pad_memory_mask=True,
+ cache=kwargs.get("cache", None),
+ layer=layer,
)
x = self.ln(x)
x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
return x
+
+ def init_state(self, x):
+ state = {}
+ for layer, block in enumerate(self.blocks):
+ state[layer] = {
+ "fsmn_cache": None,
+ "memory_key": None,
+ "memory_value": None,
+ }
+
+ return state
+
+ def final_score(self, state) -> float:
+ """Score eos (optional).
+
+ Args:
+ state: Scorer state for prefix tokens
+
+ Returns:
+ float: final score
+
+ """
+ return 0.0
+
+ def score(self, ys, state, x):
+ """Score."""
+ ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
+ logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+ return logp.squeeze(0)[-1, :], state
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index c12107e..d5e4130 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -15,6 +15,7 @@
from funasr.train_utils.device_funcs import force_gatherable
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@@ -395,6 +396,42 @@
return loss_att, acc_att, None, None
+ def init_beam_search(
+ self,
+ **kwargs,
+ ):
+ from .search import BeamSearch
+
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+ scorers.update(
+ decoder=self.model.decoder,
+ length_bonus=LengthBonus(self.vocab_size),
+ )
+
+ weights = dict(
+ decoder=1.0,
+ ctc=0.0,
+ lm=0.0,
+ ngram=0.0,
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 5),
+ weights=weights,
+ scorers=scorers,
+ sos=None,
+ eos=None,
+ vocab_size=self.vocab_size,
+ token_list=None,
+ pre_beam_score_key="full",
+ )
+
+ self.beam_search = beam_search
+
def inference(
self,
data_in,
@@ -406,6 +443,12 @@
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+ if not hasattr(self, "beam_search") or self.beam_search is None:
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
if frontend is None and not hasattr(self, "frontend"):
frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -455,25 +498,65 @@
task = [task]
task = "".join([f"<|{x}|>" for x in task])
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
- DecodingOptions["initial_prompt"] = initial_prompt
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
- DecodingOptions["language"] = language
- DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+ sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ sos_int = tokenizer.encode(sos, allowed_special="all")
+ eos = kwargs.get("model_conf").get("eos")
+ eos_int = tokenizer.encode(eos, allowed_special="all")
+ self.beam_search.sos = sos_int
+ self.beam_search.eos = eos_int[0]
- if "without_timestamps" not in DecodingOptions:
- DecodingOptions["without_timestamps"] = True
+ encoder_out, encoder_out_lens = self.encode(
+ speech[None, :, :].permute(0, 2, 1), speech_lengths
+ )
- options = whisper.DecodingOptions(**DecodingOptions)
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0],
+ maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0),
+ )
- result = whisper.decode(self.model, speech, options)
- text = f"{result.text}"
+ nbest_hyps = nbest_hyps[: self.nbest]
+
results = []
- result_i = {"key": key[0], "text": text}
+ b, n, d = encoder_out.size()
+ for i in range(b):
- results.append(result_i)
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # # remove blank symbol id, which is assumed to be 0
+ # token_int = list(
+ # filter(
+ # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+ # )
+ # )
+
+ # Change integer-ids to tokens
+ # token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.decode(token_int)
+
+ result_i = {"key": key[i], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ # ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
return results, meta_data
@@ -497,12 +580,14 @@
# decoder
del model.decoder
decoder = kwargs.get("decoder", "SenseVoiceDecoder")
- decoder_conf = kwargs.get("decoder_conf", {})
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
- vocab_size=dims.n_vocab,
- encoder_output_size=dims.n_audio_state,
- **decoder_conf,
+ n_vocab=dims.n_vocab,
+ n_ctx=dims.n_text_ctx,
+ n_state=dims.n_text_state,
+ n_head=dims.n_text_head,
+ n_layer=dims.n_text_layer,
+ **kwargs.get("decoder_conf"),
)
model.decoder = decoder
@@ -512,7 +597,7 @@
self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
self.ignore_id = kwargs.get("ignore_id", -1)
- self.vocab_size = kwargs.get("vocab_size", -1)
+ self.vocab_size = dims.n_vocab
self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
self.criterion_att = LabelSmoothingLoss(
size=self.vocab_size,
@@ -630,6 +715,42 @@
return loss_att, acc_att, None, None
+ def init_beam_search(
+ self,
+ **kwargs,
+ ):
+ from .search import BeamSearch
+
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+ scorers.update(
+ decoder=self.model.decoder,
+ length_bonus=LengthBonus(self.vocab_size),
+ )
+
+ weights = dict(
+ decoder=1.0,
+ ctc=0.0,
+ lm=0.0,
+ ngram=0.0,
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 5),
+ weights=weights,
+ scorers=scorers,
+ sos=None,
+ eos=None,
+ vocab_size=self.vocab_size,
+ token_list=None,
+ pre_beam_score_key="full",
+ )
+
+ self.beam_search = beam_search
+
def inference(
self,
data_in,
@@ -641,6 +762,12 @@
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+ if not hasattr(self, "beam_search") or self.beam_search is None:
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
if frontend is None and not hasattr(self, "frontend"):
frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -690,24 +817,64 @@
task = [task]
task = "".join([f"<|{x}|>" for x in task])
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
- DecodingOptions["initial_prompt"] = initial_prompt
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
- DecodingOptions["language"] = language
- DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+ sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ sos_int = tokenizer.encode(sos, allowed_special="all")
+ eos = kwargs.get("model_conf").get("eos")
+ eos_int = tokenizer.encode(eos, allowed_special="all")
+ self.beam_search.sos = sos_int
+ self.beam_search.eos = eos_int[0]
- if "without_timestamps" not in DecodingOptions:
- DecodingOptions["without_timestamps"] = True
+ encoder_out, encoder_out_lens = self.encode(
+ speech[None, :, :].permute(0, 2, 1), speech_lengths
+ )
- options = whisper.DecodingOptions(**DecodingOptions)
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0],
+ maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0),
+ )
- result = whisper.decode(self.model, speech, options)
- text = f"{result.text}"
+ nbest_hyps = nbest_hyps[: self.nbest]
+
results = []
- result_i = {"key": key[0], "text": text}
+ b, n, d = encoder_out.size()
+ for i in range(b):
- results.append(result_i)
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # # remove blank symbol id, which is assumed to be 0
+ # token_int = list(
+ # filter(
+ # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+ # )
+ # )
+
+ # Change integer-ids to tokens
+ # token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.decode(token_int)
+
+ result_i = {"key": key[i], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ # ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
return results, meta_data
diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
new file mode 100644
index 0000000..98d02db
--- /dev/null
+++ b/funasr/models/sense_voice/search.py
@@ -0,0 +1,453 @@
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.metrics.common import end_detect
+from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
+from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
+
+
+class Hypothesis(NamedTuple):
+ """Hypothesis data type."""
+
+ yseq: torch.Tensor
+ score: Union[float, torch.Tensor] = 0
+ scores: Dict[str, Union[float, torch.Tensor]] = dict()
+ states: Dict[str, Any] = dict()
+
+ def asdict(self) -> dict:
+ """Convert data to JSON-friendly dict."""
+ return self._replace(
+ yseq=self.yseq.tolist(),
+ score=float(self.score),
+ scores={k: float(v) for k, v in self.scores.items()},
+ )._asdict()
+
+
+class BeamSearch(torch.nn.Module):
+ """Beam search implementation."""
+
+ def __init__(
+ self,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ beam_size: int,
+ vocab_size: int,
+ sos=None,
+ eos=None,
+ token_list: List[str] = None,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = None,
+ ):
+ """Initialize beam search.
+
+ Args:
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ token_list (list[str]): List of tokens for debug log
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ """
+ super().__init__()
+ # set scorers
+ self.weights = weights
+ self.scorers = dict()
+ self.full_scorers = dict()
+ self.part_scorers = dict()
+ # this module dict is required for recursive cast
+ # `self.to(device, dtype)` in `recog.py`
+ self.nn_dict = torch.nn.ModuleDict()
+ for k, v in scorers.items():
+ w = weights.get(k, 0)
+ if w == 0 or v is None:
+ continue
+ # assert isinstance(
+ # v, ScorerInterface
+ # ), f"{k} ({type(v)}) does not implement ScorerInterface"
+ self.scorers[k] = v
+ if isinstance(v, PartialScorerInterface):
+ self.part_scorers[k] = v
+ else:
+ self.full_scorers[k] = v
+ if isinstance(v, torch.nn.Module):
+ self.nn_dict[k] = v
+
+ # set configurations
+ self.sos = sos
+ self.eos = eos
+ if isinstance(self.eos, (list, tuple)):
+ self.eos = eos[0]
+ self.token_list = token_list
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
+ self.beam_size = beam_size
+ self.n_vocab = vocab_size
+ if (
+ pre_beam_score_key is not None
+ and pre_beam_score_key != "full"
+ and pre_beam_score_key not in self.full_scorers
+ ):
+ raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+ self.pre_beam_score_key = pre_beam_score_key
+ self.do_pre_beam = (
+ self.pre_beam_score_key is not None
+ and self.pre_beam_size < self.n_vocab
+ and len(self.part_scorers) > 0
+ )
+
+ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+ """Get an initial hypothesis data.
+
+ Args:
+ x (torch.Tensor): The encoder output feature
+
+ Returns:
+ Hypothesis: The initial hypothesis.
+
+ """
+ init_states = dict()
+ init_scores = dict()
+ for k, d in self.scorers.items():
+ init_states[k] = d.init_state(x)
+ init_scores[k] = 0.0
+ if not isinstance(self.sos, (list, tuple)):
+ self.sos = [self.sos]
+ return [
+ Hypothesis(
+ score=0.0,
+ scores=init_scores,
+ states=init_states,
+ yseq=torch.tensor(self.sos, device=x.device),
+ )
+ ]
+
+ @staticmethod
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+ """Append new token to prefix tokens.
+
+ Args:
+ xs (torch.Tensor): The prefix token
+ x (int): The new token to append
+
+ Returns:
+ torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+ """
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+ return torch.cat((xs, x))
+
+ def score_full(
+ self, hyp: Hypothesis, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.full_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.full_scorers`
+ and tensor score values of shape: `(self.n_vocab,)`,
+ and state dict that has string keys
+ and state values of `self.full_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.full_scorers.items():
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
+ return scores, states
+
+ def score_partial(
+ self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.part_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.part_scorers`
+ and tensor score values of shape: `(len(ids),)`,
+ and state dict that has string keys
+ and state values of `self.part_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.part_scorers.items():
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+ return scores, states
+
+ def beam(
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute topk full token ids and partial token ids.
+
+ Args:
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+ Its shape is `(self.n_vocab,)`.
+ ids (torch.Tensor): The partial token ids to compute topk
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ The topk full token ids and partial token ids.
+ Their shapes are `(self.beam_size,)`
+
+ """
+ # no pre beam performed
+ if weighted_scores.size(0) == ids.size(0):
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ return top_ids, top_ids
+
+ # mask pruned in pre-beam not to select in topk
+ tmp = weighted_scores[ids]
+ weighted_scores[:] = -float("inf")
+ weighted_scores[ids] = tmp
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+ return top_ids, local_ids
+
+ @staticmethod
+ def merge_scores(
+ prev_scores: Dict[str, float],
+ next_full_scores: Dict[str, torch.Tensor],
+ full_idx: int,
+ next_part_scores: Dict[str, torch.Tensor],
+ part_idx: int,
+ ) -> Dict[str, torch.Tensor]:
+ """Merge scores for new hypothesis.
+
+ Args:
+ prev_scores (Dict[str, float]):
+ The previous hypothesis scores by `self.scorers`
+ next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+ full_idx (int): The next token id for `next_full_scores`
+ next_part_scores (Dict[str, torch.Tensor]):
+ scores of partial tokens by `self.part_scorers`
+ part_idx (int): The new token id for `next_part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are scalar tensors by the scorers.
+
+ """
+ new_scores = dict()
+ for k, v in next_full_scores.items():
+ new_scores[k] = prev_scores[k] + v[full_idx]
+ for k, v in next_part_scores.items():
+ new_scores[k] = prev_scores[k] + v[part_idx]
+ return new_scores
+
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+ """Merge states for new hypothesis.
+
+ Args:
+ states: states of `self.full_scorers`
+ part_states: states of `self.part_scorers`
+ part_idx (int): The new token id for `part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are states of the scorers.
+
+ """
+ new_states = dict()
+ for k, v in states.items():
+ new_states[k] = v
+ for k, d in self.part_scorers.items():
+ new_states[k] = d.select_state(part_states[k], part_idx)
+ return new_states
+
+ def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
+ """Search new tokens for running hypotheses and encoded speech x.
+
+ Args:
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
+ x (torch.Tensor): Encoded speech feature (T, D)
+
+ Returns:
+ List[Hypotheses]: Best sorted hypotheses
+
+ """
+ best_hyps = []
+ part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
+ for hyp in running_hyps:
+ # scoring
+ weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+ scores, states = self.score_full(hyp, x)
+ for k in self.full_scorers:
+ weighted_scores += self.weights[k] * scores[k]
+ # partial scoring
+ if self.do_pre_beam:
+ pre_beam_scores = (
+ weighted_scores
+ if self.pre_beam_score_key == "full"
+ else scores[self.pre_beam_score_key]
+ )
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+ part_scores, part_states = self.score_partial(hyp, part_ids, x)
+ for k in self.part_scorers:
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+ # add previous hyp score
+ weighted_scores += hyp.score
+
+ # update hyps
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+ # will be (2 x beam at most)
+ best_hyps.append(
+ Hypothesis(
+ score=weighted_scores[j],
+ yseq=self.append_token(hyp.yseq, j),
+ scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
+ states=self.merge_states(states, part_states, part_j),
+ )
+ )
+
+ # sort and prune 2 x beam -> beam
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+ : min(len(best_hyps), self.beam_size)
+ ]
+ return best_hyps
+
+ def forward(
+ self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+ ) -> List[Hypothesis]:
+ """Perform beam search.
+
+ Args:
+ x (torch.Tensor): Encoded speech feature (T, D)
+ maxlenratio (float): Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses a end-detect function
+ to automatically find maximum hypothesis lengths
+ If maxlenratio<0.0, its absolute value is interpreted
+ as a constant max output length.
+ minlenratio (float): Input length ratio to obtain min output length.
+
+ Returns:
+ list[Hypothesis]: N-best decoding results
+
+ """
+ # set length bounds
+ if maxlenratio == 0:
+ maxlen = x.shape[0]
+ elif maxlenratio < 0:
+ maxlen = -1 * int(maxlenratio)
+ else:
+ maxlen = max(1, int(maxlenratio * x.size(0)))
+ minlen = int(minlenratio * x.size(0))
+ logging.info("decoder input length: " + str(x.shape[0]))
+ logging.info("max output length: " + str(maxlen))
+ logging.info("min output length: " + str(minlen))
+
+ # main loop of prefix search
+ running_hyps = self.init_hyp(x)
+ ended_hyps = []
+ for i in range(maxlen):
+ logging.debug("position " + str(i))
+ best = self.search(running_hyps, x)
+ # post process of one iteration
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+ # end detection
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+ logging.info(f"end detected at {i}")
+ break
+ if len(running_hyps) == 0:
+ logging.info("no hypothesis. Finish decoding.")
+ break
+ else:
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+ # check the number of hypotheses reaching to eos
+ if len(nbest_hyps) == 0:
+ logging.warning(
+ "there is no N-best results, perform recognition " "again with smaller minlenratio."
+ )
+ return (
+ []
+ if minlenratio < 0.1
+ else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+ )
+
+ # report the best result
+ best = nbest_hyps[0]
+ for k, v in best.scores.items():
+ logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
+ logging.info(f"total log probability: {best.score:.2f}")
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+ if self.token_list is not None:
+ logging.info(
+ "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
+ )
+ return nbest_hyps
+
+ def post_process(
+ self,
+ i: int,
+ maxlen: int,
+ maxlenratio: float,
+ running_hyps: List[Hypothesis],
+ ended_hyps: List[Hypothesis],
+ ) -> List[Hypothesis]:
+ """Perform post-processing of beam search iterations.
+
+ Args:
+ i (int): The length of hypothesis tokens.
+ maxlen (int): The maximum length of tokens in beam search.
+ maxlenratio (int): The maximum length ratio in beam search.
+ running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+ Returns:
+ List[Hypothesis]: The new running hypotheses.
+
+ """
+ logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+ if self.token_list is not None:
+ logging.debug(
+ "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+ )
+ # add eos in the final loop to avoid that there are no ended hyps
+ if i == maxlen - 1:
+ logging.info("adding <eos> in the last position in the loop")
+ running_hyps = [
+ h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
+ ]
+
+ # add ended hypotheses to a final list, and removed them from current hypotheses
+ # (this will be a problem, number of hyps < beam)
+ remained_hyps = []
+ for hyp in running_hyps:
+ if hyp.yseq[-1] == self.eos:
+ # e.g., Word LM needs to add final <eos> score
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+ s = d.final_score(hyp.states[k])
+ hyp.scores[k] += s
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+ ended_hyps.append(hyp)
+ else:
+ remained_hyps.append(hyp)
+ return remained_hyps
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index dd0ac7a..01e2924 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -308,6 +308,7 @@
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
+ print(checkpoint["train_acc_avg"])
self.train_acc_avg = (
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
)
@@ -464,7 +465,8 @@
batch_num_epoch = len(dataloader_train)
self.log(
epoch,
- batch_idx + kwargs.get("start_step", 0),
+ batch_idx,
+ log_step=batch_idx + kwargs.get("start_step", 0),
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
@@ -633,11 +635,12 @@
tag="train",
data_split_i=0,
data_split_num=1,
+ log_step=None,
**kwargs,
):
if (batch_idx + 1) % self.log_interval == 0:
-
+ batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
"GPU, memory: usage: {:.3f} GB, "
"peak: {:.3f} GB, "
--
Gitblit v1.9.1