| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | from funasr.utils.misc import statistic_model_parameters |
| | | |
| | | class Speech2Xvector: |
| | | """Speech2Xvector class |
| | |
| | | device=device |
| | | ) |
| | | logging.info("sv_model: {}".format(sv_model)) |
| | | logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model))) |
| | | logging.info("sv_train_args: {}".format(sv_train_args)) |
| | | sv_model.to(dtype=getattr(torch, dtype)).eval() |
| | | |
| | |
| | | |
| | | |
| | | def inference_modelscope( |
| | | output_dir: Optional[str], |
| | | batch_size: int, |
| | | dtype: str, |
| | | ngpu: int, |
| | | seed: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | key_file: Optional[str], |
| | | sv_train_config: Optional[str], |
| | | sv_model_file: Optional[str], |
| | | model_tag: Optional[str], |
| | | output_dir: Optional[str] = None, |
| | | batch_size: int = 1, |
| | | dtype: str = "float32", |
| | | ngpu: int = 1, |
| | | seed: int = 0, |
| | | num_workers: int = 0, |
| | | log_level: Union[int, str] = "INFO", |
| | | key_file: Optional[str] = None, |
| | | sv_train_config: Optional[str] = "sv.yaml", |
| | | sv_model_file: Optional[str] = "sv.pth", |
| | | model_tag: Optional[str] = None, |
| | | allow_variable_data_keys: bool = True, |
| | | streaming: bool = False, |
| | | embedding_node: str = "resnet1_dense", |
| | |
| | | data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, |
| | | raw_inputs: Union[np.ndarray, torch.Tensor] = None, |
| | | output_dir_v2: Optional[str] = None, |
| | | fs: dict = None, |
| | | param_dict: Optional[dict] = None, |
| | | ): |
| | | logging.info("param_dict: {}".format(param_dict)) |