Bug fix for the speaker pipeline: fine-tuning and inference scripts for the Paraformer speaker-aware ASR model.
| New file |
| | |
| | | import os |
| | | from modelscope.metainfo import Trainers |
| | | from modelscope.trainers import build_trainer |
| | | from funasr.datasets.ms_dataset import MsDataset |
| | | |
| | | |
def modelscope_finetune(params):
    """Fine-tune a ModelScope ASR model with the settings carried by *params*.

    Args:
        params: namespace-like object providing ``model``, ``model_revision``,
            ``data_path``, ``dataset_type``, ``output_dir``, ``batch_bins``,
            ``max_epoch`` and ``lr`` attributes.

    Side effects:
        Creates ``params.output_dir`` if needed and writes checkpoints there.
    """
    # exist_ok=True already tolerates a pre-existing directory, so the
    # original os.path.exists() guard was redundant (and race-prone).
    os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        model_revision=params.model_revision,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()
| | | |
| | | |
if __name__ == '__main__':
    from funasr.utils.modelscope_param import modelscope_args
    params = modelscope_args(model="damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn", data_path="./data")
    params.output_dir = "./checkpoint"        # directory where fine-tuned checkpoints are saved
    params.data_path = "./example_data/"      # training data path (overrides the "./data" passed above)
    params.dataset_type = "small"             # use "small"; switch to "large" when the corpus exceeds 1000 hours
    params.batch_bins = 2000                  # batch size: fbank frames when dataset_type="small", milliseconds when "large"
    params.max_epoch = 50                     # maximum number of training epochs
    params.lr = 0.00005                       # learning rate
    params.model_revision = "v1.2.1"
    modelscope_finetune(params)
| New file |
| | |
| | | import os |
| | | import shutil |
| | | import argparse |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
def modelscope_infer(args):
    """Run ASR inference through a ModelScope pipeline.

    Args:
        args: argparse.Namespace with ``model``, ``output_dir``,
            ``decoding_mode``, ``hotword_txt``, ``audio_in``,
            ``batch_size_token`` and ``gpuid`` attributes.

    Returns:
        The recognition result produced by the pipeline (results are also
        written to ``args.output_dir`` by the pipeline itself).
    """
    # Pin the process to the requested GPU before the pipeline is built.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=args.model,
        output_dir=args.output_dir,
        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
    )
    # Return the result so callers can consume it programmatically; the
    # original discarded it and relied solely on files in output_dir.
    return inference_pipeline(audio_in=args.audio_in, batch_size_token=args.batch_size_token)
| | | |
if __name__ == "__main__":
    # Command-line front end; defaults reproduce the released recipe.
    arg_specs = [
        ('--model', str, "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn"),
        ('--audio_in', str, "./data/test/wav.scp"),
        ('--output_dir', str, "./results/"),
        ('--decoding_mode', str, "normal"),
        ('--hotword_txt', str, None),
        ('--batch_size_token', int, 5000),
        ('--gpuid', str, "0"),
    ]
    parser = argparse.ArgumentParser()
    for flag, flag_type, default in arg_specs:
        parser.add_argument(flag, type=flag_type, default=default)
    modelscope_infer(parser.parse_args())
| | |
| | | distribute_spk) |
| | | from funasr.build_utils.build_model_from_file import build_model_from_file |
| | | from funasr.utils.cluster_backend import ClusterBackend |
| | | from funasr.utils.modelscope_utils import get_cache_dir |
| | | from tqdm import tqdm |
| | | |
| | | def inference_asr( |
| | |
| | | time_stamp_writer: bool = True, |
| | | punc_infer_config: Optional[str] = None, |
| | | punc_model_file: Optional[str] = None, |
| | | sv_model_file: Optional[str] = "~/.cache/modelscope/hub/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin", |
| | | sv_model_file: Optional[str] = None, |
| | | streaming: bool = False, |
| | | embedding_node: str = "resnet1_dense", |
| | | sv_threshold: float = 0.9465, |
| | |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if sv_model_file is None: |
| | | sv_model_file = "{}/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/campplus_cn_common.bin".format(get_cache_dir(None)) |
| | | |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | |
| | | ##### speaker_verification ##### |
| | | ################################## |
| | | # load sv model |
| | | sv_model_dict = torch.load(sv_model_file.replace("~", os.environ['HOME']), map_location=torch.device('cpu')) |
| | | sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) |
| | | sv_model = CAMPPlus() |
| | | sv_model.load_state_dict(sv_model_dict) |
| | | sv_model.eval() |
| | |
| | | import os |
| | | from modelscope.hub.snapshot_download import snapshot_download |
| | | from pathlib import Path |
| | | |
| | | |
def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"):
    # NOTE(review): this block looks like an incomplete diff fragment —
    # `dst` and `dst_dir_root` are never defined in the visible lines;
    # confirm against the full file before trusting this documentation.
    if not os.path.exists(dst):
        # Presumably exposes model_dir under the expected cache location
        # via a symlink — TODO confirm intent once `dst` is visible.
        os.symlink(model_dir, dst)

    # The duplicated call below appears to be before/after diff residue
    # rather than an intentional double download — TODO confirm.
    model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
    model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
| | | |
def get_default_cache_dir():
    """Return the default base cache directory, ``~/.cache/modelscope``."""
    return Path.home() / '.cache' / 'modelscope'
| | | |
def get_cache_dir(model_id):
    """Resolve the ModelScope cache directory.

    Precedence: function parameter > ``MODELSCOPE_CACHE`` environment
    variable > ``~/.cache/modelscope/hub``.

    Args:
        model_id (str, optional): The model id.

    Returns:
        str: the model_id dir if model_id is not None, otherwise the cache
        root dir.
    """
    root = os.getenv('MODELSCOPE_CACHE',
                     os.path.join(get_default_cache_dir(), 'hub'))
    if model_id is None:
        return root
    # Keep the trailing slash the original produced via concatenation.
    return os.path.join(root, model_id + '/')