From 7012ca2efc130103c4acd24e3678c7ae280f8db4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 十二月 2023 20:08:55 +0800
Subject: [PATCH] funasr2 paraformer biciparaformer contextuaparaformer
---
examples/industrial_data_pretraining/paraformer-large/run.sh | 2
funasr/models/paraformer/model.py | 2
/dev/null | 655 ------------------------------------------------------
funasr/bin/train.py | 7
funasr/utils/trainer.py | 0
funasr/bin/export_model.py | 0
setup.py | 2
funasr/models/model_class_factory.py | 22 -
8 files changed, 5 insertions(+), 685 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer-large/run.sh b/examples/industrial_data_pretraining/paraformer-large/run.sh
index 8571974..ce1953c 100644
--- a/examples/industrial_data_pretraining/paraformer-large/run.sh
+++ b/examples/industrial_data_pretraining/paraformer-large/run.sh
@@ -1,5 +1,5 @@
-cmd="funasr/cli/train_cli.py"
+cmd="funasr/bin/train.py"
python $cmd \
+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
diff --git a/funasr/bin/argument.py b/funasr/bin/argument.py
deleted file mode 100644
index 0ea4ac9..0000000
--- a/funasr/bin/argument.py
+++ /dev/null
@@ -1,262 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import sys
-
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import config_argparse
-import argparse
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, default=None)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=1,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- parser.add_argument(
- "--hotword",
- type=str_or_none,
- default=None,
- help="hotword file path or hotwords seperated by space"
- )
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
- group.add_argument(
- "--mc",
- type=bool,
- default=False,
- help="MultiChannel input",
- )
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--punc_infer_config",
- type=str,
- help="PUNC infer configuration",
- )
- group.add_argument(
- "--punc_model_file",
- type=str,
- help="PUNC model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global CMVN file",
- )
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--sv_model_file",
- type=str,
- help="SV model parameter file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
- group.add_argument(
- "--beam_search_config",
- default={},
- help="The keyword arguments for transducer beam search.",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.0,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
- group.add_argument("--fake_streaming", type=str2bool, default=False)
- group.add_argument("--full_utt", type=str2bool, default=False)
- group.add_argument("--chunk_size", type=int, default=16)
- group.add_argument("--left_context", type=int, default=16)
- group.add_argument("--right_context", type=int, default=0)
- group.add_argument(
- "--display_partial_hypotheses",
- type=bool,
- default=False,
- help="Whether to display partial hypotheses during chunk-by-chunk inference.",
- )
-
- group = parser.add_argument_group("Dynamic quantization related")
- group.add_argument(
- "--quantize_asr_model",
- type=bool,
- default=False,
- help="Apply dynamic quantization to ASR model.",
- )
- group.add_argument(
- "--quantize_modules",
- nargs="*",
- default=None,
- help="""Module names to apply dynamic quantization on.
- The module names are provided as a list, where each name is separated
- by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
- Each specified name should be an attribute of 'torch.nn', e.g.:
- torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
- )
- group.add_argument(
- "--quantize_dtype",
- type=str,
- default="qint8",
- choices=["float16", "qint8"],
- help="Dtype for dynamic quantization.",
- )
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
- group.add_argument("--token_num_relax", type=int, default=1, help="")
- group.add_argument("--decoding_ind", type=int, default=0, help="")
- group.add_argument("--decoding_mode", type=str, default="model1", help="")
- group.add_argument(
- "--ctc_weight2",
- type=float,
- default=0.0,
- help="CTC weight in joint decoding",
- )
- return parser
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
deleted file mode 100644
index a1cede1..0000000
--- a/funasr/bin/asr_infer.py
+++ /dev/null
@@ -1,2004 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
-import codecs
-import copy
-import logging
-import os
-import re
-import tempfile
-from pathlib import Path
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import requests
-import torch
-from packaging.version import parse as V
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.modules.beam_search.beam_search import BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
-from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.build_utils.build_asr_model import frontend_choices
-from funasr.tokenizer.build_tokenizer import build_tokenizer
-from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
- ):
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend == 'wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
- else:
- frontend_class = frontend_choices.get_class(asr_train_args.frontend)
- frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, None, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
- from funasr.modules.beam_search.beam_search import BeamSearch
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, _ = self.asr_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
- assert len(enc) == 1, len(enc)
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- return results
-
-
-class Speech2TextParaformer:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- clas_scale: float = 1.0,
- decoding_ind: int = 0,
- **kwargs,
- ):
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- if asr_model.ctc != None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, None, device, task_name="lm"
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
- from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.cmvn_file = cmvn_file
-
- # 6. [Optional] Build hotword list from str, local file or url
- self.hotword_list = None
- self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
- self.clas_scale = clas_scale
-
- is_use_lm = lm_weight != 0.0 and lm_file is not None
- if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
- beam_search = None
- self.beam_search = beam_search
- logging.info(f"Beam_search: {self.beam_search}")
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- self.decoding_ind = decoding_ind
- if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- if decoding_ind is None:
- decoding_ind = 0 if self.decoding_ind is None else self.decoding_ind
- enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
- if isinstance(enc, tuple):
- enc = enc[0]
- # assert len(enc) == 1, len(enc)
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
- predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- if not isinstance(self.asr_model, ContextualParaformer) and \
- not isinstance(self.asr_model, NeatContextualParaformer):
- if self.hotword_list:
- logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
- pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
- enc_len,
- pre_acoustic_embeds,
- pre_token_length,
- hw_list=self.hotword_list,
- clas_scale=self.clas_scale)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- if isinstance(self.asr_model, BiCifParaformer):
- _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
- pre_token_length) # test no bias cif2
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = enc[i, :enc_len[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
- if pre_token_length[i] == 0:
- yseq = torch.tensor(
- [self.asr_model.sos] + [self.asr_model.eos], device=pre_acoustic_embeds.device
- )
- score = torch.tensor(0.0, device=pre_acoustic_embeds.device)
- else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- timestamp = []
- if isinstance(self.asr_model, BiCifParaformer):
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3],
- us_peaks[i][:enc_len[i] * 3],
- copy.copy(token),
- vad_offset=begin_time)
- results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
-
- return results
-
- def generate_hotwords_list(self, hotword_list_or_file):
- def load_seg_dict(seg_dict_file):
- seg_dict = {}
- assert isinstance(seg_dict_file, str)
- with open(seg_dict_file, "r", encoding="utf8") as f:
- lines = f.readlines()
- for line in lines:
- s = line.strip().split()
- key = s[0]
- value = s[1:]
- seg_dict[key] = " ".join(value)
- return seg_dict
-
- def seg_tokenize(txt, seg_dict):
- pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
- out_txt = ""
- for word in txt:
- word = word.lower()
- if word in seg_dict:
- out_txt += seg_dict[word] + " "
- else:
- if pattern.match(word):
- for char in word:
- if char in seg_dict:
- out_txt += seg_dict[char] + " "
- else:
- out_txt += "<unk>" + " "
- else:
- out_txt += "<unk>" + " "
- return out_txt.strip().split()
-
- seg_dict = None
- if self.cmvn_file is not None:
- model_dir = os.path.dirname(self.cmvn_file)
- seg_dict_file = os.path.join(model_dir, 'seg_dict')
- if os.path.exists(seg_dict_file):
- seg_dict = load_seg_dict(seg_dict_file)
- else:
- seg_dict = None
- # for None
- if hotword_list_or_file is None:
- hotword_list = None
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords from local txt...")
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hw_list = hw.split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids(hw_list))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- elif hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hw_list = hw.split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids(hw_list))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for text str input
- elif not hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords as str...")
- hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- hw_list = hw.strip().split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_list.append(self.converter.tokens2ids(hw_list))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- else:
- hotword_list = None
- return hotword_list
-
-
-class Speech2TextParaformerOnline:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- if asr_model.ctc != None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, None, device, task_name="lm"
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
- from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
- # 6. [Optional] Build hotword list from str, local file or url
-
- is_use_lm = lm_weight != 0.0 and lm_file is not None
- if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
- beam_search = None
- self.beam_search = beam_search
- logging.info(f"Beam_search: {self.beam_search}")
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- results = []
- cache_en = cache["encoder"]
- if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
- if cache_en["start_idx"] == 0:
- return []
- cache_en["tail_chunk"] = True
- feats = cache_en["feats"]
- feats_len = torch.tensor([feats.shape[1]])
- self.asr_model.frontend = None
- self.frontend.cache_reset()
- results = self.infer(feats, feats_len, cache)
- return results
- else:
- if self.frontend is not None:
- if cache_en["start_idx"] == 0:
- self.frontend.cache_reset()
- feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
-
- if feats.shape[1] != 0:
- results = self.infer(feats, feats_len, cache)
-
- return results
-
- @torch.no_grad()
- def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None):
- batch = {"speech": feats, "speech_lengths": feats_len}
- batch = to_device(batch, device=self.device)
- # b. Forward Encoder
- enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache)
- if isinstance(enc, tuple):
- enc = enc[0]
- # assert len(enc) == 1, len(enc)
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
- predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
- pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
- if torch.max(pre_token_length) < 1:
- return []
- decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
- decoder_out = decoder_outs
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = enc[i, :enc_len[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
- postprocessed_result = ""
- for item in token:
- if item.endswith('@@'):
- postprocessed_result += item[:-2]
- elif re.match('^[a-zA-Z]+$', item):
- postprocessed_result += item + " "
- else:
- postprocessed_result += item
-
- results.append(postprocessed_result)
-
- return results
-
-
-class Speech2TextUniASR:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- frontend_conf: dict = None,
- **kwargs,
- ):
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
- if decoding_mode == "model1":
- decoder = asr_model.decoder
- else:
- decoder = asr_model.decoder2
-
- if asr_model.ctc != None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, device, "lm"
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
- from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
- # logging.info(f"Beam_search: {beam_search}")
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.token_num_relax = token_num_relax
- self.decoding_ind = decoding_ind
- self.decoding_mode = decoding_mode
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- feats_raw = feats.clone().to(self.device)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
- # b. Forward Encoder
- _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
- if isinstance(enc, tuple):
- enc = enc[0]
- assert len(enc) == 1, len(enc)
- if self.decoding_mode == "model1":
- predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
- else:
- enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
- predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
-
- scama_mask = predictor_outs[4]
- pre_token_length = predictor_outs[1]
- pre_acoustic_embeds = predictor_outs[0]
- maxlen = pre_token_length.sum().item() + self.token_num_relax
- minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
- minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
- token = list(filter(lambda x: x != "<gbg>", token))
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- return results
-
-
-class Speech2TextMFCCA:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- **kwargs,
- ):
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, None, device, task_name="lm"
- )
- lm.to(device)
- scorers["lm"] = lm.lm
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
- # beam_search.__class__ = BatchBeamSearch
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if (speech.dim() == 3):
- speech = torch.squeeze(speech, 2)
- # speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- speech = speech.to(getattr(torch, self.dtype))
- # lenghts: (1,)
- lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- batch = {"speech": speech, "speech_lengths": lengths}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, _ = self.asr_model.encode(**batch)
-
- assert len(enc) == 1, len(enc)
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- return results
-
-
-class Speech2TextTransducer:
- """Speech2Text class for Transducer models.
- Args:
- asr_train_config: ASR model training config path.
- asr_model_file: ASR model path.
- beam_search_config: Beam search config path.
- lm_train_config: Language Model training config path.
- lm_file: Language Model config path.
- token_type: Type of token units.
- bpemodel: BPE model path.
- device: Device to use for inference.
- beam_size: Size of beam during search.
- dtype: Data type.
- lm_weight: Language model weight.
- quantize_asr_model: Whether to apply dynamic quantization to ASR model.
- quantize_modules: List of module names to apply dynamic quantization on.
- quantize_dtype: Dynamic quantization data type.
- nbest: Number of final hypothesis.
- streaming: Whether to perform chunk-by-chunk inference.
- chunk_size: Number of frames in chunk AFTER subsampling.
- left_context: Number of frames in left context AFTER subsampling.
- right_context: Number of frames in right context AFTER subsampling.
- display_partial_hypotheses: Whether to display partial hypotheses.
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- beam_search_config: Dict[str, Any] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- beam_size: int = 5,
- dtype: str = "float32",
- lm_weight: float = 1.0,
- quantize_asr_model: bool = False,
- quantize_modules: List[str] = None,
- quantize_dtype: str = "qint8",
- nbest: int = 1,
- streaming: bool = False,
- fake_streaming: bool = False,
- full_utt: bool = False,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- display_partial_hypotheses: bool = False,
- ) -> None:
- """Construct a Speech2Text object."""
- super().__init__()
-
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
-
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- if quantize_asr_model:
- if quantize_modules is not None:
- if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
- raise ValueError(
- "Only 'Linear' and 'LSTM' modules are currently supported"
- " by PyTorch and in --quantize_modules"
- )
-
- q_config = set([getattr(torch.nn, q) for q in quantize_modules])
- else:
- q_config = {torch.nn.Linear}
-
- if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
- raise ValueError(
- "float16 dtype for dynamic quantization is not supported with torch"
- " version < 1.5.0. Switching to qint8 dtype instead."
- )
- q_dtype = getattr(torch, quantize_dtype)
-
- asr_model = torch.quantization.quantize_dynamic(
- asr_model, q_config, dtype=q_dtype
- ).eval()
- else:
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, None, device, task_name="lm"
- )
- lm_scorer = lm.lm
- else:
- lm_scorer = None
-
- # 4. Build BeamSearch object
- if beam_search_config is None:
- beam_search_config = {}
-
- beam_search = BeamSearchTransducer(
- asr_model.decoder,
- asr_model.joint_network,
- beam_size,
- lm=lm_scorer,
- lm_weight=lm_weight,
- nbest=nbest,
- **beam_search_config,
- )
-
- token_list = asr_model.token_list
-
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
-
- self.converter = converter
- self.tokenizer = tokenizer
-
- self.beam_search = beam_search
- self.streaming = streaming
- self.fake_streaming = fake_streaming
- self.full_utt = full_utt
- self.chunk_size = max(chunk_size, 0)
- self.left_context = left_context
- self.right_context = max(right_context, 0)
-
- if not streaming or chunk_size == 0:
- self.streaming = False
- self.asr_model.encoder.dynamic_chunk_training = False
-
- if not fake_streaming or chunk_size == 0:
- self.fake_streaming = False
- self.asr_model.encoder.dynamic_chunk_training = False
-
- self.frontend = frontend
- self.window_size = self.chunk_size + self.right_context
-
- if self.streaming:
- self._ctx = self.asr_model.encoder.get_encoder_input_size(
- self.window_size
- )
- self._right_ctx = right_context
-
- self.last_chunk_length = (
- self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
- )
- self.reset_inference_cache()
-
- def reset_inference_cache(self) -> None:
- """Reset Speech2Text parameters."""
- self.frontend_cache = None
-
- self.asr_model.encoder.reset_streaming_cache(
- self.left_context, device=self.device
- )
- self.beam_search.reset_inference_cache()
-
- self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
- @torch.no_grad()
- def streaming_decode(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- is_final: bool = True,
- ) -> List[HypothesisTransducer]:
- """Speech2Text streaming call.
- Args:
- speech: Chunk of speech data. (S)
- is_final: Whether speech corresponds to the final chunk of data.
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if is_final:
- if self.streaming and speech.size(0) < self.last_chunk_length:
- pad = torch.zeros(
- self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
- )
- speech = torch.cat([speech, pad],
- dim=0) # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
-
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- if self.asr_model.normalize is not None:
- feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
- enc_out = self.asr_model.encoder.chunk_forward(
- feats,
- feats_lengths,
- self.num_processed_frames,
- chunk_size=self.chunk_size,
- left_context=self.left_context,
- right_context=self.right_context,
- )
- nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
-
- self.num_processed_frames += self.chunk_size
-
- if is_final:
- self.reset_inference_cache()
-
- return nbest_hyps
-
- @torch.no_grad()
- def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
- """Speech2Text call.
- Args:
- speech: Speech data. (S)
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
-
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- speech = torch.unsqueeze(speech, axis=0)
- speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- if self.asr_model.normalize is not None:
- feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
- enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
- self.right_context)
- nbest_hyps = self.beam_search(enc_out[0])
-
- return nbest_hyps
-
- @torch.no_grad()
- def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
- """Speech2Text call.
- Args:
- speech: Speech data. (S)
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
- assert check_argument_types()
-
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- speech = torch.unsqueeze(speech, axis=0)
- speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- if self.asr_model.normalize is not None:
- feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
- enc_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)
- nbest_hyps = self.beam_search(enc_out[0])
-
- return nbest_hyps
-
- @torch.no_grad()
- def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
- """Speech2Text call.
- Args:
- speech: Speech data. (S)
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
-
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- speech = torch.unsqueeze(speech, axis=0)
- speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
-
- enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
- nbest_hyps = self.beam_search(enc_out[0])
-
- return nbest_hyps
-
- def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
- """Build partial or final results from the hypotheses.
- Args:
- nbest_hyps: N-best hypothesis.
- Returns:
- results: Results containing different representation for the hypothesis.
- """
- results = []
-
- for hyp in nbest_hyps:
- token_int = list(filter(lambda x: x != 0, hyp.yseq))
-
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
-
- return results
-
-
-class Speech2TextSAASR:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
- ):
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- from funasr.tasks.sa_asr import frontend_choices
- if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
- frontend_class = frontend_choices.get_class(asr_train_args.frontend)
- frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
- else:
- frontend_class = frontend_choices.get_class(asr_train_args.frontend)
- frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = build_model_from_file(
- lm_train_config, lm_file, None, device, task_name="lm"
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
- from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
- profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
- ) -> List[
- Tuple[
- Optional[str],
- Optional[str],
- List[str],
- List[int],
- Union[HypothesisSAASR],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, text_id, token, token_int, hyp
-
- """
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if isinstance(profile, np.ndarray):
- profile = torch.tensor(profile)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- asr_enc, _, spk_enc = self.asr_model.encode(**batch)
- if isinstance(asr_enc, tuple):
- asr_enc = asr_enc[0]
- if isinstance(spk_enc, tuple):
- spk_enc = spk_enc[0]
- assert len(asr_enc) == 1, len(asr_enc)
- assert len(spk_enc) == 1, len(spk_enc)
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1: last_pos]
- else:
- token_int = hyp.yseq[1: last_pos].tolist()
-
- spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
-
- token_ori = self.converter.ids2tokens(token_int)
- text_ori = self.tokenizer.tokens2text(token_ori)
-
- text_ori_spklist = text_ori.split('$')
- cur_index = 0
- spk_choose = []
- for i in range(len(text_ori_spklist)):
- text_ori_split = text_ori_spklist[i]
- n = len(text_ori_split)
- spk_weights_local = spk_weigths[cur_index: cur_index + n]
- cur_index = cur_index + n + 1
- spk_weights_local = spk_weights_local.mean(dim=0)
- spk_choose_local = spk_weights_local.argmax(-1)
- spk_choose.append(spk_choose_local.item() + 1)
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- text_spklist = text.split('$')
- assert len(spk_choose) == len(text_spklist)
-
- spk_list = []
- for i in range(len(text_spklist)):
- text_split = text_spklist[i]
- n = len(text_split)
- spk_list.append(str(spk_choose[i]) * n)
-
- text_id = '$'.join(spk_list)
-
- assert len(text) == len(text_id)
-
- results.append((text, text_id, token, token_int, hyp))
-
- return results
-
-
-class Speech2TextWhisper:
- """Speech2Text class
-
- Examples:
- >>> import librosa
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- language: str = None,
- task: str = "transcribe",
- **kwargs,
- ):
-
- from funasr.tasks.whisper import ASRTask
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- token_list = []
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.tokenizer = tokenizer
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
- self.language = language
- self.task = task
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
-
- from funasr.utils.whisper_utils.transcribe import transcribe
- from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
- from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
-
- speech = speech[0]
- speech = pad_or_trim(speech)
- mel = log_mel_spectrogram(speech).to(self.device)
-
- if self.asr_model.is_multilingual:
- options = DecodingOptions(fp16=False, language=self.language, task=self.task)
- asr_res = decode(self.asr_model, mel, options)
- text = asr_res.text
- language = self.language if self.language else asr_res.language
- else:
- asr_res = transcribe(self.asr_model, speech, fp16=False)
- text = asr_res["text"]
- language = asr_res["language"]
- results = [(text, language)]
- return results
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
deleted file mode 100644
index 6151d28..0000000
--- a/funasr/bin/asr_inference_launch.py
+++ /dev/null
@@ -1,2248 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-from optparse import Option
-import os
-import sys
-import time
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-import torchaudio
-# import librosa
-import librosa
-import yaml
-
-from funasr.bin.asr_infer import Speech2Text
-from funasr.bin.asr_infer import Speech2TextMFCCA
-from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
-from funasr.bin.asr_infer import Speech2TextSAASR
-from funasr.bin.asr_infer import Speech2TextTransducer
-from funasr.bin.asr_infer import Speech2TextUniASR
-from funasr.bin.asr_infer import Speech2TextWhisper
-from funasr.bin.punc_infer import Text2Punc
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.vad_infer import Speech2VadSegment
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.subsampling import TooShortUttError
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import asr_utils, postprocess_utils
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.utils.speaker_utils import (check_audio_list,
- sv_preprocess,
- sv_chunk,
- extract_feature,
- postprocess,
- distribute_spk)
-import funasr.modules.cnn as sv_module
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.utils.cluster_backend import ClusterBackend
-from funasr.utils.modelscope_utils import get_cache_dir
-from tqdm import tqdm
-
-def inference_asr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- mc=mc,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
-
- logging.info("uttid: {}".format(key))
- logging.info("text predictions: {}\n".format(text))
- return asr_result_list
-
- return _forward
-
-
-def inference_paraformer(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- param_dict: dict = None,
- decoding_ind: int = 0,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- export_mode = False
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- export_mode = param_dict.get("export_mode", False)
- clas_scale = param_dict.get('clas_scale', 1.0)
- else:
- hotword_list_or_file = None
- clas_scale = 1.0
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
- clas_scale=clas_scale,
- decoding_ind=decoding_ind,
- )
-
- speech2text = Speech2TextParaformer(**speech2text_kwargs)
-
- if timestamp_model_file is not None:
- speechtext2timestamp = Speech2Timestamp(
- timestamp_cmvn_file=cmvn_file,
- timestamp_model_file=timestamp_model_file,
- timestamp_infer_config=timestamp_infer_config,
- )
- else:
- speechtext2timestamp = None
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- decoding_ind = None
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- if 'hotword' in kwargs and kwargs['hotword'] is not None:
- hotword_list_or_file = kwargs['hotword']
- if hotword_list_or_file is not None or 'hotword' in kwargs:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
- if param_dict is not None and "decoding_ind" in param_dict:
- decoding_ind = param_dict["decoding_ind"]
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- if param_dict is not None:
- use_timestamp = param_dict.get('use_timestamp', True)
- else:
- use_timestamp = True
-
- forward_time_total = 0.0
- length_total = 0.0
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
- logging.info("decoding, utt_id: {}".format(keys))
- # N-best list of (text, token, token_int, hyp_object)
-
- time_beg = time.time()
- batch["decoding_ind"] = decoding_ind
- results = speech2text(**batch)
- if len(results) < 1:
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp, 10, 6, []]] * nbest
- time_end = time.time()
- forward_time = time_end - time_beg
- lfr_factor = results[0][-1]
- length = results[0][-2]
- forward_time_total += forward_time
- length_total += length
- rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
- 100 * forward_time / (
- length * lfr_factor))
- logging.info(rtf_cur)
-
- for batch_id in range(_bs):
- result = [results[batch_id][:-2]]
-
- key = keys[batch_id]
- for n, result in zip(range(1, nbest + 1), result):
- text, token, token_int, hyp = result[0], result[1], result[2], result[3]
- timestamp = result[4] if len(result[4]) > 0 else None
- # conduct timestamp prediction here
- # timestamp inference requires token length
- # thus following inference cannot be conducted in batch
- if timestamp is None and speechtext2timestamp:
- ts_batch = {}
- ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
- ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
- ts_batch['text_lengths'] = torch.tensor([len(token)])
- us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
- ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token,
- force_time_shift=-3.0)
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["rtf"][key] = rtf_cur
-
- if text is not None:
- if use_timestamp and timestamp is not None and len(timestamp):
- postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
- else:
- postprocessed_result = postprocess_utils.sentence_postprocess(token)
- timestamp_postprocessed = ""
- if len(postprocessed_result) == 3:
- text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
- postprocessed_result[1], \
- postprocessed_result[2]
- else:
- text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
- item = {'key': key, 'value': text_postprocessed}
- if timestamp_postprocessed != "":
- item['timestamp'] = timestamp_postprocessed
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text))
- rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
- forward_time_total,
- 100 * forward_time_total / (
- length_total * lfr_factor))
- logging.info(rtf_avg)
- if writer is not None:
- ibest_writer["rtf"]["rtf_avf"] = rtf_avg
- torch.cuda.empty_cache()
- return asr_result_list
-
- return _forward
-
-
-def inference_paraformer_vad_punc(
- maxlenratio: float=0.0,
- minlenratio: float=0.0,
- batch_size: int=1,
- beam_size: int=1,
- ngpu: int=1,
- ctc_weight: float=0.0,
- lm_weight: float=0.0,
- penalty: float=0.0,
- log_level: Union[int, str]=logging.ERROR,
- # data_path_and_name_and_type,
- asr_train_config: Optional[str]=None,
- asr_model_file: Optional[str]=None,
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 0,
- vad_infer_config: Optional[str] = None,
- vad_model_file: Optional[str] = None,
- vad_cmvn_file: Optional[str] = None,
- time_stamp_writer: bool = True,
- punc_infer_config: Optional[str] = None,
- punc_model_file: Optional[str] = None,
- outputs_dict: Optional[bool] = True,
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- language = kwargs.get("model_lang", None)
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- else:
- hotword_list_or_file = None
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
- # 3. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
- )
- speech2text = Speech2TextParaformer(**speech2text_kwargs)
- text2punc = None
- if punc_model_file is not None:
- text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
- ibest_writer = writer[f"1best_recog"]
- ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
-
- if 'hotword' in kwargs:
- hotword_list_or_file = kwargs['hotword']
-
- speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
- batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
- batch_size_token = kwargs.get("batch_size_token", 6000)
- print("batch_size_token: ", batch_size_token)
-
- if speech2text.hotword_list is None:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=1,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- if param_dict is not None:
- use_timestamp = param_dict.get('use_timestamp', True)
- else:
- use_timestamp = True
-
- finish_count = 0
- file_count = 1
- lfr_factor = 6
- # 7 .Start for-loop
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- writer = None
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- beg_vad = time.time()
- vad_results = speech2vadsegment(**batch)
- end_vad = time.time()
- print("time cost vad: ", end_vad - beg_vad)
- _, vadsegments = vad_results[0], vad_results[1][0]
-
- speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
- n = len(vadsegments)
- data_with_index = [(vadsegments[i], i) for i in range(n)]
- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
- results_sorted = []
-
- if not len(sorted_data):
- key = keys[0]
- # no active segments after VAD
- if writer is not None:
- # Write empty results
- ibest_writer["token"][key] = ""
- ibest_writer["token_int"][key] = ""
- ibest_writer["vad"][key] = ""
- ibest_writer["text"][key] = ""
- ibest_writer["text_with_punc"][key] = ""
- if use_timestamp:
- ibest_writer["time_stamp"][key] = ""
-
- logging.info("decoding, utt: {}, empty speech".format(key))
- continue
-
- batch_size_token_ms = batch_size_token*60
- if speech2text.device == "cpu":
- batch_size_token_ms = 0
- if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
- batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
-
- batch_size_token_ms_cum = 0
- beg_idx = 0
- beg_asr_total = time.time()
- for j, _ in enumerate(tqdm(range(0, n))):
- batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
- if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
- continue
- batch_size_token_ms_cum = 0
- end_idx = j + 1
- speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
- beg_idx = end_idx
- batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
- batch = to_device(batch, device=device)
-
- beg_asr = time.time()
- results = speech2text(**batch)
- end_asr = time.time()
- if speech2text.device != "cpu":
- print("batch: ", speech_j.shape[0])
- print("time cost asr: ", end_asr - beg_asr)
-
- if len(results) < 1:
- results = [["", [], [], [], [], [], []]]
- results_sorted.extend(results)
- end_asr_total = time.time()
- print("total time cost asr: ", end_asr_total-beg_asr_total)
- restored_data = [0] * n
- for j in range(n):
- index = sorted_data[j][1]
- restored_data[index] = results_sorted[j]
- result = ["", [], [], [], [], [], []]
- for j in range(n):
- result[0] += restored_data[j][0]
- result[1] += restored_data[j][1]
- result[2] += restored_data[j][2]
- if len(restored_data[j][4]) > 0:
- for t in restored_data[j][4]:
- t[0] += vadsegments[j][0]
- t[1] += vadsegments[j][0]
- result[4] += restored_data[j][4]
- # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
- key = keys[0]
- # result = result_segments[0]
- text, token, token_int = result[0], result[1], result[2]
- time_stamp = result[4] if len(result[4]) > 0 else None
-
- if language == "en-bpe":
- postprocessed_result = postprocess_utils.sentence_postprocess_sentencepiece(token)
- else:
- if use_timestamp and time_stamp is not None and len(time_stamp):
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
- else:
- postprocessed_result = postprocess_utils.sentence_postprocess(token)
- text_postprocessed = ""
- time_stamp_postprocessed = ""
- text_postprocessed_punc = postprocessed_result
- if len(postprocessed_result) == 3:
- text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
- postprocessed_result[1], \
- postprocessed_result[2]
- else:
- text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
- text_postprocessed_punc = text_postprocessed
- punc_id_list = []
- if len(word_lists) > 0 and text2punc is not None:
- beg_punc = time.time()
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
- end_punc = time.time()
- print("time cost punc: ", end_punc - beg_punc)
-
- item = {'key': key, 'value': text_postprocessed_punc}
- if text_postprocessed != "":
- item['text_postprocessed'] = text_postprocessed
- if time_stamp_postprocessed != "":
- item['time_stamp'] = time_stamp_postprocessed
-
- item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["vad"][key] = "{}".format(vadsegments)
- ibest_writer["text"][key] = " ".join(word_lists)
- ibest_writer["text_with_punc"][key] = text_postprocessed_punc
- if time_stamp_postprocessed is not None:
- ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
- torch.cuda.empty_cache()
- return asr_result_list
-
- return _forward
-
-
-def inference_paraformer_vad_speaker(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- vad_infer_config: Optional[str] = None,
- vad_model_file: Optional[str] = None,
- vad_cmvn_file: Optional[str] = None,
- time_stamp_writer: bool = True,
- punc_infer_config: Optional[str] = None,
- punc_model_file: Optional[str] = None,
- sv_model_file: Optional[str] = None,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- outputs_dict: Optional[bool] = True,
- param_dict: dict = None,
-
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml")
- if not os.path.exists(sv_model_config_path):
- sv_model_config = {'sv_model_class': 'CAMPPlus','sv_model_file': 'campplus_cn_common.bin', 'models_config': {}}
- else:
- with open(sv_model_config_path, 'r') as f:
- sv_model_config = yaml.load(f, Loader=yaml.FullLoader)
- if sv_model_config['models_config'] is None:
- sv_model_config['models_config'] = {}
- sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file'])
-
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- else:
- hotword_list_or_file = None
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
- # 3. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
- )
- speech2text = Speech2TextParaformer(**speech2text_kwargs)
- text2punc = None
- if punc_model_file is not None:
- text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
- ibest_writer = writer[f"1best_recog"]
- ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
-
- if 'hotword' in kwargs:
- hotword_list_or_file = kwargs['hotword']
-
- speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
- batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
- batch_size_token = kwargs.get("batch_size_token", 6000)
- print("batch_size_token: ", batch_size_token)
-
- if speech2text.hotword_list is None:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=1,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- if param_dict is not None:
- use_timestamp = param_dict.get('use_timestamp', True)
- else:
- use_timestamp = True
-
- finish_count = 0
- file_count = 1
- lfr_factor = 6
- # 7 .Start for-loop
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- writer = None
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- beg_vad = time.time()
- vad_results = speech2vadsegment(**batch)
- end_vad = time.time()
- print("time cost vad: ", end_vad - beg_vad)
- _, vadsegments = vad_results[0], vad_results[1][0]
- ##################################
- ##### speaker_verification #####
- ##################################
- # load sv model
- if ngpu > 0:
- sv_model_dict = torch.load(sv_model_file)
- sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
- sv_model.cuda()
- else:
- sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
- sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
- sv_model.load_state_dict(sv_model_dict)
- print(f'load sv model params: {sv_model_file}')
- sv_model.eval()
- cb_model = ClusterBackend()
- vad_segments = []
- audio = batch['speech'].numpy().reshape(-1)
- for vadsegment in vadsegments:
- st = int(vadsegment[0]) / 1000
- ed = int(vadsegment[1]) / 1000
- vad_segments.append(
- [st, ed, audio[int(st * 16000):int(ed * 16000)]])
- audio_dur = check_audio_list(vad_segments)
- if audio_dur > 5:
- # sv pipeline
- segments = sv_chunk(vad_segments)
- embeddings = []
- for s in segments:
- #_, embs = self.sv_pipeline([s[2]], output_emb=True)
- # embeddings.append(embs)
- wavs = sv_preprocess([s[2]])
- # embs = self.forward(wavs)
- embs = []
- for x in wavs:
- x = extract_feature([x])
- if ngpu > 0:
- x = x.cuda()
- embs.append(sv_model(x))
- embs = torch.cat(embs)
- embeddings.append(embs.cpu().detach().numpy())
- embeddings = np.concatenate(embeddings)
- labels = cb_model(embeddings)
- sv_output = postprocess(segments, vad_segments, labels, embeddings)
- else:
- # fake speaker res for too shot utterance
- sv_output = [[0.0, vadsegments[-1][-1]/1000.0, 0]]
- logging.warning("Too short utterence found: {}, return default speaker results.".format(keys))
-
- speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
- n = len(vadsegments)
- data_with_index = [(vadsegments[i], i) for i in range(n)]
- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
- results_sorted = []
-
- if not len(sorted_data):
- key = keys[0]
- # no active segments after VAD
- if writer is not None:
- # Write empty results
- ibest_writer["token"][key] = ""
- ibest_writer["token_int"][key] = ""
- ibest_writer["vad"][key] = ""
- ibest_writer["text"][key] = ""
- ibest_writer["text_with_punc"][key] = ""
- if use_timestamp:
- ibest_writer["time_stamp"][key] = ""
-
- logging.info("decoding, utt: {}, empty speech".format(key))
- continue
-
- batch_size_token_ms = batch_size_token*60
- if speech2text.device == "cpu":
- batch_size_token_ms = 0
- if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
- batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
-
- batch_size_token_ms_cum = 0
- beg_idx = 0
- beg_asr_total = time.time()
- for j, _ in enumerate(tqdm(range(0, n))):
- batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
- if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
- continue
- batch_size_token_ms_cum = 0
- end_idx = j + 1
- speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
- beg_idx = end_idx
- batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
- batch = to_device(batch, device=device)
- # print("batch: ", speech_j.shape[0])
- beg_asr = time.time()
- results = speech2text(**batch)
- end_asr = time.time()
- # print("time cost asr: ", end_asr - beg_asr)
-
- if len(results) < 1:
- results = [["", [], [], [], [], [], []]]
- results_sorted.extend(results)
- end_asr_total = time.time()
- print("total time cost asr: ", end_asr_total-beg_asr_total)
- restored_data = [0] * n
- for j in range(n):
- index = sorted_data[j][1]
- restored_data[index] = results_sorted[j]
- result = ["", [], [], [], [], [], []]
- for j in range(n):
- result[0] += restored_data[j][0]
- result[1] += restored_data[j][1]
- result[2] += restored_data[j][2]
- if len(restored_data[j][4]) > 0:
- for t in restored_data[j][4]:
- t[0] += vadsegments[j][0]
- t[1] += vadsegments[j][0]
- result[4] += restored_data[j][4]
- # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
- key = keys[0]
- # result = result_segments[0]
- text, token, token_int = result[0], result[1], result[2]
- time_stamp = result[4] if len(result[4]) > 0 else None
-
- if use_timestamp and time_stamp is not None and len(time_stamp):
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
- else:
- postprocessed_result = postprocess_utils.sentence_postprocess(token)
- text_postprocessed = ""
- time_stamp_postprocessed = ""
- text_postprocessed_punc = postprocessed_result
- if len(postprocessed_result) == 3:
- text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
- postprocessed_result[1], \
- postprocessed_result[2]
- else:
- text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
- text_postprocessed_punc = text_postprocessed
- punc_id_list = []
- if len(word_lists) > 0 and text2punc is not None:
- beg_punc = time.time()
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
- end_punc = time.time()
- print("time cost punc: ", end_punc - beg_punc)
-
- item = {'key': key, 'value': text_postprocessed_punc}
- if text_postprocessed != "":
- item['text_postprocessed'] = text_postprocessed
- if time_stamp_postprocessed != "":
- item['time_stamp'] = time_stamp_postprocessed
-
- item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["vad"][key] = "{}".format(vadsegments)
- ibest_writer["text"][key] = " ".join(word_lists)
- ibest_writer["text_with_punc"][key] = text_postprocessed_punc
- if time_stamp_postprocessed is not None:
- ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
- torch.cuda.empty_cache()
- distribute_spk(asr_result_list[0]['sentences'], sv_output)
- return asr_result_list
-
- return _forward
-
-
-def inference_paraformer_online(
- maxlenratio: float=0.0,
- minlenratio: float=0.0,
- batch_size: int=1,
- beam_size: int=1,
- ngpu: int=1,
- ctc_weight: float=0.0,
- lm_weight: float=0.0,
- penalty: float=0.0,
- log_level: Union[int, str]=logging.ERROR,
- # data_path_and_name_and_type,
- asr_train_config: Optional[str]=None,
- asr_model_file: Optional[str]=None,
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- export_mode = False
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- )
-
- speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
-
- def _load_bytes(input):
- middle_data = np.frombuffer(input, dtype=np.int16)
- middle_data = np.asarray(middle_data)
- if middle_data.dtype.kind not in 'iu':
- raise TypeError("'middle_data' must be an array of integers")
- dtype = np.dtype('float32')
- if dtype.kind != 'f':
- raise TypeError("'dtype' must be a floating point type")
-
- i = np.iinfo(middle_data.dtype)
- abs_max = 2 ** (i.bits - 1)
- offset = i.min + abs_max
- array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
- return array
-
- def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
- if not Path(yaml_path).exists():
- raise FileExistsError(f'The {yaml_path} does not exist.')
-
- with open(str(yaml_path), 'rb') as f:
- data = yaml.load(f, Loader=yaml.Loader)
- return data
-
- def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], encoder_chunk_look_back=0,
- decoder_chunk_look_back=0, batch_size=1):
- if len(cache) > 0:
- return cache
- config = _read_yaml(asr_train_config)
- enc_output_size = config["encoder_conf"]["output_size"]
- feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
- cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
- "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
- cache["encoder"] = cache_en
-
- cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
- cache["decoder"] = cache_de
-
- return cache
-
- def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], encoder_chunk_look_back=0,
- decoder_chunk_look_back=0, batch_size=1):
- if len(cache) > 0:
- config = _read_yaml(asr_train_config)
- enc_output_size = config["encoder_conf"]["output_size"]
- feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
- cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
- "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
- cache["encoder"] = cache_en
-
- cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
- cache["decoder"] = cache_de
-
- return cache
-
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
- raw_inputs = _load_bytes(data_path_and_name_and_type[0])
- raw_inputs = torch.tensor(raw_inputs)
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- try:
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
- except:
- # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
- raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
- if raw_inputs.ndim == 2:
- raw_inputs = raw_inputs[:, 0]
- raw_inputs = torch.tensor(raw_inputs)
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, np.ndarray):
- raw_inputs = torch.tensor(raw_inputs)
- is_final = False
- cache = {}
- chunk_size = [5, 10, 5]
- encoder_chunk_look_back = 0
- decoder_chunk_look_back = 0
- if param_dict is not None and "cache" in param_dict:
- cache = param_dict["cache"]
- if param_dict is not None and "is_final" in param_dict:
- is_final = param_dict["is_final"]
- if param_dict is not None and "chunk_size" in param_dict:
- chunk_size = param_dict["chunk_size"]
- if param_dict is not None and "encoder_chunk_look_back" in param_dict:
- encoder_chunk_look_back = param_dict["encoder_chunk_look_back"]
- if encoder_chunk_look_back > 0:
- chunk_size[0] = 0
- if param_dict is not None and "decoder_chunk_look_back" in param_dict:
- decoder_chunk_look_back = param_dict["decoder_chunk_look_back"]
-
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
- asr_result_list = []
- cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
- decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
- item = {}
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- sample_offset = 0
- speech_length = raw_inputs.shape[1]
- stride_size = chunk_size[1] * 960
- cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
- decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
- final_result = ""
- for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
- if sample_offset + stride_size >= speech_length - 1:
- stride_size = speech_length - sample_offset
- cache["encoder"]["is_final"] = True
- else:
- cache["encoder"]["is_final"] = False
- input_lens = torch.tensor([stride_size])
- asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
- if len(asr_result) != 0:
- final_result += " ".join(asr_result) + " "
- item = {'key': "utt", 'value': final_result.strip()}
- else:
- input_lens = torch.tensor([raw_inputs.shape[1]])
- cache["encoder"]["is_final"] = is_final
- asr_result = speech2text(cache, raw_inputs, input_lens)
- item = {'key': "utt", 'value': " ".join(asr_result)}
-
- asr_result_list.append(item)
- if is_final:
- cache = _cache_reset(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
- decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
- return asr_result_list
-
- return _forward
-
-
-def inference_uniasr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- ngram_file: Optional[str] = None,
- cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- if param_dict is not None and "decoding_model" in param_dict:
- if param_dict["decoding_model"] == "fast":
- decoding_ind = 0
- decoding_mode = "model1"
- elif param_dict["decoding_model"] == "normal":
- decoding_ind = 0
- decoding_mode = "model2"
- elif param_dict["decoding_model"] == "offline":
- decoding_ind = 1
- decoding_mode = "model2"
- else:
- raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- ngram_file=ngram_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- token_num_relax=token_num_relax,
- decoding_ind=decoding_ind,
- decoding_mode=decoding_mode,
- )
- speech2text = Speech2TextUniASR(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- logging.info(f"Utterance: {key}")
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
- return asr_result_list
-
- return _forward
-
-
-def inference_mfcca(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2TextMFCCA(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- fs=fs,
- mc=True,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["<space>"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
- return asr_result_list
-
- return _forward
-
-
-def inference_transducer(
- output_dir: str,
- batch_size: int,
- dtype: str,
- beam_size: int,
- ngpu: int,
- seed: int,
- lm_weight: float,
- nbest: int,
- num_workers: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- beam_search_config: Optional[dict] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- token_type: Optional[str] = None,
- bpemodel: Optional[str] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- quantize_asr_model: Optional[bool] = False,
- quantize_modules: Optional[List[str]] = None,
- quantize_dtype: Optional[str] = "float16",
- streaming: Optional[bool] = False,
- fake_streaming: Optional[bool] = False,
- full_utt: Optional[bool] = False,
- chunk_size: Optional[int] = 16,
- left_context: Optional[int] = 16,
- right_context: Optional[int] = 0,
- display_partial_hypotheses: bool = False,
- **kwargs,
-) -> None:
- """Transducer model inference.
- Args:
- output_dir: Output directory path.
- batch_size: Batch decoding size.
- dtype: Data type.
- beam_size: Beam size.
- ngpu: Number of GPUs.
- seed: Random number generator seed.
- lm_weight: Weight of language model.
- nbest: Number of final hypothesis.
- num_workers: Number of workers.
- log_level: Level of verbose for logs.
- data_path_and_name_and_type:
- asr_train_config: ASR model training config path.
- asr_model_file: ASR model path.
- beam_search_config: Beam search config path.
- lm_train_config: Language Model training config path.
- lm_file: Language Model path.
- model_tag: Model tag.
- token_type: Type of token units.
- bpemodel: BPE model path.
- key_file: File key.
- allow_variable_data_keys: Whether to allow variable data keys.
- quantize_asr_model: Whether to apply dynamic quantization to ASR model.
- quantize_modules: List of module names to apply dynamic quantization on.
- quantize_dtype: Dynamic quantization data type.
- streaming: Whether to perform chunk-by-chunk inference.
- chunk_size: Number of frames in chunk AFTER subsampling.
- left_context: Number of frames in left context AFTER subsampling.
- right_context: Number of frames in right context AFTER subsampling.
- display_partial_hypotheses: Whether to display partial hypotheses.
- """
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- beam_search_config=beam_search_config,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- dtype=dtype,
- beam_size=beam_size,
- lm_weight=lm_weight,
- nbest=nbest,
- quantize_asr_model=quantize_asr_model,
- quantize_modules=quantize_modules,
- quantize_dtype=quantize_dtype,
- streaming=streaming,
- fake_streaming=fake_streaming,
- full_utt=full_utt,
- chunk_size=chunk_size,
- left_context=left_context,
- right_context=right_context,
- )
- speech2text = Speech2TextTransducer(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
- asr_result_list = []
-
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
- else:
- writer = None
-
- # 4 .Start for-loop
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
-
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
- assert len(batch.keys()) == 1
-
- try:
- if speech2text.streaming:
- speech = batch["speech"]
-
- _steps = len(speech) // speech2text._ctx
- _end = 0
- for i in range(_steps):
- _end = (i + 1) * speech2text._ctx
-
- speech2text.streaming_decode(
- speech[i * speech2text._ctx: _end + speech2text._right_ctx], is_final=False
- )
-
- final_hyps = speech2text.streaming_decode(
- speech[_end: len(speech)], is_final=True
- )
- elif speech2text.fake_streaming:
- final_hyps = speech2text.fake_streaming_decode(**batch)
- elif speech2text.full_utt:
- final_hyps = speech2text.full_utt_decode(**batch)
- else:
- final_hyps = speech2text(**batch)
-
- results = speech2text.hypotheses_to_results(final_hyps)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
- results = [[" ", ["<space>"], [2], hyp]] * nbest
-
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- item = {'key': key, 'value': text}
- asr_result_list.append(item)
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- ibest_writer["text"][key] = text
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text))
- return asr_result_list
- return _forward
-
-
-def inference_sa_asr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
-):
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2TextSAASR(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- mc=mc,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["text_id"][key] = text_id
-
- if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
-
- logging.info("uttid: {}".format(key))
- logging.info("text predictions: {}".format(text))
- logging.info("text_id predictions: {}\n".format(text_id))
- return asr_result_list
-
- return _forward
-
-def inference_whisper(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
-):
-
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if param_dict:
- language = param_dict.get("language", None)
- task = param_dict.get("task", "transcribe")
- else:
- language = None
- task = "transcribe"
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- language=language,
- task=task,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2TextWhisper(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speech2text.asr_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- mc=mc,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
-
- for n, (text, language) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["language"][key] = language
-
- if text is not None:
- item = {'key': key, 'value': text}
- asr_result_list.append(item)
- finish_count += 1
- if writer is not None:
- ibest_writer["text"][key] = text
-
- logging.info("uttid: {}".format(key))
- logging.info("text predictions: {}\n".format(text))
- return asr_result_list
-
- return _forward
-
-def inference_launch(**kwargs):
- if 'mode' in kwargs:
- mode = kwargs['mode']
- else:
- logging.info("Unknown decoding mode.")
- return None
- if mode == "asr":
- return inference_asr(**kwargs)
- elif mode == "uniasr":
- return inference_uniasr(**kwargs)
- elif mode == "paraformer":
- return inference_paraformer(**kwargs)
- elif mode == "paraformer_fake_streaming":
- return inference_paraformer(**kwargs)
- elif mode == "paraformer_streaming":
- return inference_paraformer_online(**kwargs)
- elif mode.startswith("paraformer_vad_speaker"):
- return inference_paraformer_vad_speaker(**kwargs)
- elif mode.startswith("paraformer_vad"):
- return inference_paraformer_vad_punc(**kwargs)
- elif mode == "mfcca":
- return inference_mfcca(**kwargs)
- elif mode == "rnnt":
- return inference_transducer(**kwargs)
- elif mode == "bat":
- return inference_transducer(**kwargs)
- elif mode == "sa_asr":
- return inference_sa_asr(**kwargs)
- elif mode == "whisper":
- return inference_whisper(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- from funasr.bin.argument import get_parser
- parser = get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="asr",
- help="The decoding mode",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
deleted file mode 100644
index c03bdf3..0000000
--- a/funasr/bin/build_trainer.py
+++ /dev/null
@@ -1,725 +0,0 @@
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from io import BytesIO
-
-import torch
-import yaml
-
-from funasr.build_utils.build_args import build_args
-from funasr.build_utils.build_dataloader import build_dataloader
-from funasr.build_utils.build_distributed import build_distributed
-from funasr.build_utils.build_model import build_model
-from funasr.build_utils.build_optimizer import build_optimizer
-from funasr.build_utils.build_scheduler import build_scheduler
-from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
-from funasr.modules.lora.utils import mark_only_lora_as_trainable
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
-from funasr.torch_utils.load_pretrained_model import load_pretrained_model
-from funasr.torch_utils.model_summary import model_summary
-from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.prepare_data import prepare_data
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
-
-
-def update_dct(fin_configs, root):
- if root == {}:
- return {}
- for root_key, root_value in root.items():
- if not isinstance(root[root_key], dict):
- fin_configs[root_key] = root[root_key]
- else:
- if root_key in fin_configs.keys():
- result = update_dct(fin_configs[root_key], root[root_key])
- fin_configs[root_key] = result
- else:
- fin_configs[root_key] = root[root_key]
- return fin_configs
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- description="FunASR Common Training Parser",
- )
-
- # common configuration
- parser.add_argument("--output_dir", help="model save path")
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
-
- # ddp related
- parser.add_argument(
- "--dist_backend",
- default="nccl",
- type=str,
- help="distributed backend",
- )
- parser.add_argument(
- "--dist_init_method",
- type=str,
- default="env://",
- help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
- '"WORLD_SIZE", and "RANK" are referred.',
- )
- parser.add_argument(
- "--dist_world_size",
- type=int,
- default=1,
- help="number of nodes for distributed training",
- )
- parser.add_argument(
- "--dist_rank",
- type=int,
- default=None,
- help="node rank for distributed training",
- )
- parser.add_argument(
- "--local_rank",
- type=int,
- default=None,
- help="local rank for distributed training",
- )
- parser.add_argument(
- "--dist_master_addr",
- default=None,
- type=str_or_none,
- help="The master address for distributed training. "
- "This value is used when dist_init_method == 'env://'",
- )
- parser.add_argument(
- "--dist_master_port",
- default=None,
- type=int_or_none,
- help="The master port for distributed training"
- "This value is used when dist_init_method == 'env://'",
- )
- parser.add_argument(
- "--dist_launcher",
- default=None,
- type=str_or_none,
- choices=["slurm", "mpi", None],
- help="The launcher type for distributed training",
- )
- parser.add_argument(
- "--multiprocessing_distributed",
- default=True,
- type=str2bool,
- help="Use multi-processing distributed training to launch "
- "N processes per node, which has N GPUs. This is the "
- "fastest way to use PyTorch for either single node or "
- "multi node data parallel training",
- )
- parser.add_argument(
- "--unused_parameters",
- type=str2bool,
- default=False,
- help="Whether to use the find_unused_parameters in "
- "torch.nn.parallel.DistributedDataParallel ",
- )
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
-
- # cudnn related
- parser.add_argument(
- "--cudnn_enabled",
- type=str2bool,
- default=torch.backends.cudnn.enabled,
- help="Enable CUDNN",
- )
- parser.add_argument(
- "--cudnn_benchmark",
- type=str2bool,
- default=torch.backends.cudnn.benchmark,
- help="Enable cudnn-benchmark mode",
- )
- parser.add_argument(
- "--cudnn_deterministic",
- type=str2bool,
- default=True,
- help="Enable cudnn-deterministic mode",
- )
-
- # trainer related
- parser.add_argument(
- "--max_epoch",
- type=int,
- default=40,
- help="The maximum number epoch to train",
- )
- parser.add_argument(
- "--max_update",
- type=int,
- default=sys.maxsize,
- help="The maximum number update step to train",
- )
- parser.add_argument(
- "--batch_interval",
- type=int,
- default=10000,
- help="The batch interval for saving model.",
- )
- parser.add_argument(
- "--patience",
- type=int_or_none,
- default=None,
- help="Number of epochs to wait without improvement "
- "before stopping the training",
- )
- parser.add_argument(
- "--val_scheduler_criterion",
- type=str,
- nargs=2,
- default=("valid", "loss"),
- help="The criterion used for the value given to the lr scheduler. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'and the criterion name. The mode specifying "min" or "max" can '
- "be changed by --scheduler_conf",
- )
- parser.add_argument(
- "--early_stopping_criterion",
- type=str,
- nargs=3,
- default=("valid", "loss", "min"),
- help="The criterion used for judging of early stopping. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
- )
- parser.add_argument(
- "--best_model_criterion",
- nargs="+",
- default=[
- ("train", "loss", "min"),
- ("valid", "loss", "min"),
- ("train", "acc", "max"),
- ("valid", "acc", "max"),
- ],
- help="The criterion used for judging of the best model. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
- )
- parser.add_argument(
- "--keep_nbest_models",
- type=int,
- nargs="+",
- default=[10],
- help="Remove previous snapshots excluding the n-best scored epochs",
- )
- parser.add_argument(
- "--nbest_averaging_interval",
- type=int,
- default=0,
- help="The epoch interval to apply model averaging and save nbest models",
- )
- parser.add_argument(
- "--grad_clip",
- type=float,
- default=5.0,
- help="Gradient norm threshold to clip",
- )
- parser.add_argument(
- "--grad_clip_type",
- type=float,
- default=2.0,
- help="The type of the used p-norm for gradient clip. Can be inf",
- )
- parser.add_argument(
- "--grad_noise",
- type=str2bool,
- default=False,
- help="The flag to switch to use noise injection to "
- "gradients during training",
- )
- parser.add_argument(
- "--accum_grad",
- type=int,
- default=1,
- help="The number of gradient accumulation",
- )
- parser.add_argument(
- "--resume",
- type=str2bool,
- default=False,
- help="Enable resuming if checkpoint is existing",
- )
- parser.add_argument(
- "--train_dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type for training.",
- )
- parser.add_argument(
- "--use_amp",
- type=str2bool,
- default=False,
- help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
- )
- parser.add_argument(
- "--log_interval",
- default=None,
- help="Show the logs every the number iterations in each epochs at the "
- "training phase. If None is given, it is decided according the number "
- "of training samples automatically .",
- )
- parser.add_argument(
- "--use_tensorboard",
- type=str2bool,
- default=True,
- help="Enable tensorboard logging",
- )
-
- # pretrained model related
- parser.add_argument(
- "--init_param",
- type=str,
- action="append",
- default=[],
- help="Specify the file path used for initialization of parameters. "
- "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
- "where file_path is the model file path, "
- "src_key specifies the key of model states to be used in the model file, "
- "dst_key specifies the attribute of the model to be initialized, "
- "and exclude_keys excludes keys of model states for the initialization."
- "e.g.\n"
- " # Load all parameters"
- " --init_param some/where/model.pb\n"
- " # Load only decoder parameters"
- " --init_param some/where/model.pb:decoder:decoder\n"
- " # Load only decoder parameters excluding decoder.embed"
- " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
- " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
- )
- parser.add_argument(
- "--ignore_init_mismatch",
- type=str2bool,
- default=False,
- help="Ignore size mismatch when loading pre-trained model",
- )
- parser.add_argument(
- "--freeze_param",
- type=str,
- default=[],
- action="append",
- help="Freeze parameters",
- )
-
- # dataset related
- parser.add_argument(
- "--dataset_type",
- type=str,
- default="small",
- help="whether to use dataloader for large dataset",
- )
- parser.add_argument(
- "--dataset_conf",
- action=NestedDictAction,
- default=dict(),
- help=f"The keyword arguments for dataset",
- )
- parser.add_argument(
- "--data_dir",
- type=str,
- default=None,
- help="root path of data",
- )
- parser.add_argument(
- "--train_set",
- type=str,
- default="train",
- help="train dataset",
- )
- parser.add_argument(
- "--valid_set",
- type=str,
- default="validation",
- help="dev dataset",
- )
- parser.add_argument(
- "--data_file_names",
- type=str,
- default="wav.scp,text",
- help="input data files",
- )
- parser.add_argument(
- "--speed_perturb",
- type=float,
- nargs="+",
- default=None,
- help="speed perturb",
- )
- parser.add_argument(
- "--use_preprocessor",
- type=str2bool,
- default=True,
- help="Apply preprocessing to data or not",
- )
-
- # optimization related
- parser.add_argument(
- "--optim",
- type=lambda x: x.lower(),
- default="adam",
- help="The optimizer type",
- )
- parser.add_argument(
- "--optim_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for optimizer",
- )
- parser.add_argument(
- "--scheduler",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The lr scheduler type",
- )
- parser.add_argument(
- "--scheduler_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for lr scheduler",
- )
-
- # most task related
- parser.add_argument(
- "--init",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The initialization method",
- choices=[
- "chainer",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- None,
- ],
- )
- parser.add_argument(
- "--token_list",
- type=str_or_none,
- default=None,
- help="A text mapping int-id to token",
- )
- parser.add_argument(
- "--token_type",
- type=str,
- default="bpe",
- choices=["bpe", "char", "word"],
- help="",
- )
- parser.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model file fo sentencepiece",
- )
- parser.add_argument(
- "--cleaner",
- type=str_or_none,
- choices=[None, "tacotron", "jaconv", "vietnamese"],
- default=None,
- help="Apply text cleaning",
- )
- parser.add_argument(
- "--g2p",
- type=str_or_none,
- choices=g2p_choices,
- default=None,
- help="Specify g2p method if --token_type=phn",
- )
-
- # pai related
- parser.add_argument(
- "--use_pai",
- type=str2bool,
- default=False,
- help="flag to indicate whether training on PAI",
- )
- parser.add_argument(
- "--simple_ddp",
- type=str2bool,
- default=False,
- )
- parser.add_argument(
- "--num_worker_count",
- type=int,
- default=1,
- help="The number of machines on PAI.",
- )
- parser.add_argument(
- "--access_key_id",
- type=str,
- default=None,
- help="The username for oss.",
- )
- parser.add_argument(
- "--access_key_secret",
- type=str,
- default=None,
- help="The password for oss.",
- )
- parser.add_argument(
- "--endpoint",
- type=str,
- default=None,
- help="The endpoint for oss.",
- )
- parser.add_argument(
- "--bucket_name",
- type=str,
- default=None,
- help="The bucket name for oss.",
- )
- parser.add_argument(
- "--oss_bucket",
- default=None,
- help="oss bucket.",
- )
- parser.add_argument(
- "--enable_lora",
- type=str2bool,
- default=False,
- help="Apply lora for finetuning.",
- )
- parser.add_argument(
- "--lora_bias",
- type=str,
- default="none",
- help="lora bias.",
- )
-
- return parser
-
-
-def build_trainer(modelscope_dict,
- data_dir,
- output_dir,
- train_set="train",
- dev_set="validation",
- distributed=False,
- dataset_type="small",
- batch_bins=None,
- max_epoch=None,
- optim=None,
- lr=None,
- scheduler=None,
- scheduler_conf=None,
- specaug=None,
- specaug_conf=None,
- mate_params=None,
- **kwargs):
- parser = get_parser()
- args, extra_task_params = parser.parse_known_args()
- args = build_args(args, parser, extra_task_params)
-
- if args.local_rank is not None:
- distributed = True
- else:
- distributed = False
- args.local_rank = args.local_rank if args.local_rank is not None else 0
- local_rank = args.local_rank
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
- config = modelscope_dict['am_model_config']
- finetune_config = modelscope_dict['finetune_config']
- init_param = modelscope_dict['init_model']
- cmvn_file = modelscope_dict['cmvn_file']
- seg_dict_file = modelscope_dict['seg_dict']
- if 'bpemodel' in modelscope_dict:
- bpemodel = modelscope_dict['bpemodel']
- else:
- bpemodel = None
-
- # overwrite parameters
- with open(config) as f:
- configs = yaml.safe_load(f)
- with open(finetune_config) as f:
- finetune_configs = yaml.safe_load(f)
- # set data_types
- if dataset_type == "large":
- # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
- if 'data_types' not in finetune_configs['dataset_conf']:
- finetune_configs["dataset_conf"]["data_types"] = "sound,text"
- finetune_configs = update_dct(configs, finetune_configs)
- for key, value in finetune_configs.items():
- if hasattr(args, key):
- setattr(args, key, value)
- if mate_params is not None:
- for key, value in mate_params.items():
- if hasattr(args, key):
- setattr(args, key, value)
- if mate_params is not None and "lora_params" in mate_params:
- lora_params = mate_params['lora_params']
- configs['encoder_conf'].update(lora_params)
- configs['decoder_conf'].update(lora_params)
- args.dataset_type = dataset_type
- args.init_param = [init_param]
- if mate_params is not None and "init_param" in mate_params:
- if len(mate_params["init_param"]) != 0:
- args.init_param = mate_params["init_param"]
- args.cmvn_file = cmvn_file
- if os.path.exists(seg_dict_file):
- args.seg_dict_file = seg_dict_file
- else:
- args.seg_dict_file = None
- if bpemodel is not None and os.path.exists(bpemodel):
- args.bpemodel = bpemodel
- else:
- args.bpemodel = None
- args.data_dir = data_dir
- args.train_set = train_set
- args.dev_set = dev_set
- args.output_dir = output_dir
- args.gpu_id = args.local_rank
- args.config = finetune_config
- args.use_pai = False
- args.batch_type = "length"
- args.oss_bucket = None
- args.input_size = None
- if distributed:
- args.distributed = True
- args.simple_ddp = True
- else:
- args.distributed = False
- args.ngpu = 1
- if optim is not None:
- args.optim = optim
- if lr is not None:
- args.optim_conf["lr"] = lr
- if scheduler is not None:
- args.scheduler = scheduler
- if scheduler_conf is not None:
- args.scheduler_conf = scheduler_conf
- if specaug is not None:
- args.specaug = specaug
- if specaug_conf is not None:
- args.specaug_conf = specaug_conf
- if max_epoch is not None:
- args.max_epoch = max_epoch
- if batch_bins is not None:
- if args.dataset_type == "small":
- args.batch_bins = batch_bins
- args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
- elif args.dataset_type == "large":
- args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
- else:
- raise ValueError(f"Not supported dataset_type={args.dataset_type}")
- if args.normalize in ["null", "none", "None"]:
- args.normalize = None
- if args.patience in ["null", "none", "None"]:
- args.patience = None
- args.local_rank = local_rank
-
- # set random seed
- set_all_random_seed(args.seed)
- torch.backends.cudnn.enabled = args.cudnn_enabled
- torch.backends.cudnn.benchmark = args.cudnn_benchmark
- torch.backends.cudnn.deterministic = args.cudnn_deterministic
-
- # ddp init
- distributed_option = build_distributed(args)
-
- # for logging
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- logging.basicConfig(
- level="INFO",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- else:
- logging.basicConfig(
- level="ERROR",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- # prepare files for dataloader
- prepare_data(args, distributed_option)
-
- model = build_model(args)
- model = model.to(
- dtype=getattr(torch, args.train_dtype),
- device="cuda" if args.ngpu > 0 else "cpu",
- )
- if args.enable_lora:
- mark_only_lora_as_trainable(model, args.lora_bias)
- for t in args.freeze_param:
- for k, p in model.named_parameters():
- if k.startswith(t + ".") or k == t:
- logging.info(f"Setting {k}.requires_grad = False")
- p.requires_grad = False
-
- optimizers = build_optimizer(args, model=model)
- schedulers = build_scheduler(args, optimizers)
-
- logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
- distributed_option.dist_rank,
- distributed_option.local_rank))
- logging.info(pytorch_cudnn_version())
- logging.info("Args: {}".format(args))
- logging.info(model_summary(model))
- logging.info("Optimizer: {}".format(optimizers))
- logging.info("Scheduler: {}".format(schedulers))
-
- # dump args to config.yaml
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- os.makedirs(args.output_dir, exist_ok=True)
- with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
- logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
- if args.use_pai:
- buffer = BytesIO()
- torch.save({"config": vars(args)}, buffer)
- args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
- else:
- yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
-
- for p in args.init_param:
- logging.info(f"Loading pretrained params from {p}")
- load_pretrained_model(
- model=model,
- init_param=p,
- ignore_init_mismatch=args.ignore_init_mismatch,
- map_location=f"cuda:{torch.cuda.current_device()}"
- if args.ngpu > 0
- else "cpu",
- oss_bucket=args.oss_bucket,
- )
-
- # dataloader for training/validation
- train_dataloader, valid_dataloader = build_dataloader(args)
-
- # Trainer, including model, optimizers, etc.
- trainer = build_trainer_modelscope(
- args=args,
- model=model,
- optimizers=optimizers,
- schedulers=schedulers,
- train_dataloader=train_dataloader,
- valid_dataloader=valid_dataloader,
- distributed_option=distributed_option
- )
-
- return trainer
diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
deleted file mode 100755
index b9dbdff..0000000
--- a/funasr/bin/data2vec_train.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.data2vec import Data2VecTask
-
-
-def parse_args():
- parser = Data2VecTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for data2vec Training
- Data2VecTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
deleted file mode 100755
index bb40f5e..0000000
--- a/funasr/bin/diar_infer.py
+++ /dev/null
@@ -1,272 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import logging
-import os
-from collections import OrderedDict
-from pathlib import Path
-from typing import Any
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-from scipy.ndimage import median_filter
-from torch.nn import functional as F
-
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
-from funasr.tasks.diar import DiarTask
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils.misc import statistic_model_parameters
-
-
-class Speech2DiarizationEEND:
- """Speech2Diarlization class
-
- Examples:
- >>> import librosa
- >>> import numpy as np
- >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
- >>> profile = np.load("profiles.npy")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2diar(audio, profile)
- {"spk1": [(int, int), ...], ...}
-
- """
-
- def __init__(
- self,
- diar_train_config: Union[Path, str] = None,
- diar_model_file: Union[Path, str] = None,
- device: str = "cpu",
- dtype: str = "float32",
- ):
-
- # 1. Build Diarization model
- diar_model, diar_train_args = build_model_from_file(
- config_file=diar_train_config,
- model_file=diar_model_file,
- device=device,
- task_name="diar",
- mode="eend-ola",
- )
- frontend = None
- if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
- frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
-
- # set up seed for eda
- np.random.seed(diar_train_args.seed)
- torch.manual_seed(diar_train_args.seed)
- torch.cuda.manual_seed(diar_train_args.seed)
- os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
- logging.info("diar_model: {}".format(diar_model))
- logging.info("diar_train_args: {}".format(diar_train_args))
- diar_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.diar_model = diar_model
- self.diar_train_args = diar_train_args
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- diarization results
-
- """
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.diar_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- batch = {"speech": feats, "speech_lengths": feats_len}
- batch = to_device(batch, device=self.device)
- results = self.diar_model.estimate_sequential(**batch)
-
- return results
-
-
-class Speech2DiarizationSOND:
- """Speech2Xvector class
-
- Examples:
- >>> import librosa
- >>> import numpy as np
- >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
- >>> profile = np.load("profiles.npy")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2diar(audio, profile)
- {"spk1": [(int, int), ...], ...}
-
- """
-
- def __init__(
- self,
- diar_train_config: Union[Path, str] = None,
- diar_model_file: Union[Path, str] = None,
- device: Union[str, torch.device] = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- streaming: bool = False,
- smooth_size: int = 83,
- dur_threshold: float = 10,
- ):
-
- # TODO: 1. Build Diarization model
- diar_model, diar_train_args = build_model_from_file(
- config_file=diar_train_config,
- model_file=diar_model_file,
- device=device,
- task_name="diar",
- mode="sond",
- )
- logging.info("diar_model: {}".format(diar_model))
- logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
- logging.info("diar_train_args: {}".format(diar_train_args))
- diar_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.diar_model = diar_model
- self.diar_train_args = diar_train_args
- self.token_list = diar_train_args.token_list
- self.smooth_size = smooth_size
- self.dur_threshold = dur_threshold
- self.device = device
- self.dtype = dtype
-
- def smooth_multi_labels(self, multi_label):
- multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
- return multi_label
-
- @staticmethod
- def calc_spk_turns(label_arr, spk_list):
- turn_list = []
- length = label_arr.shape[0]
- n_spk = label_arr.shape[1]
- for k in range(n_spk):
- if spk_list[k] == "None":
- continue
- in_utt = False
- start = 0
- for i in range(length):
- if label_arr[i, k] == 1 and in_utt is False:
- start = i
- in_utt = True
- if label_arr[i, k] == 0 and in_utt is True:
- turn_list.append([spk_list[k], start, i - start])
- in_utt = False
- if in_utt:
- turn_list.append([spk_list[k], start, length - start])
- return turn_list
-
- @staticmethod
- def seq2arr(seq, vec_dim=8):
- def int2vec(x, vec_dim=8, dtype=np.int32):
- b = ('{:0' + str(vec_dim) + 'b}').format(x)
- # little-endian order: lower bit first
- return (np.array(list(b)[::-1]) == '1').astype(dtype)
-
- # process oov
- seq = np.array([int(x) for x in seq])
- new_seq = []
- for i, x in enumerate(seq):
- if x < 2 ** vec_dim:
- new_seq.append(x)
- else:
- idx_list = np.where(seq < 2 ** vec_dim)[0]
- if len(idx_list) > 0:
- idx = np.abs(idx_list - i).argmin()
- new_seq.append(seq[idx_list[idx]])
- else:
- new_seq.append(0)
- return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
-
- def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
- logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
- # upsampling outputs to match inputs
- ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
- logits_idx = F.upsample(
- logits_idx.unsqueeze(1).float(),
- size=(ut,),
- mode="nearest",
- ).squeeze(1).long()
- logits_idx = logits_idx[0].tolist()
- pse_labels = [self.token_list[x] for x in logits_idx]
- if output_format == "pse_labels":
- return pse_labels, None
-
- multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
- multi_labels = self.smooth_multi_labels(multi_labels)
- if output_format == "binary_labels":
- return multi_labels, None
-
- spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
- spk_turns = self.calc_spk_turns(multi_labels, spk_list)
- results = OrderedDict()
- for spk, st, dur in spk_turns:
- if spk not in results:
- results[spk] = []
- if dur > self.dur_threshold:
- results[spk].append((st, st + dur))
-
- # sort segments in start time ascending
- for spk in results:
- results[spk] = sorted(results[spk], key=lambda x: x[0])
-
- return results, pse_labels
-
- @torch.no_grad()
- def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- profile: Union[torch.Tensor, np.ndarray],
- output_format: str = "speaker_turn"
- ):
- """Inference
-
- Args:
- speech: Input speech data
- profile: Speaker profiles
- Returns:
- diarization results for each speaker
-
- """
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if isinstance(profile, np.ndarray):
- profile = torch.tensor(profile)
-
- # data: (Nsamples,) -> (1, Nsamples)
- speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
- # lengths: (1,)
- speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
- batch = {"speech": speech, "speech_lengths": speech_lengths,
- "profile": profile, "profile_lengths": profile_lengths}
- # a. To device
- batch = to_device(batch, device=self.device)
-
- logits = self.diar_model.prediction_forward(**batch)
- results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
-
- return results, pse_labels
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
deleted file mode 100755
index f5a11b1..0000000
--- a/funasr/bin/diar_inference_launch.py
+++ /dev/null
@@ -1,506 +0,0 @@
-# !/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
-import argparse
-import logging
-import os
-import sys
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-# import librosa
-import librosa
-import torch
-from scipy.signal import medfilt
-
-from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
-from funasr.datasets.iterable_dataset import load_bytes
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_sond(
- diar_train_config: str,
- diar_model_file: str,
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 0,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- smooth_size: int = 83,
- dur_threshold: int = 10,
- out_format: str = "vad",
- param_dict: Optional[dict] = None,
- mode: str = "sond",
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("param_dict: {}".format(param_dict))
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2a. Build speech2xvec [Optional]
- if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
- "extract_profile"]:
- assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
- assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
- sv_train_config = param_dict["sv_train_config"]
- sv_model_file = param_dict["sv_model_file"]
- if "model_dir" in param_dict:
- sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
- sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
- from funasr.bin.sv_infer import Speech2Xvector
- speech2xvector_kwargs = dict(
- sv_train_config=sv_train_config,
- sv_model_file=sv_model_file,
- device=device,
- dtype=dtype,
- streaming=streaming,
- embedding_node="resnet1_dense"
- )
- logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
- speech2xvector.sv_model.eval()
-
- # 2b. Build speech2diar
- speech2diar_kwargs = dict(
- diar_train_config=diar_train_config,
- diar_model_file=diar_model_file,
- device=device,
- dtype=dtype,
- streaming=streaming,
- smooth_size=smooth_size,
- dur_threshold=dur_threshold,
- )
- logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
- speech2diar.diar_model.eval()
-
- def output_results_str(results: dict, uttid: str):
- rst = []
- mid = uttid.rsplit("-", 1)[0]
- for key in results:
- results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
- if out_format == "vad":
- for spk, segs in results.items():
- rst.append("{} {}".format(spk, segs))
- else:
- template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
- for spk, segs in results.items():
- rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
-
- return "\n".join(rst)
-
- def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
- ):
- logging.info("param_dict: {}".format(param_dict))
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, (list, tuple)):
- if not isinstance(raw_inputs[0], List):
- raw_inputs = [raw_inputs]
-
- assert all([len(example) >= 2 for example in raw_inputs]), \
- "The length of test case in raw_inputs must larger than 1 (>=2)."
-
- def prepare_dataset():
- for idx, example in enumerate(raw_inputs):
- # read waveform file
- example = [load_bytes(x) if isinstance(x, bytes) else x
- for x in example]
- # example = [librosa.load(x)[0] if isinstance(x, str) else x
- # for x in example]
- example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
- for x in example]
- # convert torch tensor to numpy array
- example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
- for x in example]
- speech = example[0]
- logging.info("Extracting profiles for {} waveforms".format(len(example) - 1))
- profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
- profile = torch.cat(profile, dim=0)
- yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
-
- loader = prepare_dataset()
- else:
- raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
- else:
- # 3. Build data-iterator
- loader = build_streaming_iterator(
- task_name="diar",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- use_collate_fn=False,
- )
-
- # 7. Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- os.makedirs(output_path, exist_ok=True)
- output_writer = open("{}/result.txt".format(output_path), "w")
- pse_label_writer = open("{}/labels.txt".format(output_path), "w")
- logging.info("Start to diarize...")
- result_list = []
- for idx, (keys, batch) in enumerate(loader):
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- results, pse_labels = speech2diar(**batch)
- # Only supporting batch_size==1
- key, value = keys[0], output_results_str(results, keys[0])
- item = {"key": key, "value": value}
- result_list.append(item)
- if output_path is not None:
- output_writer.write(value)
- output_writer.flush()
- pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
- pse_label_writer.flush()
-
- if idx % 100 == 0:
- logging.info("Processing {:5d}: {}".format(idx, key))
-
- if output_path is not None:
- output_writer.close()
- pse_label_writer.close()
-
- return result_list
-
- return _forward
-
-
-def inference_eend(
- diar_train_config: str,
- diar_model_file: str,
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- param_dict: Optional[dict] = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("param_dict: {}".format(param_dict))
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Build speech2diar
- speech2diar_kwargs = dict(
- diar_train_config=diar_train_config,
- diar_model_file=diar_model_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
- speech2diar.diar_model.eval()
-
- def output_results_str(results: dict, uttid: str):
- rst = []
- mid = uttid.rsplit("-", 1)[0]
- for key in results:
- results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
- template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
- for spk, segs in results.items():
- rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
-
- return "\n".join(rst)
-
- def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
- ):
- # 2. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
- loader = build_streaming_iterator(
- task_name="diar",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- # 3. Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- os.makedirs(output_path, exist_ok=True)
- output_writer = open("{}/result.txt".format(output_path), "w")
- result_list = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- results = speech2diar(**batch)
-
- # post process
- a = results[0][0].cpu().numpy()
- a = medfilt(a, (11, 1))
- rst = []
- for spkid, frames in enumerate(a.T):
- frames = np.pad(frames, (1, 1), 'constant')
- changes, = np.where(np.diff(frames, axis=0) != 0)
- fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
- for s, e in zip(changes[::2], changes[1::2]):
- st = s / 10.
- dur = (e - s) / 10.
- rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
-
- # Only supporting batch_size==1
- value = "\n".join(rst)
- item = {"key": keys[0], "value": value}
- result_list.append(item)
- if output_path is not None:
- output_writer.write(value)
- output_writer.flush()
-
- if output_path is not None:
- output_writer.close()
-
- return result_list
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "sond":
- return inference_sond(mode=mode, **kwargs)
- elif mode == "sond_demo":
- param_dict = {
- "extract_profile": True,
- "sv_train_config": "sv.yaml",
- "sv_model_file": "sv.pb",
- }
- if "param_dict" in kwargs and kwargs["param_dict"] is not None:
- for key in param_dict:
- if key not in kwargs["param_dict"]:
- kwargs["param_dict"][key] = param_dict[key]
- else:
- kwargs["param_dict"] = param_dict
- return inference_sond(mode=mode, **kwargs)
- elif mode == "eend-ola":
- return inference_eend(mode=mode, **kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Speaker Verification",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--diar_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--diar_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global CMVN file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("The inference configuration related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument(
- "--smooth_size",
- type=int,
- default=121,
- help="The smoothing size for post-processing"
- )
- group.add_argument(
- "--dur_threshold",
- type=int,
- default=10,
- help="The threshold of minimum duration"
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="sond",
- help="The decoding mode",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/diar_train.py b/funasr/bin/diar_train.py
deleted file mode 100755
index 16a4bd0..0000000
--- a/funasr/bin/diar_train.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.diar import DiarTask
-
-
-# for ASR Training
-def parse_args():
- parser = DiarTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- DiarTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/export/export_model.py b/funasr/bin/export_model.py
similarity index 100%
rename from funasr/export/export_model.py
rename to funasr/bin/export_model.py
diff --git a/funasr/bin/inference_cli.py b/funasr/bin/inference_cli.py
deleted file mode 100644
index f4c66f1..0000000
--- a/funasr/bin/inference_cli.py
+++ /dev/null
@@ -1,139 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-import logging
-import torch
-import numpy as np
-from funasr.utils.download_and_prepare_model import prepare_model
-
-from funasr.utils.types import str2bool
-
-def infer(task_name: str = "asr",
- model: str = None,
- # mode: str = None,
- vad_model: str = None,
- disable_vad: bool = False,
- punc_model: str = None,
- disable_punc: bool = False,
- model_hub: str = "ms",
- cache_dir: str = None,
- **kwargs,
- ):
-
- # set logging messages
- logging.basicConfig(
- level=logging.ERROR,
- )
-
- model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
- if task_name == "asr":
- from funasr.bin.asr_inference_launch import inference_launch
-
- inference_pipeline = inference_launch(**kwargs)
- elif task_name == "":
- pipeline = 1
- elif task_name == "":
- pipeline = 2
- elif task_name == "":
- pipeline = 2
-
- def _infer_fn(input, **kwargs):
- data_type = kwargs.get('data_type', 'sound')
- data_path_and_name_and_type = [input, 'speech', data_type]
- raw_inputs = None
- if isinstance(input, torch.Tensor):
- input = input.numpy()
- if isinstance(input, np.ndarray):
- data_path_and_name_and_type = None
- raw_inputs = input
-
- return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
-
- return _infer_fn
-
-
-def main(cmd=None):
- # print(get_commandline_args(), file=sys.stderr)
- from funasr.bin.argument import get_parser
-
- parser = get_parser()
- parser.add_argument('input', help='input file to transcribe')
- parser.add_argument(
- "--task_name",
- type=str,
- default="asr",
- help="The decoding mode",
- )
- parser.add_argument(
- "-m",
- "--model",
- type=str,
- default="paraformer-zh",
- help="The asr mode name",
- )
- parser.add_argument(
- "-v",
- "--vad_model",
- type=str,
- default="fsmn-vad",
- help="vad model name",
- )
- parser.add_argument(
- "-dv",
- "--disable_vad",
- type=str2bool,
- default=False,
- help="",
- )
- parser.add_argument(
- "-p",
- "--punc_model",
- type=str,
- default="ct-punc",
- help="",
- )
- parser.add_argument(
- "-dp",
- "--disable_punc",
- type=str2bool,
- default=False,
- help="",
- )
- parser.add_argument(
- "--batch_size_token",
- type=int,
- default=5000,
- help="",
- )
- parser.add_argument(
- "--batch_size_token_threshold_s",
- type=int,
- default=35,
- help="",
- )
- parser.add_argument(
- "--max_single_segment_time",
- type=int,
- default=5000,
- help="",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
-
- # set logging messages
- logging.basicConfig(
- level=logging.ERROR,
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # kwargs["ncpu"] = 2 #os.cpu_count()
- kwargs.pop("data_path_and_name_and_type")
- print("args: {}".format(kwargs))
- p = infer(**kwargs)
-
- res = p(**kwargs)
- print(res)
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
deleted file mode 100644
index f12f50a..0000000
--- a/funasr/bin/lm_inference_launch.py
+++ /dev/null
@@ -1,392 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_lm(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float] = 10,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build Model
- model, train_args = build_model_from_file(
- train_config, model_file, None, device, "lm")
- wrapped_model = ForwardAdaptor(model, "nll")
- wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- logging.info(f"Model:\n{model}")
-
- preprocessor = LMPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file
- )
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: dict = None,
- ):
- results = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "lm demo"
- if line == "":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- batch = {}
- batch['text'] = line
- if preprocessor != None:
- batch = preprocessor(key, batch)
-
- # Force data-precision
- for name in batch:
- value = batch[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f"All values must be converted to np.ndarray object "
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
- # Cast to desired type
- if value.dtype.kind == "f":
- value = value.astype("float32")
- elif value.dtype.kind == "i":
- value = value.astype("long")
- else:
- raise NotImplementedError(f"Not supported dtype: {value.dtype}")
- batch[name] = value
-
- batch["text_lengths"] = torch.from_numpy(
- np.array([len(batch["text"])], dtype='int32'))
- batch["text"] = np.expand_dims(batch["text"], axis=0)
-
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
- ## compute ppl
- ppl_out_batch = ""
- ids2tokens = preprocessor.token_id_converter.ids2tokens
- for sent_ids, sent_nll in zip(batch['text'], nll):
- pre_word = "<s>"
- cur_word = None
- sent_lst = ids2tokens(sent_ids) + ['</s>']
- ppl_out = " ".join(sent_lst) + "\n"
- for word, word_nll in zip(sent_lst, sent_nll):
- cur_word = word
- word_nll = -word_nll.cpu()
- if log_base is None:
- word_prob = np.exp(word_nll)
- else:
- word_prob = log_base ** (word_nll / np.log(log_base))
- ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
- cur=cur_word,
- pre=pre_word,
- prob=round(word_prob.item(), 8),
- word_nll=round(word_nll.item(), 8)
- )
- pre_word = cur_word
-
- sent_nll_mean = sent_nll.mean().cpu().numpy()
- sent_nll_sum = sent_nll.sum().cpu().numpy()
- if log_base is None:
- sent_ppl = np.exp(sent_nll_mean)
- else:
- sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
- ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
- sent_nll=round(-sent_nll_sum.item(), 4),
- sent_ppl=round(sent_ppl.item(), 4)
- )
- ppl_out_batch += ppl_out
- item = {'key': key, 'value': ppl_out}
- if writer is not None:
- writer["ppl"][key + ":\n"] = ppl_out
- results.append(item)
-
- return results
-
- # 3. Build data-iterator
- loader = build_streaming_iterator(
- task_name="lm",
- preprocess_args=train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- preprocess_fn=preprocessor,
- num_workers=num_workers,
- )
-
- # 4. Start for-loop
- total_nll = 0.0
- total_ntokens = 0
- ppl_out_all = ""
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- ppl_out_batch = ""
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- # NOTE(kamo): data_parallel also should work with ngpu=1,
- # but for debuggability it's better to keep this block.
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
- ## print ppl
- ids2tokens = preprocessor.token_id_converter.ids2tokens
- for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
- pre_word = "<s>"
- cur_word = None
- sent_lst = ids2tokens(sent_ids) + ['</s>']
- ppl_out = " ".join(sent_lst) + "\n"
- for word, word_nll in zip(sent_lst, sent_nll):
- cur_word = word
- word_nll = -word_nll.cpu()
- if log_base is None:
- word_prob = np.exp(word_nll)
- else:
- word_prob = log_base ** (word_nll / np.log(log_base))
- ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
- cur=cur_word,
- pre=pre_word,
- prob=round(word_prob.item(), 8),
- word_nll=round(word_nll.item(), 8)
- )
- pre_word = cur_word
-
- sent_nll_mean = sent_nll.mean().cpu().numpy()
- sent_nll_sum = sent_nll.sum().cpu().numpy()
- if log_base is None:
- sent_ppl = np.exp(sent_nll_mean)
- else:
- sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
- ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
- sent_nll=round(-sent_nll_sum.item(), 4),
- sent_ppl=round(sent_ppl.item(), 4)
- )
- ppl_out_batch += ppl_out
- utt2nll = round(-sent_nll_sum.item(), 5)
- item = {'key': key, 'value': ppl_out}
- if writer is not None:
- writer["ppl"][key + ":\n"] = ppl_out
- writer["utt2nll"][key] = str(utt2nll)
- results.append(item)
-
- ppl_out_all += ppl_out_batch
-
- assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
- # nll: (B, L) -> (B,)
- nll = nll.detach().cpu().numpy().sum(1)
- # lengths: (B,)
- lengths = lengths.detach().cpu().numpy()
- total_nll += nll.sum()
- total_ntokens += lengths.sum()
-
- if log_base is None:
- ppl = np.exp(total_nll / total_ntokens)
- else:
- ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
- avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
- total_nll=round(-total_nll.item(), 4),
- total_ppl=round(ppl.item(), 4)
- )
- item = {'key': 'AVG PPL', 'value': avg_ppl}
- ppl_out_all += avg_ppl
- if writer is not None:
- writer["ppl"]["AVG PPL : "] = avg_ppl
- results.append(item)
-
- return results
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "transformer":
- return inference_lm(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Calc perplexity",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument("--gpuid_list", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument("--njob", type=int, default=1, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument(
- "--log_base",
- type=float_or_none,
- default=10,
- help="The base of logarithm for Perplexity. "
- "If None, napier's constant is used.",
- required=False
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- required=False
- )
- group.add_argument(
- "--raw_inputs",
- type=str,
- required=False
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group.add_argument("--split_with_space", type=str2bool, default=False)
- group.add_argument("--seg_dict_file", type=str_or_none)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
- group.add_argument("--mode", type=str, default="lm")
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- kwargs.pop("gpuid_list", None)
- kwargs.pop("njob", None)
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
deleted file mode 100755
index 22b5f9c..0000000
--- a/funasr/bin/lm_train.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.lm import LMTask
-
-
-# for LM Training
-def parse_args():
- parser = LMTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for LM Training
- LMTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small" and args.ngpu != 0:
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
deleted file mode 100644
index 9efeb5b..0000000
--- a/funasr/bin/punc_infer.py
+++ /dev/null
@@ -1,282 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-import os
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.datasets.preprocessor import split_to_mini_sentence
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-
-
-class Text2Punc:
-
- def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
- ):
- # Build Model
- model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
- self.device = device
- # Wrape model to make model.nll() data-parallel
- self.wrapped_model = ForwardAdaptor(model, "inference")
- self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- # logging.info(f"Model:\n{model}")
- self.punc_list = train_args.punc_list
- self.period = 0
- for i in range(len(self.punc_list)):
- if self.punc_list[i] == ",":
- self.punc_list[i] = "锛�"
- elif self.punc_list[i] == "?":
- self.punc_list[i] = "锛�"
- elif self.punc_list[i] == "銆�":
- self.period = i
- self.seg_dict_file = None
- self.seg_jieba = False
- if "seg_jieba" in train_args:
- self.seg_jieba = train_args.seg_jieba
- self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
- self.preprocessor = CodeMixTokenizerCommonPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- seg_jieba=self.seg_jieba,
- seg_dict_file=self.seg_dict_file
- )
-
- @torch.no_grad()
- def __call__(self, text: Union[list, str], split_size=20):
- data = {"text": text}
- result = self.preprocessor(data=data, uid="12938712838719")
- split_text = self.preprocessor.pop_split_text_data(result)
- mini_sentences = split_to_mini_sentence(split_text, split_size)
- mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
- assert len(mini_sentences) == len(mini_sentences_id)
- cache_sent = []
- cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
- new_mini_sentence = ""
- new_mini_sentence_punc = []
- cache_pop_trigger_limit = 200
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence_id = mini_sentences_id[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
- data = {
- "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
- "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- }
- data = to_device(data, self.device)
- y, _ = self.wrapped_model(**data)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences) - 1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations) - 2, 1, -1):
- if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
- sentenceEnd = i
- break
- if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
- last_comma_index = i
-
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence it too long, cut off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = self.period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
-
- # if len(punctuations) == 0:
- # continue
-
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += [int(x) for x in punctuations_np]
- words_with_punc = []
- for i in range(len(mini_sentence)):
- if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
- mini_sentence[i] = mini_sentence[i].capitalize()
- if i == 0:
- if len(mini_sentence[i][0].encode()) == 1:
- mini_sentence[i] = " " + mini_sentence[i]
- if i > 0:
- if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
- mini_sentence[i] = " " + mini_sentence[i]
- words_with_punc.append(mini_sentence[i])
- if self.punc_list[punctuations[i]] != "_":
- punc_res = self.punc_list[punctuations[i]]
- if len(mini_sentence[i][0].encode()) == 1:
- if punc_res == "锛�":
- punc_res = ","
- elif punc_res == "銆�":
- punc_res = "."
- elif punc_res == "锛�":
- punc_res = "?"
- words_with_punc.append(punc_res)
- new_mini_sentence += "".join(words_with_punc)
- # Add Period for the end of the sentence
- new_mini_sentence_out = new_mini_sentence
- new_mini_sentence_punc_out = new_mini_sentence_punc
- if mini_sentence_i == len(mini_sentences) - 1:
- if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
- new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- elif new_mini_sentence[-1] == ",":
- new_mini_sentence_out = new_mini_sentence[:-1] + "."
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
- new_mini_sentence_out = new_mini_sentence + "銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
- new_mini_sentence_out = new_mini_sentence + "."
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- return new_mini_sentence_out, new_mini_sentence_punc_out
-
-
-class Text2PuncVADRealtime:
-
- def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
- ):
- # Build Model
- model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
- self.device = device
- # Wrape model to make model.nll() data-parallel
- self.wrapped_model = ForwardAdaptor(model, "inference")
- self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- # logging.info(f"Model:\n{model}")
- self.punc_list = train_args.punc_list
- self.period = 0
- for i in range(len(self.punc_list)):
- if self.punc_list[i] == ",":
- self.punc_list[i] = "锛�"
- elif self.punc_list[i] == "?":
- self.punc_list[i] = "锛�"
- elif self.punc_list[i] == "銆�":
- self.period = i
- self.preprocessor = CodeMixTokenizerCommonPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- )
-
- @torch.no_grad()
- def __call__(self, text: Union[list, str], cache: list, split_size=20):
- if cache is not None and len(cache) > 0:
- precache = "".join(cache)
- else:
- precache = ""
- cache = []
- data = {"text": precache + " " + text}
- result = self.preprocessor(data=data, uid="12938712838719")
- split_text = self.preprocessor.pop_split_text_data(result)
- mini_sentences = split_to_mini_sentence(split_text, split_size)
- mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
- assert len(mini_sentences) == len(mini_sentences_id)
- cache_sent = []
- cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
- sentence_punc_list = []
- sentence_words_list = []
- cache_pop_trigger_limit = 200
- skip_num = 0
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence_id = mini_sentences_id[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
- data = {
- "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
- "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
- }
- data = to_device(data, self.device)
- y, _ = self.wrapped_model(**data)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences) - 1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations) - 2, 1, -1):
- if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
- sentenceEnd = i
- break
- if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
- last_comma_index = i
-
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence it too long, cut off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = self.period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
-
- punctuations_np = punctuations.cpu().numpy()
- sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
- sentence_words_list += mini_sentence
-
- assert len(sentence_punc_list) == len(sentence_words_list)
- words_with_punc = []
- sentence_punc_list_out = []
- for i in range(0, len(sentence_words_list)):
- if i > 0:
- if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
- sentence_words_list[i] = " " + sentence_words_list[i]
- if skip_num < len(cache):
- skip_num += 1
- else:
- words_with_punc.append(sentence_words_list[i])
- if skip_num >= len(cache):
- sentence_punc_list_out.append(sentence_punc_list[i])
- if sentence_punc_list[i] != "_":
- words_with_punc.append(sentence_punc_list[i])
- sentence_out = "".join(words_with_punc)
-
- sentenceEnd = -1
- for i in range(len(sentence_punc_list) - 2, 1, -1):
- if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
- sentenceEnd = i
- break
- cache_out = sentence_words_list[sentenceEnd + 1:]
- if sentence_out[-1] in self.punc_list:
- sentence_out = sentence_out[:-1]
- sentence_punc_list_out[-1] = "_"
- return sentence_out, sentence_punc_list_out, cache_out
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
deleted file mode 100755
index 5d917f5..0000000
--- a/funasr/bin/punc_inference_launch.py
+++ /dev/null
@@ -1,252 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Union
-
-import torch
-
-from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_punc(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
- text2punc = Text2Punc(train_config, model_file, device)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
- ):
- results = []
- split_size = 20
-
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "demo"
- if line == "":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- result, _ = text2punc(line)
- item = {'key': key, 'value': result}
- results.append(item)
- return results
-
- for inference_text, _, _ in data_path_and_name_and_type:
- with open(inference_text, "r", encoding="utf-8") as fin:
- for line in fin:
- line = line.strip()
- segs = line.split("\t")
- if len(segs) != 2:
- continue
- key = segs[0]
- if len(segs[1]) == 0:
- continue
- result, _ = text2punc(segs[1])
- item = {'key': key, 'value': result}
- results.append(item)
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path != None:
- output_file_name = "infer.out"
- Path(output_path).mkdir(parents=True, exist_ok=True)
- output_file_path = (Path(output_path) / output_file_name).absolute()
- with open(output_file_path, "w", encoding="utf-8") as fout:
- for item_i in results:
- key_out = item_i["key"]
- value_out = item_i["value"]
- fout.write(f"{key_out}\t{value_out}\n")
- return results
-
- return _forward
-
-
-def inference_punc_vad_realtime(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- # cache: list,
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
- text2punc = Text2PuncVADRealtime(train_config, model_file, device)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
- ):
- results = []
- split_size = 10
- cache_in = param_dict["cache"]
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "demo"
- if line == "":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- result, _, cache = text2punc(line, cache_in)
- param_dict["cache"] = cache
- item = {'key': key, 'value': result}
- results.append(item)
- return results
-
- return results
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "punc":
- return inference_punc(**kwargs)
- if mode == "punc_VadRealtime":
- return inference_punc_vad_realtime(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Punctuation inference",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument("--gpuid_list", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument("--njob", type=int, default=1, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
- group.add_argument("--raw_inputs", type=str, required=False)
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--cache", type=list, required=False)
- group.add_argument("--param_dict", type=dict, required=False)
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
- group.add_argument("--mode", type=str, default="punc")
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- kwargs.pop("gpuid_list", None)
- kwargs.pop("njob", None)
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py
deleted file mode 100644
index c3cbee9..0000000
--- a/funasr/bin/punc_train.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-from funasr.tasks.punctuation import PunctuationTask
-
-
-def parse_args():
- parser = PunctuationTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- parser.add_argument(
- "--punc_list",
- type=str,
- default=None,
- help="Punctuation list",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- """
- punc training.
- """
- PunctuationTask.main(args=args, cmd=cmd)
-
-
-if __name__ == "__main__":
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
-
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu * args.num_worker_count
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu * args.num_worker_count
-
- main(args=args)
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
deleted file mode 100755
index 67106cf..0000000
--- a/funasr/bin/sa_asr_train.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.sa_asr import ASRTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- if args.ngpu > 0:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None and args.ngpu > 0:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None and args.ngpu > 0:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/ss_infer.py b/funasr/bin/ss_infer.py
deleted file mode 100644
index a3eca11..0000000
--- a/funasr/bin/ss_infer.py
+++ /dev/null
@@ -1,127 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
-import logging
-from pathlib import Path
-from typing import List
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.torch_utils.device_funcs import to_device
-
-
-class SpeechSeparator:
- """SpeechSeparator class
-
- Examples:
- >>> import librosa
- >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
- >>> audio, rate = librosa.load("speech.wav")
- >>> separated_wavs = speech_separator(audio)
-
- """
-
- def __init__(
- self,
- ss_infer_config: Union[Path, str] = None,
- ss_model_file: Union[Path, str] = None,
- device: str = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- **kwargs,
- ):
-
- # 1. Build ss model
- ss_model, ss_infer_args = build_model_from_file(
- ss_infer_config, ss_model_file, None, device, task_name="ss"
- )
-
- logging.info("ss_model: {}".format(ss_model))
- logging.info("ss_infer_args: {}".format(ss_infer_args))
-
- ss_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.ss_model = ss_model
- self.ss_infer_args = ss_infer_args
- self.device = device
- self.dtype = dtype
- self.batch_size = batch_size
-
- def decode(self, model, args, inputs, nsamples):
- decode_do_segment = False
- with torch.no_grad():
- out = []
- window = args.sample_rate * args.decode_window # decoding window length
- stride = int(window*0.75) # decoding stride if segmentation is used
- b, t = inputs.shape
- if t > window * args.one_time_decode_length:
- decode_do_segment = True # set segment decoding to true for very long sequence
-
- if t < window:
- inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
- elif t < window + stride:
- padding = window + stride - t
- inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
- else:
- if (t - window) % stride != 0:
- padding = t - (t-window)//stride * stride
- inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
- inputs = torch.from_numpy(np.float32(inputs))
- inputs = to_device(inputs, device=self.device)
- b, t = inputs.shape
- if decode_do_segment:
- outputs = np.zeros((args.num_spks, t))
- give_up_length = (window - stride)//2
- current_idx = 0
- while current_idx + window <= t:
- tmp_input = inputs[:, current_idx:current_idx+window]
- tmp_out_list = model(tmp_input,)
- for spk in range(args.num_spks):
- tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
- if current_idx == 0:
- outputs[spk, current_idx:current_idx+window-give_up_length] = \
- tmp_out_list[spk][:-give_up_length]
- else:
- outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
- tmp_out_list[spk][give_up_length:-give_up_length]
- current_idx += stride
- for spk in range(args.num_spks):
- out.append(outputs[spk, :])
- else:
- out_list = model(inputs)
- for spk in range(args.num_spks):
- out.append(out_list[spk][0, :].cpu().numpy())
-
- max_abs = 0
- for spk in range(args.num_spks):
- if max_abs < max(abs(out[spk])):
- max_abs = max(abs(out[spk]))
- for spk in range(args.num_spks):
- out[spk] = out[spk][:nsamples]
- out[spk] = out[spk]/max_abs
-
- return out
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- ) -> List[torch.Tensor]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- speech list: list of speech data
-
- """
-
- out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
-
- return out
-
diff --git a/funasr/bin/ss_inference_launch.py b/funasr/bin/ss_inference_launch.py
deleted file mode 100644
index 0c02419..0000000
--- a/funasr/bin/ss_inference_launch.py
+++ /dev/null
@@ -1,258 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
-import argparse
-import logging
-import os
-import sys
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-import librosa
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2triple_str
-from funasr.bin.ss_infer import SpeechSeparator
-
-
-def inference_ss(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- ss_infer_config: Optional[str],
- ss_model_file: Optional[str],
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- num_spks: int = 2,
- sample_rate: int = 8000,
- param_dict: dict = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech separator
- speech_separator_kwargs = dict(
- ss_infer_config=ss_infer_config,
- ss_model_file=ss_model_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
- speech_separator = SpeechSeparator(**speech_separator_kwargs)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="ss",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- num_workers=num_workers,
- )
-
- # 4 .Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if not os.path.exists(output_path):
- cmd = 'mkdir -p ' + output_path
- os.system(cmd)
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- # do speech separation
- logging.info('decoding: {}'.format(keys[0]))
- ss_results = speech_separator(**batch)
-
- for spk in range(num_spks):
- # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
- try:
- librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
- except:
- print("To write wav by librosa, you should install librosa<=0.8.0")
- raise
- torch.cuda.empty_cache()
- return ss_results
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "mossformer":
- return inference_ss(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Speech Separator Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=1,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="2",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--ss_infer_config",
- type=str,
- help="SS infer configuration",
- )
- group.add_argument(
- "--ss_model_file",
- type=str,
- help="SS model parameter file",
- )
- group.add_argument(
- "--ss_train_config",
- type=str,
- help="SS training configuration",
- )
-
- group = parser.add_argument_group("The inference configuration related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
-
- parser.add_argument(
- '--num-spks', dest='num_spks', type=int, default=2)
-
- parser.add_argument(
- '--one-time-decode-length', dest='one_time_decode_length', type=int,
- default=60, help='the max length (second) for one-time decoding')
-
- parser.add_argument(
- '--decode-window', dest='decode_window', type=int,
- default=1, help='segmental decoding window length (second)')
-
- parser.add_argument(
- '--sample-rate', dest='sample_rate', type=int, default='8000')
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="mossformer",
- help="The decoding mode",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
-
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
deleted file mode 100755
index 19cfc2e..0000000
--- a/funasr/bin/sv_infer.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import logging
-from pathlib import Path
-from typing import Any
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils.misc import statistic_model_parameters
-
-
-class Speech2Xvector:
- """Speech2Xvector class
-
- Examples:
- >>> import librosa
- >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2xvector(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- sv_train_config: Union[Path, str] = None,
- sv_model_file: Union[Path, str] = None,
- device: str = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- ):
-
- # TODO: 1. Build SV model
- sv_model, sv_train_args = build_model_from_file(
- config_file=sv_train_config,
- model_file=sv_model_file,
- cmvn_file=None,
- device=device,
- task_name="sv",
- mode="sv",
- )
- logging.info("sv_model: {}".format(sv_model))
- logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
- logging.info("sv_train_args: {}".format(sv_train_args))
- sv_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.sv_model = sv_model
- self.sv_train_args = sv_train_args
- self.device = device
- self.dtype = dtype
- self.embedding_node = embedding_node
-
- @torch.no_grad()
- def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- # data: (Nsamples,) -> (1, Nsamples)
- speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- # lengths: (1,)
- lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- batch = {"speech": speech, "speech_lengths": lengths}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, ilens = self.sv_model.encode(**batch)
-
- # c. Forward Pooling
- pooling = self.sv_model.pooling_layer(enc)
-
- # d. Forward Decoder
- outputs, embeddings = self.sv_model.decoder(pooling)
-
- if self.embedding_node not in embeddings:
- raise ValueError("Required embedding node {} not in {}".format(
- self.embedding_node, embeddings.keys()))
-
- return embeddings[self.embedding_node]
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray],
- ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
- ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
- """Inference
-
- Args:
- speech: Input speech data
- ref_speech: Reference speech to compare
- Returns:
- embedding, ref_embedding, similarity_score
-
- """
- self.sv_model.eval()
- embedding = self.calculate_embedding(speech)
- ref_emb, score = None, None
- if ref_speech is not None:
- ref_emb = self.calculate_embedding(ref_speech)
- score = torch.cosine_similarity(embedding, ref_emb)
-
- results = (embedding, ref_emb, score)
- return results
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
deleted file mode 100755
index 2f9e276..0000000
--- a/funasr/bin/sv_inference_launch.py
+++ /dev/null
@@ -1,309 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from kaldiio import WriteHelper
-
-from funasr.bin.sv_infer import Speech2Xvector
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_sv(
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- sv_train_config: Optional[str] = "sv.yaml",
- sv_model_file: Optional[str] = "sv.pb",
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- param_dict: Optional[dict] = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("param_dict: {}".format(param_dict))
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2xvector
- speech2xvector_kwargs = dict(
- sv_train_config=sv_train_config,
- sv_model_file=sv_model_file,
- device=device,
- dtype=dtype,
- streaming=streaming,
- embedding_node=embedding_node
- )
- logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
- speech2xvector.sv_model.eval()
-
- def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
- ):
- logging.info("param_dict: {}".format(param_dict))
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
- # 3. Build data-iterator
- loader = build_streaming_iterator(
- task_name="sv",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- use_collate_fn=False,
- )
-
- # 7 .Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- embd_writer, ref_embd_writer, score_writer = None, None, None
- if output_path is not None:
- os.makedirs(output_path, exist_ok=True)
- embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
- sv_result_list = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- embedding, ref_embedding, score = speech2xvector(**batch)
- # Only supporting batch_size==1
- key = keys[0]
- normalized_score = 0.0
- if score is not None:
- score = score.item()
- normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
- item = {"key": key, "value": normalized_score}
- else:
- item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
- sv_result_list.append(item)
- if output_path is not None:
- embd_writer(key, embedding[0].cpu().numpy())
- if ref_embedding is not None:
- if ref_embd_writer is None:
- ref_embd_writer = WriteHelper(
- "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
- )
- score_writer = open(os.path.join(output_path, "score.txt"), "w")
- ref_embd_writer(key, ref_embedding[0].cpu().numpy())
- score_writer.write("{} {:.6f}\n".format(key, normalized_score))
-
- if output_path is not None:
- embd_writer.close()
- if ref_embd_writer is not None:
- ref_embd_writer.close()
- score_writer.close()
-
- return sv_result_list
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "sv":
- return inference_sv(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Speaker Verification",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--sv_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--sv_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global CMVN file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("The inference configuration related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument(
- "--sv_threshold",
- type=float,
- default=0.9465,
- help="The threshold for verification"
- )
- parser.add_argument(
- "--embedding_node",
- type=str,
- default="resnet1_dense",
- help="The network node to extract embedding"
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="sv",
- help="The decoding mode",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
deleted file mode 100644
index cfe534f..0000000
--- a/funasr/bin/tp_infer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import logging
-from pathlib import Path
-from typing import Union
-
-import numpy as np
-import torch
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-
-
-class Speech2Timestamp:
- def __init__(
- self,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- timestamp_cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- dtype: str = "float32",
- **kwargs,
- ):
- # 1. Build ASR model
- tp_model, tp_train_args = build_model_from_file(
- timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
- )
- if 'cuda' in device:
- tp_model = tp_model.cuda() # force model to cuda
-
- frontend = None
- if tp_train_args.frontend is not None:
- frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
-
- logging.info("tp_model: {}".format(tp_model))
- logging.info("tp_train_args: {}".format(tp_train_args))
- tp_model.to(dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- self.tp_model = tp_model
- self.tp_train_args = tp_train_args
-
- token_list = self.tp_model.token_list
- self.converter = TokenIDConverter(token_list=token_list)
-
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if tp_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- text_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.tp_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
-
- # lfr_factor = max(1, (feats.size()[-1]//80)-1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, enc_len = self.tp_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
-
- # c. Forward Predictor
- _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len,
- text_lengths.to(self.device) + 1)
- return us_alphas, us_peaks
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
deleted file mode 100644
index 6c10254..0000000
--- a/funasr/bin/tp_inference_launch.py
+++ /dev/null
@@ -1,287 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
-import argparse
-import logging
-import os
-import sys
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_tp(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- timestamp_infer_config: Optional[str],
- timestamp_model_file: Optional[str],
- timestamp_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- split_with_space: bool = True,
- seg_dict_file: Optional[str] = None,
- **kwargs,
-):
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speechtext2timestamp_kwargs = dict(
- timestamp_infer_config=timestamp_infer_config,
- timestamp_model_file=timestamp_model_file,
- timestamp_cmvn_file=timestamp_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
- speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs)
-
- preprocessor = LMPreprocessor(
- train=False,
- token_type=speechtext2timestamp.tp_train_args.token_type,
- token_list=speechtext2timestamp.tp_train_args.token_list,
- bpemodel=None,
- text_cleaner=None,
- g2p_type=None,
- text_name="text",
- non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file,
- )
-
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
- tp_writer = writer[f"timestamp_prediction"]
- # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
- else:
- tp_writer = None
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs
- ):
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- writer = None
- if output_path is not None:
- writer = DatadirWriter(output_path)
- tp_writer = writer[f"timestamp_prediction"]
- else:
- tp_writer = None
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
- loader = build_streaming_iterator(
- task_name="asr",
- preprocess_args=speechtext2timestamp.tp_train_args,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=preprocessor,
- )
-
- tp_result_list = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- logging.info("timestamp predicting, utt_id: {}".format(keys))
- _batch = {'speech': batch['speech'],
- 'speech_lengths': batch['speech_lengths'],
- 'text_lengths': batch['text_lengths']}
- us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
-
- for batch_id in range(_bs):
- key = keys[batch_id]
- token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
- ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token,
- force_time_shift=-3.0)
- logging.warning(ts_str)
- item = {'key': key, 'value': ts_str, 'timestamp': ts_list}
- if tp_writer is not None:
- tp_writer["tp_sync"][key + '#'] = ts_str
- tp_writer["tp_time"][key + '#'] = str(ts_list)
- tp_result_list.append(item)
- return tp_result_list
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "tp_norm":
- return inference_tp(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Timestamp Prediction Inference",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--timestamp_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--timestamp_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--timestamp_cmvn_file",
- type=str,
- help="Global CMVN file",
- )
-
- group = parser.add_argument_group("The inference configuration related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="tp_norm",
- help="The decoding mode",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/cli/train_cli.py b/funasr/bin/train.py
similarity index 96%
rename from funasr/cli/train_cli.py
rename to funasr/bin/train.py
index a22d5d4..4187476 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/bin/train.py
@@ -19,18 +19,13 @@
# from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.tokenizer.funtoken import build_tokenizer
from funasr.datasets.dataset_jsonl import AudioDataset
-from funasr.cli.trainer import Trainer
+from funasr.utils.trainer import Trainer
# from funasr.utils.load_fr_py import load_class_from_path
from funasr.utils.dynamic_import import dynamic_import
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.utils.download_from_hub import download_model
-
-def preprocess_config(cfg: DictConfig):
- for key, value in cfg.items():
- if value == 'None':
- cfg[key] = None
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
deleted file mode 100644
index 5763873..0000000
--- a/funasr/bin/vad_infer.py
+++ /dev/null
@@ -1,180 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import logging
-import math
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.torch_utils.device_funcs import to_device
-
-
-class Speech2VadSegment:
- """Speech2VadSegment class
-
- Examples:
- >>> import librosa
- >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2segment(audio)
- [[10, 230], [245, 450], ...]
-
- """
-
- def __init__(
- self,
- vad_infer_config: Union[Path, str] = None,
- vad_model_file: Union[Path, str] = None,
- vad_cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- **kwargs,
- ):
-
- # 1. Build vad model
- vad_model, vad_infer_args = build_model_from_file(
- vad_infer_config, vad_model_file, None, device, task_name="vad"
- )
- frontend = None
- if vad_infer_args.frontend is not None:
- frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-
- logging.info("vad_model: {}".format(vad_model))
- logging.info("vad_infer_args: {}".format(vad_infer_args))
- vad_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.vad_model = vad_model
- self.vad_infer_args = vad_infer_args
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
- self.batch_size = batch_size
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- in_cache: Dict[str, torch.Tensor] = dict()
- ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- self.frontend.filter_length_max = math.inf
- fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
- feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
- fbanks = to_device(fbanks, device=self.device)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- else:
- raise Exception("Need to extract feats first, please configure frontend configuration")
-
- # b. Forward Encoder streaming
- t_offset = 0
- step = min(feats_len.max(), 6000)
- segments = [[]] * self.batch_size
- for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
- if t_offset + step >= feats_len - 1:
- step = feats_len - t_offset
- is_final = True
- else:
- is_final = False
- batch = {
- "feats": feats[:, t_offset:t_offset + step, :],
- "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
- "is_final": is_final,
- "in_cache": in_cache
- }
- # a. To device
- # batch = to_device(batch, device=self.device)
- segments_part, in_cache = self.vad_model(**batch)
- if segments_part:
- for batch_num in range(0, self.batch_size):
- segments[batch_num] += segments_part[batch_num]
- return fbanks, segments
-
-
-class Speech2VadSegmentOnline(Speech2VadSegment):
- """Speech2VadSegmentOnline class
-
- Examples:
- >>> import librosa
- >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
- >>> audio, rate = librosa.load("speech.wav")
- >>> speech2segment(audio)
- [[10, 230], [245, 450], ...]
-
- """
-
- def __init__(self, **kwargs):
- super(Speech2VadSegmentOnline, self).__init__(**kwargs)
- vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
- self.frontend = None
- if self.vad_infer_args.frontend is not None:
- self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- batch_size = speech.shape[0]
- segments = [[]] * batch_size
- if self.frontend is not None:
- reset = in_cache == dict()
- feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final, reset)
- fbanks, _ = self.frontend.get_fbank()
- else:
- raise Exception("Need to extract feats first, please configure frontend configuration")
- if feats.shape[0]:
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- waveforms = self.frontend.get_waveforms()
- if max_end_sil == 800 and self.vad_infer_args.vad_post_conf["max_end_silence_time"] != 800:
- max_end_sil = self.vad_infer_args.vad_post_conf["max_end_silence_time"]
-
- batch = {
- "feats": feats,
- "waveform": waveforms,
- "in_cache": in_cache,
- "is_final": is_final,
- "max_end_sil": max_end_sil
- }
- # a. To device
- batch = to_device(batch, device=self.device)
- segments, in_cache = self.vad_model.forward_online(**batch)
- # in_cache.update(batch['in_cache'])
- # in_cache = {key: value for key, value in batch['in_cache'].items()}
- return fbanks, segments, in_cache
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
deleted file mode 100644
index a031a5a..0000000
--- a/funasr/bin/vad_inference_launch.py
+++ /dev/null
@@ -1,379 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
-import torch
-
-torch.set_num_threads(1)
-
-import argparse
-import logging
-import os
-import sys
-import json
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
-
-
-def inference_vad(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- **kwargs,
-):
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="vad",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
- else:
- writer = None
- ibest_writer = None
-
- vad_results = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- # do vad segment
- _, results = speech2vadsegment(**batch)
- for i, _ in enumerate(keys):
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
- item = {'key': keys[i], 'value': results[i]}
- vad_results.append(item)
- if writer is not None:
- ibest_writer["text"][keys[i]] = "{}".format(results[i])
- torch.cuda.empty_cache()
- return vad_results
-
- return _forward
-
-
-def inference_vad_online(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- **kwargs,
-):
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = build_streaming_iterator(
- task_name="vad",
- preprocess_args=None,
- data_path_and_name_and_type=data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
- else:
- writer = None
- ibest_writer = None
-
- vad_results = []
- if param_dict is None:
- param_dict = dict()
- param_dict['in_cache'] = dict()
- param_dict['is_final'] = True
- batch_in_cache = param_dict.get('in_cache', dict())
- is_final = param_dict.get('is_final', False)
- max_end_sil = param_dict.get('max_end_sil', 800)
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch['in_cache'] = batch_in_cache
- batch['is_final'] = is_final
- batch['max_end_sil'] = max_end_sil
-
- # do vad segment
- _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
- # param_dict['in_cache'] = batch['in_cache']
- if results:
- for i, _ in enumerate(keys):
- if results[i]:
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
- item = {'key': keys[i], 'value': results[i]}
- vad_results.append(item)
- if writer is not None:
- ibest_writer["text"][keys[i]] = "{}".format(results[i])
-
- return vad_results
-
- return _forward
-
-
-def inference_launch(mode, **kwargs):
- if mode == "offline":
- return inference_vad(**kwargs)
- elif mode == "online":
- return inference_vad_online(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="VAD Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--vad_cmvn_file",
- type=str,
- help="Global CMVN file",
- )
- group.add_argument(
- "--vad_train_config",
- type=str,
- help="VAD training configuration",
- )
-
- group = parser.add_argument_group("The inference configuration related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="vad",
- help="The decoding mode",
- )
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
-
- # set logging messages
- logging.basicConfig(
- level=args.log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(kwargs))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_dir.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- inference_pipeline = inference_launch(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/cli/__init__.py b/funasr/cli/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/cli/__init__.py
+++ /dev/null
diff --git a/funasr/cli/models/__init__.py b/funasr/cli/models/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/cli/models/__init__.py
+++ /dev/null
diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py
deleted file mode 100644
index 7ca80f5..0000000
--- a/funasr/cli/models/paraformer.py
+++ /dev/null
@@ -1,655 +0,0 @@
-import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.nets_utils import make_pad_mask, pad_list
-from funasr.modules.nets_utils import th_accuracy
-from funasr.torch_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-
-from funasr.cli.model_class_factory import *
-
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-
-
-class Paraformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- # token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[str] = None,
- frontend_conf: Optional[Dict] = None,
- specaug: Optional[str] = None,
- specaug_conf: Optional[Dict] = None,
- normalize: str = None,
- normalize_conf: Optional[Dict] = None,
- encoder: str = None,
- encoder_conf: Optional[Dict] = None,
- decoder: str = None,
- decoder_conf: Optional[Dict] = None,
- ctc: str = None,
- ctc_conf: Optional[Dict] = None,
- predictor: str = None,
- predictor_conf: Optional[Dict] = None,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- # report_cer: bool = True,
- # report_wer: bool = True,
- # sym_space: str = "<space>",
- # sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- # predictor=None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- use_1st_decoder_loss: bool = False,
- **kwargs,
- ):
- assert 0.0 <= ctc_weight <= 1.0, ctc_weight
- assert 0.0 <= interctc_weight < 1.0, interctc_weight
-
- super().__init__()
-
- # import pdb;
- # pdb.set_trace()
-
- if frontend is not None:
- frontend_class = frontend_choices.get_class(frontend)
- frontend = frontend_class(**frontend_conf)
- if specaug is not None:
- specaug_class = specaug_choices.get_class(specaug)
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = normalize_choices.get_class(normalize)
- normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_choices.get_class(encoder)
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
- if decoder is not None:
- decoder_class = decoder_choices.get_class(decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
- if ctc_weight > 0.0:
-
- if ctc_conf is None:
- ctc_conf = {}
-
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
- )
- if predictor is not None:
- predictor_class = predictor_choices.get_class(predictor)
- predictor = predictor_class(**predictor_conf)
-
- # note that eos is the same as sos (equivalent ID)
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.ctc_weight = ctc_weight
- self.interctc_weight = interctc_weight
- # self.token_list = token_list.copy()
- #
- self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
- # self.preencoder = preencoder
- # self.postencoder = postencoder
- self.encoder = encoder
- #
- # if not hasattr(self.encoder, "interctc_use_conditioning"):
- # self.encoder.interctc_use_conditioning = False
- # if self.encoder.interctc_use_conditioning:
- # self.encoder.conditioning_layer = torch.nn.Linear(
- # vocab_size, self.encoder.output_size()
- # )
- #
- # self.error_calculator = None
- #
- if ctc_weight == 1.0:
- self.decoder = None
- else:
- self.decoder = decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
- if ctc_weight == 0.0:
- self.ctc = None
- else:
- self.ctc = ctc
- #
- # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
- self.predictor = predictor
- self.predictor_weight = predictor_weight
- self.predictor_bias = predictor_bias
- self.sampling_ratio = sampling_ratio
- self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
- # self.step_cur = 0
- #
- self.share_embedding = share_embedding
- if self.share_embedding:
- self.decoder.embed = None
-
- self.use_1st_decoder_loss = use_1st_decoder_loss
- self.length_normalized_loss = length_normalized_loss
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- decoding_ind: int
- """
- decoding_ind = kwargs.get("kwargs", None)
- # import pdb;
- # pdb.set_trace()
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # # for data-parallel
- # text = text[:, : text_lengths.max()]
- # speech = speech[:, :speech_lengths.max()]
-
- # 1. Encoder
- if hasattr(self.encoder, "overlap_chunk_cls"):
- ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
- loss_ctc, cer_ctc = None, None
- loss_pre = None
- stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
- # Collect Intermedaite CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # 2b. Attention decoder branch
- if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
- if self.use_1st_decoder_loss and pre_loss_att is not None:
- loss = loss + (1 - self.ctc_weight) * pre_loss_att
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum()
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- ) -> Dict[str, torch.Tensor]:
- if self.extract_feats_in_collect_stats:
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
- else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
- feats, feats_lengths = speech, speech_lengths
- return {"feats": feats, "feats_lengths": feats_lengths}
-
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
- # # 1. Extract feats
- # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(speech, speech_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # # Pre-encoder, e.g. used for raw input data
- # if self.preencoder is not None:
- # feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
- # 4. Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- if hasattr(self.encoder, "overlap_chunk_cls"):
- encoder_out, encoder_out_lens, _ = self.encoder(
- feats, feats_lengths, ctc=self.ctc, ind=ind
- )
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(
- feats, feats_lengths, ctc=self.ctc
- )
- else:
- if hasattr(self.encoder, "overlap_chunk_cls"):
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- # # Post-encoder, e.g. NLU
- # if self.postencoder is not None:
- # encoder_out, encoder_out_lens = self.postencoder(
- # encoder_out, encoder_out_lens
- # )
-
- assert encoder_out.size(0) == speech.size(0), (
- encoder_out.size(),
- speech.size(0),
- )
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
- encoder_out.size(),
- encoder_out_lens.max(),
- )
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
-
- return encoder_out, encoder_out_lens
-
- def calc_predictor(self, encoder_out, encoder_out_lens):
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
-
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
- if self.frontend is not None:
- # Frontend
- # e.g. STFT and Feature extract
- # data_loader may send time-domain signal in this case
- # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- # No frontend and no feature extract
- feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
-
- def nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ) -> torch.Tensor:
- """Compute negative log likelihood(nll) from transformer-decoder
- Normally, this function is called in batchify_nll.
- Args:
- encoder_out: (Batch, Length, Dim)
- encoder_out_lens: (Batch,)
- ys_pad: (Batch, Length)
- ys_pad_lens: (Batch,)
- """
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- ) # [batch, seqlen, dim]
- batch_size = decoder_out.size(0)
- decoder_num_class = decoder_out.size(2)
- # nll: negative log-likelihood
- nll = torch.nn.functional.cross_entropy(
- decoder_out.view(-1, decoder_num_class),
- ys_out_pad.view(-1),
- ignore_index=self.ignore_id,
- reduction="none",
- )
- nll = nll.view(batch_size, -1)
- nll = nll.sum(dim=1)
- assert nll.size(0) == batch_size
- return nll
-
- def batchify_nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- batch_size: int = 100,
- ):
- """Compute negative log likelihood(nll) from transformer-decoder
- To avoid OOM, this fuction seperate the input into batches.
- Then call nll for each batch and combine and return results.
- Args:
- encoder_out: (Batch, Length, Dim)
- encoder_out_lens: (Batch,)
- ys_pad: (Batch, Length)
- ys_pad_lens: (Batch,)
- batch_size: int, samples each batch contain when computing nll,
- you may change this to avoid OOM or increase
- GPU memory usage
- """
- total_num = encoder_out.size(0)
- if total_num <= batch_size:
- nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
- else:
- nll = []
- start_idx = 0
- while True:
- end_idx = min(start_idx + batch_size, total_num)
- batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
- batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
- batch_ys_pad = ys_pad[start_idx:end_idx, :]
- batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
- batch_nll = self.nll(
- batch_encoder_out,
- batch_encoder_out_lens,
- batch_ys_pad,
- batch_ys_pad_lens,
- )
- nll.append(batch_nll)
- start_idx = end_idx
- if start_idx == total_num:
- break
- nll = torch.cat(nll)
- assert nll.size(0) == total_num
- return nll
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
-
- # 0. sampler
- decoder_out_1st = None
- pre_loss_att = None
- if self.sampling_ratio > 0.0:
-
-
- if self.use_1st_decoder_loss:
- sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
-
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
- def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
- )
- pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
-
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- # Calc CTC loss
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
- # Calc CER using CTC
- cer_ctc = None
- if not self.training and self.error_calculator is not None:
- ys_hat = self.ctc.argmax(encoder_out).data
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
- return loss_ctc, cer_ctc
diff --git a/funasr/cli/model_class_factory.py b/funasr/models/model_class_factory.py
similarity index 93%
rename from funasr/cli/model_class_factory.py
rename to funasr/models/model_class_factory.py
index b329492..819ca21 100644
--- a/funasr/cli/model_class_factory.py
+++ b/funasr/models/model_class_factory.py
@@ -123,27 +123,7 @@
default=None,
optional=True,
)
-# model_choices = ClassChoices(
-# "model",
-# classes=dict(
-# asr=ASRModel,
-# uniasr=UniASR,
-# paraformer=Paraformer,
-# paraformer_online=ParaformerOnline,
-# paraformer_bert=ParaformerBert,
-# bicif_paraformer=BiCifParaformer,
-# contextual_paraformer=ContextualParaformer,
-# neatcontextual_paraformer=NeatContextualParaformer,
-# mfcca=MFCCA,
-# timestamp_prediction=TimestampPredictor,
-# rnnt=TransducerModel,
-# rnnt_unified=UnifiedTransducerModel,
-# bat=BATModel,
-# sa_asr=SAASRModel,
-# ),
-# type_check=None,
-# default="asr",
-# )
+
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 75b36a9..50e7cd7 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -37,7 +37,7 @@
# from funasr.models.predictor.cif import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis
-from funasr.cli.model_class_factory import *
+from funasr.models.model_class_factory import *
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
diff --git a/funasr/cli/trainer.py b/funasr/utils/trainer.py
similarity index 100%
rename from funasr/cli/trainer.py
rename to funasr/utils/trainer.py
diff --git a/setup.py b/setup.py
index 197f346..a1e47af 100644
--- a/setup.py
+++ b/setup.py
@@ -131,6 +131,6 @@
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"console_scripts": [
- "funasr = funasr.bin.inference_cli:main",
+ "funasr = funasr.bin.inference:main_hydra",
]},
)
--
Gitblit v1.9.1