From 3762d21300e1f3fa3e0cb1e67545227e6dcec3de Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: Mon, 13 Mar 2023 22:02:54 +0800
Subject: [PATCH] add streaming paraformer code
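
Add chunk-wise (streaming) inference for Paraformer: forward_chunk variants
for the SANM encoder/decoder, the CIF predictor and the positional embedding,
plus a new inference script wired into asr_inference_launch under mode
"paraformer_streaming".

Rough usage sketch (illustration only, not part of this patch). The config,
model and cmvn paths, the chunk size and the pad_left/stride values are
placeholders; the streaming cache is expected to be created and advanced by
the calling pipeline:

    import soundfile
    from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope

    pipeline = inference_modelscope(
        maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1,
        ngpu=0, ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
        asr_train_config="exp/config.yaml",  # placeholder path
        asr_model_file="exp/model.pb",       # placeholder path
        cmvn_file="exp/am.mvn",              # placeholder path
    )

    # streaming cache; only the keys read by the new forward_chunk methods are
    # shown, pad_left/stride are placeholder chunk geometry (in encoder frames)
    cache = {
        "encoder": {"start_idx": 0, "pad_left": 5, "stride": 10,
                    "cif_hidden": None, "cif_alphas": None},
        "decoder": {"decode_fsmn": None},
    }

    speech, sample_rate = soundfile.read("speech.wav")
    chunk = speech[:9600].astype("float32")  # e.g. a 600 ms chunk at 16 kHz
    results = pipeline(None, raw_inputs=chunk, param_dict={"cache": cache})
    # the caller advances cache["encoder"]["start_idx"] and feeds the next chunk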
---
funasr/modules/embedding.py | 11
funasr/models/e2e_asr_paraformer.py | 74 +++
funasr/modules/attention.py | 10
funasr/models/encoder/sanm_encoder.py | 42 ++
funasr/bin/asr_inference_paraformer_streaming.py | 907 +++++++++++++++++++++++++++++++++++++++++++++
funasr/models/predictor/cif.py | 57 ++
funasr/models/decoder/sanm_decoder.py | 59 ++
funasr/bin/asr_inference_launch.py | 3
 8 files changed, 1157 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1fae766..da1241a 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -216,6 +216,9 @@
elif mode == "paraformer":
from funasr.bin.asr_inference_paraformer import inference_modelscope
return inference_modelscope(**kwargs)
+ elif mode == "paraformer_streaming":
+ from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
+ return inference_modelscope(**kwargs)
elif mode == "paraformer_vad":
from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
return inference_modelscope(**kwargs)
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
new file mode 100644
index 0000000..9b572a0
--- /dev/null
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -0,0 +1,907 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+import sys
+import time
+import copy
+import os
+import codecs
+import tempfile
+import requests
+from pathlib import Path
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
+
+class Speech2Text:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+ >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(cache, audio)  # cache: streaming cache dict carried across chunks
+        [(text, token, token_int, hypothesis object, ...), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        if asr_model.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = asr_model.token_list
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ # 6. [Optional] Build hotword list from str, local file or url
+
+ is_use_lm = lm_weight != 0.0 and lm_file is not None
+        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
+ beam_search = None
+ self.beam_search = beam_search
+ logging.info(f"Beam_search: {self.beam_search}")
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+ self.encoder_downsampling_factor = 1
+ if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
+ self.encoder_downsampling_factor = 4
+
+ @torch.no_grad()
+ def __call__(
+ self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+ begin_time: int = 0, end_time: int = None,
+ ):
+ """Inference
+
+ Args:
+            cache: streaming cache dict carried across chunks
+            speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
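+        # LFR factor inferred from the stacked feature dimension (assuming
+        # 80-dim fbank before low-frame-rate stacking); only used for RTF reporting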
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, enc_len = self.asr_model.encode_chunk(**batch)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ # assert len(enc) == 1, len(enc)
+ enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
+
+ predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
+ pre_token_length = pre_token_length.floor().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+        decoder_out = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = enc[i, :enc_len[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+                assert isinstance(hyp, Hypothesis), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove the blank id (assumed 0) and the sos/eos id (assumed 2)
+ token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+
+ # assert check_return_type(results)
+ return results
+
+
+class Speech2TextExport:
+ """Speech2TextExport class
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
+ **kwargs,
+ ):
+
+ # 1. Build ASR model
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ token_list = asr_model.token_list
+
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ # self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+
+ model = Paraformer_export(asr_model, onnx=False)
+ self.asr_model = model
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ):
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+
+ enc_len_batch_total = feats_len.sum()
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ decoder_outs = self.asr_model(**batch)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ am_scores = decoder_out[i, :ys_pad_lens[i], :]
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ yseq.tolist(), device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+                assert isinstance(hyp, Hypothesis), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove the blank id (assumed 0) and the sos/eos id (assumed 2)
+ token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+
+ return results
+
+
+def inference(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ penalty=penalty,
+ log_level=log_level,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ raw_inputs=raw_inputs,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ key_file=key_file,
+ word_lm_train_config=word_lm_train_config,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ streaming=streaming,
+ output_dir=output_dir,
+ dtype=dtype,
+ seed=seed,
+ ngram_weight=ngram_weight,
+ nbest=nbest,
+ num_workers=num_workers,
+
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ export_mode = False
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ export_mode = param_dict.get("export_mode", False)
+ else:
+ hotword_list_or_file = None
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ batch_size = 1
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
+ )
+ if export_mode:
+ speech2text = Speech2TextExport(**speech2text_kwargs)
+ else:
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+ if hotword_list_or_file is not None or 'hotword' in kwargs:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
+
+ forward_time_total = 0.0
+ length_total = 0.0
+ finish_count = 0
+ file_count = 1
+ cache = None
+        # 7. Start the decoding for-loop
+        # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
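+        # the streaming cache (encoder/predictor/decoder states) is owned by the
+        # caller and carried across successive calls via param_dict["cache"]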
+ if param_dict is not None and "cache" in param_dict:
+ cache = param_dict["cache"]
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+ logging.info("decoding, utt_id: {}".format(keys))
+ # N-best list of (text, token, token_int, hyp_object)
+
+ time_beg = time.time()
+ results = speech2text(cache=cache, **batch)
+ if len(results) < 1:
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
+ time_end = time.time()
+ forward_time = time_end - time_beg
+ lfr_factor = results[0][-1]
+ length = results[0][-2]
+ forward_time_total += forward_time
+ length_total += length
+ rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
+ 100 * forward_time / (
+ length * lfr_factor))
+ logging.info(rtf_cur)
+
+ for batch_id in range(_bs):
+                result_list = [results[batch_id][:-2]]
+
+ key = keys[batch_id]
+                for n, result in zip(range(1, nbest + 1), result_list):
+ text, token, token_int, hyp = result[0], result[1], result[2], result[3]
+ time_stamp = None if len(result) < 5 else result[4]
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["rtf"][key] = rtf_cur
+
+ if text is not None:
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
+ time_stamp_postprocessed = ""
+ if len(postprocessed_result) == 3:
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+ postprocessed_result[1], \
+ postprocessed_result[2]
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+ item = {'key': key, 'value': text_postprocessed}
+ if time_stamp_postprocessed != "":
+ item['time_stamp'] = time_stamp_postprocessed
+ asr_result_list.append(item)
+ finish_count += 1
+ # asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text_postprocessed
+
+ logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+ rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
+ forward_time_total,
+ 100 * forward_time_total / (
+ length_total * lfr_factor))
+ logging.info(rtf_avg)
+ if writer is not None:
+            ibest_writer["rtf"]["rtf_avg"] = rtf_avg
+ return asr_result_list
+
+ return _forward
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="ASR Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--hotword",
+ type=str_or_none,
+ default=None,
+        help="hotword file path or hotwords separated by space"
+ )
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=False,
+ action="append",
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--asr_train_config",
+ type=str,
+ help="ASR training configuration",
+ )
+ group.add_argument(
+ "--asr_model_file",
+ type=str,
+ help="ASR model parameter file",
+ )
+ group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global cmvn file",
+ )
+ group.add_argument(
+ "--lm_train_config",
+ type=str,
+ help="LM training configuration",
+ )
+ group.add_argument(
+ "--lm_file",
+ type=str,
+ help="LM parameter file",
+ )
+ group.add_argument(
+ "--word_lm_train_config",
+ type=str,
+ help="Word LM training configuration",
+ )
+ group.add_argument(
+ "--word_lm_file",
+ type=str,
+ help="Word LM parameter file",
+ )
+ group.add_argument(
+ "--ngram_file",
+ type=str,
+ help="N-gram parameter file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+        help="Pretrained model tag. If this option is specified, *_train_config and "
+        "*_file will be overwritten",
+ )
+
+ group = parser.add_argument_group("Beam-search related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+ group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+ group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+ group.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+        help="Input length ratio to obtain max output length. "
+        "If maxlenratio=0.0 (default), it uses an end-detect "
+        "function "
+        "to automatically find maximum hypothesis lengths. "
+        "If maxlenratio<0.0, its absolute value is interpreted "
+        "as a constant max output length",
+ )
+ group.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length",
+ )
+ group.add_argument(
+ "--ctc_weight",
+ type=float,
+ default=0.5,
+ help="CTC weight in joint decoding",
+ )
+ group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+ group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+ group.add_argument("--streaming", type=str2bool, default=False)
+
+ group.add_argument(
+ "--frontend_conf",
+ default=None,
+ help="",
+ )
+ group.add_argument("--raw_inputs", type=list, default=None)
+ # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+
+ group = parser.add_argument_group("Text converter related")
+ group.add_argument(
+ "--token_type",
+ type=str_or_none,
+ default=None,
+ choices=["char", "bpe", None],
+ help="The token type for ASR model. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model path of sentencepiece. "
+ "If not given, refers from the training args",
+ )
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ param_dict = {'hotword': args.hotword}
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+ kwargs['param_dict'] = param_dict
+ inference(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
+
+ # from modelscope.pipelines import pipeline
+ # from modelscope.utils.constant import Tasks
+ #
+    # inference_16k_pipeline = pipeline(
+ # task=Tasks.auto_speech_recognition,
+ # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+ #
+    # rec_result = inference_16k_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+ # print(rec_result)
+
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index ab03f0b..0117430 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -947,6 +947,65 @@
)
return logp.squeeze(0), state
+ def forward_chunk(
+ self,
+ memory: torch.Tensor,
+ tgt: torch.Tensor,
+ cache: dict = None,
+    ) -> torch.Tensor:
+ """Forward decoder.
+
+        Args:
+            memory: encoder output, float32 (batch, maxlen_in, feat)
+            tgt: acoustic embeddings from the predictor,
+                float32 (batch, maxlen_out, feat)
+            cache: streaming cache; cache["decode_fsmn"] holds the per-layer
+                FSMN states and is updated in place
+        Returns:
+            x: decoded token scores before softmax (batch, maxlen_out, token)
+               if use_output_layer is True
+        """
+ x = tgt
+ if cache["decode_fsmn"] is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ new_cache = [None] * cache_layer_num
+ else:
+ new_cache = cache["decode_fsmn"]
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, None, memory, None, cache=new_cache[i]
+ )
+ new_cache[i] = c_ret
+
+ if self.num_blocks - self.att_layer_num > 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, None, memory, None, cache=new_cache[j]
+ )
+ new_cache[j] = c_ret
+
+ for decoder in self.decoders3:
+
+ x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, None, memory, None, cache=None
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+ cache["decode_fsmn"] = new_cache
+ return x
+
def forward_one_step(
self,
tgt: torch.Tensor,
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 44c9de3..02f60af 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -325,12 +325,76 @@
return encoder_out, encoder_out_lens
+ def encode_chunk(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder for one streaming chunk. Note that this method is used by asr_inference_paraformer_streaming.py
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+ feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ # Post-encoder, e.g. NLU
+ if self.postencoder is not None:
+ encoder_out, encoder_out_lens = self.postencoder(
+ encoder_out, encoder_out_lens
+ )
+
+ assert encoder_out.size(0) == speech.size(0), (
+ encoder_out.size(),
+ speech.size(0),
+ )
+ assert encoder_out.size(1) <= encoder_out_lens.max(), (
+ encoder_out.size(),
+ encoder_out_lens.max(),
+ )
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def calc_predictor_chunk(self, encoder_out, cache=None):
+
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -341,6 +405,14 @@
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
+
+ def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+ decoder_outs = self.decoder.forward_chunk(
+ encoder_out, sematic_embeds, cache["decoder"]
+ )
+ decoder_out = decoder_outs
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
@@ -1459,4 +1531,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
\ No newline at end of file
+ return var_dict_torch_update
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 0751a10..57890ef 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -347,6 +347,48 @@
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+ def forward_chunk(self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ ctc: CTC = None,
+ ):
+ xs_pad *= self.output_size() ** 0.5
+        if self.embed is not None:
+            xs_pad = self.embed.forward_chunk(xs_pad, cache)
+
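+        # streaming decoding works on a single unpadded chunk, so no padding
+        # masks are passed to the encoder layers here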
+ encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ encoder_outs = self.encoders(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), None, None
+ return xs_pad, ilens, None
+
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 5615373..74f3e68 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -199,6 +199,63 @@
return acoustic_embeds, token_num, alphas, cif_peak
+ def forward_chunk(self, hidden, cache=None):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+
+ alphas = alphas.squeeze(-1)
+ mask_chunk_predictor = None
+ if cache is not None:
+            # keep CIF weights only for the current chunk's valid stride region
+            mask_chunk_predictor = torch.zeros_like(alphas)
+            mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0
+
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+
+ if cache is not None:
+ if cache["cif_hidden"] is not None:
+ hidden = torch.cat((cache["cif_hidden"], hidden), 1)
+ if cache["cif_alphas"] is not None:
+ alphas = torch.cat((cache["cif_alphas"], alphas), -1)
+
+ token_num = alphas.sum(-1)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ len_time = alphas.size(-1)
+ last_fire_place = len_time - 1
+ last_fire_remainds = 0.0
+ pre_alphas_length = 0
+
+ mask_chunk_peak_predictor = None
+ if cache is not None:
+            # keep CIF peaks from previously accumulated frames plus the current chunk's stride region
+            mask_chunk_peak_predictor = torch.zeros_like(cif_peak)
+ if cache["cif_alphas"] is not None:
+ pre_alphas_length = cache["cif_alphas"].size(-1)
+ mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0
+ mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0
+
+ if mask_chunk_peak_predictor is not None:
+ cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1)
+
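+        # locate the last frame where the integrated CIF weight fired; its
+        # remainder and the trailing hidden frames are carried over to the next chunk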
+ for i in range(len_time):
+            if cif_peak[0][len_time - 1 - i] >= self.threshold:
+ last_fire_place = len_time - 1 - i
+ last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
+ break
+ last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
+ cache["cif_hidden"] = hidden[:, last_fire_place:, :]
+ cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+ token_num_int = token_num.floor().type(torch.int32).item()
+ return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
+
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 6277005..31d5a87 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -347,15 +347,17 @@
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
+ inputs = inputs * mask
- inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
- return x * mask
+ if mask is not None:
+ x = x * mask
+ return x
def forward_qkv(self, x):
"""Transform query, key and value.
@@ -505,7 +507,7 @@
# print("in fsmn, cache is None, x", x.size())
x = self.pad_fn(x)
- if not self.training and t <= 1:
+ if not self.training:
cache = x
else:
# print("in fsmn, cache is not None, x", x.size())
@@ -513,7 +515,7 @@
# if t < self.kernel_size:
# x = self.pad_fn(x)
x = torch.cat((cache[:, :, 1:], x), dim=2)
- x = x[:, :, -self.kernel_size:]
+ x = x[:, :, -(self.kernel_size+t-1):]
# print("in fsmn, cache is not None, x_cat", x.size())
cache = x
x = self.fsmn_block(x)
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index b61a61a..e4f9bff 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -405,4 +405,13 @@
positions = torch.arange(1, timesteps+1)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
- return x + position_encoding
\ No newline at end of file
+ return x + position_encoding
+
+ def forward_chunk(self, x, cache=None):
+ start_idx = 0
+ batch_size, timesteps, input_dim = x.size()
+ if cache is not None:
+ start_idx = cache["start_idx"]
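+        # offset positions by the number of frames already consumed so this
+        # chunk continues the absolute positional encoding of previous chunks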
+ positions = torch.arange(1, timesteps+start_idx+1)[None, :]
+ position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
+ return x + position_encoding[:, start_idx: start_idx + timesteps]
--
Gitblit v1.9.1