Merge pull request #6 from alibaba-damo-academy/dev
Update funasr to 0.1.3
| | |
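| | | # NOTE: keys.JOB.scp below is assumed (not shown in this PR) to follow the |
| | | # usual Kaldi wav-list convention of one "<utt-id> <wav-path>" pair per |
| | | # line, e.g.: |
| | | #   utt_0001 /data/wav/utt_0001.wav |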
| | | ${decode_cmd} --max-jobs-run "${inference_nj}" JOB=1:"${inference_nj}" "${_logdir}"/asr_inference.JOB.log \ |
| | | python -m funasr.bin.modelscope_infer \ |
| | | --model_name ${model_name} \ |
| | | --model_revision ${model_revision} \ |
| | | --wav_list ${_logdir}/keys.JOB.scp \ |
| | | --output_file ${_logdir}/text.JOB \ |
| | | --gpuid_list ${gpuid_list} \ |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import argparse |
| | | import logging |
| | | import sys |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch |
| | | from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim |
| | | from funasr.modules.beam_search.beam_search import BeamSearch |
| | | from funasr.modules.beam_search.beam_search import Hypothesis |
| | | from funasr.modules.scorers.ctc import CTCPrefixScorer |
| | | from funasr.modules.scorers.length_bonus import LengthBonus |
| | | from funasr.modules.scorers.scorer_interface import BatchScorerInterface |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.tasks.asr import ASRTask |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.text.build_tokenizer import build_tokenizer |
| | | from funasr.text.token_id_converter import TokenIDConverter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | |
| | | from modelscope.utils.logger import get_logger |
| | | |
| | | logger = get_logger() |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | | |
| | | global_asr_language: str = 'zh-cn' |
| | | global_sample_rate: Union[int, Dict[Any, int]] = { |
| | | 'audio_fs': 16000, |
| | | 'model_fs': 16000 |
| | | } |
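| | | # NOTE (a reading of how this dict is consumed, not documented API): |
| | | # 'audio_fs' is the sampling rate of the incoming audio and 'model_fs' the |
| | | # rate the model expects; when they differ, the data loader resamples |
| | | # (see wav_utils.torch_resample in the dataset module). |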
| | | |
| | | class Speech2Text: |
| | | """Speech2Text class |
| | | |
| | | Examples: |
| | | >>> import soundfile |
| | | >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") |
| | | >>> audio, rate = soundfile.read("speech.wav") |
| | | >>> speech2text(audio) |
| | | [(text, token, token_int, hypothesis object), ...] |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | asr_train_config: Union[Path, str] = None, |
| | | asr_model_file: Union[Path, str] = None, |
| | | lm_train_config: Union[Path, str] = None, |
| | | lm_file: Union[Path, str] = None, |
| | | token_type: str = None, |
| | | bpemodel: str = None, |
| | | device: str = "cpu", |
| | | maxlenratio: float = 0.0, |
| | | minlenratio: float = 0.0, |
| | | batch_size: int = 1, |
| | | dtype: str = "float32", |
| | | beam_size: int = 20, |
| | | ctc_weight: float = 0.5, |
| | | lm_weight: float = 1.0, |
| | | ngram_weight: float = 0.9, |
| | | penalty: float = 0.0, |
| | | nbest: int = 1, |
| | | streaming: bool = False, |
| | | frontend_conf: dict = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | asr_model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, device |
| | | ) |
| | | if asr_model.frontend is None and frontend_conf is not None: |
| | | frontend = WavFrontend(**frontend_conf) |
| | | asr_model.frontend = frontend |
| | | asr_model.to(dtype=getattr(torch, dtype)).eval() |
| | | |
| | | decoder = asr_model.decoder |
| | | |
| | | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) |
| | | token_list = asr_model.token_list |
| | | scorers.update( |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | | # 2. Build Language model |
| | | if lm_train_config is not None: |
| | | lm, lm_train_args = LMTask.build_model_from_file( |
| | | lm_train_config, lm_file, device |
| | | ) |
| | | scorers["lm"] = lm.lm |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | |
| | | # 4. Build BeamSearch object |
| | | # transducer is not supported now |
| | | beam_search_transducer = None |
| | | |
| | | weights = dict( |
| | | decoder=1.0 - ctc_weight, |
| | | ctc=ctc_weight, |
| | | lm=lm_weight, |
| | | ngram=ngram_weight, |
| | | length_bonus=penalty, |
| | | ) |
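| | | # NOTE (a reading of the weights above, not the authoritative scoring |
| | | # code): the beam search interpolates partial scores roughly as |
| | | #   score = (1 - ctc_weight) * s_dec + ctc_weight * s_ctc |
| | | #           + lm_weight * s_lm + penalty * hyp_length |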
| | | beam_search = BeamSearch( |
| | | beam_size=beam_size, |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=asr_model.sos, |
| | | eos=asr_model.eos, |
| | | vocab_size=len(token_list), |
| | | token_list=token_list, |
| | | pre_beam_score_key=None if ctc_weight == 1.0 else "full", |
| | | ) |
| | | |
| | | # TODO(karita): make all scorers batchfied |
| | | if batch_size == 1: |
| | | non_batch = [ |
| | | k |
| | | for k, v in beam_search.full_scorers.items() |
| | | if not isinstance(v, BatchScorerInterface) |
| | | ] |
| | | if len(non_batch) == 0: |
| | | if streaming: |
| | | beam_search.__class__ = BatchBeamSearchOnlineSim |
| | | beam_search.set_streaming_config(asr_train_config) |
| | | logging.info( |
| | | "BatchBeamSearchOnlineSim implementation is selected." |
| | | ) |
| | | else: |
| | | beam_search.__class__ = BatchBeamSearch |
| | | logging.info("BatchBeamSearch implementation is selected.") |
| | | else: |
| | | logging.warning( |
| | | f"As non-batch scorers {non_batch} are found, " |
| | | f"fall back to non-batch implementation." |
| | | ) |
| | | |
| | | beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() |
| | | for scorer in scorers.values(): |
| | | if isinstance(scorer, torch.nn.Module): |
| | | scorer.to(device=device, dtype=getattr(torch, dtype)).eval() |
| | | logging.info(f"Beam_search: {beam_search}") |
| | | logging.info(f"Decoding device={device}, dtype={dtype}") |
| | | |
| | | # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text |
| | | if token_type is None: |
| | | token_type = asr_train_args.token_type |
| | | if bpemodel is None: |
| | | bpemodel = asr_train_args.bpemodel |
| | | |
| | | if token_type is None: |
| | | tokenizer = None |
| | | elif token_type == "bpe": |
| | | if bpemodel is not None: |
| | | tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) |
| | | else: |
| | | tokenizer = None |
| | | else: |
| | | tokenizer = build_tokenizer(token_type=token_type) |
| | | converter = TokenIDConverter(token_list=token_list) |
| | | logging.info(f"Text tokenizer: {tokenizer}") |
| | | |
| | | self.asr_model = asr_model |
| | | self.asr_train_args = asr_train_args |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | self.beam_search = beam_search |
| | | self.beam_search_transducer = beam_search_transducer |
| | | self.maxlenratio = maxlenratio |
| | | self.minlenratio = minlenratio |
| | | self.device = device |
| | | self.dtype = dtype |
| | | self.nbest = nbest |
| | | |
| | | @torch.no_grad() |
| | | def __call__( |
| | | self, speech: Union[torch.Tensor, np.ndarray] |
| | | ) -> List[ |
| | | Tuple[ |
| | | Optional[str], |
| | | List[str], |
| | | List[int], |
| | | Union[Hypothesis], |
| | | ] |
| | | ]: |
| | | """Inference |
| | | |
| | | Args: |
| | | speech: Input speech data |
| | | Returns: |
| | | text, token, token_int, hyp |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | |
| | | # data: (Nsamples,) -> (1, Nsamples) |
| | | speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) |
| | | lfr_factor = max(1, (speech.size()[-1] // 80) - 1) |
| | | # lengths: (1,) |
| | | lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) |
| | | batch = {"speech": speech, "speech_lengths": lengths} |
| | | |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | # b. Forward Encoder |
| | | enc, _ = self.asr_model.encode(**batch) |
| | | if isinstance(enc, tuple): |
| | | enc = enc[0] |
| | | assert len(enc) == 1, len(enc) |
| | | |
| | | # c. Pass the encoder output to the beam search |
| | | nbest_hyps = self.beam_search( |
| | | x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio |
| | | ) |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | for hyp in nbest_hyps: |
| | | assert isinstance(hyp, (Hypothesis)), type(hyp) |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != 0, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | |
| | | if self.tokenizer is not None: |
| | | text = self.tokenizer.tokens2text(token) |
| | | else: |
| | | text = None |
| | | results.append((text, token, token_int, hyp)) |
| | | |
| | | assert check_return_type(results) |
| | | return results |
| | | |
| | | |
| | | def inference( |
| | | maxlenratio: float, |
| | | minlenratio: float, |
| | | batch_size: int, |
| | | dtype: str, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | seed: int, |
| | | ctc_weight: float, |
| | | lm_weight: float, |
| | | ngram_weight: float, |
| | | penalty: float, |
| | | nbest: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | data_path_and_name_and_type: list, |
| | | audio_lists: Union[List[Any], bytes], |
| | | key_file: Optional[str], |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | lm_train_config: Optional[str], |
| | | lm_file: Optional[str], |
| | | word_lm_train_config: Optional[str], |
| | | token_type: Optional[str], |
| | | bpemodel: Optional[str], |
| | | output_dir: Optional[str], |
| | | allow_variable_data_keys: bool, |
| | | streaming: bool, |
| | | frontend_conf: dict = None, |
| | | fs: Union[dict, int] = 16000, |
| | | **kwargs, |
| | | ) -> List[Any]: |
| | | assert check_argument_types() |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | if word_lm_train_config is not None: |
| | | raise NotImplementedError("Word LM is not implemented") |
| | | if ngpu > 1: |
| | | raise NotImplementedError("only single GPU decoding is supported") |
| | | |
| | | logging.basicConfig( |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if ngpu >= 1: |
| | | device = "cuda" |
| | | else: |
| | | device = "cpu" |
| | | features_type: str = data_path_and_name_and_type[1] |
| | | hop_length: int = 160 |
| | | sr: int = 16000 |
| | | if isinstance(fs, int): |
| | | sr = fs |
| | | else: |
| | | if 'model_fs' in fs and fs['model_fs'] is not None: |
| | | sr = fs['model_fs'] |
| | | if features_type != 'sound': |
| | | frontend_conf = None |
| | | if frontend_conf is not None: |
| | | if 'hop_length' in frontend_conf: |
| | | hop_length = frontend_conf['hop_length'] |
| | | |
| | | finish_count = 0 |
| | | file_count = 1 |
| | | if isinstance(audio_lists, bytes): |
| | | file_count = 1 |
| | | else: |
| | | file_count = len(audio_lists) |
| | | if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None: |
| | | mvn_file = data_path_and_name_and_type[2] |
| | | mvn_data = wav_utils.extract_CMVN_featrures(mvn_file) |
| | | frontend_conf['mvn_data'] = mvn_data |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build speech2text |
| | | speech2text_kwargs = dict( |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | device=device, |
| | | maxlenratio=maxlenratio, |
| | | minlenratio=minlenratio, |
| | | dtype=dtype, |
| | | beam_size=beam_size, |
| | | ctc_weight=ctc_weight, |
| | | lm_weight=lm_weight, |
| | | ngram_weight=ngram_weight, |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | streaming=streaming, |
| | | frontend_conf=frontend_conf, |
| | | ) |
| | | speech2text = Speech2Text(**speech2text_kwargs) |
| | | data_path_and_name_and_type_new = [ |
| | | audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1] |
| | | ] |
| | | # 3. Build data-iterator |
| | | loader = ASRTask.build_streaming_iterator_modelscope( |
| | | data_path_and_name_and_type_new, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), |
| | | collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | sample_rate=fs |
| | | ) |
| | | |
| | | # 7. Start the decoding loop |
| | | # FIXME(kamo): The output format should be discussed |
| | | asr_result_list = [] |
| | | for keys, batch in loader: |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} |
| | | |
| | | # N-best list of (text, token, token_int, hyp_object) |
| | | try: |
| | | results = speech2text(**batch) |
| | | except TooShortUttError as e: |
| | | logging.warning(f"Utterance {keys} {e}") |
| | | hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) |
| | | results = [[" ", ["<space>"], [2], hyp]] * nbest |
| | | |
| | | # Only supporting batch_size==1 |
| | | key = keys[0] |
| | | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | | asr_utils.print_progress(finish_count / file_count) |
| | | |
| | | return asr_result_list |
| | | |
| | | |
| | | |
| | | def set_parameters(language: str = None, |
| | | sample_rate: Union[int, Dict[Any, int]] = None): |
| | | if language is not None: |
| | | global global_asr_language |
| | | global_asr_language = language |
| | | if sample_rate is not None: |
| | | global global_sample_rate |
| | | global_sample_rate = sample_rate |
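| | | # Example (hypothetical values; a minimal usage sketch): |
| | | # set_parameters(language='en-us', |
| | | #                sample_rate={'audio_fs': 8000, 'model_fs': 16000}) |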
| | | |
| | | |
| | | def asr_inference(maxlenratio: float, |
| | | minlenratio: float, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | ctc_weight: float, |
| | | lm_weight: float, |
| | | penalty: float, |
| | | name_and_type: list, |
| | | audio_lists: Union[List[Any], bytes], |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | log_level: Union[int, str] = 'INFO', |
| | | batch_size: int = 1, |
| | | dtype: str = 'float32', |
| | | seed: int = 0, |
| | | key_file: Optional[str] = None, |
| | | lm_train_config: Optional[str] = None, |
| | | lm_file: Optional[str] = None, |
| | | word_lm_train_config: Optional[str] = None, |
| | | word_lm_file: Optional[str] = None, |
| | | ngram_file: Optional[str] = None, |
| | | ngram_weight: float = 0.9, |
| | | model_tag: Optional[str] = None, |
| | | token_type: Optional[str] = None, |
| | | bpemodel: Optional[str] = None, |
| | | allow_variable_data_keys: bool = False, |
| | | transducer_conf: Optional[dict] = None, |
| | | streaming: bool = False, |
| | | frontend_conf: dict = None, |
| | | fs: Union[dict, int] = None, |
| | | lang: Optional[str] = None, |
| | | outputdir: Optional[str] = None): |
| | | if lang is not None: |
| | | global global_asr_language |
| | | global_asr_language = lang |
| | | if fs is not None: |
| | | global global_sample_rate |
| | | global_sample_rate = fs |
| | | |
| | | # force CPU decoding if the input is raw audio bytes |
| | | if isinstance(audio_lists, bytes): |
| | | num_workers = 0 |
| | | ngpu = 0 |
| | | |
| | | return inference(output_dir=outputdir, |
| | | maxlenratio=maxlenratio, |
| | | minlenratio=minlenratio, |
| | | batch_size=batch_size, |
| | | dtype=dtype, |
| | | beam_size=beam_size, |
| | | ngpu=ngpu, |
| | | seed=seed, |
| | | ctc_weight=ctc_weight, |
| | | lm_weight=lm_weight, |
| | | ngram_weight=ngram_weight, |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | num_workers=num_workers, |
| | | log_level=log_level, |
| | | data_path_and_name_and_type=name_and_type, |
| | | audio_lists=audio_lists, |
| | | key_file=key_file, |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | word_lm_train_config=word_lm_train_config, |
| | | word_lm_file=word_lm_file, |
| | | ngram_file=ngram_file, |
| | | model_tag=model_tag, |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | transducer_conf=transducer_conf, |
| | | streaming=streaming, |
| | | frontend_conf=frontend_conf) |
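| | | # Example (paths and keys are placeholders; a minimal sketch assuming the |
| | | # 'sound' input type and CPU decoding -- not a tested invocation): |
| | | # results = asr_inference( |
| | | #     maxlenratio=0.0, minlenratio=0.0, beam_size=20, ngpu=0, |
| | | #     ctc_weight=0.5, lm_weight=1.0, penalty=0.0, |
| | | #     name_and_type=['speech', 'sound'], |
| | | #     audio_lists=[{'key': 'utt1', 'file': '/path/to/utt1.wav'}], |
| | | #     asr_train_config='config.yaml', asr_model_file='model.pb') |
| | | # -> [{'key': 'utt1', 'value': '...'}] |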
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="ASR Decoding", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | |
| | | # Note(kamo): Use '_' instead of '-' as separator. |
| | | # '-' is confusing if written in yaml. |
| | | parser.add_argument( |
| | | "--log_level", |
| | | type=lambda x: x.upper(), |
| | | default="INFO", |
| | | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
| | | help="The verbose level of logging", |
| | | ) |
| | | |
| | | parser.add_argument("--output_dir", type=str, required=True) |
| | | parser.add_argument( |
| | | "--ngpu", |
| | | type=int, |
| | | default=0, |
| | | help="The number of gpus. 0 indicates CPU mode", |
| | | ) |
| | | parser.add_argument( |
| | | "--gpuid_list", |
| | | type=str, |
| | | default="", |
| | | help="The visible gpus", |
| | | ) |
| | | parser.add_argument("--seed", type=int, default=0, help="Random seed") |
| | | parser.add_argument( |
| | | "--dtype", |
| | | default="float32", |
| | | choices=["float16", "float32", "float64"], |
| | | help="Data type", |
| | | ) |
| | | parser.add_argument( |
| | | "--num_workers", |
| | | type=int, |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | required=True, |
| | | action="append", |
| | | ) |
| | | group.add_argument("--audio_lists", type=list, |
| | | default=[{'key':'EdevDEWdIYQ_0021', |
| | | 'file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) |
| | | group.add_argument("--key_file", type=str_or_none) |
| | | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) |
| | | |
| | | group = parser.add_argument_group("The model configuration related") |
| | | group.add_argument( |
| | | "--asr_train_config", |
| | | type=str, |
| | | help="ASR training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--asr_model_file", |
| | | type=str, |
| | | help="ASR model parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--lm_train_config", |
| | | type=str, |
| | | help="LM training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--lm_file", |
| | | type=str, |
| | | help="LM parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--word_lm_train_config", |
| | | type=str, |
| | | help="Word LM training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--word_lm_file", |
| | | type=str, |
| | | help="Word LM parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--ngram_file", |
| | | type=str, |
| | | help="N-gram parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--model_tag", |
| | | type=str, |
| | | help="Pretrained model tag. If specify this option, *_train_config and " |
| | | "*_file will be overwritten", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Beam-search related") |
| | | group.add_argument( |
| | | "--batch_size", |
| | | type=int, |
| | | default=1, |
| | | help="The batch size for inference", |
| | | ) |
| | | group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") |
| | | group.add_argument("--beam_size", type=int, default=20, help="Beam size") |
| | | group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") |
| | | group.add_argument( |
| | | "--maxlenratio", |
| | | type=float, |
| | | default=0.0, |
| | | help="Input length ratio to obtain max output length. " |
| | | "If maxlenratio=0.0 (default), it uses a end-detect " |
| | | "function " |
| | | "to automatically find maximum hypothesis lengths." |
| | | "If maxlenratio<0.0, its absolute value is interpreted" |
| | | "as a constant max output length", |
| | | ) |
| | | group.add_argument( |
| | | "--minlenratio", |
| | | type=float, |
| | | default=0.0, |
| | | help="Input length ratio to obtain min output length", |
| | | ) |
| | | group.add_argument( |
| | | "--ctc_weight", |
| | | type=float, |
| | | default=0.5, |
| | | help="CTC weight in joint decoding", |
| | | ) |
| | | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") |
| | | group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") |
| | | group.add_argument("--streaming", type=str2bool, default=False) |
| | | |
| | | group = parser.add_argument_group("Text converter related") |
| | | group.add_argument( |
| | | "--token_type", |
| | | type=str_or_none, |
| | | default=None, |
| | | choices=["char", "bpe", None], |
| | | help="The token type for ASR model. " |
| | | "If not given, refers from the training args", |
| | | ) |
| | | group.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model path of sentencepiece. " |
| | | "If not given, refers from the training args", |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | parser = get_parser() |
| | | args = parser.parse_args(cmd) |
| | | kwargs = vars(args) |
| | | kwargs.pop("config", None) |
| | | inference(**kwargs) |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
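| | | # Typical CLI usage (script, file, and model paths are placeholders; flag |
| | | # names follow get_parser() above): |
| | | #   python asr_inference.py --output_dir exp/decode \ |
| | | #       --data_path_and_name_and_type wav.scp,speech,sound \ |
| | | #       --asr_train_config config.yaml --asr_model_file model.pb |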
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import argparse |
| | | import logging |
| | | import sys |
| | | import time |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import List |
| | | from typing import Dict |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch |
| | | from funasr.modules.beam_search.beam_search import Hypothesis |
| | | from funasr.modules.scorers.ctc import CTCPrefixScorer |
| | | from funasr.modules.scorers.length_bonus import LengthBonus |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.text.build_tokenizer import build_tokenizer |
| | | from funasr.text.token_id_converter import TokenIDConverter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | |
| | | from modelscope.utils.logger import get_logger |
| | | |
| | | logger = get_logger() |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | | |
| | | global_asr_language: str = 'zh-cn' |
| | | global_sample_rate: Union[int, Dict[Any, int]] = { |
| | | 'audio_fs': 16000, |
| | | 'model_fs': 16000 |
| | | } |
| | | |
| | | |
| | | class Speech2Text: |
| | | """Speech2Text class |
| | | |
| | | Examples: |
| | | >>> import soundfile |
| | | >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") |
| | | >>> audio, rate = soundfile.read("speech.wav") |
| | | >>> speech2text(audio) |
| | | [(text, token, token_int, hypothesis object), ...] |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | asr_train_config: Union[Path, str] = None, |
| | | asr_model_file: Union[Path, str] = None, |
| | | lm_train_config: Union[Path, str] = None, |
| | | lm_file: Union[Path, str] = None, |
| | | token_type: str = None, |
| | | bpemodel: str = None, |
| | | device: str = "cpu", |
| | | maxlenratio: float = 0.0, |
| | | minlenratio: float = 0.0, |
| | | dtype: str = "float32", |
| | | beam_size: int = 20, |
| | | ctc_weight: float = 0.5, |
| | | lm_weight: float = 1.0, |
| | | ngram_weight: float = 0.9, |
| | | penalty: float = 0.0, |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | asr_model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, device |
| | | ) |
| | | if asr_model.frontend is None and frontend_conf is not None: |
| | | frontend = WavFrontend(**frontend_conf) |
| | | asr_model.frontend = frontend |
| | | asr_model.to(dtype=getattr(torch, dtype)).eval() |
| | | |
| | | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) |
| | | token_list = asr_model.token_list |
| | | scorers.update( |
| | | ctc=ctc, |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | | # 2. Build Language model |
| | | if lm_train_config is not None: |
| | | lm, lm_train_args = LMTask.build_model_from_file( |
| | | lm_train_config, lm_file, device |
| | | ) |
| | | scorers["lm"] = lm.lm |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | |
| | | # 4. Build BeamSearch object |
| | | # transducer is not supported now |
| | | beam_search_transducer = None |
| | | |
| | | weights = dict( |
| | | decoder=1.0 - ctc_weight, |
| | | ctc=ctc_weight, |
| | | lm=lm_weight, |
| | | ngram=ngram_weight, |
| | | length_bonus=penalty, |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=beam_size, |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=asr_model.sos, |
| | | eos=asr_model.eos, |
| | | vocab_size=len(token_list), |
| | | token_list=token_list, |
| | | pre_beam_score_key=None if ctc_weight == 1.0 else "full", |
| | | ) |
| | | |
| | | beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() |
| | | for scorer in scorers.values(): |
| | | if isinstance(scorer, torch.nn.Module): |
| | | scorer.to(device=device, dtype=getattr(torch, dtype)).eval() |
| | | logging.info(f"Beam_search: {beam_search}") |
| | | logging.info(f"Decoding device={device}, dtype={dtype}") |
| | | |
| | | # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text |
| | | if token_type is None: |
| | | token_type = asr_train_args.token_type |
| | | if bpemodel is None: |
| | | bpemodel = asr_train_args.bpemodel |
| | | |
| | | if token_type is None: |
| | | tokenizer = None |
| | | elif token_type == "bpe": |
| | | if bpemodel is not None: |
| | | tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) |
| | | else: |
| | | tokenizer = None |
| | | else: |
| | | tokenizer = build_tokenizer(token_type=token_type) |
| | | converter = TokenIDConverter(token_list=token_list) |
| | | logging.info(f"Text tokenizer: {tokenizer}") |
| | | |
| | | self.asr_model = asr_model |
| | | self.asr_train_args = asr_train_args |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | self.beam_search = beam_search |
| | | self.beam_search_transducer = beam_search_transducer |
| | | self.maxlenratio = maxlenratio |
| | | self.minlenratio = minlenratio |
| | | self.device = device |
| | | self.dtype = dtype |
| | | self.nbest = nbest |
| | | |
| | | @torch.no_grad() |
| | | def __call__( |
| | | self, speech: Union[torch.Tensor, np.ndarray] |
| | | ): |
| | | """Inference |
| | | |
| | | Args: |
| | | speech: Input speech data |
| | | Returns: |
| | | text, token, token_int, hyp |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | |
| | | # data: (Nsamples,) -> (1, Nsamples) |
| | | speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) |
| | | lfr_factor = max(1, (speech.size()[-1] // 80) - 1) |
| | | # lengths: (1,) |
| | | lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) |
| | | batch = {"speech": speech, "speech_lengths": lengths} |
| | | |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | # b. Forward Encoder |
| | | enc, enc_len = self.asr_model.encode(**batch) |
| | | if isinstance(enc, tuple): |
| | | enc = enc[0] |
| | | assert len(enc) == 1, len(enc) |
| | | |
| | | predictor_outs = self.asr_model.calc_predictor(enc, enc_len) |
| | | pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1] |
| | | pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device) |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
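| | | # NOTE (a reading of the code above, not official documentation): this |
| | | # file targets Paraformer-style non-autoregressive decoding -- the |
| | | # predictor estimates one acoustic embedding per output token, the decoder |
| | | # scores all tokens in a single forward pass, and the beam search below |
| | | # consumes those per-token posteriors via `am_scores` instead of stepping |
| | | # an autoregressive decoder. |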
| | | |
| | | nbest_hyps = self.beam_search( |
| | | x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio |
| | | ) |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | results = [] |
| | | for hyp in nbest_hyps: |
| | | assert isinstance(hyp, (Hypothesis)), type(hyp) |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != 0, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | |
| | | if self.tokenizer is not None: |
| | | text = self.tokenizer.tokens2text(token) |
| | | else: |
| | | text = None |
| | | |
| | | results.append((text, token, token_int, hyp, speech.size(1), lfr_factor)) |
| | | |
| | | # assert check_return_type(results) |
| | | return results |
| | | |
| | | |
| | | def inference( |
| | | maxlenratio: float, |
| | | minlenratio: float, |
| | | batch_size: int, |
| | | dtype: str, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | seed: int, |
| | | ctc_weight: float, |
| | | lm_weight: float, |
| | | ngram_weight: float, |
| | | penalty: float, |
| | | nbest: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | data_path_and_name_and_type: list, |
| | | audio_lists: Union[List[Any], bytes], |
| | | key_file: Optional[str], |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | lm_train_config: Optional[str], |
| | | lm_file: Optional[str], |
| | | word_lm_train_config: Optional[str], |
| | | model_tag: Optional[str], |
| | | token_type: Optional[str], |
| | | bpemodel: Optional[str], |
| | | output_dir: Optional[str], |
| | | allow_variable_data_keys: bool, |
| | | frontend_conf: dict = None, |
| | | fs: Union[dict, int] = 16000, |
| | | **kwargs, |
| | | ) -> List[Any]: |
| | | assert check_argument_types() |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | if word_lm_train_config is not None: |
| | | raise NotImplementedError("Word LM is not implemented") |
| | | if ngpu > 1: |
| | | raise NotImplementedError("only single GPU decoding is supported") |
| | | |
| | | logging.basicConfig( |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if ngpu >= 1: |
| | | device = "cuda" |
| | | else: |
| | | device = "cpu" |
| | | # data_path_and_name_and_type = data_path_and_name_and_type[0] |
| | | features_type: str = data_path_and_name_and_type[1] |
| | | hop_length: int = 160 |
| | | sr: int = 16000 |
| | | if isinstance(fs, int): |
| | | sr = fs |
| | | else: |
| | | if 'model_fs' in fs and fs['model_fs'] is not None: |
| | | sr = fs['model_fs'] |
| | | if features_type != 'sound': |
| | | frontend_conf = None |
| | | if frontend_conf is not None: |
| | | if 'hop_length' in frontend_conf: |
| | | hop_length = frontend_conf['hop_length'] |
| | | |
| | | finish_count = 0 |
| | | file_count = 1 |
| | | if isinstance(audio_lists, bytes): |
| | | file_count = 1 |
| | | else: |
| | | file_count = len(audio_lists) |
| | | if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None: |
| | | mvn_file = data_path_and_name_and_type[2] |
| | | mvn_data = wav_utils.extract_CMVN_featrures(mvn_file) |
| | | frontend_conf['mvn_data'] = mvn_data |
| | | |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build speech2text |
| | | speech2text_kwargs = dict( |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | device=device, |
| | | maxlenratio=maxlenratio, |
| | | minlenratio=minlenratio, |
| | | dtype=dtype, |
| | | beam_size=beam_size, |
| | | ctc_weight=ctc_weight, |
| | | lm_weight=lm_weight, |
| | | ngram_weight=ngram_weight, |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | frontend_conf=frontend_conf, |
| | | ) |
| | | speech2text = Speech2Text(**speech2text_kwargs) |
| | | |
| | | data_path_and_name_and_type_new = [ |
| | | audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1] |
| | | ] |
| | | |
| | | # 3. Build data-iterator |
| | | loader = ASRTask.build_streaming_iterator_modelscope( |
| | | data_path_and_name_and_type_new, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), |
| | | collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | sample_rate=fs |
| | | ) |
| | | |
| | | forward_time_total = 0.0 |
| | | length_total = 0.0 |
| | | lfr_factor = 1  # default so the summary log below cannot hit a NameError |
| | | asr_result_list = [] |
| | | # 7. Start the decoding loop |
| | | # FIXME(kamo): The output format should be discussed |
| | | for keys, batch in loader: |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} |
| | | |
| | | logging.info("decoding, utt_id: {}".format(keys)) |
| | | # N-best list of (text, token, token_int, hyp_object) |
| | | |
| | | try: |
| | | time_beg = time.time() |
| | | results = speech2text(**batch) |
| | | time_end = time.time() |
| | | forward_time = time_end - time_beg |
| | | lfr_factor = results[0][-1] |
| | | length = results[0][-2] |
| | | results = [results[0][:-2]] |
| | | forward_time_total += forward_time |
| | | length_total += length |
| | | logging.info( |
| | | "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}". |
| | | format(length, forward_time, 100 * forward_time / (length * lfr_factor))) |
| | | except TooShortUttError as e: |
| | | logging.warning(f"Utterance {keys} {e}") |
| | | hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) |
| | | results = [[" ", ["<space>"], [2], hyp]] * nbest |
| | | |
| | | # Only supporting batch_size==1 |
| | | key = keys[0] |
| | | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | |
| | | logging.info("decoding, predictions: {}".format(text)) |
| | | finish_count += 1 |
| | | asr_utils.print_progress(finish_count / file_count) |
| | | |
| | | logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}". |
| | | format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))) |
| | | if features_type == 'sound': |
| | | # data format is wav |
| | | length_total_seconds = length_total / sr |
| | | length_total_bytes = length_total * 2 |
| | | else: |
| | | # data format is kaldi_ark |
| | | length_total_seconds = length_total * hop_length / sr |
| | | length_total_bytes = length_total * hop_length * 2 |
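| | | # Worked reading (an interpretation of the code, not a spec): for |
| | | # kaldi_ark input, length_total counts feature frames, so audio seconds = |
| | | # frames * hop_length / sr (e.g. 1000 frames * 160 / 16000 = 10 s), and |
| | | # the byte counts assume 16-bit PCM samples, hence the factor of 2. |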
| | | |
| | | logger.info( |
| | | header_colors + # noqa: * |
| | | 'decoding, feature length total: {}bytes, forward_time total: {:.4f}s, rtf avg: {:.4f}' |
| | | .format(length_total_bytes, forward_time_total, forward_time_total / |
| | | length_total_seconds) + end_colors) |
| | | |
| | | return asr_result_list |
| | | |
| | | |
| | | def set_parameters(language: str = None, |
| | | sample_rate: Union[int, Dict[Any, int]] = None): |
| | | if language is not None: |
| | | global global_asr_language |
| | | global_asr_language = language |
| | | if sample_rate is not None: |
| | | global global_sample_rate |
| | | global_sample_rate = sample_rate |
| | | |
| | | |
| | | def asr_inference(maxlenratio: float, |
| | | minlenratio: float, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | ctc_weight: float, |
| | | lm_weight: float, |
| | | penalty: float, |
| | | name_and_type: list, |
| | | audio_lists: Union[List[Any], bytes], |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | log_level: Union[int, str] = 'INFO', |
| | | batch_size: int = 1, |
| | | dtype: str = 'float32', |
| | | seed: int = 0, |
| | | key_file: Optional[str] = None, |
| | | lm_train_config: Optional[str] = None, |
| | | lm_file: Optional[str] = None, |
| | | word_lm_train_config: Optional[str] = None, |
| | | word_lm_file: Optional[str] = None, |
| | | ngram_file: Optional[str] = None, |
| | | ngram_weight: float = 0.9, |
| | | model_tag: Optional[str] = None, |
| | | token_type: Optional[str] = None, |
| | | bpemodel: Optional[str] = None, |
| | | allow_variable_data_keys: bool = False, |
| | | transducer_conf: Optional[dict] = None, |
| | | streaming: bool = False, |
| | | frontend_conf: dict = None, |
| | | fs: Union[dict, int] = None, |
| | | lang: Optional[str] = None, |
| | | outputdir: Optional[str] = None): |
| | | if lang is not None: |
| | | global global_asr_language |
| | | global_asr_language = lang |
| | | if fs is not None: |
| | | global global_sample_rate |
| | | global_sample_rate = fs |
| | | |
| | | # force CPU decoding if the input is raw audio bytes |
| | | if isinstance(audio_lists, bytes): |
| | | num_workers = 0 |
| | | ngpu = 0 |
| | | |
| | | return inference(output_dir=outputdir, |
| | | maxlenratio=maxlenratio, |
| | | minlenratio=minlenratio, |
| | | batch_size=batch_size, |
| | | dtype=dtype, |
| | | beam_size=beam_size, |
| | | ngpu=ngpu, |
| | | seed=seed, |
| | | ctc_weight=ctc_weight, |
| | | lm_weight=lm_weight, |
| | | ngram_weight=ngram_weight, |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | num_workers=num_workers, |
| | | log_level=log_level, |
| | | data_path_and_name_and_type=name_and_type, |
| | | audio_lists=audio_lists, |
| | | key_file=key_file, |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | word_lm_train_config=word_lm_train_config, |
| | | word_lm_file=word_lm_file, |
| | | ngram_file=ngram_file, |
| | | model_tag=model_tag, |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | transducer_conf=transducer_conf, |
| | | streaming=streaming, |
| | | frontend_conf=frontend_conf) |
| | | |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="ASR Decoding", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | |
| | | # Note(kamo): Use '_' instead of '-' as separator. |
| | | # '-' is confusing if written in yaml. |
| | | parser.add_argument( |
| | | "--log_level", |
| | | type=lambda x: x.upper(), |
| | | default="INFO", |
| | | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
| | | help="The verbose level of logging", |
| | | ) |
| | | |
| | | parser.add_argument("--output_dir", type=str, required=True) |
| | | parser.add_argument( |
| | | "--ngpu", |
| | | type=int, |
| | | default=0, |
| | | help="The number of gpus. 0 indicates CPU mode", |
| | | ) |
| | | parser.add_argument("--seed", type=int, default=0, help="Random seed") |
| | | parser.add_argument( |
| | | "--dtype", |
| | | default="float32", |
| | | choices=["float16", "float32", "float64"], |
| | | help="Data type", |
| | | ) |
| | | parser.add_argument( |
| | | "--num_workers", |
| | | type=int, |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | required=True, |
| | | action="append", |
| | | ) |
| | | group.add_argument("--audio_lists", type=list, default=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) |
| | | group.add_argument("--key_file", type=str_or_none) |
| | | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) |
| | | |
| | | group = parser.add_argument_group("The model configuration related") |
| | | group.add_argument( |
| | | "--asr_train_config", |
| | | type=str, |
| | | help="ASR training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--asr_model_file", |
| | | type=str, |
| | | help="ASR model parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--lm_train_config", |
| | | type=str, |
| | | help="LM training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--lm_file", |
| | | type=str, |
| | | help="LM parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--word_lm_train_config", |
| | | type=str, |
| | | help="Word LM training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--word_lm_file", |
| | | type=str, |
| | | help="Word LM parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--ngram_file", |
| | | type=str, |
| | | help="N-gram parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--model_tag", |
| | | type=str, |
| | | help="Pretrained model tag. If specify this option, *_train_config and " |
| | | "*_file will be overwritten", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Beam-search related") |
| | | group.add_argument( |
| | | "--batch_size", |
| | | type=int, |
| | | default=1, |
| | | help="The batch size for inference", |
| | | ) |
| | | group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") |
| | | group.add_argument("--beam_size", type=int, default=20, help="Beam size") |
| | | group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") |
| | | group.add_argument( |
| | | "--maxlenratio", |
| | | type=float, |
| | | default=0.0, |
| | | help="Input length ratio to obtain max output length. " |
| | | "If maxlenratio=0.0 (default), it uses a end-detect " |
| | | "function " |
| | | "to automatically find maximum hypothesis lengths." |
| | | "If maxlenratio<0.0, its absolute value is interpreted" |
| | | "as a constant max output length", |
| | | ) |
| | | group.add_argument( |
| | | "--minlenratio", |
| | | type=float, |
| | | default=0.0, |
| | | help="Input length ratio to obtain min output length", |
| | | ) |
| | | group.add_argument( |
| | | "--ctc_weight", |
| | | type=float, |
| | | default=0.5, |
| | | help="CTC weight in joint decoding", |
| | | ) |
| | | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") |
| | | group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") |
| | | group.add_argument("--streaming", type=str2bool, default=False) |
| | | |
| | | group.add_argument( |
| | | "--asr_model_config", |
| | | default=None, |
| | | help="", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Text converter related") |
| | | group.add_argument( |
| | | "--token_type", |
| | | type=str_or_none, |
| | | default=None, |
| | | choices=["char", "bpe", None], |
| | | help="The token type for ASR model. " |
| | | "If not given, refers from the training args", |
| | | ) |
| | | group.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model path of sentencepiece. " |
| | | "If not given, refers from the training args", |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | parser = get_parser() |
| | | args = parser.parse_args(cmd) |
| | | kwargs = vars(args) |
| | | kwargs.pop("config", None) |
| | | inference(**kwargs) |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| |
| | | parser.add_argument("--model_name", |
| | | type=str, |
| | | default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", |
| | | help="model name in modelscope") |
| | | parser.add_argument("--model_revision", |
| | | type=str, |
| | | default="v1.0.3", |
| | | help="model revision in modelscope") |
| | | parser.add_argument("--local_model_path", |
| | | type=str, |
| | | default=None, |
| | |
| | | if args.local_model_path is None: |
| | | inference_pipeline = pipeline( |
| | | task=Tasks.auto_speech_recognition, |
| | | model="damo/{}".format(args.model_name), |
| | | model_revision=args.model_revision) |
| | | else: |
| | | inference_pipeline = pipeline( |
| | | task=Tasks.auto_speech_recognition, |
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | # Part of the implementation is borrowed from espnet/espnet. |
| | | """Iterable dataset module.""" |
| | | import copy |
| | | from io import StringIO |
| | | from pathlib import Path |
| | | from typing import Callable, Collection, Dict, Iterator, Tuple, Union |
| | | |
| | | import kaldiio |
| | | import numpy as np |
| | | import soundfile |
| | | import torch |
| | | from funasr.datasets.dataset import ESPnetDataset |
| | | from torch.utils.data.dataset import IterableDataset |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.utils import wav_utils |
| | | |
| | | |
| | | def load_kaldi(input): |
| | | retval = kaldiio.load_mat(input) |
| | | if isinstance(retval, tuple): |
| | | assert len(retval) == 2, len(retval) |
| | | if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray): |
| | | # sound scp case |
| | | rate, array = retval |
| | | elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray): |
| | | # Extended ark format case |
| | | array, rate = retval |
| | | else: |
| | | raise RuntimeError( |
| | | f'Unexpected type: {type(retval[0])}, {type(retval[1])}') |
| | | |
| | | # Multichannel wave file |
| | | # array: (NSample, Channel) or (NSample,) |
| | | |
| | | else: |
| | | # Normal ark case |
| | | assert isinstance(retval, np.ndarray), type(retval) |
| | | array = retval |
| | | return array |
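| | | # Example (sketch; the rxspecifier path is a placeholder): |
| | | # feats = load_kaldi('feats.ark:42')  # -> np.ndarray of features |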
| | | |
| | | |
| | | DATA_TYPES = { |
| | | 'sound': |
| | | lambda x: soundfile.read(x)[0], |
| | | 'kaldi_ark': |
| | | load_kaldi, |
| | | 'npy': |
| | | np.load, |
| | | 'text_int': |
| | | lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '), |
| | | 'csv_int': |
| | | lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','), |
| | | 'text_float': |
| | | lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' ' |
| | | ), |
| | | 'csv_float': |
| | | lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=',' |
| | | ), |
| | | 'text': |
| | | lambda x: x, |
| | | } |
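| | | # Example (sketch; 'example.wav' is a placeholder): loaders are looked up |
| | | # by type tag, e.g. |
| | | # array = DATA_TYPES['sound']('example.wav')  # waveform as np.ndarray |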
| | | |
| | | |
| | | class IterableESPnetDatasetModelScope(IterableDataset): |
| | | """Pytorch Dataset class for ESPNet. |
| | | |
| | | Examples: |
| | | >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'), |
| | | ... ('token_int', 'output', 'text_int')], |
| | | ... ) |
| | | >>> for uid, data in dataset: |
| | | ... data |
| | | {'input': per_utt_array, 'output': per_utt_array} |
| | | """ |
| | | def __init__(self, |
| | | path_name_type_list: Collection[Tuple[any, str, str]], |
| | | preprocess: Callable[[str, Dict[str, np.ndarray]], |
| | | Dict[str, np.ndarray]] = None, |
| | | float_dtype: str = 'float32', |
| | | int_dtype: str = 'long', |
| | | key_file: str = None, |
| | | sample_rate: Union[dict, int] = 16000): |
| | | assert check_argument_types() |
| | | if len(path_name_type_list) == 0: |
| | | raise ValueError( |
| | | '1 or more elements are required for "path_name_type_list"') |
| | | |
| | | self.preprocess = preprocess |
| | | |
| | | self.float_dtype = float_dtype |
| | | self.int_dtype = int_dtype |
| | | self.key_file = key_file |
| | | self.sample_rate = sample_rate |
| | | |
| | | self.debug_info = {} |
| | | non_iterable_list = [] |
| | | self.path_name_type_list = [] |
| | | |
| | | path_list = path_name_type_list[0] |
| | | name = path_name_type_list[1] |
| | | _type = path_name_type_list[2] |
| | | if name in self.debug_info: |
| | | raise RuntimeError(f'"{name}" is duplicated for data-key') |
| | | self.debug_info[name] = path_list, _type |
| | | # for path, name, _type in path_name_type_list: |
| | | for path in path_list: |
| | | self.path_name_type_list.append((path, name, _type)) |
| | | |
| | | if len(non_iterable_list) != 0: |
| | | # Some types don't support iterable mode |
| | | self.non_iterable_dataset = ESPnetDataset( |
| | | path_name_type_list=non_iterable_list, |
| | | preprocess=preprocess, |
| | | float_dtype=float_dtype, |
| | | int_dtype=int_dtype, |
| | | ) |
| | | else: |
| | | self.non_iterable_dataset = None |
| | | |
| | | self.apply_utt2category = False |
| | | |
| | | def has_name(self, name) -> bool: |
| | | return name in self.debug_info |
| | | |
| | | def names(self) -> Tuple[str, ...]: |
| | | return tuple(self.debug_info) |
| | | |
| | | def __repr__(self): |
| | | _mes = self.__class__.__name__ |
| | | _mes += '(' |
| | | for name, (path, _type) in self.debug_info.items(): |
| | | _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}' |
| | | _mes += f'\n preprocess: {self.preprocess})' |
| | | return _mes |
| | | |
| | | def __iter__( |
| | | self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]: |
| | | torch.set_printoptions(profile='default') |
| | | count = len(self.path_name_type_list) |
| | | for idx in range(count): |
| | | # 2. Load the entry from each line and create a dict |
| | | data = {} |
| | | # 2.a. Load data streamingly |
| | | |
| | | # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav |
| | | value = self.path_name_type_list[idx][0]['file'] |
| | | uid = self.path_name_type_list[idx][0]['key'] |
| | | # name: speech |
| | | name = self.path_name_type_list[idx][1] |
| | | _type = self.path_name_type_list[idx][2] |
| | | func = DATA_TYPES[_type] |
| | | array = func(value) |
| | | |
| | | # 2.b. audio resample |
| | | if _type == 'sound': |
| | | audio_sr: int = 16000 |
| | | model_sr: int = 16000 |
| | | if isinstance(self.sample_rate, int): |
| | | model_sr = self.sample_rate |
| | | else: |
| | | if 'audio_sr' in self.sample_rate: |
| | | audio_sr = self.sample_rate['audio_sr'] |
| | | if 'model_sr' in self.sample_rate: |
| | | model_sr = self.sample_rate['model_sr'] |
| | | array = wav_utils.torch_resample(array, audio_sr, model_sr) |
| | | |
| | | # array: [ 1.25122070e-03 ... ] |
| | | data[name] = array |
| | | |
| | | # 3. [Option] Apply preprocessing |
| | | # e.g. espnet2.train.preprocessor:CommonPreprocessor |
| | | if self.preprocess is not None: |
| | | data = self.preprocess(uid, data) |
| | | # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])} |
| | | |
| | | # 4. Force data-precision |
| | | for name in data: |
| | | # value is np.ndarray data |
| | | value = data[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f'All values must be converted to np.ndarray object ' |
| | | f'by preprocessing, but "{name}" is still {type(value)}.' |
| | | ) |
| | | |
| | | # Cast to desired type |
| | | if value.dtype.kind == 'f': |
| | | value = value.astype(self.float_dtype) |
| | | elif value.dtype.kind == 'i': |
| | | value = value.astype(self.int_dtype) |
| | | else: |
| | | raise NotImplementedError( |
| | | f'Not supported dtype: {value.dtype}') |
| | | data[name] = value |
| | | |
| | | yield uid, data |
| | | |
| | | if count == 0: |
| | | raise RuntimeError('No iteration') |
| | | |
| | | |
| | | class IterableESPnetBytesModelScope(IterableDataset): |
| | | """Pytorch audio bytes class for ESPNet. |
| | | |
| | | Examples: |
| | | >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'), |
| | | ... ('token_int', 'output', 'text_int')], |
| | | ... ) |
| | | >>> for uid, data in dataset: |
| | | ... data |
| | | {'input': per_utt_array, 'output': per_utt_array} |
| | | """ |
| | | def __init__(self, |
| | | path_name_type_list: Collection[Tuple[any, str, str]], |
| | | preprocess: Callable[[str, Dict[str, np.ndarray]], |
| | | Dict[str, np.ndarray]] = None, |
| | | float_dtype: str = 'float32', |
| | | int_dtype: str = 'long', |
| | | key_file: str = None, |
| | | sample_rate: Union[dict, int] = 16000): |
| | | assert check_argument_types() |
| | | if len(path_name_type_list) == 0: |
| | | raise ValueError( |
| | | '1 or more elements are required for "path_name_type_list"') |
| | | |
| | | self.preprocess = preprocess |
| | | |
| | | self.float_dtype = float_dtype |
| | | self.int_dtype = int_dtype |
| | | self.key_file = key_file |
| | | self.sample_rate = sample_rate |
| | | |
| | | self.debug_info = {} |
| | | non_iterable_list = [] |
| | | self.path_name_type_list = [] |
| | | |
| | | audio_data = path_name_type_list[0] |
| | | name = path_name_type_list[1] |
| | | _type = path_name_type_list[2] |
| | | if name in self.debug_info: |
| | | raise RuntimeError(f'"{name}" is duplicated for data-key') |
| | | self.debug_info[name] = audio_data, _type |
| | | self.path_name_type_list.append((audio_data, name, _type)) |
| | | |
| | | if len(non_iterable_list) != 0: |
| | | # Some types don't support iterable mode |
| | | self.non_iterable_dataset = ESPnetDataset( |
| | | path_name_type_list=non_iterable_list, |
| | | preprocess=preprocess, |
| | | float_dtype=float_dtype, |
| | | int_dtype=int_dtype, |
| | | ) |
| | | else: |
| | | self.non_iterable_dataset = None |
| | | |
| | | self.apply_utt2category = False |
| | | |
| | | if float_dtype == 'float32': |
| | | self.np_dtype = np.float32 |
| | | |
| | | def has_name(self, name) -> bool: |
| | | return name in self.debug_info |
| | | |
| | | def names(self) -> Tuple[str, ...]: |
| | | return tuple(self.debug_info) |
| | | |
| | | def __repr__(self): |
| | | _mes = self.__class__.__name__ |
| | | _mes += '(' |
| | | for name, (path, _type) in self.debug_info.items(): |
| | | _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}' |
| | | _mes += f'\n preprocess: {self.preprocess})' |
| | | return _mes |
| | | |
| | | def __iter__( |
| | | self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]: |
| | | |
| | | torch.set_printoptions(profile='default') |
| | | # 2. Load the entry from each line and create a dict |
| | | data = {} |
| | | # 2.a. Load data streamingly |
| | | |
| | | value = self.path_name_type_list[0][0] |
| | | uid = 'pcm_data' |
| | | # name: speech |
| | | name = self.path_name_type_list[0][1] |
| | | _type = self.path_name_type_list[0][2] |
| | | func = DATA_TYPES[_type] |
| | | # array: [ 1.25122070e-03 ... ] |
| | | # data[name] = np.frombuffer(value, dtype=self.np_dtype) |
| | | |
| | | # 2.b. byte(PCM16) to float32 |
| | | middle_data = np.frombuffer(value, dtype=np.int16) |
| | | middle_data = np.asarray(middle_data) |
| | | if middle_data.dtype.kind not in 'iu': |
| | | raise TypeError("'middle_data' must be an array of integers") |
| | | dtype = np.dtype('float32') |
| | | if dtype.kind != 'f': |
| | | raise TypeError("'dtype' must be a floating point type") |
| | | |
| | | i = np.iinfo(middle_data.dtype) |
| | | abs_max = 2**(i.bits - 1) |
| | | offset = i.min + abs_max |
| | | array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, |
| | | dtype=self.np_dtype) |
| | | |
| | | # 2.c. audio resample |
| | | if _type == 'sound': |
| | | audio_sr: int = 16000 |
| | | model_sr: int = 16000 |
| | | if isinstance(self.sample_rate, int): |
| | | model_sr = self.sample_rate |
| | | else: |
| | | if 'audio_sr' in self.sample_rate: |
| | | audio_sr = self.sample_rate['audio_sr'] |
| | | if 'model_sr' in self.sample_rate: |
| | | model_sr = self.sample_rate['model_sr'] |
| | | array = wav_utils.torch_resample(array, audio_sr, model_sr) |
| | | |
| | | data[name] = array |
| | | |
| | | # 3. [Option] Apply preprocessing |
| | | # e.g. espnet2.train.preprocessor:CommonPreprocessor |
| | | if self.preprocess is not None: |
| | | data = self.preprocess(uid, data) |
| | | # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])} |
| | | |
| | | # 4. Force data-precision |
| | | for name in data: |
| | | # value is np.ndarray data |
| | | value = data[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f'All values must be converted to np.ndarray object ' |
| | | f'by preprocessing, but "{name}" is still {type(value)}.') |
| | | |
| | | # Cast to desired type |
| | | if value.dtype.kind == 'f': |
| | | value = value.astype(self.float_dtype) |
| | | elif value.dtype.kind == 'i': |
| | | value = value.astype(self.int_dtype) |
| | | else: |
| | | raise NotImplementedError( |
| | | f'Not supported dtype: {value.dtype}') |
| | | data[name] = value |
| | | |
| | | yield uid, data |
| | |
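Step 2.b above converts little-endian PCM16 bytes to float32 by subtracting the integer offset and dividing by the absolute maximum; for signed int16 the offset is 0 and the divisor 32768, so output values land in [-1, 1). (The trailing np.frombuffer with a matching float32 dtype is then only a buffer reinterpretation and leaves the values unchanged.) A self-contained check of the same arithmetic:

    import numpy as np

    # two extreme int16 samples, serialized the way a PCM stream arrives
    pcm_bytes = np.array([-32768, 32767], dtype=np.int16).tobytes()

    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    info = np.iinfo(samples.dtype)        # bits=16, min=-32768
    abs_max = 2 ** (info.bits - 1)        # 32768
    offset = info.min + abs_max           # 0 for signed integer types
    floats = (samples.astype(np.float32) - offset) / abs_max

    print(floats)  # [-1.         0.9999695], i.e. -1.0 and 32767/32768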
| | | |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): |
| | | |
| | | decoder_out, _ = self.decoder( |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens |
| | | ) |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |
| | | return decoder_out, ys_pad_lens |
| | | |
| | |
| | | postencoder: Optional[AbsPostEncoder], |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | joint_network: Optional[torch.nn.Module], |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | ignore_id: int = -1, |
| | |
| | | postencoder=postencoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | joint_network=joint_network, |
| | | ctc_weight=ctc_weight, |
| | | interctc_weight=interctc_weight, |
| | | ignore_id=ignore_id, |
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | # Part of the implementation is borrowed from espnet/espnet. |
| | | |
| | | import copy |
| | | from typing import Optional, Tuple, Union |
| | | |
| | | import humanfriendly |
| | | import numpy as np |
| | | import torch |
| | | import torchaudio.compliance.kaldi as kaldi |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.layers.log_mel import LogMel |
| | | from funasr.layers.stft import Stft |
| | | from funasr.utils.get_default_kwargs import get_default_kwargs |
| | | from funasr.modules.frontends.frontend import Frontend |
| | | from typeguard import check_argument_types |
| | | |
| | | |
| | | def apply_cmvn(inputs, mvn): # noqa |
| | | """ |
| | | Apply CMVN with mvn data |
| | | """ |
| | | |
| | | device = inputs.device |
| | | dtype = inputs.dtype |
| | | frame, dim = inputs.shape |
| | | |
| | | means = np.tile(mvn[0:1, :dim], (frame, 1)) |
| | | variances = np.tile(mvn[1:2, :dim], (frame, 1)) |
| | | inputs += torch.from_numpy(means).type(dtype).to(device) |
| | | inputs *= torch.from_numpy(variances).type(dtype).to(device) |
| | | |
| | | return inputs.type(torch.float32) |
| | | |
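apply_cmvn expects mvn row 0 to hold the shift added to every frame and row 1 the scale each frame is then multiplied by, matching the AddShift/Rescale layout produced by the CMVN loaders later in this PR. A small worked example with apply_cmvn as defined above (the numbers are ours):

    import numpy as np
    import torch

    mvn = np.array([[-2.0, -4.0],      # add-shift (negated means)
                    [0.5, 0.25]])      # rescale (inverse stddevs)
    feats = torch.tensor([[4.0, 8.0],
                          [6.0, 12.0]])
    print(apply_cmvn(feats, mvn))
    # tensor([[1., 1.],
    #         [2., 2.]])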
| | | |
| | | def apply_lfr(inputs, lfr_m, lfr_n): |
| | | LFR_inputs = [] |
| | | T = inputs.shape[0] |
| | | T_lfr = int(np.ceil(T / lfr_n)) |
| | | left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1) |
| | | inputs = torch.vstack((left_padding, inputs)) |
| | | T = T + (lfr_m - 1) // 2 |
| | | for i in range(T_lfr): |
| | | if lfr_m <= T - i * lfr_n: |
| | | LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1)) |
| | | else: # process last LFR frame |
| | | num_padding = lfr_m - (T - i * lfr_n) |
| | | frame = (inputs[i * lfr_n:]).view(-1) |
| | | for _ in range(num_padding): |
| | | frame = torch.hstack((frame, inputs[-1])) |
| | | LFR_inputs.append(frame) |
| | | LFR_outputs = torch.vstack(LFR_inputs) |
| | | return LFR_outputs.type(torch.float32) |
| | | |
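apply_lfr stacks lfr_m consecutive frames into one output frame and advances lfr_n frames per step, after left-padding with (lfr_m - 1) // 2 copies of the first frame, so the output has ceil(T / lfr_n) frames of lfr_m * D dimensions. A quick shape check with apply_lfr as defined above:

    import torch

    feats = torch.arange(40, dtype=torch.float32).reshape(10, 4)  # T=10, D=4
    out = apply_lfr(feats, lfr_m=3, lfr_n=2)
    print(out.shape)  # torch.Size([5, 12]): ceil(10 / 2) frames of 3 * 4 dims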
| | | |
| | | class WavFrontend(AbsFrontend): |
| | | """Conventional frontend structure for ASR. |
| | | """ |
| | | def __init__( |
| | | self, |
| | | fs: Union[int, str] = 16000, |
| | | n_fft: int = 512, |
| | | win_length: int = 400, |
| | | hop_length: int = 160, |
| | | window: Optional[str] = 'hamming', |
| | | center: bool = True, |
| | | normalized: bool = False, |
| | | onesided: bool = True, |
| | | n_mels: int = 80, |
| | | fmin: int = None, |
| | | fmax: int = None, |
| | | lfr_m: int = 1, |
| | | lfr_n: int = 1, |
| | | htk: bool = False, |
| | | mvn_data=None, |
| | | frontend_conf: Optional[dict] = get_default_kwargs(Frontend), |
| | | apply_stft: bool = True, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | if isinstance(fs, str): |
| | | fs = humanfriendly.parse_size(fs) |
| | | |
| | | # Deepcopy (In general, dict shouldn't be used as default arg) |
| | | frontend_conf = copy.deepcopy(frontend_conf) |
| | | self.hop_length = hop_length |
| | | self.win_length = win_length |
| | | self.window = window |
| | | self.fs = fs |
| | | self.mvn_data = mvn_data |
| | | self.lfr_m = lfr_m |
| | | self.lfr_n = lfr_n |
| | | |
| | | if apply_stft: |
| | | self.stft = Stft( |
| | | n_fft=n_fft, |
| | | win_length=win_length, |
| | | hop_length=hop_length, |
| | | center=center, |
| | | window=window, |
| | | normalized=normalized, |
| | | onesided=onesided, |
| | | ) |
| | | else: |
| | | self.stft = None |
| | | self.apply_stft = apply_stft |
| | | |
| | | if frontend_conf is not None: |
| | | self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) |
| | | else: |
| | | self.frontend = None |
| | | |
| | | self.logmel = LogMel( |
| | | fs=fs, |
| | | n_fft=n_fft, |
| | | n_mels=n_mels, |
| | | fmin=fmin, |
| | | fmax=fmax, |
| | | htk=htk, |
| | | ) |
| | | self.n_mels = n_mels |
| | | self.frontend_type = 'default' |
| | | |
| | | def output_size(self) -> int: |
| | | return self.n_mels |
| | | |
| | | def forward( |
| | | self, input: torch.Tensor, |
| | | input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | |
| | | sample_frequency = self.fs |
| | | num_mel_bins = self.n_mels |
| | | # convert window and hop sizes from samples to milliseconds for kaldi.fbank |
| | | frame_length = self.win_length * 1000 / sample_frequency |
| | | frame_shift = self.hop_length * 1000 / sample_frequency |
| | | |
| | | # scale the float waveform from [-1, 1] up to the int16 range kaldi expects |
| | | waveform = input * (1 << 15) |
| | | |
| | | mat = kaldi.fbank(waveform, |
| | | num_mel_bins=num_mel_bins, |
| | | frame_length=frame_length, |
| | | frame_shift=frame_shift, |
| | | dither=1.0, |
| | | energy_floor=0.0, |
| | | window_type=self.window, |
| | | sample_frequency=sample_frequency) |
| | | if self.lfr_m != 1 or self.lfr_n != 1: |
| | | mat = apply_lfr(mat, self.lfr_m, self.lfr_n) |
| | | if self.mvn_data is not None: |
| | | mat = apply_cmvn(mat, self.mvn_data) |
| | | |
| | | input_feats = mat[None, :] |
| | | # lengths tensor for the single utterance in the batch |
| | | feats_lens = torch.tensor([input_feats.shape[1]], dtype=torch.float32) |
| | | |
| | | return input_feats, feats_lens |
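A hypothetical usage sketch of the frontend above; with the defaults (25 ms window, 10 ms shift), one second of 16 kHz audio yields 98 frames:

    import torch

    frontend = WavFrontend(fs=16000, n_mels=80)
    wav = torch.randn(1, 16000)  # 1 s of float audio, roughly in [-1, 1]
    feats, feats_lens = frontend(wav, torch.tensor([16000]))
    print(feats.shape)  # torch.Size([1, 98, 80])
    print(feats_lens)   # tensor([98.])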
| | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| | | class CifPredictor(nn.Module): |
| | | def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0): |
| | | def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45): |
| | | super(CifPredictor, self).__init__() |
| | | |
| | | self.pad = nn.ConstantPad1d((l_order, r_order), 0) |
| | |
| | | from funasr.datasets.dataset import DATA_TYPES |
| | | from funasr.datasets.dataset import ESPnetDataset |
| | | from funasr.datasets.iterable_dataset import IterableESPnetDataset |
| | | from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope |
| | | from funasr.iterators.abs_iter_factory import AbsIterFactory |
| | | from funasr.iterators.chunk_iter_factory import ChunkIterFactory |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | |
| | | @classmethod |
| | | def check_task_requirements( |
| | | cls, |
| | | dataset: Union[AbsDataset, IterableESPnetDataset], |
| | | dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope], |
| | | allow_variable_data_keys: bool, |
| | | train: bool, |
| | | inference: bool = False, |
| | |
| | | **kwargs, |
| | | ) |
| | | |
| | | @classmethod |
| | | def build_streaming_iterator_modelscope( |
| | | cls, |
| | | data_path_and_name_and_type, |
| | | preprocess_fn, |
| | | collate_fn, |
| | | key_file: str = None, |
| | | batch_size: int = 1, |
| | | dtype: str = np.float32, |
| | | num_workers: int = 1, |
| | | allow_variable_data_keys: bool = False, |
| | | ngpu: int = 0, |
| | | inference: bool = False, |
| | | sample_rate: Union[dict, int] = 16000 |
| | | ) -> DataLoader: |
| | | """Build DataLoader using iterable dataset""" |
| | | assert check_argument_types() |
| | | # For backward compatibility for pytorch DataLoader |
| | | if collate_fn is not None: |
| | | kwargs = dict(collate_fn=collate_fn) |
| | | else: |
| | | kwargs = {} |
| | | |
| | | audio_data = data_path_and_name_and_type[0] |
| | | if isinstance(audio_data, bytes): |
| | | dataset = IterableESPnetBytesModelScope( |
| | | data_path_and_name_and_type, |
| | | float_dtype=dtype, |
| | | preprocess=preprocess_fn, |
| | | key_file=key_file, |
| | | sample_rate=sample_rate |
| | | ) |
| | | else: |
| | | dataset = IterableESPnetDatasetModelScope( |
| | | data_path_and_name_and_type, |
| | | float_dtype=dtype, |
| | | preprocess=preprocess_fn, |
| | | key_file=key_file, |
| | | sample_rate=sample_rate |
| | | ) |
| | | |
| | | if dataset.apply_utt2category: |
| | | kwargs.update(batch_size=1) |
| | | else: |
| | | kwargs.update(batch_size=batch_size) |
| | | |
| | | cls.check_task_requirements(dataset, |
| | | allow_variable_data_keys, |
| | | train=False, |
| | | inference=inference) |
| | | |
| | | return DataLoader( |
| | | dataset=dataset, |
| | | pin_memory=ngpu > 0, |
| | | num_workers=num_workers, |
| | | **kwargs, |
| | | ) |
| | | |
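A hypothetical call sketch for the classmethod above; we assume it is reachable through a task class such as ASRTask (the inference script elsewhere in this PR imports ASRTask, but the exact owner of this method is not shown here). The bytes-vs-path branch is driven entirely by the type of the first element:

    import numpy as np

    pcm_bytes = (np.random.randn(8000) * 1000).astype(np.int16).tobytes()
    loader = ASRTask.build_streaming_iterator_modelscope(
        (pcm_bytes, 'speech', 'sound'),  # bytes -> IterableESPnetBytesModelScope
        preprocess_fn=None,
        collate_fn=None,
        sample_rate={'audio_sr': 8000, 'model_sr': 16000},
    )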
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | | @classmethod |
| | | def build_model_from_file( |
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | |
| | | import os |
| | | import shutil |
| | | import ssl |
| | | |
| | | import nltk |
| | | |
| | | # create the nltk_data dir if it does not exist |
| | | try: |
| | | nltk.data.find('.') |
| | | except LookupError: |
| | | dir_list = nltk.data.path |
| | | for dir_item in dir_list: |
| | | if not os.path.exists(dir_item): |
| | | os.mkdir(dir_item) |
| | | if os.path.exists(dir_item): |
| | | break |
| | | |
| | | # download one package if nltk_data does not exist |
| | | try: |
| | | nltk.data.find('.') |
| | | except: # noqa: * |
| | | try: |
| | | _create_unverified_https_context = ssl._create_unverified_context |
| | | except AttributeError: |
| | | pass |
| | | else: |
| | | ssl._create_default_https_context = _create_unverified_https_context |
| | | |
| | | nltk.download('cmudict', halt_on_error=False, raise_on_error=True) |
| | | |
| | | # deploy taggers/averaged_perceptron_tagger |
| | | try: |
| | | nltk.data.find('taggers/averaged_perceptron_tagger') |
| | | except: # noqa: * |
| | | data_dir = nltk.data.find('.') |
| | | target_dir = os.path.join(data_dir, 'taggers') |
| | | if not os.path.exists(target_dir): |
| | | os.mkdir(target_dir) |
| | | src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages', |
| | | 'averaged_perceptron_tagger.zip') |
| | | shutil.copyfile(src_file, |
| | | os.path.join(target_dir, 'averaged_perceptron_tagger.zip')) |
| | | shutil._unpack_zipfile( |
| | | os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir) |
| | | |
| | | # deploy corpora/cmudict |
| | | try: |
| | | nltk.data.find('corpora/cmudict') |
| | | except: # noqa: * |
| | | data_dir = nltk.data.find('.') |
| | | target_dir = os.path.join(data_dir, 'corpora') |
| | | if not os.path.exists(target_dir): |
| | | os.mkdir(target_dir) |
| | | src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages', |
| | | 'cmudict.zip') |
| | | shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip')) |
| | | shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir) |
| | | |
| | | try: |
| | | nltk.data.find('taggers/averaged_perceptron_tagger') |
| | | except: # noqa: * |
| | | try: |
| | | _create_unverified_https_context = ssl._create_unverified_context |
| | | except AttributeError: |
| | | pass |
| | | else: |
| | | ssl._create_default_https_context = _create_unverified_https_context |
| | | |
| | | nltk.download('averaged_perceptron_tagger', |
| | | halt_on_error=False, |
| | | raise_on_error=True) |
| | | |
| | | try: |
| | | nltk.data.find('corpora/cmudict') |
| | | except: # noqa: * |
| | | try: |
| | | _create_unverified_https_context = ssl._create_unverified_context |
| | | except AttributeError: |
| | | pass |
| | | else: |
| | | ssl._create_default_https_context = _create_unverified_https_context |
| | | |
| | | nltk.download('cmudict', halt_on_error=False, raise_on_error=True) |
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | |
| | | import os |
| | | import struct |
| | | from typing import Any, Dict, List, Union |
| | | |
| | | import librosa |
| | | import numpy as np |
| | | import pkg_resources |
| | | from modelscope.utils.logger import get_logger |
| | | |
| | | logger = get_logger() |
| | | |
| | | green_color = '\033[1;32m' |
| | | red_color = '\033[0;31;40m' |
| | | yellow_color = '\033[0;33;40m' |
| | | end_color = '\033[0m' |
| | | |
| | | global_asr_language = 'zh-cn' |
| | | |
| | | |
| | | def get_version(): |
| | | return float(pkg_resources.get_distribution('easyasr').version) |
| | | |
| | | |
| | | def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str): |
| | | r_audio_fs = None |
| | | |
| | | if audio_format == 'wav': |
| | | r_audio_fs = get_sr_from_wav(audio_in) |
| | | elif audio_format == 'pcm' and isinstance(audio_in, bytes): |
| | | r_audio_fs = get_sr_from_bytes(audio_in) |
| | | |
| | | return r_audio_fs |
| | | |
| | | |
| | | def type_checking(audio_in: Union[str, bytes], |
| | | audio_fs: int = None, |
| | | recog_type: str = None, |
| | | audio_format: str = None): |
| | | r_recog_type = recog_type |
| | | r_audio_format = audio_format |
| | | r_wav_path = audio_in |
| | | |
| | | if isinstance(audio_in, str): |
| | | assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist' |
| | | elif isinstance(audio_in, bytes): |
| | | assert len(audio_in) > 0, 'audio in is empty' |
| | | r_audio_format = 'pcm' |
| | | r_recog_type = 'wav' |
| | | |
| | | if r_recog_type is None: |
| | | # audio_in is wav, recog_type is wav_file |
| | | if os.path.isfile(audio_in): |
| | | if audio_in.endswith('.wav') or audio_in.endswith('.WAV'): |
| | | r_recog_type = 'wav' |
| | | r_audio_format = 'wav' |
| | | |
| | | # recog_type is datasets_file |
| | | elif os.path.isdir(audio_in): |
| | | dir_name = os.path.basename(audio_in) |
| | | if 'test' in dir_name: |
| | | r_recog_type = 'test' |
| | | elif 'dev' in dir_name: |
| | | r_recog_type = 'dev' |
| | | elif 'train' in dir_name: |
| | | r_recog_type = 'train' |
| | | |
| | | if r_audio_format is None: |
| | | if find_file_by_ends(audio_in, '.ark'): |
| | | r_audio_format = 'kaldi_ark' |
| | | elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends( |
| | | audio_in, '.WAV'): |
| | | r_audio_format = 'wav' |
| | | elif find_file_by_ends(audio_in, '.records'): |
| | | r_audio_format = 'tfrecord' |
| | | |
| | | if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': |
| | | # datasets with kaldi_ark file |
| | | r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) |
| | | elif r_audio_format == 'tfrecord' and r_recog_type != 'wav': |
| | | # datasets with tensorflow records file |
| | | r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) |
| | | elif r_audio_format == 'wav' and r_recog_type != 'wav': |
| | | # datasets with waveform files |
| | | r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) |
| | | |
| | | return r_recog_type, r_audio_format, r_wav_path |
| | | |
| | | |
| | | def get_sr_from_bytes(wav: bytes): |
| | | sr = None |
| | | data = wav |
| | | if len(data) > 44: |
| | | try: |
| | | header_fields = {} |
| | | header_fields['ChunkID'] = str(data[0:4], 'UTF-8') |
| | | header_fields['Format'] = str(data[8:12], 'UTF-8') |
| | | header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') |
| | | if header_fields['ChunkID'] == 'RIFF' and header_fields[ |
| | | 'Format'] == 'WAVE' and header_fields[ |
| | | 'Subchunk1ID'] == 'fmt ': |
| | | header_fields['SampleRate'] = struct.unpack('<I', |
| | | data[24:28])[0] |
| | | sr = header_fields['SampleRate'] |
| | | except Exception: |
| | | # no treatment |
| | | pass |
| | | else: |
| | | logger.warning('audio bytes length ' + str(len(data)) + ' is invalid') |
| | | |
| | | return sr |
| | | |
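get_sr_from_bytes reads the sample rate straight out of a canonical 44-byte RIFF header (little-endian uint32 at byte offset 24). A quick self-check that builds such a header by hand and feeds it back in; note the function requires strictly more than 44 bytes, hence the trailing data byte:

    import struct

    sr_in = 8000
    header = (b'RIFF' + struct.pack('<I', 36) + b'WAVE'
              + b'fmt ' + struct.pack('<IHHIIHH', 16, 1, 1, sr_in,
                                      sr_in * 2, 2, 16)
              + b'data' + struct.pack('<I', 0))
    assert len(header) == 44
    print(get_sr_from_bytes(header + b'\x00'))  # 8000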
| | | |
| | | def get_sr_from_wav(fname: str): |
| | | fs = None |
| | | if os.path.isfile(fname): |
| | | audio, fs = librosa.load(fname, sr=None) |
| | | return fs |
| | | elif os.path.isdir(fname): |
| | | dir_files = os.listdir(fname) |
| | | for file in dir_files: |
| | | file_path = os.path.join(fname, file) |
| | | if os.path.isfile(file_path): |
| | | if file_path.endswith('.wav') or file_path.endswith('.WAV'): |
| | | fs = get_sr_from_wav(file_path) |
| | | elif os.path.isdir(file_path): |
| | | fs = get_sr_from_wav(file_path) |
| | | |
| | | if fs is not None: |
| | | break |
| | | |
| | | return fs |
| | | |
| | | |
| | | def find_file_by_ends(dir_path: str, ends: str): |
| | | dir_files = os.listdir(dir_path) |
| | | for file in dir_files: |
| | | file_path = os.path.join(dir_path, file) |
| | | if os.path.isfile(file_path): |
| | | if file_path.endswith(ends): |
| | | return True |
| | | elif os.path.isdir(file_path): |
| | | if find_file_by_ends(file_path, ends): |
| | | return True |
| | | |
| | | return False |
| | | |
| | | |
| | | def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]: |
| | | dir_files = os.listdir(dir_path) |
| | | for file in dir_files: |
| | | file_path = os.path.join(dir_path, file) |
| | | if os.path.isfile(file_path): |
| | | if file_path.endswith('.wav') or file_path.endswith('.WAV'): |
| | | wav_list.append(file_path) |
| | | elif os.path.isdir(file_path): |
| | | recursion_dir_all_wav(wav_list, file_path) |
| | | |
| | | return wav_list |
| | | |
| | | |
| | | def set_parameters(language: str = None): |
| | | if language is not None: |
| | | global global_asr_language |
| | | global_asr_language = language |
| | | |
| | | |
| | | def compute_wer(hyp_list: List[Any], |
| | | ref_list: List[Any], |
| | | lang: str = None) -> Dict[str, Any]: |
| | | assert len(hyp_list) > 0, 'hyp list is empty' |
| | | assert len(ref_list) > 0, 'ref list is empty' |
| | | |
| | | if lang is not None: |
| | | global global_asr_language |
| | | global_asr_language = lang |
| | | |
| | | rst = { |
| | | 'Wrd': 0, |
| | | 'Corr': 0, |
| | | 'Ins': 0, |
| | | 'Del': 0, |
| | | 'Sub': 0, |
| | | 'Snt': 0, |
| | | 'Err': 0.0, |
| | | 'S.Err': 0.0, |
| | | 'wrong_words': 0, |
| | | 'wrong_sentences': 0 |
| | | } |
| | | |
| | | for h_item in hyp_list: |
| | | for r_item in ref_list: |
| | | if h_item['key'] == r_item['key']: |
| | | out_item = compute_wer_by_line(h_item['value'], |
| | | r_item['value'], |
| | | global_asr_language) |
| | | rst['Wrd'] += out_item['nwords'] |
| | | rst['Corr'] += out_item['cor'] |
| | | rst['wrong_words'] += out_item['wrong'] |
| | | rst['Ins'] += out_item['ins'] |
| | | rst['Del'] += out_item['del'] |
| | | rst['Sub'] += out_item['sub'] |
| | | rst['Snt'] += 1 |
| | | if out_item['wrong'] > 0: |
| | | rst['wrong_sentences'] += 1 |
| | | print_wrong_sentence(key=h_item['key'], |
| | | hyp=h_item['value'], |
| | | ref=r_item['value']) |
| | | else: |
| | | print_correct_sentence(key=h_item['key'], |
| | | hyp=h_item['value'], |
| | | ref=r_item['value']) |
| | | |
| | | break |
| | | |
| | | if rst['Wrd'] > 0: |
| | | rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) |
| | | if rst['Snt'] > 0: |
| | | rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) |
| | | |
| | | return rst |
| | | |
| | | |
| | | def compute_wer_by_line(hyp: str, |
| | | ref: str, |
| | | lang: str = 'zh-cn') -> Dict[str, Any]: |
| | | if lang != 'zh-cn': |
| | | hyp = hyp.split() |
| | | ref = ref.split() |
| | | |
| | | hyp = list(map(lambda x: x.lower(), hyp)) |
| | | ref = list(map(lambda x: x.lower(), ref)) |
| | | |
| | | len_hyp = len(hyp) |
| | | len_ref = len(ref) |
| | | |
| | | cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) |
| | | |
| | | ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) |
| | | |
| | | for i in range(len_hyp + 1): |
| | | cost_matrix[i][0] = i |
| | | for j in range(len_ref + 1): |
| | | cost_matrix[0][j] = j |
| | | |
| | | for i in range(1, len_hyp + 1): |
| | | for j in range(1, len_ref + 1): |
| | | if hyp[i - 1] == ref[j - 1]: |
| | | cost_matrix[i][j] = cost_matrix[i - 1][j - 1] |
| | | else: |
| | | substitution = cost_matrix[i - 1][j - 1] + 1 |
| | | insertion = cost_matrix[i - 1][j] + 1 |
| | | deletion = cost_matrix[i][j - 1] + 1 |
| | | |
| | | compare_val = [substitution, insertion, deletion] |
| | | |
| | | min_val = min(compare_val) |
| | | operation_idx = compare_val.index(min_val) + 1 |
| | | cost_matrix[i][j] = min_val |
| | | ops_matrix[i][j] = operation_idx |
| | | |
| | | match_idx = [] |
| | | i = len_hyp |
| | | j = len_ref |
| | | rst = { |
| | | 'nwords': len_ref, |
| | | 'cor': 0, |
| | | 'wrong': 0, |
| | | 'ins': 0, |
| | | 'del': 0, |
| | | 'sub': 0 |
| | | } |
| | | while i >= 0 or j >= 0: |
| | | i_idx = max(0, i) |
| | | j_idx = max(0, j) |
| | | |
| | | if ops_matrix[i_idx][j_idx] == 0: # correct |
| | | if i - 1 >= 0 and j - 1 >= 0: |
| | | match_idx.append((j - 1, i - 1)) |
| | | rst['cor'] += 1 |
| | | |
| | | i -= 1 |
| | | j -= 1 |
| | | |
| | | elif ops_matrix[i_idx][j_idx] == 2: # insert |
| | | i -= 1 |
| | | rst['ins'] += 1 |
| | | |
| | | elif ops_matrix[i_idx][j_idx] == 3: # delete |
| | | j -= 1 |
| | | rst['del'] += 1 |
| | | |
| | | elif ops_matrix[i_idx][j_idx] == 1: # substitute |
| | | i -= 1 |
| | | j -= 1 |
| | | rst['sub'] += 1 |
| | | |
| | | if i < 0 and j >= 0: |
| | | rst['del'] += 1 |
| | | elif j < 0 and i >= 0: |
| | | rst['ins'] += 1 |
| | | |
| | | match_idx.reverse() |
| | | wrong_cnt = cost_matrix[len_hyp][len_ref] |
| | | rst['wrong'] = wrong_cnt |
| | | |
| | | return rst |
| | | |
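compute_wer_by_line is a standard Levenshtein alignment: the cost matrix is filled with substitution/insertion/deletion costs, then the ops matrix is walked back to count each error type, with the reference length as the word count. A worked example in the default per-character zh-cn mode (ASCII characters used for readability):

    # hypothesis 'abd' vs reference 'abcd': three correct, one deletion
    rst = compute_wer_by_line('abd', 'abcd')
    print(rst)
    # {'nwords': 4, 'cor': 3, 'wrong': 1, 'ins': 0, 'del': 1, 'sub': 0}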
| | | |
| | | def print_wrong_sentence(key: str, hyp: str, ref: str): |
| | | space = len(key) |
| | | print(key + yellow_color + ' ref: ' + ref) |
| | | print(' ' * space + red_color + ' hyp: ' + hyp + end_color) |
| | | |
| | | |
| | | def print_correct_sentence(key: str, hyp: str, ref: str): |
| | | space = len(key) |
| | | print(key + yellow_color + ' ref: ' + ref) |
| | | print(' ' * space + green_color + ' hyp: ' + hyp + end_color) |
| | | |
| | | |
| | | def print_progress(percent): |
| | | if percent > 1: |
| | | percent = 1 |
| | | res = int(50 * percent) * '#' |
| | | print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='') |
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | |
| | | import string |
| | | from typing import Any, List, Union |
| | | |
| | | |
| | | def isChinese(ch: str): |
| | | return '\u4e00' <= ch <= '\u9fff' |
| | | |
| | | |
| | | def isAllChinese(word: Union[List[Any], str]): |
| | | word_lists = [] |
| | | table = str.maketrans('', '', string.punctuation) |
| | | for i in word: |
| | | cur = i.translate(table) |
| | | cur = cur.replace(' ', '') |
| | | cur = cur.replace('</s>', '') |
| | | cur = cur.replace('<s>', '') |
| | | word_lists.append(cur) |
| | | |
| | | if len(word_lists) == 0: |
| | | return False |
| | | |
| | | for ch in word_lists: |
| | | if isChinese(ch) is False: |
| | | return False |
| | | return True |
| | | |
| | | |
| | | def isAllAlpha(word: Union[List[Any], str]): |
| | | word_lists = [] |
| | | table = str.maketrans('', '', string.punctuation) |
| | | for i in word: |
| | | cur = i.translate(table) |
| | | cur = cur.replace(' ', '') |
| | | cur = cur.replace('</s>', '') |
| | | cur = cur.replace('<s>', '') |
| | | word_lists.append(cur) |
| | | |
| | | if len(word_lists) == 0: |
| | | return False |
| | | |
| | | for ch in word_lists: |
| | | if ch.isalpha() is False: |
| | | return False |
| | | elif ch.isalpha() is True and isChinese(ch) is True: |
| | | return False |
| | | |
| | | return True |
| | | |
| | | |
| | | def abbr_dispose(words: List[Any]) -> List[Any]: |
| | | words_size = len(words) |
| | | word_lists = [] |
| | | abbr_begin = [] |
| | | abbr_end = [] |
| | | last_num = -1 |
| | | for num in range(words_size): |
| | | if num <= last_num: |
| | | continue |
| | | |
| | | if len(words[num]) == 1 and words[num].encode('utf-8').isalpha(): |
| | | if num + 1 < words_size and words[ |
| | | num + 1] == ' ' and num + 2 < words_size and len( |
| | | words[num + |
| | | 2]) == 1 and words[num + |
| | | 2].encode('utf-8').isalpha(): |
| | | # found the beginning of an abbreviation |
| | | abbr_begin.append(num) |
| | | num += 2 |
| | | abbr_end.append(num) |
| | | # to find the end of abbr |
| | | while True: |
| | | num += 1 |
| | | if num < words_size and words[num] == ' ': |
| | | num += 1 |
| | | if num < words_size and len( |
| | | words[num]) == 1 and words[num].encode( |
| | | 'utf-8').isalpha(): |
| | | abbr_end.pop() |
| | | abbr_end.append(num) |
| | | last_num = num |
| | | else: |
| | | break |
| | | else: |
| | | break |
| | | |
| | | last_num = -1 |
| | | for num in range(words_size): |
| | | if num <= last_num: |
| | | continue |
| | | |
| | | if num in abbr_begin: |
| | | word_lists.append(words[num].upper()) |
| | | num += 1 |
| | | while num < words_size: |
| | | if num in abbr_end: |
| | | word_lists.append(words[num].upper()) |
| | | last_num = num |
| | | break |
| | | else: |
| | | if words[num].encode('utf-8').isalpha(): |
| | | word_lists.append(words[num].upper()) |
| | | num += 1 |
| | | else: |
| | | word_lists.append(words[num]) |
| | | |
| | | return word_lists |
| | | |
| | | |
| | | def sentence_postprocess(words: List[Any]): |
| | | middle_lists = [] |
| | | word_lists = [] |
| | | word_item = '' |
| | | |
| | | # clean the word list: drop sentence and unk markers |
| | | for i in words: |
| | | word = '' |
| | | if isinstance(i, str): |
| | | word = i |
| | | else: |
| | | word = i.decode('utf-8') |
| | | |
| | | if word in ['<s>', '</s>', '<unk>']: |
| | | continue |
| | | else: |
| | | middle_lists.append(word) |
| | | |
| | | # all chinese characters |
| | | if isAllChinese(middle_lists): |
| | | for ch in middle_lists: |
| | | word_lists.append(ch.replace(' ', '')) |
| | | |
| | | # all alpha characters |
| | | elif isAllAlpha(middle_lists): |
| | | for ch in middle_lists: |
| | | word = '' |
| | | if '@@' in ch: |
| | | word = ch.replace('@@', '') |
| | | word_item += word |
| | | else: |
| | | word_item += ch |
| | | word_lists.append(word_item) |
| | | word_lists.append(' ') |
| | | word_item = '' |
| | | |
| | | # mix characters |
| | | else: |
| | | alpha_blank = False |
| | | for ch in middle_lists: |
| | | word = '' |
| | | if isAllChinese(ch): |
| | | if alpha_blank is True: |
| | | word_lists.pop() |
| | | word_lists.append(ch) |
| | | alpha_blank = False |
| | | elif '@@' in ch: |
| | | word = ch.replace('@@', '') |
| | | word_item += word |
| | | alpha_blank = False |
| | | elif isAllAlpha(ch): |
| | | word_item += ch |
| | | word_lists.append(word_item) |
| | | word_lists.append(' ') |
| | | word_item = '' |
| | | alpha_blank = True |
| | | else: |
| | | raise ValueError('invalid character: {}'.format(ch)) |
| | | |
| | | word_lists = abbr_dispose(word_lists) |
| | | sentence = ''.join(word_lists).strip() |
| | | return sentence |
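sentence_postprocess drops sentence and unk markers, joins Chinese characters without spaces, merges '@@'-suffixed BPE pieces into whole words, space-separates alphabetic words, and upper-cases spelled-out abbreviations via abbr_dispose. Two quick examples, given the functions above:

    print(sentence_postprocess(['hel@@', 'lo', '</s>', 'world']))  # hello world
    print(sentence_postprocess(['你', '好', 'wor@@', 'ld']))        # 你好world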
| New file |
| | |
| | | # Copyright (c) Alibaba, Inc. and its affiliates. |
| | | |
| | | import math |
| | | import os |
| | | from typing import Any, Dict, Union |
| | | |
| | | import kaldiio |
| | | import librosa |
| | | import numpy as np |
| | | import torch |
| | | import torchaudio |
| | | import torchaudio.compliance.kaldi as kaldi |
| | | |
| | | |
| | | def ndarray_resample(audio_in: np.ndarray, |
| | | fs_in: int = 16000, |
| | | fs_out: int = 16000) -> np.ndarray: |
| | | audio_out = audio_in |
| | | if fs_in != fs_out: |
| | | audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out) |
| | | return audio_out |
| | | |
| | | |
| | | def torch_resample(audio_in: torch.Tensor, |
| | | fs_in: int = 16000, |
| | | fs_out: int = 16000) -> torch.Tensor: |
| | | audio_out = audio_in |
| | | if fs_in != fs_out: |
| | | audio_out = torchaudio.transforms.Resample(orig_freq=fs_in, |
| | | new_freq=fs_out)(audio_in) |
| | | return audio_out |
| | | |
| | | |
| | | def extract_CMVN_featrures(mvn_file): |
| | | """ |
| | | extract CMVN from cmvn.ark |
| | | """ |
| | | |
| | | if not os.path.exists(mvn_file): |
| | | return None |
| | | try: |
| | | cmvn = kaldiio.load_mat(mvn_file) |
| | | means = [] |
| | | variance = [] |
| | | |
| | | for i in range(cmvn.shape[1] - 1): |
| | | means.append(float(cmvn[0][i])) |
| | | |
| | | count = float(cmvn[0][-1]) |
| | | |
| | | for i in range(cmvn.shape[1] - 1): |
| | | variance.append(float(cmvn[1][i])) |
| | | |
| | | for i in range(len(means)): |
| | | means[i] /= count |
| | | variance[i] = variance[i] / count - means[i] * means[i] |
| | | if variance[i] < 1.0e-20: |
| | | variance[i] = 1.0e-20 |
| | | variance[i] = 1.0 / math.sqrt(variance[i]) |
| | | |
| | | cmvn = np.array([means, variance]) |
| | | return cmvn |
| | | except Exception: |
| | | cmvn = extract_CMVN_features_txt(mvn_file) |
| | | return cmvn |
| | | |
| | | |
| | | def extract_CMVN_features_txt(mvn_file): # noqa |
| | | with open(mvn_file, 'r', encoding='utf-8') as f: |
| | | lines = f.readlines() |
| | | |
| | | add_shift_list = [] |
| | | rescale_list = [] |
| | | for i in range(len(lines)): |
| | | line_item = lines[i].split() |
| | | if line_item[0] == '<AddShift>': |
| | | line_item = lines[i + 1].split() |
| | | if line_item[0] == '<LearnRateCoef>': |
| | | add_shift_line = line_item[3:(len(line_item) - 1)] |
| | | add_shift_list = list(add_shift_line) |
| | | continue |
| | | elif line_item[0] == '<Rescale>': |
| | | line_item = lines[i + 1].split() |
| | | if line_item[0] == '<LearnRateCoef>': |
| | | rescale_line = line_item[3:(len(line_item) - 1)] |
| | | rescale_list = list(rescale_line) |
| | | continue |
| | | add_shift_list_f = [float(s) for s in add_shift_list] |
| | | rescale_list_f = [float(s) for s in rescale_list] |
| | | cmvn = np.array([add_shift_list_f, rescale_list_f]) |
| | | return cmvn |
| | | |
| | | |
| | | def build_LFR_features(inputs, m=7, n=6): # noqa |
| | | """ |
| | | Implements frame stacking and frame skipping (low frame rate, LFR). |
| | | If m = 1 and n = 1, the original features are returned unchanged. |
| | | If m = 1 and n > 1, it works like frame skipping. |
| | | If m > 1 and n = 1, it works like stacking, but only stacks frames to the right. |
| | | If m > 1 and n > 1, it works like LFR. |
| | | |
| | | Args: |
| | | inputs: T x D np.ndarray |
| | | m: number of frames to stack |
| | | n: number of frames to skip |
| | | """ |
| | | # LFR_inputs_batch = [] |
| | | # for inputs in inputs_batch: |
| | | LFR_inputs = [] |
| | | T = inputs.shape[0] |
| | | T_lfr = int(np.ceil(T / n)) |
| | | left_padding = np.tile(inputs[0], ((m - 1) // 2, 1)) |
| | | inputs = np.vstack((left_padding, inputs)) |
| | | T = T + (m - 1) // 2 |
| | | for i in range(T_lfr): |
| | | if m <= T - i * n: |
| | | LFR_inputs.append(np.hstack(inputs[i * n:i * n + m])) |
| | | else: # process last LFR frame |
| | | num_padding = m - (T - i * n) |
| | | frame = np.hstack(inputs[i * n:]) |
| | | for _ in range(num_padding): |
| | | frame = np.hstack((frame, inputs[-1])) |
| | | LFR_inputs.append(frame) |
| | | return np.vstack(LFR_inputs) |
| | | |
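build_LFR_features is the NumPy counterpart of the torch apply_lfr in wav_frontend above: same left padding, same ceil(T / n) output length, same m-frame stacking. A quick shape check with the defaults m=7, n=6:

    import numpy as np

    feats = np.arange(20, dtype=np.float32).reshape(10, 2)  # T=10, D=2
    out = build_LFR_features(feats)
    print(out.shape)  # (2, 14): ceil(10 / 6) frames of 7 * 2 dims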
| | | |
| | | def compute_fbank(wav_file, |
| | | num_mel_bins=80, |
| | | frame_length=25, |
| | | frame_shift=10, |
| | | dither=0.0, |
| | | is_pcm=False, |
| | | fs: Union[int, Dict[Any, int]] = 16000): |
| | | audio_sr: int = 16000 |
| | | model_sr: int = 16000 |
| | | if isinstance(fs, int): |
| | | model_sr = fs |
| | | audio_sr = fs |
| | | else: |
| | | model_sr = fs['model_fs'] |
| | | audio_sr = fs['audio_fs'] |
| | | |
| | | if is_pcm is True: |
| | | # byte(PCM16) to float32, and resample |
| | | value = wav_file |
| | | middle_data = np.frombuffer(value, dtype=np.int16) |
| | | middle_data = np.asarray(middle_data) |
| | | if middle_data.dtype.kind not in 'iu': |
| | | raise TypeError("'middle_data' must be an array of integers") |
| | | dtype = np.dtype('float32') |
| | | if dtype.kind != 'f': |
| | | raise TypeError("'dtype' must be a floating point type") |
| | | |
| | | i = np.iinfo(middle_data.dtype) |
| | | abs_max = 2**(i.bits - 1) |
| | | offset = i.min + abs_max |
| | | waveform = np.frombuffer( |
| | | (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) |
| | | waveform = ndarray_resample(waveform, audio_sr, model_sr) |
| | | waveform = torch.from_numpy(waveform.reshape(1, -1)) |
| | | else: |
| | | # load pcm from wav, and resample |
| | | waveform, audio_sr = torchaudio.load(wav_file) |
| | | waveform = waveform * (1 << 15) |
| | | waveform = torch_resample(waveform, audio_sr, model_sr) |
| | | |
| | | mat = kaldi.fbank(waveform, |
| | | num_mel_bins=num_mel_bins, |
| | | frame_length=frame_length, |
| | | frame_shift=frame_shift, |
| | | dither=dither, |
| | | energy_floor=0.0, |
| | | window_type='hamming', |
| | | sample_frequency=model_sr) |
| | | |
| | | input_feats = mat |
| | | |
| | | return input_feats |
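A hypothetical end-to-end call of compute_fbank on raw PCM16 bytes, resampling from 8 kHz up to a 16 kHz model rate before feature extraction (the audio is random; the shapes are what matter):

    import numpy as np

    pcm_bytes = (np.random.randn(8000) * 1000).astype(np.int16).tobytes()  # 1 s at 8 kHz
    feats = compute_fbank(pcm_bytes, is_pcm=True,
                          fs={'audio_fs': 8000, 'model_fs': 16000})
    print(feats.shape)  # torch.Size([98, 80]): 25 ms windows at a 10 ms shift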