From a9e857e45250b16af60d5fe3efcd06e685f6506a Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: Sat, 03 Dec 2022 16:39:38 +0800
Subject: [PATCH] update funasr 0.1.3
---
funasr/version.txt | 2
funasr/datasets/iterable_dataset_modelscope.py | 349 +++++++
funasr/utils/wav_utils.py | 178 +++
funasr/models/e2e_asr_paraformer.py | 5
funasr/tasks/abs_task.py | 61 +
funasr/bin/modelscope_infer.py | 7
funasr/utils/postprocess_utils.py | 174 +++
funasr/models/frontend/wav_frontend.py | 155 +++
funasr/bin/asr_inference_paraformer_modelscope.py | 686 ++++++++++++++
egs_modelscope/common/modelscope_utils/modelscope_infer.sh | 1
funasr/models/predictor/cif.py | 2
funasr/bin/asr_inference_modelscope.py | 687 ++++++++++++++
funasr/utils/asr_env_checking.py | 85 +
funasr/utils/asr_utils.py | 327 +++++++
14 files changed, 2712 insertions(+), 7 deletions(-)
diff --git a/egs_modelscope/common/modelscope_utils/modelscope_infer.sh b/egs_modelscope/common/modelscope_utils/modelscope_infer.sh
index 80f0d16..a0c606f 100755
--- a/egs_modelscope/common/modelscope_utils/modelscope_infer.sh
+++ b/egs_modelscope/common/modelscope_utils/modelscope_infer.sh
@@ -65,6 +65,7 @@
${decode_cmd} --max-jobs-run "${inference_nj}" JOB=1:"${inference_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.modelscope_infer \
--model_name ${model_name} \
+ --model_revision ${model_revision} \
--wav_list ${_logdir}/keys.JOB.scp \
--output_file ${_logdir}/text.JOB \
--gpuid_list ${gpuid_list} \
diff --git a/funasr/bin/asr_inference_modelscope.py b/funasr/bin/asr_inference_modelscope.py
new file mode 100755
index 0000000..fd9bd66
--- /dev/null
+++ b/funasr/bin/asr_inference_modelscope.py
@@ -0,0 +1,687 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
+from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
+from funasr.modules.beam_search.beam_search import BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
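+# Default sample-rate settings: 'audio_fs' presumably gives the incoming audio's
+# sampling rate and 'model_fs' the rate the acoustic model expects; audio is
+# resampled from the former to the latter when they differ.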
+global_sample_rate: Union[int, Dict[Any, int]] = {
+ 'audio_fs': 16000,
+ 'model_fs': 16000
+}
+
+class Speech2Text:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, device
+ )
+ if asr_model.frontend is None and frontend_conf is not None:
+ frontend = WavFrontend(**frontend_conf)
+ asr_model.frontend = frontend
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ # TODO(karita): make all scorers batchfied
+ if batch_size == 1:
+ non_batch = [
+ k
+ for k, v in beam_search.full_scorers.items()
+ if not isinstance(v, BatchScorerInterface)
+ ]
+ if len(non_batch) == 0:
+ if streaming:
+ beam_search.__class__ = BatchBeamSearchOnlineSim
+ beam_search.set_streaming_config(asr_train_config)
+ logging.info(
+ "BatchBeamSearchOnlineSim implementation is selected."
+ )
+ else:
+ beam_search.__class__ = BatchBeamSearch
+ logging.info("BatchBeamSearch implementation is selected.")
+ else:
+ logging.warning(
+ f"As non-batch scorers {non_batch} are found, "
+ f"fall back to non-batch implementation."
+ )
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+ logging.info(f"Beam_search: {beam_search}")
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray]
+ ) -> List[
+ Tuple[
+ Optional[str],
+ List[str],
+ List[int],
+ Union[Hypothesis],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ # data: (Nsamples,) -> (1, Nsamples)
+ speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
+ # lengths: (1,)
+ lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ batch = {"speech": speech, "speech_lengths": lengths}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, _ = self.asr_model.encode(**batch)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ assert len(enc) == 1, len(enc)
+
+        # c. Pass the encoder result to the beam search
+ nbest_hyps = self.beam_search(
+ x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
+
+
+def inference(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ ctc_weight: float,
+ lm_weight: float,
+ ngram_weight: float,
+ penalty: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: list,
+ audio_lists: Union[List[Any], bytes],
+ key_file: Optional[str],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ word_lm_train_config: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ output_dir: Optional[str],
+ allow_variable_data_keys: bool,
+ streaming: bool,
+ frontend_conf: dict = None,
+ fs: Union[dict, int] = 16000,
+ **kwargs,
+) -> List[Any]:
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1:
+ device = "cuda"
+ else:
+ device = "cpu"
+ features_type: str = data_path_and_name_and_type[1]
+ hop_length: int = 160
+ sr: int = 16000
+ if isinstance(fs, int):
+ sr = fs
+ else:
+ if 'model_fs' in fs and fs['model_fs'] is not None:
+ sr = fs['model_fs']
+ if features_type != 'sound':
+ frontend_conf = None
+ if frontend_conf is not None:
+ if 'hop_length' in frontend_conf:
+ hop_length = frontend_conf['hop_length']
+
+ finish_count = 0
+ file_count = 1
+ if isinstance(audio_lists, bytes):
+ file_count = 1
+ else:
+ file_count = len(audio_lists)
+ if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
+ mvn_file = data_path_and_name_and_type[2]
+ mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
+ frontend_conf['mvn_data'] = mvn_data
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ frontend_conf=frontend_conf,
+ )
+ speech2text = Speech2Text(**speech2text_kwargs)
+ data_path_and_name_and_type_new = [
+ audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
+ ]
+ # 3. Build data-iterator
+ loader = ASRTask.build_streaming_iterator_modelscope(
+ data_path_and_name_and_type_new,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ sample_rate=fs
+ )
+
+    # 7. Start for-loop
+    # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ if text is not None:
+ text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+
+ return asr_result_list
+
+
+
+def set_parameters(language: str = None,
+ sample_rate: Union[int, Dict[Any, int]] = None):
+ if language is not None:
+ global global_asr_language
+ global_asr_language = language
+ if sample_rate is not None:
+ global global_sample_rate
+ global_sample_rate = sample_rate
+
+
+def asr_inference(maxlenratio: float,
+ minlenratio: float,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ name_and_type: list,
+ audio_lists: Union[List[Any], bytes],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ nbest: int = 1,
+ num_workers: int = 1,
+ log_level: Union[int, str] = 'INFO',
+ batch_size: int = 1,
+ dtype: str = 'float32',
+ seed: int = 0,
+ key_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ word_lm_file: Optional[str] = None,
+ ngram_file: Optional[str] = None,
+ ngram_weight: float = 0.9,
+ model_tag: Optional[str] = None,
+ token_type: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ transducer_conf: Optional[dict] = None,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ fs: Union[dict, int] = None,
+ lang: Optional[str] = None,
+ outputdir: Optional[str] = None):
+ if lang is not None:
+ global global_asr_language
+ global_asr_language = lang
+ if fs is not None:
+ global global_sample_rate
+ global_sample_rate = fs
+
+ # force use CPU if data type is bytes
+ if isinstance(audio_lists, bytes):
+ num_workers = 0
+ ngpu = 0
+
+ return inference(output_dir=outputdir,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ dtype=dtype,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ seed=seed,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ num_workers=num_workers,
+ log_level=log_level,
+ data_path_and_name_and_type=name_and_type,
+ audio_lists=audio_lists,
+ key_file=key_file,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ word_lm_train_config=word_lm_train_config,
+ word_lm_file=word_lm_file,
+ ngram_file=ngram_file,
+ model_tag=model_tag,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ transducer_conf=transducer_conf,
+ streaming=streaming,
+ frontend_conf=frontend_conf)
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="ASR Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument(
+ "--gpuid_list",
+ type=str,
+ default="",
+ help="The visible gpus",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=True,
+ action="append",
+ )
+ group.add_argument("--audio_lists", type=list,
+ default=[{'key':'EdevDEWdIYQ_0021',
+ 'file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--asr_train_config",
+ type=str,
+ help="ASR training configuration",
+ )
+ group.add_argument(
+ "--asr_model_file",
+ type=str,
+ help="ASR model parameter file",
+ )
+ group.add_argument(
+ "--lm_train_config",
+ type=str,
+ help="LM training configuration",
+ )
+ group.add_argument(
+ "--lm_file",
+ type=str,
+ help="LM parameter file",
+ )
+ group.add_argument(
+ "--word_lm_train_config",
+ type=str,
+ help="Word LM training configuration",
+ )
+ group.add_argument(
+ "--word_lm_file",
+ type=str,
+ help="Word LM parameter file",
+ )
+ group.add_argument(
+ "--ngram_file",
+ type=str,
+ help="N-gram parameter file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+ help="Pretrained model tag. If specify this option, *_train_config and "
+ "*_file will be overwritten",
+ )
+
+ group = parser.add_argument_group("Beam-search related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+ group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+ group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+ group.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain max output length. "
+ "If maxlenratio=0.0 (default), it uses a end-detect "
+ "function "
+ "to automatically find maximum hypothesis lengths."
+ "If maxlenratio<0.0, its absolute value is interpreted"
+ "as a constant max output length",
+ )
+ group.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length",
+ )
+ group.add_argument(
+ "--ctc_weight",
+ type=float,
+ default=0.5,
+ help="CTC weight in joint decoding",
+ )
+ group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+ group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+ group.add_argument("--streaming", type=str2bool, default=False)
+
+ group = parser.add_argument_group("Text converter related")
+ group.add_argument(
+ "--token_type",
+ type=str_or_none,
+ default=None,
+ choices=["char", "bpe", None],
+ help="The token type for ASR model. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model path of sentencepiece. "
+ "If not given, refers from the training args",
+ )
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+ inference(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/bin/asr_inference_paraformer_modelscope.py b/funasr/bin/asr_inference_paraformer_modelscope.py
new file mode 100755
index 0000000..d64fe2b
--- /dev/null
+++ b/funasr/bin/asr_inference_paraformer_modelscope.py
@@ -0,0 +1,686 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Any
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import List
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+ 'audio_fs': 16000,
+ 'model_fs': 16000
+}
+
+
+class Speech2Text:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, device
+ )
+ if asr_model.frontend is None and frontend_conf is not None:
+ frontend = WavFrontend(**frontend_conf)
+ asr_model.frontend = frontend
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+ logging.info(f"Beam_search: {beam_search}")
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray]
+ ):
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ # data: (Nsamples,) -> (1, Nsamples)
+ speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
+ # lengths: (1,)
+ lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ batch = {"speech": speech, "speech_lengths": lengths}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, enc_len = self.asr_model.encode(**batch)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ assert len(enc) == 1, len(enc)
+
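+        # Paraformer-style decoding: the CIF predictor estimates token-level
+        # acoustic embeddings and a token count, and the decoder then scores
+        # all tokens in a single non-autoregressive pass.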
+ predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
+ pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
+ pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device)
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ nbest_hyps = self.beam_search(
+ x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ results.append((text, token, token_int, hyp, speech.size(1), lfr_factor))
+
+ # assert check_return_type(results)
+ return results
+
+
+def inference(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ ctc_weight: float,
+ lm_weight: float,
+ ngram_weight: float,
+ penalty: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: list,
+ audio_lists: Union[List[Any], bytes],
+ key_file: Optional[str],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ word_lm_train_config: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ output_dir: Optional[str],
+ allow_variable_data_keys: bool,
+ frontend_conf: dict = None,
+ fs: Union[dict, int] = 16000,
+ **kwargs,
+) -> List[Any]:
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1:
+ device = "cuda"
+ else:
+ device = "cpu"
+ # data_path_and_name_and_type = data_path_and_name_and_type[0]
+ features_type: str = data_path_and_name_and_type[1]
+ hop_length: int = 160
+ sr: int = 16000
+ if isinstance(fs, int):
+ sr = fs
+ else:
+ if 'model_fs' in fs and fs['model_fs'] is not None:
+ sr = fs['model_fs']
+ if features_type != 'sound':
+ frontend_conf = None
+ if frontend_conf is not None:
+ if 'hop_length' in frontend_conf:
+ hop_length = frontend_conf['hop_length']
+
+ finish_count = 0
+ file_count = 1
+ if isinstance(audio_lists, bytes):
+ file_count = 1
+ else:
+ file_count = len(audio_lists)
+ if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
+ mvn_file = data_path_and_name_and_type[2]
+ mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
+ frontend_conf['mvn_data'] = mvn_data
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ frontend_conf=frontend_conf,
+ )
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ data_path_and_name_and_type_new = [
+ audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
+ ]
+
+ # 3. Build data-iterator
+ loader = ASRTask.build_streaming_iterator_modelscope(
+ data_path_and_name_and_type_new,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ sample_rate=fs
+ )
+
+ forward_time_total = 0.0
+ length_total = 0.0
+ asr_result_list = []
+    # 7. Start for-loop
+    # FIXME(kamo): The output format should be discussed
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ logging.info("decoding, utt_id: {}".format(keys))
+ # N-best list of (text, token, token_int, hyp_object)
+
+ try:
+ time_beg = time.time()
+ results = speech2text(**batch)
+ time_end = time.time()
+ forward_time = time_end - time_beg
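+            # Strip the trailing (length, lfr_factor) entries so results matches
+            # the (text, token, token_int, hyp) layout expected below.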
+ lfr_factor = results[0][-1]
+ length = results[0][-2]
+ results = [results[0][:-2]]
+ forward_time_total += forward_time
+ length_total += length
+ logging.info(
+ "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
+ format(length, forward_time, 100 * forward_time / (length * lfr_factor)))
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ if text is not None:
+ text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+
+ logging.info("decoding, predictions: {}".format(text))
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+
+ logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
+ format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
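+    # Convert the accumulated length into seconds/bytes: for 'sound' input,
+    # length_total counts raw samples; for kaldi_ark features it counts frames,
+    # hence the hop_length scaling (byte counts assume 16-bit samples).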
+ if features_type == 'sound':
+ # data format is wav
+ length_total_seconds = length_total / sr
+ length_total_bytes = length_total * 2
+ else:
+ # data format is kaldi_ark
+ length_total_seconds = length_total * hop_length / sr
+ length_total_bytes = length_total * hop_length * 2
+
+ logger.info(
+ header_colors + # noqa: *
+ 'decoding, feature length total: {}bytes, forward_time total: {:.4f}s, rtf avg: {:.4f}'
+ .format(length_total_bytes, forward_time_total, forward_time_total /
+ length_total_seconds) + end_colors)
+
+ return asr_result_list
+
+
+def set_parameters(language: str = None,
+ sample_rate: Union[int, Dict[Any, int]] = None):
+ if language is not None:
+ global global_asr_language
+ global_asr_language = language
+ if sample_rate is not None:
+ global global_sample_rate
+ global_sample_rate = sample_rate
+
+
+def asr_inference(maxlenratio: float,
+ minlenratio: float,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ name_and_type: list,
+ audio_lists: Union[List[Any], bytes],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ nbest: int = 1,
+ num_workers: int = 1,
+ log_level: Union[int, str] = 'INFO',
+ batch_size: int = 1,
+ dtype: str = 'float32',
+ seed: int = 0,
+ key_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ word_lm_file: Optional[str] = None,
+ ngram_file: Optional[str] = None,
+ ngram_weight: float = 0.9,
+ model_tag: Optional[str] = None,
+ token_type: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ transducer_conf: Optional[dict] = None,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ fs: Union[dict, int] = None,
+ lang: Optional[str] = None,
+ outputdir: Optional[str] = None):
+ if lang is not None:
+ global global_asr_language
+ global_asr_language = lang
+ if fs is not None:
+ global global_sample_rate
+ global_sample_rate = fs
+
+ # force use CPU if data type is bytes
+ if isinstance(audio_lists, bytes):
+ num_workers = 0
+ ngpu = 0
+
+ return inference(output_dir=outputdir,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ dtype=dtype,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ seed=seed,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ num_workers=num_workers,
+ log_level=log_level,
+ data_path_and_name_and_type=name_and_type,
+ audio_lists=audio_lists,
+ key_file=key_file,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ word_lm_train_config=word_lm_train_config,
+ word_lm_file=word_lm_file,
+ ngram_file=ngram_file,
+ model_tag=model_tag,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ transducer_conf=transducer_conf,
+ streaming=streaming,
+ frontend_conf=frontend_conf)
+
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="ASR Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=True,
+ action="append",
+ )
+ group.add_argument("--audio_lists", type=list, default=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--asr_train_config",
+ type=str,
+ help="ASR training configuration",
+ )
+ group.add_argument(
+ "--asr_model_file",
+ type=str,
+ help="ASR model parameter file",
+ )
+ group.add_argument(
+ "--lm_train_config",
+ type=str,
+ help="LM training configuration",
+ )
+ group.add_argument(
+ "--lm_file",
+ type=str,
+ help="LM parameter file",
+ )
+ group.add_argument(
+ "--word_lm_train_config",
+ type=str,
+ help="Word LM training configuration",
+ )
+ group.add_argument(
+ "--word_lm_file",
+ type=str,
+ help="Word LM parameter file",
+ )
+ group.add_argument(
+ "--ngram_file",
+ type=str,
+ help="N-gram parameter file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+ help="Pretrained model tag. If specify this option, *_train_config and "
+ "*_file will be overwritten",
+ )
+
+ group = parser.add_argument_group("Beam-search related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+ group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+ group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+ group.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain max output length. "
+ "If maxlenratio=0.0 (default), it uses a end-detect "
+ "function "
+ "to automatically find maximum hypothesis lengths."
+ "If maxlenratio<0.0, its absolute value is interpreted"
+ "as a constant max output length",
+ )
+ group.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length",
+ )
+ group.add_argument(
+ "--ctc_weight",
+ type=float,
+ default=0.5,
+ help="CTC weight in joint decoding",
+ )
+ group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+ group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+ group.add_argument("--streaming", type=str2bool, default=False)
+
+ group.add_argument(
+ "--asr_model_config",
+ default=None,
+ help="",
+ )
+
+ group = parser.add_argument_group("Text converter related")
+ group.add_argument(
+ "--token_type",
+ type=str_or_none,
+ default=None,
+ choices=["char", "bpe", None],
+ help="The token type for ASR model. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model path of sentencepiece. "
+ "If not given, refers from the training args",
+ )
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+ inference(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/bin/modelscope_infer.py b/funasr/bin/modelscope_infer.py
index 440c881..74c2fb7 100755
--- a/funasr/bin/modelscope_infer.py
+++ b/funasr/bin/modelscope_infer.py
@@ -15,6 +15,10 @@
type=str,
default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model name in modelscope")
+ parser.add_argument("--model_revision",
+ type=str,
+ default="v1.0.3",
+ help="model revision in modelscope")
parser.add_argument("--local_model_path",
type=str,
default=None,
@@ -62,7 +66,8 @@
if args.local_model_path is None:
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
- model="damo/{}".format(args.model_name))
+ model="damo/{}".format(args.model_name),
+ model_revision=args.model_revision)
else:
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
diff --git a/funasr/datasets/iterable_dataset_modelscope.py b/funasr/datasets/iterable_dataset_modelscope.py
new file mode 100644
index 0000000..860492c
--- /dev/null
+++ b/funasr/datasets/iterable_dataset_modelscope.py
@@ -0,0 +1,349 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+"""Iterable dataset module."""
+import copy
+from io import StringIO
+from pathlib import Path
+from typing import Callable, Collection, Dict, Iterator, Tuple, Union
+
+import kaldiio
+import numpy as np
+import soundfile
+import torch
+from funasr.datasets.dataset import ESPnetDataset
+from torch.utils.data.dataset import IterableDataset
+from typeguard import check_argument_types
+
+from funasr.utils import wav_utils
+
+
+def load_kaldi(input):
+ retval = kaldiio.load_mat(input)
+ if isinstance(retval, tuple):
+ assert len(retval) == 2, len(retval)
+ if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
+ # sound scp case
+ rate, array = retval
+ elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
+ # Extended ark format case
+ array, rate = retval
+ else:
+ raise RuntimeError(
+ f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
+
+        # Multichannel wave file
+        # array: (Nsample, Channel) or (Nsample,)
+
+ else:
+ # Normal ark case
+ assert isinstance(retval, np.ndarray), type(retval)
+ array = retval
+ return array
+
+
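+# Loader functions keyed by data type: each maps a file path (or raw text) to a
+# numpy array or string.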
+DATA_TYPES = {
+ 'sound':
+ lambda x: soundfile.read(x)[0],
+ 'kaldi_ark':
+ load_kaldi,
+ 'npy':
+ np.load,
+ 'text_int':
+ lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
+ 'csv_int':
+ lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
+ 'text_float':
+ lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
+ ),
+ 'csv_float':
+ lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
+ ),
+ 'text':
+ lambda x: x,
+}
+
+
+class IterableESPnetDatasetModelScope(IterableDataset):
+ """Pytorch Dataset class for ESPNet.
+
+ Examples:
+        >>> dataset = IterableESPnetDatasetModelScope([('wav.scp', 'input', 'sound'),
+ ... ('token_int', 'output', 'text_int')],
+ ... )
+ >>> for uid, data in dataset:
+ ... data
+ {'input': per_utt_array, 'output': per_utt_array}
+ """
+ def __init__(self,
+ path_name_type_list: Collection[Tuple[any, str, str]],
+ preprocess: Callable[[str, Dict[str, np.ndarray]],
+ Dict[str, np.ndarray]] = None,
+ float_dtype: str = 'float32',
+ int_dtype: str = 'long',
+ key_file: str = None,
+ sample_rate: Union[dict, int] = 16000):
+ assert check_argument_types()
+ if len(path_name_type_list) == 0:
+ raise ValueError(
+ '1 or more elements are required for "path_name_type_list"')
+
+ self.preprocess = preprocess
+
+ self.float_dtype = float_dtype
+ self.int_dtype = int_dtype
+ self.key_file = key_file
+ self.sample_rate = sample_rate
+
+ self.debug_info = {}
+ non_iterable_list = []
+ self.path_name_type_list = []
+
+ path_list = path_name_type_list[0]
+ name = path_name_type_list[1]
+ _type = path_name_type_list[2]
+ if name in self.debug_info:
+ raise RuntimeError(f'"{name}" is duplicated for data-key')
+ self.debug_info[name] = path_list, _type
+ # for path, name, _type in path_name_type_list:
+ for path in path_list:
+ self.path_name_type_list.append((path, name, _type))
+
+ if len(non_iterable_list) != 0:
+            # Some types don't support iterable mode
+ self.non_iterable_dataset = ESPnetDataset(
+ path_name_type_list=non_iterable_list,
+ preprocess=preprocess,
+ float_dtype=float_dtype,
+ int_dtype=int_dtype,
+ )
+ else:
+ self.non_iterable_dataset = None
+
+ self.apply_utt2category = False
+
+ def has_name(self, name) -> bool:
+ return name in self.debug_info
+
+ def names(self) -> Tuple[str, ...]:
+ return tuple(self.debug_info)
+
+ def __repr__(self):
+ _mes = self.__class__.__name__
+ _mes += '('
+ for name, (path, _type) in self.debug_info.items():
+ _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
+ _mes += f'\n preprocess: {self.preprocess})'
+ return _mes
+
+ def __iter__(
+ self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
+ torch.set_printoptions(profile='default')
+ count = len(self.path_name_type_list)
+ for idx in range(count):
+ # 2. Load the entry from each line and create a dict
+ data = {}
+            # 2.a. Load the data in a streaming fashion
+
+ # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
+ value = self.path_name_type_list[idx][0]['file']
+ uid = self.path_name_type_list[idx][0]['key']
+ # name: speech
+ name = self.path_name_type_list[idx][1]
+ _type = self.path_name_type_list[idx][2]
+ func = DATA_TYPES[_type]
+ array = func(value)
+
+ # 2.b. audio resample
+ if _type == 'sound':
+ audio_sr: int = 16000
+ model_sr: int = 16000
+ if isinstance(self.sample_rate, int):
+ model_sr = self.sample_rate
+ else:
+ if 'audio_sr' in self.sample_rate:
+ audio_sr = self.sample_rate['audio_sr']
+ if 'model_sr' in self.sample_rate:
+ model_sr = self.sample_rate['model_sr']
+ array = wav_utils.torch_resample(array, audio_sr, model_sr)
+
+ # array: [ 1.25122070e-03 ... ]
+ data[name] = array
+
+ # 3. [Option] Apply preprocessing
+ # e.g. espnet2.train.preprocessor:CommonPreprocessor
+ if self.preprocess is not None:
+ data = self.preprocess(uid, data)
+ # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
+
+ # 4. Force data-precision
+ for name in data:
+ # value is np.ndarray data
+ value = data[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f'All values must be converted to np.ndarray object '
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+
+ # Cast to desired type
+ if value.dtype.kind == 'f':
+ value = value.astype(self.float_dtype)
+ elif value.dtype.kind == 'i':
+ value = value.astype(self.int_dtype)
+ else:
+ raise NotImplementedError(
+ f'Not supported dtype: {value.dtype}')
+ data[name] = value
+
+ yield uid, data
+
+ if count == 0:
+ raise RuntimeError('No iteration')
+
+
+class IterableESPnetBytesModelScope(IterableDataset):
+ """Pytorch audio bytes class for ESPNet.
+
+ Examples:
+        >>> dataset = IterableESPnetBytesModelScope([('audio bytes', 'input', 'sound'),
+ ... ('token_int', 'output', 'text_int')],
+ ... )
+ >>> for uid, data in dataset:
+ ... data
+ {'input': per_utt_array, 'output': per_utt_array}
+ """
+ def __init__(self,
+ path_name_type_list: Collection[Tuple[any, str, str]],
+ preprocess: Callable[[str, Dict[str, np.ndarray]],
+ Dict[str, np.ndarray]] = None,
+ float_dtype: str = 'float32',
+ int_dtype: str = 'long',
+ key_file: str = None,
+ sample_rate: Union[dict, int] = 16000):
+ assert check_argument_types()
+ if len(path_name_type_list) == 0:
+ raise ValueError(
+ '1 or more elements are required for "path_name_type_list"')
+
+ self.preprocess = preprocess
+
+ self.float_dtype = float_dtype
+ self.int_dtype = int_dtype
+ self.key_file = key_file
+ self.sample_rate = sample_rate
+
+ self.debug_info = {}
+ non_iterable_list = []
+ self.path_name_type_list = []
+
+ audio_data = path_name_type_list[0]
+ name = path_name_type_list[1]
+ _type = path_name_type_list[2]
+ if name in self.debug_info:
+ raise RuntimeError(f'"{name}" is duplicated for data-key')
+ self.debug_info[name] = audio_data, _type
+ self.path_name_type_list.append((audio_data, name, _type))
+
+ if len(non_iterable_list) != 0:
+            # Some types don't support iterable mode
+ self.non_iterable_dataset = ESPnetDataset(
+ path_name_type_list=non_iterable_list,
+ preprocess=preprocess,
+ float_dtype=float_dtype,
+ int_dtype=int_dtype,
+ )
+ else:
+ self.non_iterable_dataset = None
+
+ self.apply_utt2category = False
+
+ if float_dtype == 'float32':
+ self.np_dtype = np.float32
+
+ def has_name(self, name) -> bool:
+ return name in self.debug_info
+
+ def names(self) -> Tuple[str, ...]:
+ return tuple(self.debug_info)
+
+ def __repr__(self):
+ _mes = self.__class__.__name__
+ _mes += '('
+ for name, (path, _type) in self.debug_info.items():
+ _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
+ _mes += f'\n preprocess: {self.preprocess})'
+ return _mes
+
+ def __iter__(
+ self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
+
+ torch.set_printoptions(profile='default')
+ # 2. Load the entry from each line and create a dict
+ data = {}
+        # 2.a. Load the data in a streaming fashion
+
+ value = self.path_name_type_list[0][0]
+ uid = 'pcm_data'
+ # name: speech
+ name = self.path_name_type_list[0][1]
+ _type = self.path_name_type_list[0][2]
+ func = DATA_TYPES[_type]
+ # array: [ 1.25122070e-03 ... ]
+ # data[name] = np.frombuffer(value, dtype=self.np_dtype)
+
+ # 2.b. byte(PCM16) to float32
+ middle_data = np.frombuffer(value, dtype=np.int16)
+ middle_data = np.asarray(middle_data)
+ if middle_data.dtype.kind not in 'iu':
+ raise TypeError("'middle_data' must be an array of integers")
+ dtype = np.dtype('float32')
+ if dtype.kind != 'f':
+ raise TypeError("'dtype' must be a floating point type")
+
+ i = np.iinfo(middle_data.dtype)
+ abs_max = 2**(i.bits - 1)
+ offset = i.min + abs_max
+ array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
+ dtype=self.np_dtype)
+
+ # 2.c. audio resample
+ if _type == 'sound':
+ audio_sr: int = 16000
+ model_sr: int = 16000
+ if isinstance(self.sample_rate, int):
+ model_sr = self.sample_rate
+ else:
+ if 'audio_sr' in self.sample_rate:
+ audio_sr = self.sample_rate['audio_sr']
+ if 'model_sr' in self.sample_rate:
+ model_sr = self.sample_rate['model_sr']
+ array = wav_utils.torch_resample(array, audio_sr, model_sr)
+
+ data[name] = array
+
+ # 3. [Option] Apply preprocessing
+ # e.g. espnet2.train.preprocessor:CommonPreprocessor
+ if self.preprocess is not None:
+ data = self.preprocess(uid, data)
+ # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
+
+ # 4. Force data-precision
+ for name in data:
+ # value is np.ndarray data
+ value = data[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f'All values must be converted to np.ndarray object '
+ f'by preprocessing, but "{name}" is still {type(value)}.')
+
+ # Cast to desired type
+ if value.dtype.kind == 'f':
+ value = value.astype(self.float_dtype)
+ elif value.dtype.kind == 'i':
+ value = value.astype(self.int_dtype)
+ else:
+ raise NotImplementedError(
+ f'Not supported dtype: {value.dtype}')
+ data[name] = value
+
+ yield uid, data
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 5ea28f3..89f7cf0 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -330,9 +330,10 @@
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
- decoder_out, _ = self.decoder(
+ decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
)
+ decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@@ -553,7 +554,6 @@
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
- joint_network: Optional[torch.nn.Module],
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
@@ -590,7 +590,6 @@
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
- joint_network=joint_network,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
new file mode 100644
index 0000000..c0b28ff
--- /dev/null
+++ b/funasr/models/frontend/wav_frontend.py
@@ -0,0 +1,155 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+
+import copy
+from typing import Optional, Tuple, Union
+
+import humanfriendly
+import numpy as np
+import torch
+import torchaudio.compliance.kaldi as kaldi
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.layers.log_mel import LogMel
+from funasr.layers.stft import Stft
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.frontends.frontend import Frontend
+from typeguard import check_argument_types
+
+
+def apply_cmvn(inputs, mvn): # noqa
+ """
+ Apply CMVN with mvn data
+ """
+
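+    # mvn is expected to hold the additive shift in row 0 and the multiplicative
+    # scale in row 1 (typically -mean and 1/stddev for CMVN).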
+ device = inputs.device
+ dtype = inputs.dtype
+ frame, dim = inputs.shape
+
+    means = np.tile(mvn[0:1, :dim], (frame, 1))
+    vars = np.tile(mvn[1:2, :dim], (frame, 1))
+    inputs += torch.from_numpy(means).type(dtype).to(device)
+    inputs *= torch.from_numpy(vars).type(dtype).to(device)
+
+ return inputs.type(torch.float32)
+
+
+def apply_lfr(inputs, lfr_m, lfr_n):
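+    # Low frame rate (LFR): stack lfr_m consecutive frames into one output frame
+    # and advance by lfr_n frames, padding the start with copies of the first
+    # frame and the end with copies of the last frame.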
+ LFR_inputs = []
+ T = inputs.shape[0]
+ T_lfr = int(np.ceil(T / lfr_n))
+ left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
+ inputs = torch.vstack((left_padding, inputs))
+ T = T + (lfr_m - 1) // 2
+ for i in range(T_lfr):
+ if lfr_m <= T - i * lfr_n:
+ LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
+ else: # process last LFR frame
+ num_padding = lfr_m - (T - i * lfr_n)
+ frame = (inputs[i * lfr_n:]).view(-1)
+ for _ in range(num_padding):
+ frame = torch.hstack((frame, inputs[-1]))
+ LFR_inputs.append(frame)
+ LFR_outputs = torch.vstack(LFR_inputs)
+ return LFR_outputs.type(torch.float32)
+
+
+class WavFrontend(AbsFrontend):
+ """Conventional frontend structure for ASR.
+ """
+ def __init__(
+ self,
+ fs: Union[int, str] = 16000,
+ n_fft: int = 512,
+ win_length: int = 400,
+ hop_length: int = 160,
+ window: Optional[str] = 'hamming',
+ center: bool = True,
+ normalized: bool = False,
+ onesided: bool = True,
+ n_mels: int = 80,
+ fmin: int = None,
+ fmax: int = None,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ htk: bool = False,
+ mvn_data=None,
+ frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+ apply_stft: bool = True,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ if isinstance(fs, str):
+ fs = humanfriendly.parse_size(fs)
+
+ # Deepcopy (In general, dict shouldn't be used as default arg)
+ frontend_conf = copy.deepcopy(frontend_conf)
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.window = window
+ self.fs = fs
+ self.mvn_data = mvn_data
+ self.lfr_m = lfr_m
+ self.lfr_n = lfr_n
+
+ if apply_stft:
+ self.stft = Stft(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ center=center,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ )
+ else:
+ self.stft = None
+ self.apply_stft = apply_stft
+
+ if frontend_conf is not None:
+ self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
+ else:
+ self.frontend = None
+
+ self.logmel = LogMel(
+ fs=fs,
+ n_fft=n_fft,
+ n_mels=n_mels,
+ fmin=fmin,
+ fmax=fmax,
+ htk=htk,
+ )
+ self.n_mels = n_mels
+ self.frontend_type = 'default'
+
+ def output_size(self) -> int:
+ return self.n_mels
+
+ def forward(
+ self, input: torch.Tensor,
+ input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ sample_frequency = self.fs
+ num_mel_bins = self.n_mels
+ frame_length = self.win_length * 1000 / sample_frequency
+ frame_shift = self.hop_length * 1000 / sample_frequency
+
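+        # kaldi.fbank expects samples in the 16-bit integer range, so scale the
+        # float waveform (assumed to lie in [-1, 1]) accordingly.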
+ waveform = input * (1 << 15)
+
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=1.0,
+ energy_floor=0.0,
+ window_type=self.window,
+ sample_frequency=sample_frequency)
+ if self.lfr_m != 1 or self.lfr_n != 1:
+ mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+ if self.mvn_data is not None:
+ mat = apply_cmvn(mat, self.mvn_data)
+
+ input_feats = mat[None, :]
+        # number of frames of the single utterance, as a 1-element tensor
+        feats_lens = input_feats.new_full((1,), input_feats.shape[1])
+
+ return input_feats, feats_lens
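+
+
+# Minimal usage sketch (the argument values below are illustrative assumptions,
+# not taken from a released model configuration):
+#
+#   frontend = WavFrontend(fs=16000, n_mels=80, lfr_m=7, lfr_n=6, mvn_data=cmvn)
+#   feats, feats_lens = frontend(waveform, waveform_lengths)
+#   # feats: (1, T_lfr, n_mels * lfr_m) when lfr_m > 1, feats_lens: (1,)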
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index cf60eaf..ea41c6c 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -4,7 +4,7 @@
from funasr.modules.nets_utils import make_pad_mask
class CifPredictor(nn.Module):
- def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0):
+ def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
super(CifPredictor, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5ea78c3..d716423 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -38,6 +38,7 @@
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@@ -1026,7 +1027,7 @@
@classmethod
def check_task_requirements(
cls,
- dataset: Union[AbsDataset, IterableESPnetDataset],
+ dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
allow_variable_data_keys: bool,
train: bool,
inference: bool = False,
@@ -1748,6 +1749,64 @@
**kwargs,
)
+ @classmethod
+ def build_streaming_iterator_modelscope(
+ cls,
+ data_path_and_name_and_type,
+ preprocess_fn,
+ collate_fn,
+ key_file: str = None,
+ batch_size: int = 1,
+ dtype: str = np.float32,
+ num_workers: int = 1,
+ allow_variable_data_keys: bool = False,
+ ngpu: int = 0,
+ inference: bool = False,
+ sample_rate: Union[dict, int] = 16000
+ ) -> DataLoader:
+        """Build a DataLoader over an iterable dataset for ModelScope inference."""
+ assert check_argument_types()
+ # For backward compatibility for pytorch DataLoader
+ if collate_fn is not None:
+ kwargs = dict(collate_fn=collate_fn)
+ else:
+ kwargs = {}
+
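+        # Raw audio passed in as bytes (e.g. PCM/wav data from a service
+        # request) is routed to the bytes-based dataset; file lists go through
+        # the file-based ModelScope dataset.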
+ audio_data = data_path_and_name_and_type[0]
+ if isinstance(audio_data, bytes):
+ dataset = IterableESPnetBytesModelScope(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ sample_rate=sample_rate
+ )
+ else:
+ dataset = IterableESPnetDatasetModelScope(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ sample_rate=sample_rate
+ )
+
+ if dataset.apply_utt2category:
+ kwargs.update(batch_size=1)
+ else:
+ kwargs.update(batch_size=batch_size)
+
+ cls.check_task_requirements(dataset,
+ allow_variable_data_keys,
+ train=False,
+ inference=inference)
+
+ return DataLoader(
+ dataset=dataset,
+ pin_memory=ngpu > 0,
+ num_workers=num_workers,
+ **kwargs,
+ )
+
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
diff --git a/funasr/utils/asr_env_checking.py b/funasr/utils/asr_env_checking.py
new file mode 100644
index 0000000..c393ee5
--- /dev/null
+++ b/funasr/utils/asr_env_checking.py
@@ -0,0 +1,85 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import shutil
+import ssl
+
+import nltk
+
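+# Importing this module prepares the local nltk_data directory: it creates a
+# usable nltk_data path if none exists, deploys the bundled 'cmudict' and
+# 'averaged_perceptron_tagger' packages from funasr/nltk_packages, and falls
+# back to nltk.download() over an unverified HTTPS context when the bundled
+# data cannot be found.
+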
+# create the nltk_data directory if it does not exist
+try:
+ nltk.data.find('.')
+except LookupError:
+ dir_list = nltk.data.path
+ for dir_item in dir_list:
+ if not os.path.exists(dir_item):
+ os.mkdir(dir_item)
+ if os.path.exists(dir_item):
+ break
+
+# download one package if nltk_data does not exist
+try:
+ nltk.data.find('.')
+except Exception:
+ try:
+ _create_unverified_https_context = ssl._create_unverified_context
+ except AttributeError:
+ pass
+ else:
+ ssl._create_default_https_context = _create_unverified_https_context
+
+ nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
+
+# deploy taggers/averaged_perceptron_tagger
+try:
+ nltk.data.find('taggers/averaged_perceptron_tagger')
+except Exception:
+ data_dir = nltk.data.find('.')
+ target_dir = os.path.join(data_dir, 'taggers')
+ if not os.path.exists(target_dir):
+ os.mkdir(target_dir)
+ src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
+ 'averaged_perceptron_tagger.zip')
+ shutil.copyfile(src_file,
+ os.path.join(target_dir, 'averaged_perceptron_tagger.zip'))
+ shutil._unpack_zipfile(
+ os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir)
+
+# deploy corpora/cmudict
+try:
+ nltk.data.find('corpora/cmudict')
+except Exception:
+ data_dir = nltk.data.find('.')
+ target_dir = os.path.join(data_dir, 'corpora')
+ if not os.path.exists(target_dir):
+ os.mkdir(target_dir)
+ src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
+ 'cmudict.zip')
+ shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip'))
+ shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir)
+
+try:
+ nltk.data.find('taggers/averaged_perceptron_tagger')
+except Exception:
+ try:
+ _create_unverified_https_context = ssl._create_unverified_context
+ except AttributeError:
+ pass
+ else:
+ ssl._create_default_https_context = _create_unverified_https_context
+
+ nltk.download('averaged_perceptron_tagger',
+ halt_on_error=False,
+ raise_on_error=True)
+
+try:
+ nltk.data.find('corpora/cmudict')
+except Exception:
+ try:
+ _create_unverified_https_context = ssl._create_unverified_context
+ except AttributeError:
+ pass
+ else:
+ ssl._create_default_https_context = _create_unverified_https_context
+
+ nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
new file mode 100644
index 0000000..4258f05
--- /dev/null
+++ b/funasr/utils/asr_utils.py
@@ -0,0 +1,327 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import struct
+from typing import Any, Dict, List, Union
+
+import librosa
+import numpy as np
+import pkg_resources
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+green_color = '\033[1;32m'
+red_color = '\033[0;31;40m'
+yellow_color = '\033[0;33;40m'
+end_color = '\033[0m'
+
+global_asr_language = 'zh-cn'
+
+
+def get_version():
+ return float(pkg_resources.get_distribution('easyasr').version)
+
+
+def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
+ r_audio_fs = None
+
+ if audio_format == 'wav':
+ r_audio_fs = get_sr_from_wav(audio_in)
+ elif audio_format == 'pcm' and isinstance(audio_in, bytes):
+ r_audio_fs = get_sr_from_bytes(audio_in)
+
+ return r_audio_fs
+
+
+def type_checking(audio_in: Union[str, bytes],
+ audio_fs: int = None,
+ recog_type: str = None,
+ audio_format: str = None):
+ r_recog_type = recog_type
+ r_audio_format = audio_format
+ r_wav_path = audio_in
+
+ if isinstance(audio_in, str):
+ assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
+ elif isinstance(audio_in, bytes):
+        assert len(audio_in) > 0, 'audio_in is empty'
+ r_audio_format = 'pcm'
+ r_recog_type = 'wav'
+
+ if r_recog_type is None:
+ # audio_in is wav, recog_type is wav_file
+ if os.path.isfile(audio_in):
+ if audio_in.endswith('.wav') or audio_in.endswith('.WAV'):
+ r_recog_type = 'wav'
+ r_audio_format = 'wav'
+
+ # recog_type is datasets_file
+ elif os.path.isdir(audio_in):
+ dir_name = os.path.basename(audio_in)
+ if 'test' in dir_name:
+ r_recog_type = 'test'
+ elif 'dev' in dir_name:
+ r_recog_type = 'dev'
+ elif 'train' in dir_name:
+ r_recog_type = 'train'
+
+ if r_audio_format is None:
+ if find_file_by_ends(audio_in, '.ark'):
+ r_audio_format = 'kaldi_ark'
+ elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
+ audio_in, '.WAV'):
+ r_audio_format = 'wav'
+ elif find_file_by_ends(audio_in, '.records'):
+ r_audio_format = 'tfrecord'
+
+ if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
+ # datasets with kaldi_ark file
+ r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
+ elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
+ # datasets with tensorflow records file
+ r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
+ elif r_audio_format == 'wav' and r_recog_type != 'wav':
+ # datasets with waveform files
+ r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
+
+ return r_recog_type, r_audio_format, r_wav_path
+
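+# Example (the directory layout is hypothetical): for a dataset directory
+# '/nfs/aishell/test' that contains .wav files, type_checking returns
+# ('test', 'wav', '/nfs'), i.e. recog_type, audio_format and the dataset root
+# two levels above the wav directory.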
+
+def get_sr_from_bytes(wav: bytes):
+ sr = None
+ data = wav
+ if len(data) > 44:
+ try:
+ header_fields = {}
+ header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
+ header_fields['Format'] = str(data[8:12], 'UTF-8')
+ header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
+            if (header_fields['ChunkID'] == 'RIFF'
+                    and header_fields['Format'] == 'WAVE'
+                    and header_fields['Subchunk1ID'] == 'fmt '):
+                header_fields['SampleRate'] = struct.unpack('<I', data[24:28])[0]
+ sr = header_fields['SampleRate']
+ except Exception:
+ # no treatment
+ pass
+ else:
+        logger.warning(
+            'audio bytes length %d is too short to parse a WAV header', len(data))
+
+ return sr
+
+
+def get_sr_from_wav(fname: str):
+ fs = None
+ if os.path.isfile(fname):
+ audio, fs = librosa.load(fname, sr=None)
+ return fs
+ elif os.path.isdir(fname):
+ dir_files = os.listdir(fname)
+ for file in dir_files:
+ file_path = os.path.join(fname, file)
+ if os.path.isfile(file_path):
+ if file_path.endswith('.wav') or file_path.endswith('.WAV'):
+ fs = get_sr_from_wav(file_path)
+ elif os.path.isdir(file_path):
+ fs = get_sr_from_wav(file_path)
+
+ if fs is not None:
+ break
+
+ return fs
+
+
+def find_file_by_ends(dir_path: str, ends: str):
+ dir_files = os.listdir(dir_path)
+ for file in dir_files:
+ file_path = os.path.join(dir_path, file)
+ if os.path.isfile(file_path):
+ if file_path.endswith(ends):
+ return True
+ elif os.path.isdir(file_path):
+ if find_file_by_ends(file_path, ends):
+ return True
+
+ return False
+
+
+def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
+ dir_files = os.listdir(dir_path)
+ for file in dir_files:
+ file_path = os.path.join(dir_path, file)
+ if os.path.isfile(file_path):
+ if file_path.endswith('.wav') or file_path.endswith('.WAV'):
+ wav_list.append(file_path)
+ elif os.path.isdir(file_path):
+ recursion_dir_all_wav(wav_list, file_path)
+
+ return wav_list
+
+
+def set_parameters(language: str = None):
+ if language is not None:
+ global global_asr_language
+ global_asr_language = language
+
+
+def compute_wer(hyp_list: List[Any],
+ ref_list: List[Any],
+ lang: str = None) -> Dict[str, Any]:
+ assert len(hyp_list) > 0, 'hyp list is empty'
+ assert len(ref_list) > 0, 'ref list is empty'
+
+ if lang is not None:
+ global global_asr_language
+ global_asr_language = lang
+
+ rst = {
+ 'Wrd': 0,
+ 'Corr': 0,
+ 'Ins': 0,
+ 'Del': 0,
+ 'Sub': 0,
+ 'Snt': 0,
+ 'Err': 0.0,
+ 'S.Err': 0.0,
+ 'wrong_words': 0,
+ 'wrong_sentences': 0
+ }
+
+ for h_item in hyp_list:
+ for r_item in ref_list:
+ if h_item['key'] == r_item['key']:
+ out_item = compute_wer_by_line(h_item['value'],
+ r_item['value'],
+ global_asr_language)
+ rst['Wrd'] += out_item['nwords']
+ rst['Corr'] += out_item['cor']
+ rst['wrong_words'] += out_item['wrong']
+ rst['Ins'] += out_item['ins']
+ rst['Del'] += out_item['del']
+ rst['Sub'] += out_item['sub']
+ rst['Snt'] += 1
+ if out_item['wrong'] > 0:
+ rst['wrong_sentences'] += 1
+ print_wrong_sentence(key=h_item['key'],
+ hyp=h_item['value'],
+ ref=r_item['value'])
+ else:
+ print_correct_sentence(key=h_item['key'],
+ hyp=h_item['value'],
+ ref=r_item['value'])
+
+ break
+
+ if rst['Wrd'] > 0:
+ rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+ if rst['Snt'] > 0:
+ rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+ return rst
+
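+# Illustrative call (keys and transcripts are invented):
+#   hyp = [{'key': 'utt1', 'value': '今天天气不错'}]
+#   ref = [{'key': 'utt1', 'value': '今天天气很好'}]
+#   rst = compute_wer(hyp, ref, lang='zh-cn')
+#   # rst['Err'] is the character error rate in percent and rst['S.Err'] the
+#   # sentence error rate.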
+
+def compute_wer_by_line(hyp: Union[str, List[str]],
+                        ref: Union[str, List[str]],
+                        lang: str = 'zh-cn') -> Dict[str, Any]:
+ if lang != 'zh-cn':
+ hyp = hyp.split()
+ ref = ref.split()
+
+ hyp = list(map(lambda x: x.lower(), hyp))
+ ref = list(map(lambda x: x.lower(), ref))
+
+ len_hyp = len(hyp)
+ len_ref = len(ref)
+
+ cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+ ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+ for i in range(len_hyp + 1):
+ cost_matrix[i][0] = i
+ for j in range(len_ref + 1):
+ cost_matrix[0][j] = j
+
+ for i in range(1, len_hyp + 1):
+ for j in range(1, len_ref + 1):
+ if hyp[i - 1] == ref[j - 1]:
+ cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+ else:
+ substitution = cost_matrix[i - 1][j - 1] + 1
+ insertion = cost_matrix[i - 1][j] + 1
+ deletion = cost_matrix[i][j - 1] + 1
+
+ compare_val = [substitution, insertion, deletion]
+
+ min_val = min(compare_val)
+ operation_idx = compare_val.index(min_val) + 1
+ cost_matrix[i][j] = min_val
+ ops_matrix[i][j] = operation_idx
+
+ match_idx = []
+ i = len_hyp
+ j = len_ref
+ rst = {
+ 'nwords': len_ref,
+ 'cor': 0,
+ 'wrong': 0,
+ 'ins': 0,
+ 'del': 0,
+ 'sub': 0
+ }
+ while i >= 0 or j >= 0:
+ i_idx = max(0, i)
+ j_idx = max(0, j)
+
+ if ops_matrix[i_idx][j_idx] == 0: # correct
+ if i - 1 >= 0 and j - 1 >= 0:
+ match_idx.append((j - 1, i - 1))
+ rst['cor'] += 1
+
+ i -= 1
+ j -= 1
+
+ elif ops_matrix[i_idx][j_idx] == 2: # insert
+ i -= 1
+ rst['ins'] += 1
+
+ elif ops_matrix[i_idx][j_idx] == 3: # delete
+ j -= 1
+ rst['del'] += 1
+
+ elif ops_matrix[i_idx][j_idx] == 1: # substitute
+ i -= 1
+ j -= 1
+ rst['sub'] += 1
+
+ if i < 0 and j >= 0:
+ rst['del'] += 1
+ elif j < 0 and i >= 0:
+ rst['ins'] += 1
+
+ match_idx.reverse()
+ wrong_cnt = cost_matrix[len_hyp][len_ref]
+ rst['wrong'] = wrong_cnt
+
+ return rst
+
+
+def print_wrong_sentence(key: str, hyp: str, ref: str):
+ space = len(key)
+ print(key + yellow_color + ' ref: ' + ref)
+ print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
+
+
+def print_correct_sentence(key: str, hyp: str, ref: str):
+ space = len(key)
+ print(key + yellow_color + ' ref: ' + ref)
+ print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
+
+
+def print_progress(percent):
+ if percent > 1:
+ percent = 1
+ res = int(50 * percent) * '#'
+ print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
new file mode 100644
index 0000000..72080ae
--- /dev/null
+++ b/funasr/utils/postprocess_utils.py
@@ -0,0 +1,174 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import string
+from typing import Any, List, Union
+
+
+def isChinese(ch: str):
+ if '\u4e00' <= ch <= '\u9fff':
+ return True
+ return False
+
+
+def isAllChinese(word: Union[List[Any], str]):
+ word_lists = []
+ table = str.maketrans('', '', string.punctuation)
+ for i in word:
+ cur = i.translate(table)
+ cur = cur.replace(' ', '')
+ cur = cur.replace('</s>', '')
+ cur = cur.replace('<s>', '')
+ word_lists.append(cur)
+
+ if len(word_lists) == 0:
+ return False
+
+ for ch in word_lists:
+ if isChinese(ch) is False:
+ return False
+ return True
+
+
+def isAllAlpha(word: Union[List[Any], str]):
+ word_lists = []
+ table = str.maketrans('', '', string.punctuation)
+ for i in word:
+ cur = i.translate(table)
+ cur = cur.replace(' ', '')
+ cur = cur.replace('</s>', '')
+ cur = cur.replace('<s>', '')
+ word_lists.append(cur)
+
+ if len(word_lists) == 0:
+ return False
+
+ for ch in word_lists:
+ if ch.isalpha() is False:
+ return False
+ elif ch.isalpha() is True and isChinese(ch) is True:
+ return False
+
+ return True
+
+
+def abbr_dispose(words: List[Any]) -> List[Any]:
+ words_size = len(words)
+ word_lists = []
+ abbr_begin = []
+ abbr_end = []
+ last_num = -1
+ for num in range(words_size):
+ if num <= last_num:
+ continue
+
+ if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
+            if (num + 1 < words_size and words[num + 1] == ' '
+                    and num + 2 < words_size and len(words[num + 2]) == 1
+                    and words[num + 2].encode('utf-8').isalpha()):
+ # found the begin of abbr
+ abbr_begin.append(num)
+ num += 2
+ abbr_end.append(num)
+ # to find the end of abbr
+ while True:
+ num += 1
+ if num < words_size and words[num] == ' ':
+ num += 1
+                        if (num < words_size and len(words[num]) == 1
+                                and words[num].encode('utf-8').isalpha()):
+ abbr_end.pop()
+ abbr_end.append(num)
+ last_num = num
+ else:
+ break
+ else:
+ break
+
+ last_num = -1
+ for num in range(words_size):
+ if num <= last_num:
+ continue
+
+ if num in abbr_begin:
+ word_lists.append(words[num].upper())
+ num += 1
+ while num < words_size:
+ if num in abbr_end:
+ word_lists.append(words[num].upper())
+ last_num = num
+ break
+ else:
+ if words[num].encode('utf-8').isalpha():
+ word_lists.append(words[num].upper())
+ num += 1
+ else:
+ word_lists.append(words[num])
+
+ return word_lists
+
+
+def sentence_postprocess(words: List[Any]):
+ middle_lists = []
+ word_lists = []
+ word_item = ''
+
+ # wash words lists
+ for i in words:
+ word = ''
+ if isinstance(i, str):
+ word = i
+ else:
+ word = i.decode('utf-8')
+
+ if word in ['<s>', '</s>', '<unk>']:
+ continue
+ else:
+ middle_lists.append(word)
+
+ # all chinese characters
+ if isAllChinese(middle_lists):
+ for ch in middle_lists:
+ word_lists.append(ch.replace(' ', ''))
+
+ # all alpha characters
+ elif isAllAlpha(middle_lists):
+ for ch in middle_lists:
+ word = ''
+ if '@@' in ch:
+ word = ch.replace('@@', '')
+ word_item += word
+ else:
+ word_item += ch
+ word_lists.append(word_item)
+ word_lists.append(' ')
+ word_item = ''
+
+ # mix characters
+ else:
+ alpha_blank = False
+ for ch in middle_lists:
+ word = ''
+ if isAllChinese(ch):
+ if alpha_blank is True:
+ word_lists.pop()
+ word_lists.append(ch)
+ alpha_blank = False
+ elif '@@' in ch:
+ word = ch.replace('@@', '')
+ word_item += word
+ alpha_blank = False
+ elif isAllAlpha(ch):
+ word_item += ch
+ word_lists.append(word_item)
+ word_lists.append(' ')
+ word_item = ''
+ alpha_blank = True
+ else:
+ raise ValueError('invalid character: {}'.format(ch))
+
+ word_lists = abbr_dispose(word_lists)
+ sentence = ''.join(word_lists).strip()
+ return sentence
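+
+
+# Example of the intended behaviour (tokens are invented for illustration):
+#   sentence_postprocess(['今', '天', 'hel@@', 'lo', '</s>'])  ->  '今天hello'
+# BPE pieces marked with '@@' are merged, sentence tags such as '</s>' are
+# dropped, and Chinese characters are concatenated without spaces.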
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
new file mode 100644
index 0000000..d8564f2
--- /dev/null
+++ b/funasr/utils/wav_utils.py
@@ -0,0 +1,178 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+import os
+from typing import Any, Dict, Union
+
+import kaldiio
+import librosa
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+
+def ndarray_resample(audio_in: np.ndarray,
+ fs_in: int = 16000,
+ fs_out: int = 16000) -> np.ndarray:
+ audio_out = audio_in
+ if fs_in != fs_out:
+ audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
+ return audio_out
+
+
+def torch_resample(audio_in: torch.Tensor,
+ fs_in: int = 16000,
+ fs_out: int = 16000) -> torch.Tensor:
+ audio_out = audio_in
+ if fs_in != fs_out:
+ audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
+ new_freq=fs_out)(audio_in)
+ return audio_out
+
+
+def extract_CMVN_featrures(mvn_file):
+    """
+    Extract CMVN statistics (shift and scale rows) from a Kaldi cmvn.ark file,
+    falling back to the text (AddShift/Rescale) format if loading fails.
+    """
+
+ if not os.path.exists(mvn_file):
+ return None
+ try:
+ cmvn = kaldiio.load_mat(mvn_file)
+ means = []
+ variance = []
+
+ for i in range(cmvn.shape[1] - 1):
+ means.append(float(cmvn[0][i]))
+
+ count = float(cmvn[0][-1])
+
+ for i in range(cmvn.shape[1] - 1):
+ variance.append(float(cmvn[1][i]))
+
+ for i in range(len(means)):
+ means[i] /= count
+ variance[i] = variance[i] / count - means[i] * means[i]
+ if variance[i] < 1.0e-20:
+ variance[i] = 1.0e-20
+ variance[i] = 1.0 / math.sqrt(variance[i])
+
+ cmvn = np.array([means, variance])
+ return cmvn
+ except Exception:
+ cmvn = extract_CMVN_features_txt(mvn_file)
+ return cmvn
+
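+# Example (the file name is hypothetical):
+#   cmvn = extract_CMVN_featrures('am.mvn')
+#   # cmvn is an np.ndarray with two rows of per-dimension statistics
+#   # (shift/mean and scale), or None if the file does not exist.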
+
+def extract_CMVN_features_txt(mvn_file): # noqa
+ with open(mvn_file, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+
+ add_shift_list = []
+ rescale_list = []
+ for i in range(len(lines)):
+ line_item = lines[i].split()
+ if line_item[0] == '<AddShift>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ add_shift_line = line_item[3:(len(line_item) - 1)]
+ add_shift_list = list(add_shift_line)
+ continue
+ elif line_item[0] == '<Rescale>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ rescale_line = line_item[3:(len(line_item) - 1)]
+ rescale_list = list(rescale_line)
+ continue
+ add_shift_list_f = [float(s) for s in add_shift_list]
+ rescale_list_f = [float(s) for s in rescale_list]
+ cmvn = np.array([add_shift_list_f, rescale_list_f])
+ return cmvn
+
+
+def build_LFR_features(inputs, m=7, n=6): # noqa
+    """
+    Low frame rate (LFR) feature construction: stack m frames, then skip n frames.
+    If m = 1 and n = 1, the original features are returned unchanged.
+    If m = 1 and n > 1, it only skips frames.
+    If m > 1 and n = 1, it only stacks frames.
+    If m > 1 and n > 1, it performs full LFR.
+
+    Args:
+        inputs: T x D np.ndarray of frame-level features
+        m: number of frames to stack
+        n: number of frames to skip
+    """
+ LFR_inputs = []
+ T = inputs.shape[0]
+ T_lfr = int(np.ceil(T / n))
+ left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
+ inputs = np.vstack((left_padding, inputs))
+ T = T + (m - 1) // 2
+ for i in range(T_lfr):
+ if m <= T - i * n:
+ LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
+ else: # process last LFR frame
+ num_padding = m - (T - i * n)
+ frame = np.hstack(inputs[i * n:])
+ for _ in range(num_padding):
+ frame = np.hstack((frame, inputs[-1]))
+ LFR_inputs.append(frame)
+ return np.vstack(LFR_inputs)
+
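+# Shape example (sizes are illustrative): with the defaults m=7 and n=6, an
+# input of shape (100, 80) becomes an array of shape (17, 560), i.e.
+# ceil(T / n) stacked frames of dimension D * m.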
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ is_pcm=False,
+ fs: Union[int, Dict[Any, int]] = 16000):
+ audio_sr: int = 16000
+ model_sr: int = 16000
+ if isinstance(fs, int):
+ model_sr = fs
+ audio_sr = fs
+ else:
+ model_sr = fs['model_fs']
+ audio_sr = fs['audio_fs']
+
+ if is_pcm is True:
+ # byte(PCM16) to float32, and resample
+ value = wav_file
+ middle_data = np.frombuffer(value, dtype=np.int16)
+ middle_data = np.asarray(middle_data)
+ if middle_data.dtype.kind not in 'iu':
+ raise TypeError("'middle_data' must be an array of integers")
+ dtype = np.dtype('float32')
+ if dtype.kind != 'f':
+ raise TypeError("'dtype' must be a floating point type")
+
+ i = np.iinfo(middle_data.dtype)
+ abs_max = 2**(i.bits - 1)
+ offset = i.min + abs_max
+ waveform = np.frombuffer(
+ (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+ waveform = ndarray_resample(waveform, audio_sr, model_sr)
+ waveform = torch.from_numpy(waveform.reshape(1, -1))
+ else:
+ # load pcm from wav, and resample
+ waveform, audio_sr = torchaudio.load(wav_file)
+ waveform = waveform * (1 << 15)
+ waveform = torch_resample(waveform, audio_sr, model_sr)
+
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type='hamming',
+ sample_frequency=model_sr)
+
+ input_feats = mat
+
+ return input_feats
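+
+
+# Typical use (the path is illustrative):
+#   feats = compute_fbank('example.wav', num_mel_bins=80, fs=16000)
+#   # feats is a (num_frames, num_mel_bins) torch.Tensor of fbank features.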
diff --git a/funasr/version.txt b/funasr/version.txt
index 6e8bf73..b1e80bb 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.1.0
+0.1.3
--
Gitblit v1.9.1