From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky
---
funasr/bin/asr_inference.py | 272 ++++++++++++++++++++++++++++++++++++------------------
1 files changed, 181 insertions(+), 91 deletions(-)
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
old mode 100755
new mode 100644
index 6ee0ffe..f3b4d56
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -12,6 +12,7 @@
from typing import Sequence
from typing import Tuple
from typing import Union
+from typing import Dict
import numpy as np
import torch
@@ -38,6 +39,12 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
class Speech2Text:
@@ -45,7 +52,7 @@
Examples:
>>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -56,6 +63,7 @@
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
@@ -72,6 +80,7 @@
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
+ frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
@@ -79,8 +88,12 @@
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, device
+ asr_train_config, asr_model_file, cmvn_file, device
)
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
@@ -129,36 +142,6 @@
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
- # TODO(karita): make all scorers batchfied
- if batch_size == 1:
- non_batch = [
- k
- for k, v in beam_search.full_scorers.items()
- if not isinstance(v, BatchScorerInterface)
- ]
- if len(non_batch) == 0:
- if streaming:
- beam_search.__class__ = BatchBeamSearchOnlineSim
- beam_search.set_streaming_config(asr_train_config)
- logging.info(
- "BatchBeamSearchOnlineSim implementation is selected."
- )
- else:
- beam_search.__class__ = BatchBeamSearch
- logging.info("BatchBeamSearch implementation is selected.")
- else:
- logging.warning(
- f"As non-batch scorers {non_batch} are found, "
- f"fall back to non-batch implementation."
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
- logging.info(f"Beam_search: {beam_search}")
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
@@ -188,10 +171,11 @@
self.device = device
self.dtype = dtype
self.nbest = nbest
+ self.frontend = frontend
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray]
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
@@ -203,7 +187,7 @@
"""Inference
Args:
- data: Input speech data
+ speech: Input speech data
Returns:
text, token, token_int, hyp
@@ -214,11 +198,16 @@
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
- # data: (Nsamples,) -> (1, Nsamples)
- speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- # lengths: (1,)
- lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- batch = {"speech": speech, "speech_lengths": lengths}
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
@@ -262,35 +251,99 @@
assert check_return_type(results)
return results
-
def inference(
- output_dir: str,
maxlenratio: float,
minlenratio: float,
batch_size: int,
- dtype: str,
beam_size: int,
ngpu: int,
- seed: int,
ctc_weight: float,
lm_weight: float,
- ngram_weight: float,
penalty: float,
- nbest: int,
- num_workers: int,
log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- key_file: Optional[str],
+ data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
- lm_train_config: Optional[str],
- lm_file: Optional[str],
- word_lm_train_config: Optional[str],
- token_type: Optional[str],
- bpemodel: Optional[str],
- allow_variable_data_keys: bool,
- streaming: bool,
+ cmvn_file: Optional[str] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
**kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ penalty=penalty,
+ log_level=log_level,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ raw_inputs=raw_inputs,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ key_file=key_file,
+ word_lm_train_config=word_lm_train_config,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ streaming=streaming,
+ output_dir=output_dir,
+ dtype=dtype,
+ seed=seed,
+ ngram_weight=ngram_weight,
+ nbest=nbest,
+ num_workers=num_workers,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
):
assert check_argument_types()
if batch_size > 1:
@@ -299,24 +352,25 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
- if ngpu >= 1:
+
+ if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
@@ -335,52 +389,81 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
-
- # 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- with DatadirWriter(output_dir) as writer:
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["<space>"], [2], hyp]] * nbest
-
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
if text is not None:
- ibest_writer["text"][key] = text
-
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+ return asr_result_list
+
+ return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -429,9 +512,11 @@
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
- required=True,
+ required=False,
action="append",
)
+ group.add_argument("--raw_inputs", type=list, default=None)
+ # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
@@ -447,6 +532,11 @@
help="ASR model parameter file",
)
group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global cmvn file",
+ )
+ group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
--
Gitblit v1.9.1