Merge pull request #1247 from alibaba-damo-academy/funasr1.0
Funasr1.0
| | |
| | | ### Speech Recognition (Non-streaming) |
| | | ```python |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="paraformer-zh") |
| | | # for long audio files, you can add a VAD model |
| | | # model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc") |
| | | |
| | | res = model(input="asr_example_zh.wav", batch_size=64) |
| | | # paraformer-zh is a multi-functional ASR model |
| | | # enable or omit the vad, punc and spk models as needed |
| | | model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \ |
| | | vad_model="fsmn-vad", vad_model_revision="v2.0.2", \ |
| | | punc_model="ct-punc-c", punc_model_revision="v2.0.2", \ |
| | | spk_model="cam++", spk_model_revision="v2.0.2") |
| | | res = model(input=f"{model.model_path}/example/asr_example.wav", |
| | | batch_size=16, |
| | | hotword='魔搭') |
| | | print(res) |
| | | ``` |
| | | Note: `model_hub` selects the model repository: `ms` downloads from ModelScope, and `hf` downloads from Hugging Face. |
| | |
| | | encoder_chunk_look_back = 4 # number of chunks to look back for encoder self-attention |
| | | decoder_chunk_look_back = 1 # number of encoder chunks to look back for decoder cross-attention |
| | | |
| | | model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.0") |
| | | model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.2") |
| | | |
| | | import soundfile |
| | | import os |
| | |
| | | ```python |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="ct-punc", model_revision="v2.0.1") |
| | | model = AutoModel(model="ct-punc", model_revision="v2.0.2") |
| | | |
| | | res = model(input="那今天的会就到这里吧 happy new year 明年见") |
| | | print(res) |
| | |
| | | ```python |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="fa-zh", model_revision="v2.0.0") |
| | | model = AutoModel(model="fa-zh", model_revision="v2.0.2") |
| | | |
| | | wav_file = f"{model.model_path}/example/asr_example.wav" |
| | | text_file = f"{model.model_path}/example/text.txt" |
| | |
| | | fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2) |
| | | |
| | | for batch_idx, fbank_dict in enumerate(fbanks): |
| | | res = model(**fbank_dict) |
| | | print(res) |
| | | res = model(**fbank_dict) |
| | | print(res) |
| | |
| | | vad_model_revision="v2.0.2", |
| | | punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | punc_model_revision="v2.0.2", |
| | | spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | spk_model_revision="v2.0.2", |
| | | ) |
| | | |
| | | res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", |
| | | hotword='达摩院 磨搭') |
| | | res = model(input=f"{model.model_path}/example/asr_example.wav", |
| | | hotword='达摩院 魔搭') |
| | | print(res) |
| | |
| | | def generate_with_vad(self, input, input_len=None, **cfg): |
| | | |
| | | # step.1: compute the vad model |
| | | model = self.vad_model |
| | | kwargs = self.vad_kwargs |
| | | kwargs.update(cfg) |
| | | self.vad_kwargs.update(cfg) |
| | | beg_vad = time.time() |
| | | res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg) |
| | | vad_res = res |
| | | res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg) |
| | | end_vad = time.time() |
| | | print(f"time cost vad: {end_vad - beg_vad:0.3f}") |
| | | |
| | |
| | | if not len(sorted_data): |
| | | logging.info("decoding, utt: {}, empty speech".format(key)) |
| | | continue |
| | | |
| | | |
| | | # if kwargs["device"] == "cpu": |
| | | # batch_size = 0 |
| | | if len(sorted_data) > 0 and len(sorted_data[0]) > 0: |
| | | batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]) |
| | | |
| | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from io import BytesIO |
| | | from collections.abc import Sequence |
| | | import torch |
| | | import hydra |
| | | import logging |
| | | import argparse |
| | | from io import BytesIO |
| | | import torch.distributed as dist |
| | | from collections.abc import Sequence |
| | | from omegaconf import DictConfig, OmegaConf |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.models.lora.utils import mark_only_lora_as_trainable |
| | | from torch.nn.parallel import DistributedDataParallel as DDP |
| | | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | | |
| | | from funasr.register import tables |
| | | from funasr.optimizers import optim_classes |
| | | from funasr.train_utils.trainer import Trainer |
| | | from funasr.schedulers import scheduler_classes |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.train_utils.initialize import initialize |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.models.lora.utils import mark_only_lora_as_trainable |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | # from funasr.tokenizer.build_tokenizer import build_tokenizer |
| | | # from funasr.tokenizer.token_id_converter import TokenIDConverter |
| | | # from funasr.tokenizer.funtoken import build_tokenizer |
| | | from funasr.train_utils.trainer import Trainer |
| | | import torch.distributed as dist |
| | | from torch.nn.parallel import DistributedDataParallel as DDP |
| | | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @hydra.main(config_name=None, version_base=None) |
| | | def main_hydra(kwargs: DictConfig): |
| | | if kwargs.get("debug", False): |
| | | import pdb; pdb.set_trace() |
| | | if kwargs.get("debug", False): |
| | | import pdb; pdb.set_trace() |
| | | |
| | | assert "model" in kwargs |
| | | if "model_conf" not in kwargs: |
| | | logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms"))) |
| | | kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs) |
| | | |
| | | assert "model" in kwargs |
| | | if "model_conf" not in kwargs: |
| | | logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms"))) |
| | | kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs) |
| | | |
| | | |
| | | main(**kwargs) |
| | | main(**kwargs) |
| | | |
| | | |
| | | def main(**kwargs): |
| | | # preprocess_config(kwargs) |
| | | # import pdb; pdb.set_trace() |
| | | # set random seed |
| | | tables.print() |
| | | set_all_random_seed(kwargs.get("seed", 0)) |
| | | torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) |
| | | torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) |
| | | torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) |
| | | |
| | | local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
| | | # Check if we are using DDP or FSDP |
| | | use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1 |
| | | use_fsdp = kwargs.get("use_fsdp", None) |
| | | if use_ddp or use_fsdp: |
| | | dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://') |
| | | torch.cuda.set_device(local_rank) |
| | | |
| | | # save config.yaml |
| | | if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0: |
| | | os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True) |
| | | yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml") |
| | | OmegaConf.save(config=kwargs, f=yaml_file) |
| | | logging.info("config.yaml is saved to: %s", yaml_file) |
| | | # preprocess_config(kwargs) |
| | | # import pdb; pdb.set_trace() |
| | | # set random seed |
| | | tables.print() |
| | | set_all_random_seed(kwargs.get("seed", 0)) |
| | | torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) |
| | | torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) |
| | | torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) |
| | | |
| | | local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
| | | # Check if we are using DDP or FSDP |
| | | use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1 |
| | | use_fsdp = kwargs.get("use_fsdp", None) |
| | | if use_ddp or use_fsdp: |
| | | dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://') |
| | | torch.cuda.set_device(local_rank) |
| | | |
| | | # save config.yaml |
| | | if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0: |
| | | os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True) |
| | | yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml") |
| | | OmegaConf.save(config=kwargs, f=yaml_file) |
| | | logging.info("config.yaml is saved to: %s", yaml_file) |
| | | |
| | | tokenizer = kwargs.get("tokenizer", None) |
| | | if tokenizer is not None: |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer) |
| | | tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) |
| | | kwargs["tokenizer"] = tokenizer |
| | | |
| | | # build frontend if frontend is not None |
| | | frontend = kwargs.get("frontend", None) |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get(frontend) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | kwargs["frontend"] = frontend |
| | | kwargs["input_size"] = frontend.output_size() |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) |
| | | tokenizer = kwargs.get("tokenizer", None) |
| | | if tokenizer is not None: |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer) |
| | | tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) |
| | | kwargs["tokenizer"] = tokenizer |
| | | |
| | | # build frontend if frontend is not None |
| | | frontend = kwargs.get("frontend", None) |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get(frontend) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | kwargs["frontend"] = frontend |
| | | kwargs["input_size"] = frontend.output_size() |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) |
| | | |
| | | |
| | | |
| | | # init_param |
| | | init_param = kwargs.get("init_param", None) |
| | | if init_param is not None: |
| | | if not isinstance(init_param, (list, tuple)): |
| | | init_param = (init_param,) |
| | | logging.info("init_param is not None: %s", init_param) |
| | | for p in init_param: |
| | | logging.info(f"Loading pretrained params from {p}") |
| | | load_pretrained_model( |
| | | model=model, |
| | | init_param=p, |
| | | ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), |
| | | oss_bucket=kwargs.get("oss_bucket", None), |
| | | ) |
| | | else: |
| | | initialize(model, kwargs.get("init", "kaiming_normal")) |
| | | # init_param |
| | | init_param = kwargs.get("init_param", None) |
| | | if init_param is not None: |
| | | if not isinstance(init_param, (list, tuple)): |
| | | init_param = (init_param,) |
| | | logging.info("init_param is not None: %s", init_param) |
| | | for p in init_param: |
| | | logging.info(f"Loading pretrained params from {p}") |
| | | load_pretrained_model( |
| | | model=model, |
| | | init_param=p, |
| | | ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), |
| | | oss_bucket=kwargs.get("oss_bucket", None), |
| | | ) |
| | | else: |
| | | initialize(model, kwargs.get("init", "kaiming_normal")) |
| | | |
| | | |
| | | # freeze_param |
| | | freeze_param = kwargs.get("freeze_param", None) |
| | | if freeze_param is not None: |
| | | freeze_param = eval(freeze_param) |
| | | if isinstance(freeze_param, Sequence): |
| | | freeze_param = (freeze_param,) |
| | | logging.info("freeze_param is not None: %s", freeze_param) |
| | | for t in freeze_param: |
| | | for k, p in model.named_parameters(): |
| | | if k.startswith(t + ".") or k == t: |
| | | logging.info(f"Setting {k}.requires_grad = False") |
| | | p.requires_grad = False |
| | | |
| | | # freeze_param |
| | | freeze_param = kwargs.get("freeze_param", None) |
| | | if freeze_param is not None: |
| | | freeze_param = eval(freeze_param) |
| | | if isinstance(freeze_param, Sequence): |
| | | freeze_param = (freeze_param,) |
| | | logging.info("freeze_param is not None: %s", freeze_param) |
| | | for t in freeze_param: |
| | | for k, p in model.named_parameters(): |
| | | if k.startswith(t + ".") or k == t: |
| | | logging.info(f"Setting {k}.requires_grad = False") |
| | | p.requires_grad = False |
| | | |
| | | |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP(model, device_ids=[local_rank], |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) |
| | | elif use_fsdp: |
| | | model = FSDP(model).cuda(local_rank) |
| | | else: |
| | | model = model.to(device=kwargs.get("device", "cuda")) |
| | | |
| | | |
| | | # optim |
| | | optim = kwargs.get("optim", "adam") |
| | | assert optim in optim_classes |
| | | optim_class = optim_classes.get(optim) |
| | | optim = optim_class(model.parameters(), **kwargs.get("optim_conf")) |
| | | |
| | | # scheduler |
| | | scheduler = kwargs.get("scheduler", "warmuplr") |
| | | assert scheduler in scheduler_classes |
| | | scheduler_class = scheduler_classes.get(scheduler) |
| | | scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP(model, device_ids=[local_rank], |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) |
| | | elif use_fsdp: |
| | | model = FSDP(model).cuda(local_rank) |
| | | else: |
| | | model = model.to(device=kwargs.get("device", "cuda")) |
| | | |
| | | |
| | | # optim |
| | | optim = kwargs.get("optim", "adam") |
| | | assert optim in optim_classes |
| | | optim_class = optim_classes.get(optim) |
| | | optim = optim_class(model.parameters(), **kwargs.get("optim_conf")) |
| | | |
| | | # scheduler |
| | | scheduler = kwargs.get("scheduler", "warmuplr") |
| | | assert scheduler in scheduler_classes |
| | | scheduler_class = scheduler_classes.get(scheduler) |
| | | scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # dataset |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # dataset |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) |
| | | |
| | | # dataloader |
| | | batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | if batch_sampler is not None: |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | | collate_fn=dataset_tr.collator, |
| | | batch_sampler=batch_sampler, |
| | | num_workers=kwargs.get("dataset_conf").get("num_workers", 4), |
| | | pin_memory=True) |
| | | |
| | | # dataloader |
| | | batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | if batch_sampler is not None: |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | | collate_fn=dataset_tr.collator, |
| | | batch_sampler=batch_sampler, |
| | | num_workers=kwargs.get("dataset_conf").get("num_workers", 4), |
| | | pin_memory=True) |
| | | |
| | | |
| | | trainer = Trainer( |
| | | model=model, |
| | | optim=optim, |
| | | scheduler=scheduler, |
| | | dataloader_train=dataloader_tr, |
| | | dataloader_val=None, |
| | | local_rank=local_rank, |
| | | use_ddp=use_ddp, |
| | | use_fsdp=use_fsdp, |
| | | **kwargs.get("train_conf"), |
| | | ) |
| | | trainer.run() |
| | | |
| | | if use_ddp or use_fsdp: |
| | | torch.distributed.destroy_process_group() |
| | | trainer = Trainer( |
| | | model=model, |
| | | optim=optim, |
| | | scheduler=scheduler, |
| | | dataloader_train=dataloader_tr, |
| | | dataloader_val=None, |
| | | local_rank=local_rank, |
| | | use_ddp=use_ddp, |
| | | use_fsdp=use_fsdp, |
| | | **kwargs.get("train_conf"), |
| | | ) |
| | | trainer.run() |
| | | |
| | | if use_ddp or use_fsdp: |
| | | torch.distributed.destroy_process_group() |
| | | |
| | | |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main_hydra() |
| | | main_hydra() |
| | |
| | | import torch |
| | | import json |
| | | import torch.distributed as dist |
| | | import numpy as np |
| | | import kaldiio |
| | | import librosa |
| | | import torchaudio |
| | | import time |
| | | import logging |
| | | |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import extract_fbank |
| | | |
| | | |
| | | @tables.register("dataset_classes", "AudioDataset") |
| | | class AudioDataset(torch.utils.data.Dataset): |
| | | """ |
| | | AudioDataset |
| | | """ |
| | | def __init__(self, |
| | | path, |
| | | index_ds: str = None, |
| | | frontend=None, |
| | | tokenizer=None, |
| | | int_pad_value: int = -1, |
| | | float_pad_value: float = 0.0, |
| | | **kwargs): |
| | | super().__init__() |
| | | index_ds_class = tables.index_ds_classes.get(index_ds) |
| | | self.index_ds = index_ds_class(path) |
| | | preprocessor_speech = kwargs.get("preprocessor_speech", None) |
| | | if preprocessor_speech: |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech) |
| | | preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) |
| | | self.preprocessor_speech = preprocessor_speech |
| | | preprocessor_text = kwargs.get("preprocessor_text", None) |
| | | if preprocessor_text: |
| | | preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text) |
| | | preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) |
| | | self.preprocessor_text = preprocessor_text |
| | | |
| | | self.frontend = frontend |
| | | self.fs = 16000 if frontend is None else frontend.fs |
| | | self.data_type = "sound" |
| | | self.tokenizer = tokenizer |
| | | """ |
| | | AudioDataset |
| | | """ |
| | | def __init__(self, |
| | | path, |
| | | index_ds: str = None, |
| | | frontend=None, |
| | | tokenizer=None, |
| | | int_pad_value: int = -1, |
| | | float_pad_value: float = 0.0, |
| | | **kwargs): |
| | | super().__init__() |
| | | index_ds_class = tables.index_ds_classes.get(index_ds) |
| | | self.index_ds = index_ds_class(path) |
| | | preprocessor_speech = kwargs.get("preprocessor_speech", None) |
| | | if preprocessor_speech: |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech) |
| | | preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) |
| | | self.preprocessor_speech = preprocessor_speech |
| | | preprocessor_text = kwargs.get("preprocessor_text", None) |
| | | if preprocessor_text: |
| | | preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text) |
| | | preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) |
| | | self.preprocessor_text = preprocessor_text |
| | | |
| | | self.frontend = frontend |
| | | self.fs = 16000 if frontend is None else frontend.fs |
| | | self.data_type = "sound" |
| | | self.tokenizer = tokenizer |
| | | |
| | | self.int_pad_value = int_pad_value |
| | | self.float_pad_value = float_pad_value |
| | | |
| | | def get_source_len(self, index): |
| | | item = self.index_ds[index] |
| | | return self.index_ds.get_source_len(item) |
| | | |
| | | def get_target_len(self, index): |
| | | item = self.index_ds[index] |
| | | return self.index_ds.get_target_len(item) |
| | | |
| | | def __len__(self): |
| | | return len(self.index_ds) |
| | | |
| | | def __getitem__(self, index): |
| | | item = self.index_ds[index] |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | source = item["source"] |
| | | data_src = load_audio(source, fs=self.fs) |
| | | if self.preprocessor_speech: |
| | | data_src = self.preprocessor_speech(data_src) |
| | | speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d] |
| | | self.int_pad_value = int_pad_value |
| | | self.float_pad_value = float_pad_value |
| | | |
| | | def get_source_len(self, index): |
| | | item = self.index_ds[index] |
| | | return self.index_ds.get_source_len(item) |
| | | |
| | | def get_target_len(self, index): |
| | | item = self.index_ds[index] |
| | | return self.index_ds.get_target_len(item) |
| | | |
| | | def __len__(self): |
| | | return len(self.index_ds) |
| | | |
| | | def __getitem__(self, index): |
| | | item = self.index_ds[index] |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | source = item["source"] |
| | | data_src = load_audio(source, fs=self.fs) |
| | | if self.preprocessor_speech: |
| | | data_src = self.preprocessor_speech(data_src) |
| | | speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d] |
| | | |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | | target = self.preprocessor_text(target) |
| | | ids = self.tokenizer.encode(target) |
| | | ids_lengths = len(ids) |
| | | text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32) |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | | target = self.preprocessor_text(target) |
| | | ids = self.tokenizer.encode(target) |
| | | ids_lengths = len(ids) |
| | | text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32) |
| | | |
| | | return {"speech": speech[0, :, :], |
| | | "speech_lengths": speech_lengths, |
| | | "text": text, |
| | | "text_lengths": text_lengths, |
| | | } |
| | | |
| | | |
| | | def collator(self, samples: list=None): |
| | | return {"speech": speech[0, :, :], |
| | | "speech_lengths": speech_lengths, |
| | | "text": text, |
| | | "text_lengths": text_lengths, |
| | | } |
| | | |
| | | |
| | | def collator(self, samples: list=None): |
| | | outputs = {} |
| | | for sample in samples: |
| | | for key in sample.keys(): |
| | | if key not in outputs: |
| | | outputs[key] = [] |
| | | outputs[key].append(sample[key]) |
| | | |
| | | for key, data_list in outputs.items(): |
| | | if data_list[0].dtype == torch.int64: |
| | | |
| | | outputs = {} |
| | | for sample in samples: |
| | | for key in sample.keys(): |
| | | if key not in outputs: |
| | | outputs[key] = [] |
| | | outputs[key].append(sample[key]) |
| | | |
| | | for key, data_list in outputs.items(): |
| | | if data_list[0].dtype == torch.int64: |
| | | |
| | | pad_value = self.int_pad_value |
| | | else: |
| | | pad_value = self.float_pad_value |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) |
| | | return outputs |
| | | pad_value = self.int_pad_value |
| | | else: |
| | | pad_value = self.float_pad_value |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) |
| | | return outputs |
| | | |
| | |
| | | import torch |
| | | import json |
| | | import torch.distributed as dist |
| | | import time |
| | | import torch |
| | | import logging |
| | | import torch.distributed as dist |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("index_ds_classes", "IndexDSJsonl") |
| | | class IndexDSJsonl(torch.utils.data.Dataset): |
| | | |
| | | def __init__(self, path): |
| | | super().__init__() |
| | | |
| | | contents = [] |
| | | with open(path, encoding='utf-8') as fin: |
| | | for line in fin: |
| | | data = json.loads(line.strip()) |
| | | if "text" in data: # for sft |
| | | self.contents.append(data['text']) |
| | | if "source" in data: # for speech lab pretrain |
| | | prompt = data["prompt"] |
| | | source = data["source"] |
| | | target = data["target"] |
| | | source_len = data["source_len"] |
| | | target_len = data["target_len"] |
| | | |
| | | def __init__(self, path): |
| | | super().__init__() |
| | | |
| | | contents = [] |
| | | with open(path, encoding='utf-8') as fin: |
| | | for line in fin: |
| | | data = json.loads(line.strip()) |
| | | if "text" in data: # for sft |
| | | self.contents.append(data['text']) |
| | | if "source" in data: # for speech lab pretrain |
| | | prompt = data["prompt"] |
| | | source = data["source"] |
| | | target = data["target"] |
| | | source_len = data["source_len"] |
| | | target_len = data["target_len"] |
| | | |
| | | contents.append({"source": source, |
| | | "prompt": prompt, |
| | | "target": target, |
| | | "source_len": source_len, |
| | | "target_len": target_len, |
| | | } |
| | | ) |
| | | |
| | | self.contents = [] |
| | | total_num = len(contents) |
| | | try: |
| | | rank = dist.get_rank() |
| | | world_size = dist.get_world_size() |
| | | except: |
| | | rank = 0 |
| | | world_size = 1 |
| | | logging.warning("distributed is not initialized, only single shard") |
| | | num_per_rank = total_num // world_size |
| | | |
| | | # rank = 0 |
| | | # import ipdb; ipdb.set_trace() |
| | | self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank] |
| | | |
| | | logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents))) |
| | | contents.append({"source": source, |
| | | "prompt": prompt, |
| | | "target": target, |
| | | "source_len": source_len, |
| | | "target_len": target_len, |
| | | } |
| | | ) |
| | | |
| | | self.contents = [] |
| | | total_num = len(contents) |
| | | try: |
| | | rank = dist.get_rank() |
| | | world_size = dist.get_world_size() |
| | | except: |
| | | rank = 0 |
| | | world_size = 1 |
| | | logging.warning("distributed is not initialized, only single shard") |
| | | num_per_rank = total_num // world_size |
| | | |
| | | # rank = 0 |
| | | # import ipdb; ipdb.set_trace() |
| | | self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank] |
| | | |
| | | logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents))) |
| | | |
| | | def __len__(self): |
| | | return len(self.contents) |
| | | |
| | | def __getitem__(self, index): |
| | | return self.contents[index] |
| | | |
| | | def get_source_len(self, data_dict): |
| | | return data_dict["source_len"] |
| | | def __len__(self): |
| | | return len(self.contents) |
| | | |
| | | def __getitem__(self, index): |
| | | return self.contents[index] |
| | | |
| | | def get_source_len(self, data_dict): |
| | | return data_dict["source_len"] |
| | | |
| | | def get_target_len(self, data_dict): |
| | | |
| | | return data_dict["target_len"] if "target_len" in data_dict else 0 |
| | | def get_target_len(self, data_dict): |
| | | |
| | | return data_dict["target_len"] if "target_len" in data_dict else 0 |
| | |
| | | import torch |
| | | |
| | | import numpy as np |
| | | |
| | | from funasr.register import tables |
| | |
| | | |
| | | @tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler") |
| | | class BatchSampler(torch.utils.data.BatchSampler): |
| | | |
| | | def __init__(self, dataset, |
| | | batch_type: str = "example", |
| | | batch_size: int = 100, |
| | | buffer_size: int = 30, |
| | | drop_last: bool = False, |
| | | shuffle: bool = True, |
| | | **kwargs): |
| | | |
| | | self.drop_last = drop_last |
| | | self.pre_idx = -1 |
| | | self.dataset = dataset |
| | | self.total_samples = len(dataset) |
| | | self.batch_type = batch_type |
| | | self.batch_size = batch_size |
| | | self.buffer_size = buffer_size |
| | | self.max_token_length = kwargs.get("max_token_length", 5000) |
| | | self.shuffle_idx = np.arange(self.total_samples) |
| | | self.shuffle = shuffle |
| | | |
| | | def __len__(self): |
| | | return self.total_samples |
| | | |
| | | def set_epoch(self, epoch): |
| | | np.random.seed(epoch) |
| | | |
| | | def __iter__(self): |
| | | |
| | | if self.shuffle: |
| | | np.random.shuffle(self.shuffle_idx) |
| | | |
| | | batch = [] |
| | | max_token = 0 |
| | | num_sample = 0 |
| | | |
| | | iter_num = (self.total_samples - 1) // self.buffer_size + 1 |
| | | # print("iter_num: ", iter_num) |
| | | for iter in range(self.pre_idx + 1, iter_num): |
| | | datalen_with_index = [] |
| | | for i in range(self.buffer_size): |
| | | idx = iter * self.buffer_size + i |
| | | if idx >= self.total_samples: |
| | | continue |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.get_source_len(idx_map) + \ |
| | | self.dataset.get_target_len(idx_map) |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |
| | | datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) |
| | | for item in datalen_with_index_sort: |
| | | idx, sample_len_cur_raw = item |
| | | if sample_len_cur_raw > self.max_token_length: |
| | | continue |
| | | |
| | | max_token_cur = max(max_token, sample_len_cur_raw) |
| | | max_token_padding = 1 + num_sample |
| | | if self.batch_type == 'length': |
| | | max_token_padding *= max_token_cur |
| | | if max_token_padding <= self.batch_size: |
| | | batch.append(idx) |
| | | max_token = max_token_cur |
| | | num_sample += 1 |
| | | else: |
| | | yield batch |
| | | batch = [idx] |
| | | max_token = sample_len_cur_raw |
| | | num_sample = 1 |
| | | |
| | | def __init__(self, dataset, |
| | | batch_type: str = "example", |
| | | batch_size: int = 100, |
| | | buffer_size: int = 30, |
| | | drop_last: bool = False, |
| | | shuffle: bool = True, |
| | | **kwargs): |
| | | |
| | | self.drop_last = drop_last |
| | | self.pre_idx = -1 |
| | | self.dataset = dataset |
| | | self.total_samples = len(dataset) |
| | | self.batch_type = batch_type |
| | | self.batch_size = batch_size |
| | | self.buffer_size = buffer_size |
| | | self.max_token_length = kwargs.get("max_token_length", 5000) |
| | | self.shuffle_idx = np.arange(self.total_samples) |
| | | self.shuffle = shuffle |
| | | |
| | | def __len__(self): |
| | | return self.total_samples |
| | | |
| | | def set_epoch(self, epoch): |
| | | np.random.seed(epoch) |
| | | |
| | | def __iter__(self): |
| | | |
| | | if self.shuffle: |
| | | np.random.shuffle(self.shuffle_idx) |
| | | |
| | | batch = [] |
| | | max_token = 0 |
| | | num_sample = 0 |
| | | |
| | | iter_num = (self.total_samples - 1) // self.buffer_size + 1 |
| | | # print("iter_num: ", iter_num) |
| | | for iter in range(self.pre_idx + 1, iter_num): |
| | | datalen_with_index = [] |
| | | for i in range(self.buffer_size): |
| | | idx = iter * self.buffer_size + i |
| | | if idx >= self.total_samples: |
| | | continue |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.get_source_len(idx_map) + \ |
| | | self.dataset.get_target_len(idx_map) |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |
| | | datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) |
| | | for item in datalen_with_index_sort: |
| | | idx, sample_len_cur_raw = item |
| | | if sample_len_cur_raw > self.max_token_length: |
| | | continue |
| | | |
| | | max_token_cur = max(max_token, sample_len_cur_raw) |
| | | max_token_padding = 1 + num_sample |
| | | if self.batch_type == 'length': |
| | | max_token_padding *= max_token_cur |
| | | if max_token_padding <= self.batch_size: |
| | | batch.append(idx) |
| | | max_token = max_token_cur |
| | | num_sample += 1 |
| | | else: |
| | | yield batch |
| | | batch = [idx] |
| | | max_token = sample_len_cur_raw |
| | | num_sample = 1 |
| | | |
| | |
| | | import json |
| | | import os |
| | | import json |
| | | from omegaconf import OmegaConf |
| | | import torch |
| | | |
| | | from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf |
| | | |
| | | |
def download_model(**kwargs):
    """Resolve the model referenced in ``kwargs`` through the selected hub.

    Args:
        **kwargs: Model options. ``model_hub`` selects the hub and defaults
            to ``"ms"``.

    Returns:
        The value returned by ``download_from_ms`` when the hub is ``"ms"``;
        otherwise ``kwargs`` unchanged (no other hub is handled here).
    """
    model_hub = kwargs.get("model_hub", "ms")
    if model_hub == "ms":
        kwargs = download_from_ms(**kwargs)

    return kwargs
| | | |
def download_from_ms(**kwargs):
    """Locate (or download) a model directory and merge its config into kwargs.

    ``kwargs["model"]`` may be a short alias from ``name_maps_ms``, a model
    id, or a local path. If the resolved value is not an existing path it is
    passed to ``get_or_download_model_dir``.

    Two directory layouts are recognised:
      * ``config.yaml`` + ``model.pb``: the YAML is loaded and known resource
        files found in the directory are wired into the configuration.
      * ``configuration.json``: its ``file_path_metas`` entry is expanded to
        absolute paths and the referenced ``config`` file is loaded.

    In both cases caller-supplied kwargs override values from the config
    file, except ``model``, which is always taken from the config file.

    Returns:
        A plain container produced by ``OmegaConf.to_container``.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision")
    if not os.path.exists(model_or_path):
        model_or_path = get_or_download_model_dir(
            model_or_path,
            model_revision,
            is_training=kwargs.get("is_training"),
            # The option used to be looked up under the key "kwargs", which
            # made `check_latest=...` from callers a no-op. Read the real key
            # and keep the old one as a fallback for compatibility.
            check_latest=kwargs.get("check_latest", kwargs.get("kwargs", True)),
        )
    kwargs["model_path"] = model_or_path

    def in_model_dir(name):
        # Path of `name` inside the resolved model directory.
        return os.path.join(model_or_path, name)

    config = in_model_dir("config.yaml")
    if os.path.exists(config) and os.path.exists(in_model_dir("model.pb")):
        config = OmegaConf.load(config)
        kwargs = OmegaConf.merge(config, kwargs)
        kwargs["init_param"] = in_model_dir("model.pb")
        # tokens.json is checked after tokens.txt, so it wins when both exist.
        if os.path.exists(in_model_dir("tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = in_model_dir("tokens.txt")
        if os.path.exists(in_model_dir("tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = in_model_dir("tokens.json")
        if os.path.exists(in_model_dir("seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = in_model_dir("seg_dict")
        if os.path.exists(in_model_dir("bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = in_model_dir("bpe.model")
        # The merge above let the caller's alias/id override the config's
        # model entry; restore the value from the config file.
        kwargs["model"] = config["model"]
        if os.path.exists(in_model_dir("am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = in_model_dir("am.mvn")
        if os.path.exists(in_model_dir("jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = in_model_dir("jieba_usr_dict")
    elif os.path.exists(in_model_dir("configuration.json")):
        with open(in_model_dir("configuration.json"), 'r', encoding='utf-8') as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    return OmegaConf.to_container(kwargs, resolve=True)
| | | |
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Copy ``file_path_metas`` into ``cfg`` with paths rooted at the model dir.

    String values are joined onto ``model_or_path`` and stored only when the
    resulting path exists. Dict values are processed recursively into a
    nested dict under the same key. Values of any other type are ignored.

    Args:
        model_or_path: Directory the relative paths are resolved against.
        file_path_metas: Possibly nested mapping of key -> relative path.
            Anything that is not a dict is ignored.
        cfg: Mapping to fill in place. A fresh dict is created when omitted,
            so separate calls never share state.

    Returns:
        ``cfg`` (the top-level mapping), after all keys have been processed.
    """
    if cfg is None:
        cfg = {}

    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                # Recurse without returning: returning here used to stop at
                # the first nested dict and drop every key after it.
                add_file_root_path(model_or_path, v, cfg.setdefault(k, {}))

    return cfg
| | | |
| | | |
def get_or_download_model_dir(
        model,
        model_revision=None,
        is_training=False,
        check_latest=True,
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
        model_revision (str, optional): model version number.
        is_training (bool): selects the ``Invoke.LOCAL_TRAINER`` user-agent
            tag instead of ``Invoke.PIPELINE``.
        check_latest (bool): for an existing local path, whether to call
            ``check_local_model_is_latest``; a failing check is reported
            but never fatal.

    Returns:
        The local directory: ``model`` itself when it is a directory, its
        parent when it is a file, or the directory returned by
        ``snapshot_download`` when ``model`` is not an existing path.
    """
    is_local = os.path.exists(model)
    if is_local:
        model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
        if not check_latest:
            # An existing local path must never be sent to snapshot_download;
            # without a freshness check there is nothing else to do.
            return model_cache_dir

    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
    user_agent = {
        Invoke.KEY: key,
        ThirdParty.KEY: "funasr"
    }

    if is_local:
        try:
            check_local_model_is_latest(model_cache_dir, user_agent=user_agent)
        except Exception:
            # Best effort only; KeyboardInterrupt/SystemExit still propagate.
            print("could not check the latest version")
        return model_cache_dir

    return snapshot_download(
        model,
        revision=model_revision,
        user_agent=user_agent)
| | |
| | | |
| | | |
# Short model aliases -> ModelScope model ids.
# "paraformer-zh" used to be listed twice; in a dict literal the later entry
# wins, so only that effective mapping is kept.
name_maps_ms = {
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
}
| | | |
| | | name_maps_hf = { |
| | |
| | | from pathlib import Path |
| | | import os |
| | | import argparse |
| | | from pathlib import Path |
| | | |
| | | from funasr.utils.types import str2bool |
| | | |
| | | |
def main():
    """Command-line entry: fetch a model and export it to ONNX if needed.

    ``--model-name`` is used as a local path when it exists; otherwise it is
    downloaded with ModelScope's ``snapshot_download`` into ``--export-dir``.
    When ``--export`` is true and the expected ``model.onnx`` (or
    ``model_quant.onnx`` with ``--quantize``) is missing, the model is
    exported with ``ModelExport``.

    Raises:
        ValueError: if the model is neither an existing path nor
            downloadable from ModelScope.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    args = parser.parse_args()

    model_dir = args.model_name
    if not Path(args.model_name).exists():
        from modelscope.hub.snapshot_download import snapshot_download
        try:
            model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
        except Exception as err:
            # Raising a bare string is itself a TypeError in Python 3; raise a
            # real exception and keep the download failure as its cause.
            raise ValueError(
                "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir)
            ) from err
    if args.export:
        model_file = os.path.join(model_dir, 'model.onnx')
        if args.quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.bin.export_model import ModelExport
            # NOTE(review): --type and --device are parsed but not used here;
            # the export is always ONNX on CPU.
            export_model = ModelExport(
                cache_dir=args.export_dir,
                onnx=True,
                device="cpu",
                quant=args.quantize,
            )
            export_model.export(model_dir)
| | | |
| | | |
# Run the CLI exactly once when executed as a script.
if __name__ == "__main__":
    main()
| | |
| | | |
@tables.register("model_classes", "Branchformer")
class Branchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Adds no behaviour of its own; all arguments go to Transformer.
        super().__init__(*args, **kwargs)
| | |
| | | |
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Adds no behaviour of its own; all arguments go to Transformer.
        super().__init__(*args, **kwargs)
| | |
| | | punc_array = punctuations |
| | | else: |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | |
| | |
| | | |
@tables.register("model_classes", "EBranchformer")
class EBranchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Adds no behaviour of its own; all arguments go to Transformer.
        super().__init__(*args, **kwargs)
| | |
| | | |
@tables.register("model_classes", "SANM")
class SANM(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Adds no behaviour of its own; all arguments go to Transformer.
        super().__init__(*args, **kwargs)
| | |
| | | |
| | | import math |
| | | import torch |
| | | import numpy as np |
| | | import math |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | import logging |
| | | import torch.nn.functional as F |
| | | from funasr.models.scama.utils import sequence_mask |
| | | |
| | | from funasr.models.scama.utils import sequence_mask |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | |
| | | |
| | | class overlap_chunk(): |
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

    """
| | | def __init__(self, |
| | | chunk_size: tuple = (16,), |
| | | stride: tuple = (10,), |
| | | pad_left: tuple = (0,), |
| | | encoder_att_look_back_factor: tuple = (1,), |
| | | shfit_fsmn: int = 0, |
| | | decoder_att_look_back_factor: tuple = (1,), |
| | | ): |
| | | ): |
| | | |
| | | pad_left = self.check_chunk_size_args(chunk_size, pad_left) |
| | | encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor) |
| | | decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor) |
| | | self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \ |
| | | = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor |
| | | self.shfit_fsmn = shfit_fsmn |
| | | self.x_add_mask = None |
| | | self.x_rm_mask = None |
| | | self.x_len = None |
| | | self.mask_shfit_chunk = None |
| | | self.mask_chunk_predictor = None |
| | | self.mask_att_chunk_encoder = None |
| | | self.mask_shift_att_chunk_decoder = None |
| | | self.chunk_outs = None |
| | | self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \ |
| | | = None, None, None, None, None |
| | | pad_left = self.check_chunk_size_args(chunk_size, pad_left) |
| | | encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor) |
| | | decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor) |
| | | self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \ |
| | | = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor |
| | | self.shfit_fsmn = shfit_fsmn |
| | | self.x_add_mask = None |
| | | self.x_rm_mask = None |
| | | self.x_len = None |
| | | self.mask_shfit_chunk = None |
| | | self.mask_chunk_predictor = None |
| | | self.mask_att_chunk_encoder = None |
| | | self.mask_shift_att_chunk_decoder = None |
| | | self.chunk_outs = None |
| | | self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \ |
| | | = None, None, None, None, None |
| | | |
| | | def check_chunk_size_args(self, chunk_size, x): |
| | | if len(x) < len(chunk_size): |
| | | x = [x[0] for i in chunk_size] |
| | | return x |
| | | def check_chunk_size_args(self, chunk_size, x): |
| | | if len(x) < len(chunk_size): |
| | | x = [x[0] for i in chunk_size] |
| | | return x |
| | | |
| | | def get_chunk_size(self, |
| | | ind: int = 0 |
| | | ): |
| | | # with torch.no_grad: |
| | | chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \ |
| | | self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind] |
| | | self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \ |
| | | = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor |
| | | return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur |
| | | def get_chunk_size(self, |
| | | ind: int = 0 |
| | | ): |
| | | # with torch.no_grad: |
| | | chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \ |
| | | self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind] |
| | | self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \ |
| | | = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor |
| | | return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur |
| | | |
| | | def random_choice(self, training=True, decoding_ind=None): |
| | | chunk_num = len(self.chunk_size) |
| | | ind = 0 |
| | | if training and chunk_num > 1: |
| | | ind = torch.randint(0, chunk_num, ()).cpu().item() |
| | | if not training and decoding_ind is not None: |
| | | ind = int(decoding_ind) |
| | | def random_choice(self, training=True, decoding_ind=None): |
| | | chunk_num = len(self.chunk_size) |
| | | ind = 0 |
| | | if training and chunk_num > 1: |
| | | ind = torch.randint(0, chunk_num, ()).cpu().item() |
| | | if not training and decoding_ind is not None: |
| | | ind = int(decoding_ind) |
| | | |
| | | return ind |
| | | return ind |
| | | |
| | | |
| | | |
| | | |
| | | def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1): |
| | | def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1): |
| | | |
| | | with torch.no_grad(): |
| | | x_len = x_len.cpu().numpy() |
| | | x_len_max = x_len.max() |
| | | with torch.no_grad(): |
| | | x_len = x_len.cpu().numpy() |
| | | x_len_max = x_len.max() |
| | | |
| | | chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind) |
| | | shfit_fsmn = self.shfit_fsmn |
| | | pad_right = chunk_size - stride - pad_left |
| | | chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind) |
| | | shfit_fsmn = self.shfit_fsmn |
| | | pad_right = chunk_size - stride - pad_left |
| | | |
| | | chunk_num_batch = np.ceil(x_len/stride).astype(np.int32) |
| | | x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride |
| | | x_len_chunk = x_len_chunk.astype(x_len.dtype) |
| | | x_len_chunk_max = x_len_chunk.max() |
| | | chunk_num_batch = np.ceil(x_len/stride).astype(np.int32) |
| | | x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride |
| | | x_len_chunk = x_len_chunk.astype(x_len.dtype) |
| | | x_len_chunk_max = x_len_chunk.max() |
| | | |
| | | chunk_num = int(math.ceil(x_len_max/stride)) |
| | | dtype = np.int32 |
| | | max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left) |
| | | x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) |
| | | x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) |
| | | mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) |
| | | mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype) |
| | | mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) |
| | | mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype) |
| | | for chunk_ids in range(chunk_num): |
| | | # x_mask add |
| | | fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype) |
| | | x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) |
| | | x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype) |
| | | x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype) |
| | | x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) |
| | | x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] |
| | | x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0) |
| | | x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0) |
| | | chunk_num = int(math.ceil(x_len_max/stride)) |
| | | dtype = np.int32 |
| | | max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left) |
| | | x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) |
| | | x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) |
| | | mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) |
| | | mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype) |
| | | mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) |
| | | mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype) |
| | | for chunk_ids in range(chunk_num): |
| | | # x_mask add |
| | | fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype) |
| | | x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) |
| | | x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype) |
| | | x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype) |
| | | x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1) |
| | | x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp] |
| | | x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0) |
| | | x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0) |
| | | |
| | | # x_mask rm |
| | | fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype) |
| | | padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype) |
| | | padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype) |
| | | x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) |
| | | x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype) |
| | | x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype) |
| | | x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0) |
| | | x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride] |
| | | x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1) |
| | | x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1) |
| | | # x_mask rm |
| | | fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype) |
| | | padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype) |
| | | padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype) |
| | | x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) |
| | | x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype) |
| | | x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype) |
| | | x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0) |
| | | x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride] |
| | | x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1) |
| | | x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1) |
| | | |
| | | # fsmn_padding_mask |
| | | pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) |
| | | ones_1 = np.ones([chunk_size, num_units], dtype=dtype) |
| | | mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0) |
| | | mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) |
| | | # fsmn_padding_mask |
| | | pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) |
| | | ones_1 = np.ones([chunk_size, num_units], dtype=dtype) |
| | | mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0) |
| | | mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) |
| | | |
| | | # predictor mask |
| | | zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) |
| | | ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) |
| | | zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype) |
| | | ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) |
| | | mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0) |
| | | mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) |
| | | # predictor mask |
| | | zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) |
| | | ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) |
| | | zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype) |
| | | ones_zeros = np.concatenate([ones_2, zeros_3], axis=0) |
| | | mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0) |
| | | mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0) |
| | | |
| | | # encoder att mask |
| | | zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype) |
| | | # encoder att mask |
| | | zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype) |
| | | |
| | | zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) |
| | | zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype) |
| | | zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) |
| | | zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype) |
| | | |
| | | encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) |
| | | zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) |
| | | ones_2_mid = np.ones([stride, stride], dtype=dtype) |
| | | zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype) |
| | | zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype) |
| | | ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0) |
| | | ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1) |
| | | ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) |
| | | encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) |
| | | zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) |
| | | ones_2_mid = np.ones([stride, stride], dtype=dtype) |
| | | zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype) |
| | | zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype) |
| | | ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0) |
| | | ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1) |
| | | ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) |
| | | |
| | | zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) |
| | | ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) |
| | | ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) |
| | | zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) |
| | | ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) |
| | | ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) |
| | | |
| | | zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0) |
| | | zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype) |
| | | zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0) |
| | | zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype) |
| | | |
| | | ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1) |
| | | mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0) |
| | | mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0) |
| | | ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1) |
| | | mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0) |
| | | mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0) |
| | | |
| | | |
| | | # decoder fsmn_shift_att_mask |
| | | zeros_1 = np.zeros([shfit_fsmn, 1]) |
| | | ones_1 = np.ones([chunk_size, 1]) |
| | | mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0) |
| | | mask_shift_att_chunk_decoder = np.concatenate( |
| | | [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0) |
| | | # decoder fsmn_shift_att_mask |
| | | zeros_1 = np.zeros([shfit_fsmn, 1]) |
| | | ones_1 = np.ones([chunk_size, 1]) |
| | | mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0) |
| | | mask_shift_att_chunk_decoder = np.concatenate( |
| | | [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0) |
| | | |
| | | self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left] |
| | | self.x_len_chunk = x_len_chunk |
| | | self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] |
| | | self.x_len = x_len |
| | | self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] |
| | | self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :] |
| | | self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max] |
| | | self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :] |
| | | self.chunk_outs = (self.x_add_mask, |
| | | self.x_len_chunk, |
| | | self.x_rm_mask, |
| | | self.x_len, |
| | | self.mask_shfit_chunk, |
| | | self.mask_chunk_predictor, |
| | | self.mask_att_chunk_encoder, |
| | | self.mask_shift_att_chunk_decoder) |
| | | self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left] |
| | | self.x_len_chunk = x_len_chunk |
| | | self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] |
| | | self.x_len = x_len |
| | | self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] |
| | | self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :] |
| | | self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max] |
| | | self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :] |
| | | self.chunk_outs = (self.x_add_mask, |
| | | self.x_len_chunk, |
| | | self.x_rm_mask, |
| | | self.x_len, |
| | | self.mask_shfit_chunk, |
| | | self.mask_chunk_predictor, |
| | | self.mask_att_chunk_encoder, |
| | | self.mask_shift_att_chunk_decoder) |
| | | |
| | | return self.chunk_outs |
| | | return self.chunk_outs |
| | | |
| | | |
def split_chunk(self, x, x_len, chunk_outs):
    """
    Re-arrange a padded batch into its chunked layout.

    :param x: (b, t, d) tensor; the slice covering the valid frames is
        zeroed in place beyond each sequence length
    :param x_len: (b) tensor of valid lengths for ``x``
    :param chunk_outs: tuple of numpy arrays laid out like ``self.chunk_outs``
        (index 0 is the add-mask, index 1 the chunked lengths); when ``None``
        the getters fall back to ``self.chunk_outs``
    :return: ``(x_chunk, x_len_chunk)``
    """
    # Drop frames beyond the longest sequence and zero out the padding.
    x = x[:, :x_len.max(), :]
    b, t, d = x.size()
    x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device)
    x *= x_len_mask[:, :, None]

    x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
    x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)

    # Left-pad the time axis by ``self.pad_left_cur`` frames (exactly once).
    pad = (0, 0, self.pad_left_cur, 0)
    x = F.pad(x, pad, "constant", 0.0)
    b, t, d = x.size()

    # Flatten to (t, b*d) so one matrix product with the add-mask produces
    # the chunked time axis for every batch item at once.
    x = torch.transpose(x, 1, 0)
    x = torch.reshape(x, [t, -1])
    x_chunk = torch.mm(x_add_mask, x)
    x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)

    return x_chunk, x_len_chunk
| | | |
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
    """
    Map a chunked batch back to the un-chunked time axis.

    :param x_chunk: (b, t, d) tensor in chunked layout; the slice covering
        the valid frames is zeroed in place beyond each chunked length
    :param x_len_chunk: (b) tensor of valid lengths for ``x_chunk``
    :param chunk_outs: tuple of numpy arrays laid out like ``self.chunk_outs``
        (index 2 is the remove-mask, index 3 the original lengths); when
        ``None`` the getters fall back to ``self.chunk_outs``
    :return: ``(x, x_len)``
    """
    # Drop frames beyond the longest chunked sequence and zero the padding.
    x_chunk = x_chunk[:, :x_len_chunk.max(), :]
    b, t, d = x_chunk.size()
    x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(x_chunk.device)
    x_chunk *= x_len_chunk_mask[:, :, None]

    x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
    x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)

    # Flatten to (t, b*d) so one matrix product with the remove-mask
    # rebuilds the time axis for every batch item at once.
    x_chunk = torch.transpose(x_chunk, 1, 0)
    x_chunk = torch.reshape(x_chunk, [t, -1])
    x = torch.mm(x_rm_mask, x_chunk)
    x = torch.reshape(x, [-1, b, d]).transpose(1, 0)

    return x, x_len
| | | |
def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
    """
    Return the add-mask (entry ``idx`` of the chunk outputs) as a tensor.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param idx: position of the add-mask inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: torch tensor built from the selected numpy array
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
    """
    Return the chunked lengths (entry ``idx`` of the chunk outputs) as a tensor.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param idx: position of the chunked lengths inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: torch tensor built from the selected numpy array
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
| | | |
def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
    """
    Return the remove-mask (entry ``idx`` of the chunk outputs) as a tensor.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param idx: position of the remove-mask inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: torch tensor built from the selected numpy array
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
    """
    Return the un-chunked lengths (entry ``idx`` of the chunk outputs) as a tensor.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param idx: position of the lengths inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: torch tensor built from the selected numpy array
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
| | | |
def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
    """
    Return the shift-chunk mask tiled over batch and feature axes.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param batch_size: number of repeats along the new leading axis
    :param num_units: number of repeats along the last axis
    :param idx: position of the mask inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: tensor with a new leading axis of size ``batch_size`` and the
        last axis repeated ``num_units`` times
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        # Prepend a batch axis, then repeat over batch and last axis.
        x = np.tile(x[None, :, :], [batch_size, 1, num_units])
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
    """
    Return the predictor chunk mask tiled over batch and feature axes.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param batch_size: number of repeats along the new leading axis
    :param num_units: number of repeats along the last axis
    :param idx: position of the mask inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: tensor with a new leading axis of size ``batch_size`` and the
        last axis repeated ``num_units`` times
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        # Prepend a batch axis, then repeat over batch and last axis.
        x = np.tile(x[None, :, :], [batch_size, 1, num_units])
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
    """
    Return the encoder attention chunk mask repeated for every batch item.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param batch_size: number of repeats along the new leading axis
    :param idx: position of the mask inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: tensor with a new leading axis of size ``batch_size``
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        # Prepend a batch axis and repeat the 2-D mask along it.
        x = np.tile(x[None, :, :], [batch_size, 1, 1])
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
    """
    Return the decoder shift-attention chunk mask repeated for every batch item.

    Only column 0 of the stored 2-D mask is used.

    :param chunk_outs: tuple of numpy arrays; ``self.chunk_outs`` is used
        when this is ``None``
    :param device: device the tensor is moved to
    :param batch_size: number of repeats along the leading axis
    :param idx: position of the mask inside the tuple
    :param dtype: torch dtype of the returned tensor
    :return: tensor of shape ``(batch_size, 1, rows)`` where ``rows`` is the
        first dimension of the stored mask
    """
    with torch.no_grad():
        x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
        # Take column 0, add two leading axes, repeat along the batch axis.
        x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
        x = torch.from_numpy(x).type(dtype).to(device)
    return x
| | | |
| | | |
| | | |
| | | def build_scama_mask_for_cross_attention_decoder( |
| | | predictor_alignments: torch.Tensor, |
| | | predictor_alignments: torch.Tensor, |
| | | encoder_sequence_length: torch.Tensor, |
| | | chunk_size: int = 5, |
| | | encoder_chunk_size: int = 5, |
| | |
| | | attention_chunk_size: int = 1, |
| | | attention_chunk_type: str = 'chunk', |
| | | step=None, |
| | | predictor_mask_chunk_hopping: torch.Tensor = None, |
| | | decoder_att_look_back_factor: int = 1, |
| | | mask_shift_att_chunk_decoder: torch.Tensor = None, |
| | | target_length: torch.Tensor = None, |
| | | is_training=True, |
| | | predictor_mask_chunk_hopping: torch.Tensor = None, |
| | | decoder_att_look_back_factor: int = 1, |
| | | mask_shift_att_chunk_decoder: torch.Tensor = None, |
| | | target_length: torch.Tensor = None, |
| | | is_training=True, |
| | | dtype: torch.dtype = torch.float32): |
| | | with torch.no_grad(): |
| | | device = predictor_alignments.device |
| | | batch_size, chunk_num = predictor_alignments.size() |
| | | maximum_encoder_length = encoder_sequence_length.max().item() |
| | | int_type = predictor_alignments.dtype |
| | | if not is_training: |
| | | target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype) |
| | | maximum_target_length = target_length.max() |
| | | predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1) |
| | | predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1) |
| | | |
| | | |
| | | index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device) |
| | | index = torch.cumsum(index, dim=1) |
| | | index = index[:, :, None].repeat(1, 1, chunk_num) |
| | | |
| | | index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type) |
| | | index_div_bool_zeros = index_div == 0 |
| | | index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1 |
| | | |
| | | index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num) |
| | | |
| | | index_div_bool_zeros_count *= chunk_size |
| | | index_div_bool_zeros_count += attention_chunk_center_bias |
| | | index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length) |
| | | index_div_bool_zeros_count_ori = index_div_bool_zeros_count |
| | | |
| | | index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size |
| | | max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size |
| | | |
| | | mask_flip, mask_flip2 = None, None |
| | | if attention_chunk_size is not None: |
| | | index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size |
| | | index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk) |
| | | index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device) |
| | | mask_flip = 1 - index_div_bool_zeros_count_beg_mask |
| | | attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1) |
| | | index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2 |
| | | |
| | | index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk) |
| | | index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device) |
| | | mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask |
| | | |
| | | mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device) |
| | | |
| | | if predictor_mask_chunk_hopping is not None: |
| | | b, k, t = mask.size() |
| | | predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1) |
| | | |
| | | mask_mask_flip = mask |
| | | if mask_flip is not None: |
| | | mask_mask_flip = mask_flip * mask |
| | | |
| | | def _fn(): |
| | | mask_sliced = mask[:b, :k, encoder_chunk_size:t] |
| | | zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device) |
| | | mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2) |
| | | _, _, tt = predictor_mask_chunk_hopping.size() |
| | | pad_right_p = max_len_chunk - tt |
| | | predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0) |
| | | masked = mask_sliced * predictor_mask_chunk_hopping_pad |
| | | |
| | | mask_true = mask_mask_flip + masked |
| | | return mask_true |
| | | |
| | | mask = _fn() if t > chunk_size else mask_mask_flip |
| | | |
| | | |
| | | |
| | | if mask_flip2 is not None: |
| | | mask *= mask_flip2 |
| | | |
| | | mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device) |
| | | mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None] |
| | | |
| | | |
| | | |
| | | mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device) |
| | | mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :] |
| | | |
| | | |
| | | |
| | | |
| | | if attention_chunk_type == 'full': |
| | | mask = torch.ones_like(mask).to(device) |
| | | if mask_shift_att_chunk_decoder is not None: |
| | | mask = mask * mask_shift_att_chunk_decoder |
| | | mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device) |
| | | with torch.no_grad(): |
| | | device = predictor_alignments.device |
| | | batch_size, chunk_num = predictor_alignments.size() |
| | | maximum_encoder_length = encoder_sequence_length.max().item() |
| | | int_type = predictor_alignments.dtype |
| | | if not is_training: |
| | | target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype) |
| | | maximum_target_length = target_length.max() |
| | | predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1) |
| | | predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1) |
| | | |
| | | |
| | | index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device) |
| | | index = torch.cumsum(index, dim=1) |
| | | index = index[:, :, None].repeat(1, 1, chunk_num) |
| | | |
| | | index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type) |
| | | index_div_bool_zeros = index_div == 0 |
| | | index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1 |
| | | |
| | | index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num) |
| | | |
| | | index_div_bool_zeros_count *= chunk_size |
| | | index_div_bool_zeros_count += attention_chunk_center_bias |
| | | index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length) |
| | | index_div_bool_zeros_count_ori = index_div_bool_zeros_count |
| | | |
| | | index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size |
| | | max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size |
| | | |
| | | mask_flip, mask_flip2 = None, None |
| | | if attention_chunk_size is not None: |
| | | index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size |
| | | index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk) |
| | | index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device) |
| | | mask_flip = 1 - index_div_bool_zeros_count_beg_mask |
| | | attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1) |
| | | index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2 |
| | | |
| | | index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk) |
| | | index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device) |
| | | mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask |
| | | |
| | | mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device) |
| | | |
| | | if predictor_mask_chunk_hopping is not None: |
| | | b, k, t = mask.size() |
| | | predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1) |
| | | |
| | | mask_mask_flip = mask |
| | | if mask_flip is not None: |
| | | mask_mask_flip = mask_flip * mask |
| | | |
| | | def _fn(): |
| | | mask_sliced = mask[:b, :k, encoder_chunk_size:t] |
| | | zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device) |
| | | mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2) |
| | | _, _, tt = predictor_mask_chunk_hopping.size() |
| | | pad_right_p = max_len_chunk - tt |
| | | predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0) |
| | | masked = mask_sliced * predictor_mask_chunk_hopping_pad |
| | | |
| | | mask_true = mask_mask_flip + masked |
| | | return mask_true |
| | | |
| | | mask = _fn() if t > chunk_size else mask_mask_flip |
| | | |
| | | |
| | | |
| | | if mask_flip2 is not None: |
| | | mask *= mask_flip2 |
| | | |
| | | mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device) |
| | | mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None] |
| | | |
| | | |
| | | |
| | | mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device) |
| | | mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :] |
| | | |
| | | |
| | | |
| | | |
| | | if attention_chunk_type == 'full': |
| | | mask = torch.ones_like(mask).to(device) |
| | | if mask_shift_att_chunk_decoder is not None: |
| | | mask = mask * mask_shift_att_chunk_decoder |
| | | mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device) |
| | | |
| | | return mask |
| | | return mask |
| | | |
| | |
| | | import os |
| | | import torch |
| | | from torch.nn import functional as F |
| | | import yaml |
| | | import torch |
| | | import numpy as np |
| | | from torch.nn import functional as F |
| | | |
| | | |
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """
    Build a mask that is 1 for positions below each length and 0 elsewhere.

    :param lengths: integer tensor of lengths; a new trailing axis of size
        ``maxlen`` is appended to its shape
    :param maxlen: size of the trailing axis; defaults to ``lengths.max()``
    :param dtype: torch dtype of the returned mask
    :param device: device to move the mask to; when ``None`` the mask stays
        on the device of ``lengths``
    :return: mask tensor of shape ``lengths.shape + (maxlen,)``
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Create the position index directly on the device of ``lengths``.
    positions = torch.arange(0, maxlen, 1, device=lengths.device)
    mask = (positions < torch.unsqueeze(lengths, dim=-1)).detach()
    mask = mask.type(dtype)
    return mask.to(device) if device is not None else mask
| | | |
def apply_cmvn(inputs, mvn):
    """
    Normalise features in place with a shift row and a scale row.

    ``inputs`` is modified in place: row 0 of ``mvn`` is subtracted from
    every frame, then every frame is multiplied by row 1 of ``mvn``.

    :param inputs: (frame, dim) tensor of features
    :param mvn: numpy array with at least two rows and at least ``dim``
        columns; only the first ``dim`` columns are used
    :return: the normalised features cast to ``torch.float32``
    """
    device = inputs.device
    dtype = inputs.dtype
    _, dim = inputs.shape
    # (1, dim) rows broadcast over the frame axis, so no explicit tiling.
    shift = torch.from_numpy(mvn[0:1, :dim]).type(dtype).to(device)
    scale = torch.from_numpy(mvn[1:2, :dim]).type(dtype).to(device)
    inputs -= shift
    inputs *= scale

    return inputs.type(torch.float32)
| | | |
| | | |
| | | |
| | |
| | | |
| | | |
| | | |
| | | outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True) |
| | | outputs *= stoch_layer_coeff |
| | | outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True) |
| | | outputs *= stoch_layer_coeff |
| | | |
| | | input_dim = inputs.size(-1) |
| | | output_dim = outputs.size(-1) |
| | | input_dim = inputs.size(-1) |
| | | output_dim = outputs.size(-1) |
| | | |
| | | if input_dim == output_dim: |
| | | outputs += inputs |
| | | return outputs |
| | | if input_dim == output_dim: |
| | | outputs += inputs |
| | | return outputs |
| | | |
| | | |
def proc_tf_vocab(vocab_path):
    """
    Read a vocabulary file into a token list, one token per line.

    Trailing whitespace is stripped from every line, and ``'<unk>'`` is
    appended when the file does not already contain it.

    :param vocab_path: path of the UTF-8 vocabulary file
    :return: list of tokens in file order
    """
    with open(vocab_path, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    if '<unk>' not in token_list:
        token_list.append('<unk>')
    return token_list
| | | |
| | | |
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
    """
    Write ``config.yaml`` into ``output_dir`` with the vocabulary embedded.

    The YAML file at ``config_path`` is loaded, its ``token_list`` entry is
    set to the tokens read from ``vocab_path``, and the result is dumped to
    ``<output_dir>/config.yaml``.

    :param config_path: path of the source YAML config
    :param vocab_path: path of the vocabulary file, one token per line
    :param output_dir: directory for the generated config; created if missing
    """
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config['token_list'] = token_list

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
| | | |
| | | |
class NoAliasSafeDumper(yaml.SafeDumper):
    """Safe YAML dumper that never emits anchors or aliases."""

    # Anchors/aliases are disabled because they make the output hard to read.
    def ignore_aliases(self, data):
        """Report every object as alias-free so it is written out in full."""
        return True
| | | |
| | | |
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias.

    :param data: object to serialise
    :param stream: optional writable stream handed to ``yaml.dump``
    :param kwargs: extra keyword arguments forwarded to ``yaml.dump``
    :return: whatever ``yaml.dump`` returns
    """
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )
| | | |
| | | |
if __name__ == '__main__':
    import sys

    # Expected arguments: config path, vocab path, output directory.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} CONFIG_PATH VOCAB_PATH OUTPUT_DIR")

    config_path = sys.argv[1]
    vocab_path = sys.argv[2]
    output_dir = sys.argv[3]
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)
| | |
| | | ys_pad_lens, |
| | | hw_list, |
| | | nfilter=50, |
| | | seaco_weight=1.0): |
| | | seaco_weight=1.0): |
| | | # decoder forward |
| | | decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) |
| | | decoder_pred = torch.log_softmax(decoder_out, dim=-1) |
| | |
| | | |
| | | dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation |
| | | dha_pred = torch.log_softmax(dha_output, dim=-1) |
| | | # import pdb; pdb.set_trace() |
| | | def _merge_res(dec_output, dha_output): |
| | | lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) |
| | | dha_ids = dha_output.max(-1)[-1][0] |
| | | dha_ids = dha_output.max(-1)[-1]# [0] |
| | | dha_mask = (dha_ids == 8377).int().unsqueeze(-1) |
| | | a = (1 - lmbd) / lmbd |
| | | b = 1 / lmbd |
| | |
| | | logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) |
| | | return logits |
| | | merged_pred = _merge_res(decoder_pred, dha_pred) |
| | | # import pdb; pdb.set_trace() |
| | | return merged_pred |
| | | else: |
| | | return decoder_pred |
| | |
| | | speech_lengths: (Batch, ) |
| | | """ |
| | | # with autocast(False): |
| | | # # 1. Extract feats |
| | | # feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | # # 1. Extract feats |
| | | # feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | # |
| | | # # 2. Data augmentation |
| | | # if self.specaug is not None and self.training: |
| | | # feats, feats_lengths = self.specaug(feats, feats_lengths) |
| | | # # 2. Data augmentation |
| | | # if self.specaug is not None and self.training: |
| | | # feats, feats_lengths = self.specaug(feats, feats_lengths) |
| | | # |
| | | # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | # if self.normalize is not None: |
| | | # feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | # if self.normalize is not None: |
| | | # feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # Pre-encoder, e.g. used for raw input data |
| | | # if self.preencoder is not None: |
| | | # feats, feats_lengths = self.preencoder(feats, feats_lengths) |
| | | # feats, feats_lengths = self.preencoder(feats, feats_lengths) |
| | | encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk( |
| | | encoder_out, |
| | | encoder_out_lens, |
| | |
| | | |
| | | # # Post-encoder, e.g. NLU |
| | | # if self.postencoder is not None: |
| | | # encoder_out, encoder_out_lens = self.postencoder( |
| | | # encoder_out, encoder_out_lens |
| | | # ) |
| | | # encoder_out, encoder_out_lens = self.postencoder( |
| | | # encoder_out, encoder_out_lens |
| | | # ) |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | from funasr.optimizers.sgd import SGD |
| | | |
# Lookup table from optimizer name to optimizer class.
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
    adagrad=torch.optim.Adagrad,
    adamax=torch.optim.Adamax,
    asgd=torch.optim.ASGD,
    lbfgs=torch.optim.LBFGS,
    rmsprop=torch.optim.RMSprop,
    rprop=torch.optim.Rprop,
)
| | |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | |
# Lookup table from scheduler name to learning-rate scheduler class.
scheduler_classes = dict(
    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
    lambdalr=torch.optim.lr_scheduler.LambdaLR,
    steplr=torch.optim.lr_scheduler.StepLR,
    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
| | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | from typing import Iterable |
| | | from typing import List |
| | | from pathlib import Path |
| | | from typing import Dict |
| | | from typing import Iterable |
| | | from typing import List |
| | | from typing import Union |
| | | import json |
| | | |
| | | import numpy as np |
| | | from abc import ABC |
| | | from pathlib import Path |
| | | from abc import abstractmethod |
| | | from typing import Union, Iterable, List, Dict |
| | | |
| | | |
class AbsTokenizer(ABC):
    """Abstract interface for converting between raw text and token strings."""

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into a list of token strings."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join ``tokens`` back into a text string."""
        raise NotImplementedError
| | | |
| | | |
class BaseTokenizer(ABC):
    """Vocabulary-backed tokenizer base class.

    Holds the token list and the token -> id mapping, and provides
    ``encode``/``decode`` on top of the abstract ``text2tokens`` /
    ``tokens2text`` pair that subclasses implement.
    """

    def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        """Load the vocabulary and build the token -> id lookup.

        Args:
            token_list: A path (``str`` or ``Path``) ending in ``.txt`` (one
                token per line) or ``.json`` (a JSON document holding the
                tokens), or any iterable of token strings. When ``None`` no
                vocabulary attributes are set.
            unk_symbol: Token used for out-of-vocabulary items; it must be
                present in the vocabulary.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If a token occurs twice or ``unk_symbol`` is missing.
        """
        if token_list is not None:
            # ``Path`` has no ``endswith``; compare suffixes on the string form
            # so that both ``str`` and ``Path`` inputs are handled.
            path_name = str(token_list) if isinstance(token_list, (Path, str)) else ""

            if path_name.endswith(".txt"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with token_list.open("r", encoding="utf-8") as f:
                    self.token_list: List[str] = [line.rstrip() for line in f]
            elif path_name.endswith(".json"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with open(token_list, 'r', encoding='utf-8') as f:
                    self.token_list: List[str] = json.load(f)
            else:
                self.token_list: List[str] = list(token_list)
                # Short preview: the first three tokens plus the vocabulary size.
                self.token_list_repr = ""
                for t in self.token_list[:3]:
                    self.token_list_repr += f"{t}, "
                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"

            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i

            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]

    def encode(self, text):
        """Tokenize ``text`` and return the corresponding list of token ids."""
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)

        return text_ints

    def decode(self, text_ints):
        """Map token ids back to tokens and join them into text."""
        token = self.ids2tokens(text_ints)
        text = self.tokens2text(token)
        return text

    def get_num_vocabulary_size(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Look up the token for each id.

        Raises:
            ValueError: If ``integers`` is an ndarray that is not 1-dimensional.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Look up the id for each token; unknown tokens map to ``unk_id``."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into a list of token strings."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join ``tokens`` back into a text string."""
        raise NotImplementedError
| | |
| | | import torch |
| | | import os |
| | | from funasr.train_utils.device_funcs import to_device |
| | | import logging |
| | | import time |
| | | import torch |
| | | import logging |
| | | from tqdm import tqdm |
| | | from contextlib import nullcontext |
| | | import torch.distributed as dist |
| | | from contextlib import nullcontext |
| | | |
| | | from funasr.train_utils.device_funcs import to_device |
| | | from funasr.train_utils.recursive_op import recursive_average |
| | | |
| | | |
class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """

    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): Rank of this process on the local node; rank 0 updates the progress text.
            use_ddp (bool): Whether the loss is averaged across workers as for DistributedDataParallel.
            use_fsdp (bool): Whether the loss is averaged across workers as for FSDP.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                resume (str, optional): The file path to a checkpoint to resume training from.
                    Values that are not a path (including the default ``True``) do not trigger a load.
                accum_grad (int): Number of mini-batches to accumulate per optimizer step.
                grad_clip (float, optional): Max gradient norm; clipping is skipped when absent.
                grad_clip_type (float): Norm type used for clipping. Default is 2.0.
        """

        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = kwargs.get('output_dir', './')
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.kwargs = kwargs

        # Only a real path can be loaded. The default ``True`` used to be handed
        # to os.path.isfile(), where a bool is interpreted as file descriptor 1.
        if self.resume and isinstance(self.resume, (str, os.PathLike)):
            self._resume_checkpoint(self.resume)

        # Ask torch.distributed explicitly instead of catching every exception.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if os.path.isfile(resume_path):
            # Load onto CPU first; load_state_dict copies the values onto the
            # devices of the live model/optimizer, so no rank allocates on the
            # GPU that happened to write the checkpoint.
            checkpoint = torch.load(resume_path, map_location="cpu")
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at '{resume_path}', starting from scratch")

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        # NOTE(review): the range is inclusive of max_epoch, so a fresh run trains
        # max_epoch + 1 epochs and the progress text ends at "N+1/N" - confirm intent.
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            # self._validate_epoch(epoch)
            if self.rank == 0:
                self._save_checkpoint(epoch)
            self.scheduler.step()

    def _train_epoch(self, epoch):
        """
        Defines the training process for a single epoch with gradient accumulation.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        num_batches = len(self.dataloader_train)
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=num_batches,
                    dynamic_ncols=True)

        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        grad_clip = self.kwargs.get("grad_clip", None)
        # Only wrapped (DDP/FSDP) models expose no_sync(); plain modules do not.
        can_skip_sync = hasattr(self.model, "no_sync")
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_train):
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
            batch = to_device(batch, self.device)

            is_update_step = (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == num_batches
            # Gradients must be synchronized on the micro-batch that precedes the
            # optimizer step; communication is skipped only on the earlier ones.
            if is_update_step or not can_skip_sync:
                sync_context = nullcontext
            else:
                sync_context = self.model.no_sync
            with sync_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if is_update_step:
                skip_update = False
                # Perform gradient clipping if it is set
                if grad_clip is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=grad_clip,
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        skip_update = True

                if not skip_update:
                    # Execute an optimization step (update model parameters)
                    self.optim.step()
                    self.scheduler.step()
                # Clear gradients for the next accumulation stage (also drops
                # non-finite gradients when the update was skipped)
                self.optim.zero_grad()
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"

                speed_stats["total_time"] = total_time

            # Always advance the bar, including when an update was skipped.
            pbar.update(1)
            if self.local_rank == 0:
                description = (
                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                )
                pbar.set_description(description)

        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            for data, target in self.dataloader_val:
                # Implement the model validation steps here
                pass
| | |
| | | import logging |
| | | from torch.nn.utils.rnn import pad_sequence |
| | | try: |
| | | from funasr.download.file import download_from_url |
| | | from funasr.download.file import download_from_url |
| | | except: |
| | | print("urllib is not installed, if you infer from url, please install it first.") |
| | | print("urllib is not installed, if you infer from url, please install it first.") |
| | | |
| | | |
| | | |
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    """Load one input, or a (nested) list of inputs, into model-ready data.

    Args:
        data_or_path_or_list: A URL, a local file path, a text string, a numpy
            array of audio samples, or a list/tuple of those.
        fs: Target sampling rate; audio tensors are resampled to it.
        audio_fs: Sampling rate of the incoming audio. Replaced by the rate
            reported by ``torchaudio.load`` when a sound file is read.
        data_type: ``"sound"``, ``"text"``, ``"image"`` or ``"video"``; or a
            list/tuple of those when every item of the input list is itself a
            tuple with one entry per type.
        tokenizer: Object with an ``encode(text)`` method, used for ``"text"``.
        **kwargs: Forwarded on recursion. If it contains ``"cache"``, its
            ``"is_final"`` entry is set to ``True`` when the input is a file.

    Returns:
        The loaded item. For a list input, a list of loaded items; for a list
        input with a list/tuple ``data_type``, one list per data type.
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            # Each sample is a tuple with one entry per data type; the result is
            # transposed into one list per data type.
            data_or_path_or_list_ret = [[] for _ in data_type]
            for sample in data_or_path_or_list:
                for j, (data_type_j, item_j) in enumerate(zip(data_type, sample)):
                    loaded = load_audio_text_image_video(item_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    data_or_path_or_list_ret[j].append(loaded)

            return data_or_path_or_list_ret
        else:
            # The tokenizer has to be forwarded, otherwise lists of text come
            # back unencoded.
            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, tokenizer=tokenizer, **kwargs) for audio in data_or_path_or_list]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            # Keep only the first channel.
            data_or_path_or_list = data_or_path_or_list[0, :]
        elif data_type == "text" and tokenizer is not None:
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        elif data_type == "image":  # not implemented yet; the path is returned as-is
            pass
        elif data_type == "video":  # not implemented yet; the path is returned as-is
            pass

        # if data_in is a file or url, set is_final=True
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
    elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
    # Any other input is returned unchanged.

    # Only tensors can be resampled; unsupported inputs (e.g. an image path) are
    # passed through instead of failing on the indexing below.
    if audio_fs != fs and data_type != "text" and isinstance(data_or_path_or_list, torch.Tensor):
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list
| | | |
def load_bytes(input):
    """Decode raw 16-bit PCM bytes into float32 samples.

    Args:
        input: Bytes-like object holding native-endian int16 samples.

    Returns:
        A 1-D ``np.float32`` array with each sample divided by 32768, so values
        lie in [-1.0, 1.0).

    Raises:
        ValueError: If the buffer size is not a multiple of 2 bytes (raised by
            ``np.frombuffer``).
    """
    pcm = np.frombuffer(input, dtype=np.int16)
    # 2 ** (bits - 1): magnitude of the most negative int16 value.
    abs_max = 2 ** (np.iinfo(pcm.dtype).bits - 1)
    return pcm.astype(np.float32) / abs_max
| | | |
def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
    """Batch the input and run it through ``frontend``.

    Args:
        data: One of
            * ``np.ndarray`` or ``torch.Tensor``: a tensor with fewer than two
              dimensions gets a leading batch axis (``[N] -> [1, N]``);
            * ``list``/``tuple`` of arrays or tensors: items are zero-padded
              to a common length with ``pad_sequence(batch_first=True)``.
        data_len: Optional lengths. For array/tensor input it defaults to
            ``[data.shape[1]]``. For list/tuple input any passed value is
            ignored and replaced by the first-dimension size of every item.
        data_type: Accepted for interface compatibility; not read here.
        frontend: Callable invoked as ``frontend(data, data_len, **kwargs)``
            that returns ``(features, lengths)``.
        **kwargs: Forwarded unchanged to ``frontend``.

    Returns:
        ``(features, lengths)`` where ``features`` is cast to ``torch.float32``
        and ``lengths`` to ``torch.int32``. If ``frontend`` returns the lengths
        as a list/tuple, they are wrapped in one extra outer list before
        conversion, giving a tensor of shape ``(1, len(lengths))``.

    Raises:
        TypeError: If ``frontend`` is not provided.
    """
    if frontend is None:
        raise TypeError("extract_fbank() requires a callable `frontend`")

    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        if len(data.shape) < 2:
            data = data[None, :]  # data: [batch, N]
        data_len = [data.shape[1]] if data_len is None else data_len
    elif isinstance(data, (list, tuple)):
        data_list, data_len = [], []
        for data_i in data:
            if isinstance(data_i, np.ndarray):
                data_i = torch.from_numpy(data_i)
            data_list.append(data_i)
            data_len.append(data_i.shape[0])
        data = pad_sequence(data_list, batch_first=True)  # data: [batch, N]

    data, data_len = frontend(data, data_len, **kwargs)

    if isinstance(data_len, (list, tuple)):
        # The extra outer list is kept on purpose: callers receive shape (1, B).
        data_len = torch.tensor([data_len])
    return data.to(torch.float32), data_len.to(torch.int32)
| | | |
| | |
| | | import torch |
| | | from torch.nn.utils.rnn import pad_sequence |
| | | |
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut the first batch item into VAD segments and pad them into one batch.

    Args:
        speech: Tensor indexed as ``speech[0, begin:end]``; only batch item 0
            is used.
        speech_lengths: Indexable whose first element is the valid length of
            ``speech[0]``; segment ends are clipped to it.
        vad_segments: Iterable of segments where ``segment[0][0]`` and
            ``segment[0][1]`` are the begin and end positions. Both are
            multiplied by 16 to obtain indices (presumably milliseconds at
            16 kHz, i.e. 16 samples per ms -- TODO confirm with the caller).

    Returns:
        ``(feats_pad, speech_lengths_pad)``: the slices zero-padded to the
        longest one (batch first) and an ``int32`` tensor with the length of
        each slice.
    """
    # Coerce once to a plain int so min() never yields a 0-dim tensor.
    total_len = int(speech_lengths[0])

    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        bed_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), total_len)
        speech_list.append(speech[0, bed_idx:end_idx])
        speech_lengths_list.append(end_idx - bed_idx)

    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
    # Build the int32 tensor directly; going through float32 is inexact above 2**24.
    speech_lengths_pad = torch.tensor(speech_lengths_list, dtype=torch.int32)
    return feats_pad, speech_lengths_pad
| | | |
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Cut a 1-D sample sequence into VAD segments (no padding is applied).

    Args:
        speech: Sequence sliced as ``speech[begin:end]``.
        speech_lengths: Scalar valid length of ``speech``; segment ends are
            clipped to it.
        vad_segments: Iterable of segments where ``segment[0][0]`` and
            ``segment[0][1]`` are the begin and end positions. Both are
            multiplied by 16 to obtain indices (presumably milliseconds at
            16 kHz -- TODO confirm with the caller).

    Returns:
        ``(speech_list, speech_lengths_list)``: the slices and, for each one,
        ``end - begin`` after clipping.
    """
    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        bed_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), speech_lengths)
        speech_list.append(speech[bed_idx:end_idx])
        speech_lengths_list.append(end_idx - bed_idx)

    return speech_list, speech_lengths_list
| | |
| | | |
# Pick the Paraformer runtime implementation: ONNX Runtime when requested on
# the command line, libtorch otherwise. Only the selected backend is imported.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(
    args.model_dir,
    batch_size=1,
    quantize=args.quantize,
    intra_op_num_threads=args.intra_op_num_threads,
)

# The handle is consumed by the statements that follow this hunk.
wav_file_f = open(args.wav_file, 'r')
| | |
| | | |
# Create the output directory; exist_ok avoids the check-then-create race.
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

# With an explicit `newline`, every "\n" written is translated to that string.
newline = '\r\n' if os.name == 'nt' else '\n'  # Windows vs. Linux/Mac

# The context manager closes both result files even if inference raises.
with open(os.path.join(output_dir, "text"), "w", newline=newline) as text_f, \
        open(os.path.join(output_dir, "token"), "w", newline=newline) as token_f:
    for wav_path_i in wav_files:
        # Each list line is "<wav_name> <wav_path>".
        wav_name, wav_path = wav_path_i.strip().split()
        result = model(wav_path)
        text_f.write("{} {}\n".format(wav_name, result[0]['preds'][0]))
        # Flush per utterance so partial results survive an interrupted run.
        text_f.flush()
        token_f.write("{} {}\n".format(wav_name, result[0]['preds'][1]))
        token_f.flush()
| | | |
| | | |
| | |
| | | |
# Pick the Paraformer runtime implementation: ONNX Runtime when requested on
# the command line, libtorch otherwise. Only the selected backend is imported.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(
    args.model_dir,
    batch_size=1,
    quantize=args.quantize,
    intra_op_num_threads=args.intra_op_num_threads,
)

# The handle is consumed by the statements that follow this hunk.
wav_file_f = open(args.wav_file, 'r')
| | |
def _wav_path_from_line(line):
    """Return the second field of a list line split on a tab, or a space if no tab is present."""
    sep = "\t" if "\t" in line else " "
    return line.split(sep)[1].strip()


# Time `num` repeated runs on the first listed wav (presumably a warm-up -- confirm).
num = 30
wav_path = _wav_path_from_line(wav_files[0])
for i in range(num):
    beg_time = time.time()
    result = model(wav_path)
    end_time = time.time()
    duration = end_time-beg_time
    total += duration
    print(result)
    # NOTE(review): 5.53 looks like the hard-coded duration in seconds of that wav -- confirm.
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))

# infer time
beg_time = time.time()
for wav_path_i in wav_files:
    result = model(_wav_path_from_line(wav_path_i))
end_time = time.time()
duration = (end_time-beg_time)*1000
print("total_time_comput_ms: {}".format(int(duration)))

# Total audio length: samples loaded at 16 kHz divided by 16 samples per millisecond.
duration_time = 0.0
for wav_path_i in wav_files:
    waveform, _ = librosa.load(_wav_path_from_line(wav_path_i), sr=16000)
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))
| | |
| | | |
# Pick the Paraformer runtime implementation: ONNX Runtime when requested on
# the command line, libtorch otherwise. Only the selected backend is imported.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(
    args.model_dir,
    batch_size=args.batch_size,
    quantize=args.quantize,
    intra_op_num_threads=args.intra_op_num_threads,
)

# The handle is consumed by the statements that follow this hunk.
wav_file_f = open(args.wav_file, 'r')
| | |
def _wav_path_from_line(line):
    """Return the second field of a list line split on a tab, or a space if no tab is present."""
    sep = "\t" if "\t" in line else " "
    return line.split(sep)[1].strip()


# Time `num` repeated runs on the first listed wav (presumably a warm-up -- confirm).
num = 30
wav_path = _wav_path_from_line(wav_files[0])
for i in range(num):
    beg_time = time.time()
    result = model(wav_path)
    end_time = time.time()
    duration = end_time-beg_time
    total += duration
    print(result)
    # NOTE(review): 5.53 looks like the hard-coded duration in seconds of that wav -- confirm.
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))

# infer time
# All paths are collected first and passed to the model in a single call;
# the timer starts before the parsing, as before.
beg_time = time.time()
wav_path = [_wav_path_from_line(wav_path_i) for wav_path_i in wav_files]
result = model(wav_path)
end_time = time.time()
duration = (end_time-beg_time)*1000
| | |
| | | |
# Total audio length: samples loaded at 16 kHz divided by 16 samples per millisecond.
duration_time = 0.0
for wav_path_i in wav_files:
    # Take the second field of the list line, split on a tab or else a space.
    sep = "\t" if "\t" in wav_path_i else " "
    wav_path = wav_path_i.split(sep)[1].strip()
    waveform, _ = librosa.load(wav_path, sr=16000)
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))