From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 12 Apr 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky
---
funasr/bin/asr_inference.py | 309 +++++++++++++++++++++++++++------------------------
1 file changed, 163 insertions(+), 146 deletions(-)
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
old mode 100755
new mode 100644
index b937f88..f3b4d56
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -42,25 +42,17 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
header_colors = '\033[95m'
end_colors = '\033[0m'
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
- 'audio_fs': 16000,
- 'model_fs': 16000
-}
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
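
A quick usage sketch may help here, since this hunk renames the model file in the docstring and the hunks below thread a new cmvn_file argument through the constructor and add a speech_lengths argument to __call__. The file names, the batch dimension, and the explicit lengths tensor are assumptions, not part of the patch:

    import soundfile
    import torch

    # Assumed file names; cmvn_file is the constructor argument added below.
    speech2text = Speech2Text(
        asr_train_config="asr_config.yml",
        asr_model_file="asr.pb",
        cmvn_file="am.mvn",
    )
    audio, rate = soundfile.read("speech.wav")
    # (1, samples) batch and explicit lengths, assumed to be needed once
    # feature extraction runs outside the model (see the __call__ hunk below).
    speech = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
    lengths = torch.tensor([speech.size(1)], dtype=torch.int32)
    results = speech2text(speech, lengths)  # [(text, token, token_int, hyp), ...]
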
@@ -71,6 +63,7 @@
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
@@ -95,13 +88,14 @@
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, device
+ asr_train_config, asr_model_file, cmvn_file, device
)
- if asr_model.frontend is None and frontend_conf is not None:
- frontend = WavFrontend(**frontend_conf)
- asr_model.frontend = frontend
- # logging.info("asr_model: {}".format(asr_model))
- # logging.info("asr_train_args: {}".format(asr_train_args))
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
@@ -164,7 +158,7 @@
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
- # logging.info(f"Text tokenizer: {tokenizer}")
+ logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
@@ -177,10 +171,11 @@
self.device = device
self.dtype = dtype
self.nbest = nbest
+ self.frontend = frontend
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray]
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
@@ -203,12 +198,16 @@
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
- # data: (Nsamples,) -> (1, Nsamples)
- speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
- # lengths: (1,)
- lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- batch = {"speech": speech, "speech_lengths": lengths}
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
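
The hunk above moves feature extraction out of the model and derives lfr_factor from the last feature dimension. A worked example of that formula, assuming 80-dim fbank features with optional low-frame-rate (LFR) stacking:

    # lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
    # plain 80-dim fbank:          max(1, (80 // 80) - 1)  = 1
    # 7-frame LFR stack (560-dim): max(1, (560 // 80) - 1) = 6
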
@@ -252,7 +251,6 @@
assert check_return_type(results)
return results
-
def inference(
maxlenratio: float,
minlenratio: float,
@@ -266,7 +264,8 @@
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
- audio_lists: Union[List[Any], bytes] = None,
+ cmvn_file: Optional[str] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
@@ -281,10 +280,70 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
- frontend_conf: dict = None,
- fs: Union[dict, int] = 16000,
- lang: Optional[str] = None,
**kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ penalty=penalty,
+ log_level=log_level,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ raw_inputs=raw_inputs,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ key_file=key_file,
+ word_lm_train_config=word_lm_train_config,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ streaming=streaming,
+ output_dir=output_dir,
+ dtype=dtype,
+ seed=seed,
+ ngram_weight=ngram_weight,
+ nbest=nbest,
+ num_workers=num_workers,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
if batch_size > 1:
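
The refactor above splits the old one-shot inference() into a builder, inference_modelscope(), which loads the model once, and a _forward closure (defined below) that can be called per request. A usage sketch; the paths, decoding weights, and scp triple are assumptions:

    import soundfile

    # Build once: loads the ASR model, tokenizer, and scorers.
    pipeline = inference_modelscope(
        maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=10,
        ngpu=1, ctc_weight=0.5, lm_weight=0.0, penalty=0.0, log_level="INFO",
        asr_train_config="exp/asr_config.yml",
        asr_model_file="exp/asr.pb",
        cmvn_file="exp/am.mvn",
    )
    # Decode repeatedly without reloading: either file lists ...
    results = pipeline([("wav.scp", "speech", "sound")])
    # ... or an in-memory waveform, which _forward wraps as
    # [raw_inputs, "speech", "waveform"] before building the iterator.
    audio, rate = soundfile.read("speech.wav")
    results = pipeline(None, raw_inputs=audio)
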
@@ -293,63 +352,25 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
- if ngpu >= 1:
+
+ if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
- hop_length: int = 160
- sr: int = 16000
- if isinstance(fs, int):
- sr = fs
- else:
- if 'model_fs' in fs and fs['model_fs'] is not None:
- sr = fs['model_fs']
- # data_path_and_name_and_type for modelscope: (data from audio_lists)
- # ['speech', 'sound', 'am.mvn']
- # data_path_and_name_and_type for funasr:
- # [('/mnt/data/jiangyu.xzy/exp/maas/mvn.1.scp', 'speech', 'kaldi_ark')]
- if isinstance(data_path_and_name_and_type[0], Tuple):
- features_type: str = data_path_and_name_and_type[0][1]
- elif isinstance(data_path_and_name_and_type[0], str):
- features_type: str = data_path_and_name_and_type[1]
- else:
- raise NotImplementedError("unknown features type:{0}".format(data_path_and_name_and_type))
- if features_type != 'sound':
- frontend_conf = None
- flag_modelscope = False
- else:
- flag_modelscope = True
- if frontend_conf is not None:
- if 'hop_length' in frontend_conf:
- hop_length = frontend_conf['hop_length']
-
- finish_count = 0
- file_count = 1
- if flag_modelscope and not isinstance(data_path_and_name_and_type[0], Tuple):
- data_path_and_name_and_type_new = [
- audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
- ]
- if isinstance(audio_lists, bytes):
- file_count = 1
- else:
- file_count = len(audio_lists)
- if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
- mvn_file = data_path_and_name_and_type[2]
- mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
- frontend_conf['mvn_data'] = mvn_data
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
@@ -365,29 +386,26 @@
penalty=penalty,
nbest=nbest,
streaming=streaming,
- frontend_conf=frontend_conf,
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
-
- # 3. Build data-iterator
- if flag_modelscope:
- loader = ASRTask.build_streaming_iterator_modelscope(
- data_path_and_name_and_type_new,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- sample_rate=fs
- )
- else:
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
+ fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
@@ -396,62 +414,56 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
-
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["<space>"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
+
+ finish_count = 0
+ file_count = 1
+    # 7. Start for-loop
+    # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
if writer is not None:
- ibest_writer["text"][key] = text
- return asr_result_list
-
-
-def set_parameters(language: str = None,
- sample_rate: Union[int, Dict[Any, int]] = None):
- if language is not None:
- global global_asr_language
- global_asr_language = language
- if sample_rate is not None:
- global global_sample_rate
- global_sample_rate = sample_rate
-
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+ return asr_result_list
+
+ return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
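
Continuing the builder sketch shown after inference(): the writer calls in _forward imply the per-call output override and on-disk layout below. The directory name is an assumption:

    # 'pipeline' and 'audio' come from the earlier sketch; output_dir_v2
    # overrides the output_dir captured when the pipeline was built.
    results = pipeline(None, raw_inputs=audio, output_dir_v2="decode_out")
    # Files written per n-best rank (token_int writing is commented out above):
    #   decode_out/1best_recog/token, decode_out/1best_recog/score,
    #   decode_out/1best_recog/text
    for item in results:  # [{'key': <utt id>, 'value': <postprocessed text>}, ...]
        print(item["key"], item["value"])
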
@@ -500,10 +512,10 @@
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
- required=True,
+ required=False,
action="append",
)
- group.add_argument("--audio_lists", type=list, default=None)
+ group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
@@ -520,6 +532,11 @@
help="ASR model parameter file",
)
group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global cmvn file",
+ )
+ group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
--
Gitblit v1.9.1