Dev gzf decoding (#1695)
* resume from step
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* batch
* train_loss_avg train_acc_avg
* train_loss_avg train_acc_avg
* train_loss_avg train_acc_avg
* log step
* wav is not exist
* wav is not exist
* decoding
* decoding
| | |
| | | model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch") |
| | | |
| | | res = model.generate( |
| | | input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav" |
| | | input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav", |
| | | decoding_ctc_weight=0.0, |
| | | ) |
| | | print(res) |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | from funasr import AutoModel |
| | | |
# Demo: run the SenseVoice model with beam-search decoding on one sample file.

# Audio to transcribe (remote sample file).
input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"

# Options forwarded to `model.generate` under the `DecodingOptions` keyword.
DecodingOptions = dict(
    task=("ASR", "AED", "SER"),
    language="auto",
    fp16=True,
    gain_event=True,
)

# NOTE(review): `model` is an absolute path on a local machine -- point it at
# your own copy of the checkpoint before running.
model = AutoModel(
    model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscopeFSMN",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs=dict(max_single_segment_time=30000),
)

res = model.generate(
    input=input_wav,
    batch_size_s=0,
    DecodingOptions=DecodingOptions,
    beam_size=5,
)
print(res)
| | |
| | | data = json.loads(line.strip()) |
| | | |
| | | wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data") |
| | | waveform, _ = librosa.load(wav_path, sr=16000) |
| | | sample_num = len(waveform) |
| | | source_len = int(sample_num / 16000 * 1000 / 10) |
| | | source_len_old = data["source_len"] |
| | | # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100: |
| | | # logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}") |
| | | data["source_len"] = source_len |
| | | data["source"] = wav_path |
| | | jsonl_line = json.dumps(data, ensure_ascii=False) |
| | | lines[i] = jsonl_line |
| | | if os.path.exists(wav_path): |
| | | waveform, _ = librosa.load(wav_path, sr=16000) |
| | | sample_num = len(waveform) |
| | | source_len = int(sample_num / 16000 * 1000 / 10) |
| | | source_len_old = data["source_len"] |
| | | # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100: |
| | | # logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}") |
| | | data["source_len"] = source_len |
| | | data["source"] = wav_path |
| | | jsonl_line = json.dumps(data, ensure_ascii=False) |
| | | lines[i] = jsonl_line |
| | | |
| | | |
| | | def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1): |
| | |
| | | |
| | | import torch |
| | | import random |
| | | |
| | | import traceback |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video |
| | | |
| | |
| | | if idx == 0: |
| | | index_cur = index |
| | | else: |
| | | if index <= self.retry: |
| | | index_cur = index + idx |
| | | else: |
| | | index_cur = torch.randint(0, index, ()).item() |
| | | index_cur = torch.randint(0, len(self.index_ds), ()).item() |
| | | |
| | | item = self.index_ds[index_cur] |
| | | |
| | | source = item["source"] |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | try: |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | except Exception as e: |
| | | logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") |
| | | continue |
| | | |
| | | if self.preprocessor_speech: |
| | | data_src = self.preprocessor_speech(data_src, fs=self.fs) |
| | | speech, speech_lengths = extract_fbank( |
| | |
| | | ) |
| | | |
| | | if self.batch_type != "example": |
| | | for i in range(3): |
| | | for i in range(10): |
| | | outputs = self._filter_badcase(outputs, i=i) |
| | | |
| | | return outputs |
| | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import Tensor, nn |
| | | from funasr.models.transformer.utils.mask import subsequent_mask |
| | | |
| | | |
| | | class LayerNorm(nn.LayerNorm): |
| | |
| | | |
| | | return x |
| | | |
| | | def init_state(self, x): |
| | | state = {} |
| | | |
| | | return state |
| | | |
| | | def final_score(self, state) -> float: |
| | | """Score eos (optional). |
| | | |
| | | Args: |
| | | state: Scorer state for prefix tokens |
| | | |
| | | Returns: |
| | | float: final score |
| | | |
| | | """ |
| | | return 0.0 |
| | | |
| | | def score(self, ys, state, x): |
| | | """Score.""" |
| | | ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) |
| | | logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state) |
| | | return logp.squeeze(0)[-1, :], state |
| | | |
| | | |
| | | class MultiHeadedAttentionSANMDecoder(nn.Module): |
| | | """Multi-Head Attention layer. |
| | |
| | | kv_cache: Optional[dict] = None, |
| | | **kwargs, |
| | | ): |
| | | cache = kwargs.get("cache", {}) |
| | | layer = kwargs.get("layer", 0) |
| | | is_pad_mask = kwargs.get("is_pad_mask", False) |
| | | is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False) |
| | | x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0] |
| | | |
| | | fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None |
| | | # if fsmn_cache is not None: |
| | | # x = x[:, -1:] |
| | | att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache) |
| | | # if len(cache)>1: |
| | | # cache[layer]["fsmn_cache"] = fsmn_cache |
| | | # x = x[:, -1:] |
| | | x = x + att_res |
| | | if self.cross_attn: |
| | | x = ( |
| | | x |
| | |
| | | |
| | | ys_in_lens = kwargs.get("ys_in_lens", None) |
| | | |
| | | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 |
| | | tgt, memory = x, xa |
| | | tgt[tgt == -1] = 0 |
| | | tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)] |
| | | tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)] |
| | | # tgt = self.dropout(tgt) |
| | | |
| | | x = tgt.to(memory.dtype) |
| | |
| | | memory_mask=memory_mask, |
| | | is_pad_mask=False, |
| | | is_pad_memory_mask=True, |
| | | cache=kwargs.get("cache", None), |
| | | layer=layer, |
| | | ) |
| | | |
| | | x = self.ln(x) |
| | | x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() |
| | | |
| | | return x |
| | | |
| | | def init_state(self, x): |
| | | state = {} |
| | | for layer, block in enumerate(self.blocks): |
| | | state[layer] = { |
| | | "fsmn_cache": None, |
| | | "memory_key": None, |
| | | "memory_value": None, |
| | | } |
| | | |
| | | return state |
| | | |
| | | def final_score(self, state) -> float: |
| | | """Score eos (optional). |
| | | |
| | | Args: |
| | | state: Scorer state for prefix tokens |
| | | |
| | | Returns: |
| | | float: final score |
| | | |
| | | """ |
| | | return 0.0 |
| | | |
| | | def score(self, ys, state, x): |
| | | """Score.""" |
| | | ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) |
| | | logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state) |
| | | return logp.squeeze(0)[-1, :], state |
| | |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from . import whisper_lib as whisper |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | |
| | | from funasr.register import tables |
| | | |
| | |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | def init_beam_search( |
| | | self, |
| | | **kwargs, |
| | | ): |
| | | from .search import BeamSearch |
| | | |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | |
| | | scorers.update( |
| | | decoder=self.model.decoder, |
| | | length_bonus=LengthBonus(self.vocab_size), |
| | | ) |
| | | |
| | | weights = dict( |
| | | decoder=1.0, |
| | | ctc=0.0, |
| | | lm=0.0, |
| | | ngram=0.0, |
| | | length_bonus=kwargs.get("penalty", 0.0), |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=kwargs.get("beam_size", 5), |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=None, |
| | | eos=None, |
| | | vocab_size=self.vocab_size, |
| | | token_list=None, |
| | | pre_beam_score_key="full", |
| | | ) |
| | | |
| | | self.beam_search = beam_search |
| | | |
| | | def inference( |
| | | self, |
| | | data_in, |
| | |
| | | ): |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | if not hasattr(self, "beam_search") or self.beam_search is None: |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | |
| | | task = [task] |
| | | task = "".join([f"<|{x}|>" for x in task]) |
| | | initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") |
| | | DecodingOptions["initial_prompt"] = initial_prompt |
| | | |
| | | language = DecodingOptions.get("language", None) |
| | | language = None if language == "auto" else language |
| | | DecodingOptions["language"] = language |
| | | |
| | | DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) |
| | | sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt |
| | | sos_int = tokenizer.encode(sos, allowed_special="all") |
| | | eos = kwargs.get("model_conf").get("eos") |
| | | eos_int = tokenizer.encode(eos, allowed_special="all") |
| | | self.beam_search.sos = sos_int |
| | | self.beam_search.eos = eos_int[0] |
| | | |
| | | if "without_timestamps" not in DecodingOptions: |
| | | DecodingOptions["without_timestamps"] = True |
| | | encoder_out, encoder_out_lens = self.encode( |
| | | speech[None, :, :].permute(0, 2, 1), speech_lengths |
| | | ) |
| | | |
| | | options = whisper.DecodingOptions(**DecodingOptions) |
| | | # c. Passed the encoder result and the beam search |
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], |
| | | maxlenratio=kwargs.get("maxlenratio", 0.0), |
| | | minlenratio=kwargs.get("minlenratio", 0.0), |
| | | ) |
| | | |
| | | result = whisper.decode(self.model, speech, options) |
| | | text = f"{result.text}" |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "text": text} |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | results.append(result_i) |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # # remove blank symbol id, which is assumed to be 0 |
| | | # token_int = list( |
| | | # filter( |
| | | # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int |
| | | # ) |
| | | # ) |
| | | |
| | | # Change integer-ids to tokens |
| | | # token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.decode(token_int) |
| | | |
| | | result_i = {"key": key[i], "text": text} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | # ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text |
| | | |
| | | return results, meta_data |
| | | |
| | |
| | | # decoder |
| | | del model.decoder |
| | | decoder = kwargs.get("decoder", "SenseVoiceDecoder") |
| | | decoder_conf = kwargs.get("decoder_conf", {}) |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=dims.n_vocab, |
| | | encoder_output_size=dims.n_audio_state, |
| | | **decoder_conf, |
| | | n_vocab=dims.n_vocab, |
| | | n_ctx=dims.n_text_ctx, |
| | | n_state=dims.n_text_state, |
| | | n_head=dims.n_text_head, |
| | | n_layer=dims.n_text_layer, |
| | | **kwargs.get("decoder_conf"), |
| | | ) |
| | | model.decoder = decoder |
| | | |
| | |
| | | |
| | | self.activation_checkpoint = kwargs.get("activation_checkpoint", False) |
| | | self.ignore_id = kwargs.get("ignore_id", -1) |
| | | self.vocab_size = kwargs.get("vocab_size", -1) |
| | | self.vocab_size = dims.n_vocab |
| | | self.length_normalized_loss = kwargs.get("length_normalized_loss", True) |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=self.vocab_size, |
| | |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | def init_beam_search( |
| | | self, |
| | | **kwargs, |
| | | ): |
| | | from .search import BeamSearch |
| | | |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | |
| | | scorers.update( |
| | | decoder=self.model.decoder, |
| | | length_bonus=LengthBonus(self.vocab_size), |
| | | ) |
| | | |
| | | weights = dict( |
| | | decoder=1.0, |
| | | ctc=0.0, |
| | | lm=0.0, |
| | | ngram=0.0, |
| | | length_bonus=kwargs.get("penalty", 0.0), |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=kwargs.get("beam_size", 5), |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=None, |
| | | eos=None, |
| | | vocab_size=self.vocab_size, |
| | | token_list=None, |
| | | pre_beam_score_key="full", |
| | | ) |
| | | |
| | | self.beam_search = beam_search |
| | | |
| | | def inference( |
| | | self, |
| | | data_in, |
| | |
| | | ): |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | if not hasattr(self, "beam_search") or self.beam_search is None: |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | |
| | | task = [task] |
| | | task = "".join([f"<|{x}|>" for x in task]) |
| | | initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") |
| | | DecodingOptions["initial_prompt"] = initial_prompt |
| | | |
| | | language = DecodingOptions.get("language", None) |
| | | language = None if language == "auto" else language |
| | | DecodingOptions["language"] = language |
| | | |
| | | DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) |
| | | sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt |
| | | sos_int = tokenizer.encode(sos, allowed_special="all") |
| | | eos = kwargs.get("model_conf").get("eos") |
| | | eos_int = tokenizer.encode(eos, allowed_special="all") |
| | | self.beam_search.sos = sos_int |
| | | self.beam_search.eos = eos_int[0] |
| | | |
| | | if "without_timestamps" not in DecodingOptions: |
| | | DecodingOptions["without_timestamps"] = True |
| | | encoder_out, encoder_out_lens = self.encode( |
| | | speech[None, :, :].permute(0, 2, 1), speech_lengths |
| | | ) |
| | | |
| | | options = whisper.DecodingOptions(**DecodingOptions) |
| | | # c. Passed the encoder result and the beam search |
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], |
| | | maxlenratio=kwargs.get("maxlenratio", 0.0), |
| | | minlenratio=kwargs.get("minlenratio", 0.0), |
| | | ) |
| | | |
| | | result = whisper.decode(self.model, speech, options) |
| | | text = f"{result.text}" |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "text": text} |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | results.append(result_i) |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # # remove blank symbol id, which is assumed to be 0 |
| | | # token_int = list( |
| | | # filter( |
| | | # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int |
| | | # ) |
| | | # ) |
| | | |
| | | # Change integer-ids to tokens |
| | | # token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.decode(token_int) |
| | | |
| | | result_i = {"key": key[i], "text": text} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | # ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text |
| | | |
| | | return results, meta_data |
| New file |
| | |
| | | from itertools import chain |
| | | import logging |
| | | from typing import Any |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import NamedTuple |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import torch |
| | | |
| | | from funasr.metrics.common import end_detect |
| | | from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface |
| | | from funasr.models.transformer.scorers.scorer_interface import ScorerInterface |
| | | |
| | | |
class Hypothesis(NamedTuple):
    """Hypothesis data type.

    Fields:
        yseq: Token ids of the prefix decoded so far.
        score: Accumulated weighted score of the prefix.
        scores: Accumulated score per scorer name.
        states: Scorer state per scorer name.
    """

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    # NOTE: a NamedTuple default is evaluated once, so every instance built
    # without `scores` / `states` shares the same dict object. Treat these
    # defaults as read-only and pass fresh dicts when they must be modified.
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict.

        Returns:
            dict: The fields of this hypothesis, with `yseq` as a list,
                `score` as a float and every value of `scores` as a float.
                `states` is passed through unchanged.
        """
        # `yseq` is usually a tensor, but callers may also carry it as a plain
        # sequence; only call `.tolist()` when it is available.
        if hasattr(self.yseq, "tolist"):
            yseq = self.yseq.tolist()
        else:
            yseq = list(self.yseq)
        return self._replace(
            yseq=yseq,
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()
| | | |
| | | |
class BeamSearch(torch.nn.Module):
    """Beam search implementation.

    At every step the scores of all registered scorers are combined into one
    weighted score per vocabulary entry, and the `beam_size` best extended
    prefixes are kept as running hypotheses.
    """

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos=None,
        eos=None,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The size of the vocabulary
            sos (int or list[int]): Start of sequence id(s). A single id is
                wrapped into a list when the initial hypothesis is built.
            eos (int or list[int]): End of sequence id. If a list or tuple is
                given, only its first element is used.
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        Raises:
            KeyError: If `pre_beam_score_key` is neither `None`, `"full"`
                nor the name of a registered full scorer.

        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            # NOTE: the ScorerInterface isinstance check is intentionally
            # disabled here, so scorers are accepted by duck typing.
            # assert isinstance(
            #     v, ScorerInterface
            # ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        if isinstance(self.eos, (list, tuple)):
            self.eos = eos[0]
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        # pre-beam pruning only matters when partial scorers exist and the
        # pre-beam is actually smaller than the vocabulary
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Note:
            `self.sos` is normalized to a list as a side effect, so it may
            hold one start token or a multi-token start prompt.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            List[Hypothesis]: A one-element list with the initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        if not isinstance(self.sos, (list, tuple)):
            self.sos = [self.sos]
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor(self.sos, device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Note:
            When pre-beam pruning was applied, `weighted_scores` is modified
            in place: every entry not listed in `ids` is set to `-inf`.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(min(self.beam_size, len(ids)),)`

        """
        # never ask topk for more entries than there are candidates; this would
        # raise when the beam is wider than the vocabulary / the pre-beam
        k = min(self.beam_size, ids.size(0))

        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(k)[1]
            return top_ids, top_ids

        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(k)[1]
        local_ids = weighted_scores[ids].topk(k)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_states`

        Returns:
            Dict[str, Any]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            List[Hypotheses]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

            # sort and prune 2 x beam -> beam
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results, best first. Empty when
                no hypothesis ended and `minlenratio` is below 0.1.

        """
        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        elif maxlenratio < 0:
            maxlen = -1 * int(maxlenratio)
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            best = self.search(running_hyps, x)
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition " "again with smaller minlenratio."
            )
            # retry with a lower minlenratio until it drops below 0.1
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
            )
        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Hypotheses whose last token is `self.eos` receive the scorers' final
        scores and are appended to `ended_hyps` (modified in place); all
        others are returned as the new running hypotheses.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
            ]

        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # Work on a copy so the per-scorer dict (possibly shared with
                # another Hypothesis or the class-level default) is not
                # modified in place.
                final_scores = dict(hyp.scores)
                total_score = hyp.score
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    final_scores[k] = final_scores[k] + s
                    total_score = total_score + self.weights[k] * s
                ended_hyps.append(hyp._replace(score=total_score, scores=final_scores))
            else:
                remained_hyps.append(hyp)
        return remained_hyps
| | |
| | | checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0 |
| | | ) |
| | | self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch |
| | | print(checkpoint["train_acc_avg"]) |
| | | self.train_acc_avg = ( |
| | | checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0 |
| | | ) |
| | |
| | | batch_num_epoch = len(dataloader_train) |
| | | self.log( |
| | | epoch, |
| | | batch_idx + kwargs.get("start_step", 0), |
| | | batch_idx, |
| | | log_step=batch_idx + kwargs.get("start_step", 0), |
| | | step_in_epoch=self.step_in_epoch, |
| | | batch_num_epoch=batch_num_epoch, |
| | | lr=lr, |
| | |
| | | tag="train", |
| | | data_split_i=0, |
| | | data_split_num=1, |
| | | log_step=None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | if (batch_idx + 1) % self.log_interval == 0: |
| | | |
| | | batch_idx = log_step if log_step is not None else batch_idx |
| | | gpu_info = ( |
| | | "GPU, memory: usage: {:.3f} GB, " |
| | | "peak: {:.3f} GB, " |