| funasr/bin/tp_infer.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/bin/tp_inference_launch.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/build_utils/build_model_from_file.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/build_utils/build_streaming_iterator.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/bin/tp_infer.py
@@ -1,57 +1,35 @@ # -*- encoding: utf-8 -*- #!/usr/bin/env python3 # -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) import argparse import logging from optparse import Option import sys import json from pathlib import Path from typing import Any from typing import List from typing import Optional from typing import Sequence from typing import Tuple from typing import Union from typing import Dict import numpy as np import torch from typeguard import check_argument_types from funasr.fileio.datadir_writer import DatadirWriter from funasr.datasets.preprocessor import LMPreprocessor from funasr.tasks.asr import ASRTaskAligner as ASRTask from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.models.frontend.wav_frontend import WavFrontend from funasr.text.token_id_converter import TokenIDConverter from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.torch_utils.device_funcs import to_device class Speech2Timestamp: def __init__( self, timestamp_infer_config: Union[Path, str] = None, timestamp_model_file: Union[Path, str] = None, timestamp_cmvn_file: Union[Path, str] = None, device: str = "cpu", dtype: str = "float32", **kwargs, self, timestamp_infer_config: Union[Path, str] = None, timestamp_model_file: Union[Path, str] = None, timestamp_cmvn_file: Union[Path, str] = None, device: str = "cpu", dtype: str = "float32", **kwargs, ): assert check_argument_types() # 1. Build ASR model tp_model, tp_train_args = ASRTask.build_model_from_file( timestamp_infer_config, timestamp_model_file, device=device tp_model, tp_train_args = build_model_from_file( timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp" ) if 'cuda' in device: tp_model = tp_model.cuda() # force model to cuda @@ -59,13 +37,12 @@ frontend = None if tp_train_args.frontend is not None: frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf) logging.info("tp_model: {}".format(tp_model)) logging.info("tp_train_args: {}".format(tp_train_args)) tp_model.to(dtype=getattr(torch, dtype)).eval() logging.info(f"Decoding device={device}, dtype={dtype}") self.tp_model = tp_model self.tp_train_args = tp_train_args @@ -79,13 +56,13 @@ self.encoder_downsampling_factor = 1 if tp_train_args.encoder_conf["input_layer"] == "conv2d": self.encoder_downsampling_factor = 4 @torch.no_grad() def __call__( self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, text_lengths: Union[torch.Tensor, np.ndarray] = None self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, text_lengths: Union[torch.Tensor, np.ndarray] = None ): assert check_argument_types() @@ -113,8 +90,6 @@ enc = enc[0] # c. Forward Predictor _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1) _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device) + 1) return us_alphas, us_peaks funasr/bin/tp_inference_launch.py
@@ -1,5 +1,5 @@ # -*- encoding: utf-8 -*- #!/usr/bin/env python3 # -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) @@ -8,87 +8,66 @@ import logging import os import sys from typing import Union, Dict, Any from funasr.utils import config_argparse from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none import argparse import logging from optparse import Option import sys import json from pathlib import Path from typing import Any from typing import List from typing import Optional from typing import Sequence from typing import Tuple from typing import Union from typing import Dict import numpy as np import torch from typeguard import check_argument_types from funasr.fileio.datadir_writer import DatadirWriter from funasr.bin.tp_infer import Speech2Timestamp from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.datasets.preprocessor import LMPreprocessor from funasr.tasks.asr import ASRTaskAligner as ASRTask from funasr.torch_utils.device_funcs import to_device from funasr.fileio.datadir_writer import DatadirWriter from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.cli_utils import get_commandline_args from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.models.frontend.wav_frontend import WavFrontend from funasr.text.token_id_converter import TokenIDConverter from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard from funasr.bin.tp_infer import Speech2Timestamp def inference_tp( batch_size: int, ngpu: int, log_level: Union[int, str], # data_path_and_name_and_type, timestamp_infer_config: Optional[str], timestamp_model_file: Optional[str], timestamp_cmvn_file: Optional[str] = None, # raw_inputs: Union[np.ndarray, torch.Tensor] = None, key_file: Optional[str] = None, allow_variable_data_keys: bool = False, output_dir: Optional[str] = None, dtype: str = "float32", seed: int = 0, num_workers: int = 1, split_with_space: bool = True, seg_dict_file: Optional[str] = None, **kwargs, batch_size: int, ngpu: int, log_level: Union[int, str], # data_path_and_name_and_type, timestamp_infer_config: Optional[str], timestamp_model_file: Optional[str], timestamp_cmvn_file: Optional[str] = None, # raw_inputs: Union[np.ndarray, torch.Tensor] = None, key_file: Optional[str] = None, allow_variable_data_keys: bool = False, output_dir: Optional[str] = None, dtype: str = "float32", seed: int = 0, num_workers: int = 1, split_with_space: bool = True, seg_dict_file: Optional[str] = None, **kwargs, ): assert check_argument_types() ncpu = kwargs.get("ncpu", 1) torch.set_num_threads(ncpu) if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") logging.basicConfig( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" else: device = "cpu" # 1. Set random-seed set_all_random_seed(seed) # 2. Build speech2vadsegment speechtext2timestamp_kwargs = dict( timestamp_infer_config=timestamp_infer_config, @@ -99,7 +78,7 @@ ) logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs)) speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs) preprocessor = LMPreprocessor( train=False, token_type=speechtext2timestamp.tp_train_args.token_type, @@ -112,21 +91,21 @@ split_with_space=split_with_space, seg_dict_file=seg_dict_file, ) if output_dir is not None: writer = DatadirWriter(output_dir) tp_writer = writer[f"timestamp_prediction"] # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list) else: tp_writer = None def _forward( data_path_and_name_and_type, raw_inputs: Union[np.ndarray, torch.Tensor] = None, output_dir_v2: Optional[str] = None, fs: dict = None, param_dict: dict = None, **kwargs data_path_and_name_and_type, raw_inputs: Union[np.ndarray, torch.Tensor] = None, output_dir_v2: Optional[str] = None, fs: dict = None, param_dict: dict = None, **kwargs ): output_path = output_dir_v2 if output_dir_v2 is not None else output_dir writer = None @@ -140,32 +119,31 @@ if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] loader = ASRTask.build_streaming_iterator( data_path_and_name_and_type, loader = build_streaming_iterator( task_name="asr", preprocess_args=speechtext2timestamp.tp_train_args, data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, key_file=key_file, num_workers=num_workers, preprocess_fn=preprocessor, collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False), allow_variable_data_keys=allow_variable_data_keys, inference=True, ) tp_result_list = [] for keys, batch in loader: assert isinstance(batch, dict), type(batch) assert all(isinstance(s, str) for s in keys), keys _bs = len(next(iter(batch.values()))) assert len(keys) == _bs, f"{len(keys)} != {_bs}" logging.info("timestamp predicting, utt_id: {}".format(keys)) _batch = {'speech': batch['speech'], 'speech_lengths': batch['speech_lengths'], 'text_lengths': batch['text_lengths']} us_alphas, us_cif_peak = speechtext2timestamp(**_batch) for batch_id in range(_bs): key = keys[batch_id] token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id]) @@ -178,10 +156,8 @@ tp_writer["tp_time"][key + '#'] = str(ts_list) tp_result_list.append(item) return tp_result_list return _forward def inference_launch(mode, **kwargs): @@ -190,6 +166,7 @@ else: logging.info("Unknown decoding mode: {}".format(mode)) return None def get_parser(): parser = config_argparse.ArgumentParser( @@ -306,7 +283,6 @@ inference_pipeline = inference_launch(**kwargs) return inference_pipeline(kwargs["data_path_and_name_and_type"]) if __name__ == "__main__": funasr/build_utils/build_model_from_file.py
@@ -87,7 +87,7 @@ ckpt, mode, ): assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp" logging.info("start convert tf model to torch model") from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict var_dict_tf = load_tf_dict(ckpt) @@ -148,7 +148,7 @@ if model.decoder is not None: var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) else: elif "mode" == "sv": # speech encoder var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) @@ -158,7 +158,19 @@ # decoder var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) else: # encoder var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) # predictor var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) # decoder var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) # bias_encoder var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) return var_dict_torch_update return var_dict_torch_update funasr/build_utils/build_streaming_iterator.py
@@ -5,7 +5,7 @@ from funasr.datasets.iterable_dataset import IterableESPnetDataset from funasr.datasets.small_datasets.collate_fn import CommonCollateFn from funasr.datasets.small_datasets.preprocessor import build_preprocess from funasr.build_utils.build_model_from_file import build_model_from_file def build_streaming_iterator( task_name, @@ -18,6 +18,7 @@ dtype: str = np.float32, num_workers: int = 1, use_collate_fn: bool = True, preprocess_fn=None, ngpu: int = 0, train: bool=False, ) -> DataLoader: @@ -25,7 +26,9 @@ assert check_argument_types() # preprocess if preprocess_args is not None: if preprocess_fn is not None: preprocess_fn = preprocess_fn elif preprocess_args is not None: preprocess_args.task_name = task_name preprocess_fn = build_preprocess(preprocess_args, train) else: