Shi Xian
2024-01-15 ddbc8b5eded1fff6084001d160d46b532020ecb7
Merge pull request #1247 from alibaba-damo-academy/funasr1.0

Funasr1.0
29个文件已修改
2843 ■■■■ 已修改文件
README.md 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/demo.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/seaco_paraformer/demo.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/inference.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 294 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/datasets.py 171 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 108 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/samplers.py 141 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/download/download_from_hub.py 193 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/download/name_maps_from_hub.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/download/runtime_sdk_download_tool.py 76 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/branchformer/model.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer/model.py 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_branchformer/model.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/model.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/chunk_utilis.py 644 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/utils.py 107 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/model.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/uniasr/e2e_uni_asr.py 24 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/optimizers/__init__.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/schedulers/__init__.py 24 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/abs_tokenizer.py 178 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 444 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/load_utils.py 170 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/vad_utils.py 50 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/utils/test_cer.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/utils/test_rtf.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/utils/test_rtf_gpu.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md
@@ -90,12 +90,15 @@
### Speech Recognition (Non-streaming)
```python
from funasr import AutoModel
model = AutoModel(model="paraformer-zh")
# for the long duration wav, you could add vad model
# model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc")
res = model(input="asr_example_zh.wav", batch_size=64)
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
                  vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
                  punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
                  spk_model="cam++", spk_model_revision="v2.0.2")
res = model(input=f"{model.model_path}/example/asr_example.wav",
            batch_size=16,
            hotword='魔搭')
print(res)
```
Note: `model_hub` specifies the model repository: `ms` selects download from ModelScope, and `hf` selects download from Hugging Face.
@@ -108,7 +111,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.0")
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.2")
import soundfile
import os
@@ -163,7 +166,7 @@
```python
from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.1")
model = AutoModel(model="ct-punc", model_revision="v2.0.2")
res = model(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
@@ -172,7 +175,7 @@
```python
from funasr import AutoModel
model = AutoModel(model="fa-zh", model_revision="v2.0.0")
model = AutoModel(model="fa-zh", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/asr_example.wav"
examples/industrial_data_pretraining/paraformer/demo.py
@@ -18,5 +18,5 @@
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
for batch_idx, fbank_dict in enumerate(fbanks):
    res = model(**fbank_dict)
    print(res)
    res = model(**fbank_dict)
    print(res)
examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -11,8 +11,10 @@
                  vad_model_revision="v2.0.2",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.2",
                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                  spk_model_revision="v2.0.2",
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
            hotword='达摩院 磨搭')
res = model(input=f"{model.model_path}/example/asr_example.wav",
            hotword='达摩院 魔搭')
print(res)
funasr/bin/inference.py
@@ -274,12 +274,9 @@
    def generate_with_vad(self, input, input_len=None, **cfg):
        
        # step.1: compute the vad model
        model = self.vad_model
        kwargs = self.vad_kwargs
        kwargs.update(cfg)
        self.vad_kwargs.update(cfg)
        beg_vad = time.time()
        res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
        vad_res = res
        res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
        end_vad = time.time()
        print(f"time cost vad: {end_vad - beg_vad:0.3f}")
@@ -312,10 +309,7 @@
            if not len(sorted_data):
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue
            # if kwargs["device"] == "cpu":
            #     batch_size = 0
            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
            
funasr/bin/train.py
@@ -1,178 +1,180 @@
import argparse
import logging
import os
import sys
from io import BytesIO
from collections.abc import Sequence
import torch
import hydra
import logging
import argparse
from io import BytesIO
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.models.lora.utils import mark_only_lora_as_trainable
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer
from funasr.train_utils.trainer import Trainer
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.download.download_from_hub import download_model
from funasr.register import tables
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()
    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    main(**kwargs)
    main(**kwargs)
def main(**kwargs):
    # preprocess_config(kwargs)
    # import pdb; pdb.set_trace()
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    # Check if we are using DDP or FSDP
    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
        torch.cuda.set_device(local_rank)
    # save config.yaml
    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)
    # preprocess_config(kwargs)
    # import pdb; pdb.set_trace()
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    # Check if we are using DDP or FSDP
    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
        torch.cuda.set_device(local_rank)
    # save config.yaml
    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer
    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # import pdb;
    # pdb.set_trace()
    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer
    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # import pdb;
    # pdb.set_trace()
    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                init_param=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
            )
    else:
        initialize(model, kwargs.get("init", "kaiming_normal"))
    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                init_param=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
            )
    else:
        initialize(model, kwargs.get("init", "kaiming_normal"))
    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
    # import pdb;
    # pdb.set_trace()
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
    # import pdb;
    # pdb.set_trace()
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    if batch_sampler is not None:
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                pin_memory=True)
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    if batch_sampler is not None:
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                pin_memory=True)
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=None,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        **kwargs.get("train_conf"),
    )
    trainer.run()
    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=None,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        **kwargs.get("train_conf"),
    )
    trainer.run()
    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()
if __name__ == "__main__":
    main_hydra()
    main_hydra()
funasr/datasets/audio_datasets/datasets.py
@@ -1,102 +1,93 @@
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank
@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
    """
    AudioDataset
    """
    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                  **kwargs):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text
        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
    """
    AudioDataset
    """
    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                  **kwargs):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text
        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)
    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)
    def __len__(self):
        return len(self.index_ds)
    def __getitem__(self, index):
        item = self.index_ds[index]
        # import pdb;
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)
    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)
    def __len__(self):
        return len(self.index_ds)
    def __getitem__(self, index):
        item = self.index_ds[index]
        # import pdb;
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
        return {"speech": speech[0, :, :],
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                }
    def collator(self, samples: list=None):
        return {"speech": speech[0, :, :],
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                }
    def collator(self, samples: list=None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])
        for key, data_list in outputs.items():
            if data_list[0].dtype == torch.int64:
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])
        for key, data_list in outputs.items():
            if data_list[0].dtype == torch.int64:
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs
funasr/datasets/audio_datasets/index_ds.py
@@ -1,64 +1,64 @@
import torch
import json
import torch.distributed as dist
import time
import torch
import logging
import torch.distributed as dist
from funasr.register import tables
@tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset):
    def __init__(self, path):
        super().__init__()
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    self.contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    prompt = data["prompt"]
                    source = data["source"]
                    target = data["target"]
                    source_len = data["source_len"]
                    target_len = data["target_len"]
    def __init__(self, path):
        super().__init__()
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    self.contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    prompt = data["prompt"]
                    source = data["source"]
                    target = data["target"]
                    source_len = data["source_len"]
                    target_len = data["target_len"]
                    contents.append({"source": source,
                                     "prompt": prompt,
                                     "target": target,
                                     "source_len": source_len,
                                     "target_len": target_len,
                                     }
                                    )
        self.contents = []
        total_num = len(contents)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size
        # rank = 0
        # import ipdb; ipdb.set_trace()
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
                    contents.append({"source": source,
                                     "prompt": prompt,
                                     "target": target,
                                     "source_len": source_len,
                                     "target_len": target_len,
                                     }
                                    )
        self.contents = []
        total_num = len(contents)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size
        # rank = 0
        # import ipdb; ipdb.set_trace()
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
    def __len__(self):
        return len(self.contents)
    def __getitem__(self, index):
        return self.contents[index]
    def get_source_len(self, data_dict):
        return data_dict["source_len"]
    def __len__(self):
        return len(self.contents)
    def __getitem__(self, index):
        return self.contents[index]
    def get_source_len(self, data_dict):
        return data_dict["source_len"]
    def get_target_len(self, data_dict):
        return data_dict["target_len"] if "target_len" in data_dict else 0
    def get_target_len(self, data_dict):
        return data_dict["target_len"] if "target_len" in data_dict else 0
funasr/datasets/audio_datasets/samplers.py
@@ -1,5 +1,4 @@
import torch
import numpy as np
from funasr.register import tables
@@ -7,74 +6,74 @@
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that sorts samples by length inside fixed-size buffers.

    The (optionally shuffled) sample order is cut into consecutive buffers of
    ``buffer_size`` samples. Inside each buffer the samples are sorted by
    ``source_len + target_len`` and packed greedily into batches until the
    budget ``batch_size`` would be exceeded.

    Args:
        dataset: object providing ``__len__``, ``get_source_len(index)`` and
            ``get_target_len(index)``.
        batch_type: ``"length"`` makes the budget
            ``(samples in batch) * (longest sample in batch)``; any other value
            makes the budget the plain number of samples.
        batch_size: budget a single batch must not exceed.
        buffer_size: number of samples sorted together.
        drop_last: when true, the trailing batch left over at the end of an
            iteration is discarded instead of being yielded.
        shuffle: when true, the sample order is reshuffled (in place, using the
            global NumPy RNG) at the start of every iteration.
        **kwargs: ``max_token_length`` (default 5000) - samples whose combined
            length exceeds it are skipped.
    """

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
                 **kwargs):
        # The parent initializer is not invoked; all state used by this class
        # is set up right here.
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle

    def __len__(self):
        """Return the number of samples (not the number of batches)."""
        return self.total_samples

    def set_epoch(self, epoch):
        """Seed the global NumPy RNG with ``epoch`` so the next shuffle is reproducible."""
        np.random.seed(epoch)

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_id in range(self.pre_idx + 1, iter_num):
            start = buffer_id * self.buffer_size
            stop = min(start + self.buffer_size, self.total_samples)

            datalen_with_index = []
            for idx in range(start, stop):
                # Record the *mapped* index: it is the sample whose length is
                # measured, so it must also be the sample placed in the batch.
                idx_map = int(self.shuffle_idx[idx])
                sample_len_cur = self.dataset.get_source_len(idx_map) + \
                                 self.dataset.get_target_len(idx_map)
                datalen_with_index.append([idx_map, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for idx_map, sample_len_cur_raw in datalen_with_index_sort:
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type == 'length':
                    # Padded cost: every sample is as long as the longest one.
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx_map)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx_map]
                    max_token = sample_len_cur_raw
                    num_sample = 1

        # Emit whatever is left so the last samples are not silently lost,
        # unless the caller explicitly asked to drop the trailing batch.
        if batch and not self.drop_last:
            yield batch
funasr/download/download_from_hub.py
@@ -1,110 +1,111 @@
import json
import os
import json
from omegaconf import OmegaConf
import torch
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
def download_model(**kwargs):
    """Resolve the requested model through the hub selected by ``model_hub``.

    Args:
        **kwargs: model options. ``model_hub`` (default ``"ms"``) selects the
            hub; for ``"ms"`` all options are forwarded to ``download_from_ms``.

    Returns:
        The options returned by ``download_from_ms`` when the hub is ``"ms"``;
        otherwise ``kwargs`` unchanged.
    """
    model_hub = kwargs.get("model_hub", "ms")
    if model_hub == "ms":
        kwargs = download_from_ms(**kwargs)
    return kwargs
def download_from_ms(**kwargs):
    """Locate (or download) a model directory and merge its configuration into ``kwargs``.

    ``kwargs["model"]`` may be a short alias listed in ``name_maps_ms``, a hub
    model id, or an existing local path. Two directory layouts are handled:

    * ``config.yaml`` + ``model.pb``: the YAML is loaded, ``kwargs`` are merged
      on top of it, and the paths of the resource files found in the directory
      (token list, seg dict, bpe model, cmvn, jieba dict) are filled in.
    * ``configuration.json``: its ``file_path_metas`` entries are resolved
      against the model directory, ``kwargs`` override them, and the result is
      merged on top of the config file those entries point to.

    Args:
        **kwargs: model options; reads ``model``, ``model_revision``,
            ``is_training`` and ``check_latest``.

    Returns:
        The merged configuration converted back to plain containers.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision")
    if not os.path.exists(model_or_path):
        model_or_path = get_or_download_model_dir(
            model_or_path,
            model_revision,
            is_training=kwargs.get("is_training"),
            # Read the caller's "check_latest" option; the previous lookup used
            # the key "kwargs", so the option could never be turned off.
            check_latest=kwargs.get("check_latest", True),
        )
    kwargs["model_path"] = model_or_path

    config = os.path.join(model_or_path, "config.yaml")
    if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
        config = OmegaConf.load(config)
        kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pb")
        kwargs["init_param"] = init_param
        # tokens.json is checked after tokens.txt, so it wins when both exist.
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
        # The caller's "model" value (alias/id/path) is replaced by the model
        # entry from the loaded config.
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        # Explicit caller options take precedence over paths found on disk.
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    # NOTE(review): when neither layout is found, kwargs is still a plain dict
    # here - confirm OmegaConf.to_container handles that case as intended.
    return OmegaConf.to_container(kwargs, resolve=True)
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = None):
    """Resolve relative file names against a model directory, keeping only existing files.

    Walks ``file_path_metas`` recursively. Every string value is joined with
    ``model_or_path``; when the resulting path exists it is stored in ``cfg``
    under the same key. Nested dicts are mirrored as nested dicts in ``cfg``.

    Args:
        model_or_path: directory the relative file names are resolved against.
        file_path_metas: (possibly nested) mapping of keys to relative file
            names; anything that is not a dict is ignored.
        cfg: mapping updated in place; a fresh dict is created when omitted.

    Returns:
        ``cfg`` (the top-level mapping), after it has been updated.
    """
    # A new dict per call: a shared ``{}`` default would leak paths between calls.
    if cfg is None:
        cfg = {}
    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                # Recurse without returning, so the keys that follow a nested
                # dict are still processed.
                add_file_root_path(model_or_path, v, cfg.setdefault(k, {}))
    return cfg
def get_or_download_model_dir(
        model,
        model_revision=None,
        is_training=False,
        check_latest=True,
    ):
    """ Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
        model_revision  (str, optional): model version number.
        is_training (bool): selects the invoke tag reported in the user agent
            (local trainer when true, pipeline otherwise).
        check_latest (bool): when ``model`` is an existing local path, ask the
            hub whether the local copy is the latest one.

    Returns:
        Path of the local model directory.
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
    user_agent = {
        Invoke.KEY: key,
        ThirdParty.KEY: "funasr"
    }

    # NOTE(review): an existing local path with check_latest=False falls through
    # to the download branch below - confirm this is intended.
    if os.path.exists(model) and check_latest:
        model_cache_dir = model if os.path.isdir(
            model) else os.path.dirname(model)
        try:
            check_local_model_is_latest(
                model_cache_dir,
                user_agent=user_agent)
        except Exception:
            # Best effort only: a failed freshness check must not block the use
            # of the local copy. KeyboardInterrupt/SystemExit still propagate.
            print("could not check the latest version")
    else:
        model_cache_dir = snapshot_download(
            model,
            revision=model_revision,
            user_agent=user_agent)
    return model_cache_dir
funasr/download/name_maps_from_hub.py
@@ -1,14 +1,13 @@
# Short model aliases -> ModelScope model ids.
# "paraformer-zh" used to be listed twice; only the later (seaco) value ever took
# effect, so the shadowed entry is dropped and the effective mapping is unchanged.
name_maps_ms = {
    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
}
name_maps_hf = {
funasr/download/runtime_sdk_download_tool.py
@@ -1,45 +1,47 @@
from pathlib import Path
import os
import argparse
from pathlib import Path
from funasr.utils.types import str2bool
def main():
    """Command-line entry point: fetch a model and, if requested, export it to ONNX.

    When ``--model-name`` is not an existing local path it is downloaded from
    ModelScope into ``--export-dir``. With ``--export`` enabled, the model is
    exported unless the expected ``.onnx`` file (``model_quant.onnx`` when
    ``--quantize`` is set, ``model.onnx`` otherwise) already exists.

    Raises:
        ValueError: if the model cannot be downloaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
    # The options below other than --quantize and --model_revision are accepted
    # but not read anywhere in this function.
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    args = parser.parse_args()

    model_dir = args.model_name
    if not Path(args.model_name).exists():
        from modelscope.hub.snapshot_download import snapshot_download
        try:
            model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
        except Exception as err:
            # Raising a bare string is itself a TypeError in Python 3; raise a
            # real exception and keep the download failure as its cause.
            raise ValueError(
                "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format
                (model_dir)
            ) from err

    if args.export:
        model_file = os.path.join(model_dir, 'model.onnx')
        if args.quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.bin.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=args.export_dir,
                onnx=True,
                device="cpu",
                quant=args.quantize,
            )
            export_model.export(model_dir)
if __name__ == "__main__":
    # Run the command-line tool exactly once when executed as a script.
    main()
funasr/models/branchformer/model.py
@@ -5,12 +5,12 @@
@tables.register("model_classes", "Branchformer")
class Branchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Nothing is added at this level: every argument is forwarded to the
        # Transformer base class, which is initialized exactly once.
        super().__init__(*args, **kwargs)
funasr/models/conformer/model.py
@@ -7,13 +7,13 @@
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Nothing is added at this level: every argument is forwarded to the
        # Transformer base class, which is initialized exactly once.
        super().__init__(*args, **kwargs)
funasr/models/ct_transformer/model.py
@@ -344,7 +344,6 @@
                punc_array = punctuations
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        results.append(result_i)
    
funasr/models/e_branchformer/model.py
@@ -5,12 +5,12 @@
@tables.register("model_classes", "EBranchformer")
class EBranchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Nothing is added at this level: every argument is forwarded to the
        # Transformer base class, which is initialized exactly once.
        super().__init__(*args, **kwargs)
funasr/models/sanm/model.py
@@ -7,12 +7,12 @@
@tables.register("model_classes", "SANM")
class SANM(Transformer):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Nothing is added at this level: every argument is forwarded to the
        # Transformer base class, which is initialized exactly once.
        super().__init__(*args, **kwargs)
funasr/models/scama/chunk_utilis.py
@@ -1,289 +1,287 @@
import math
import torch
import numpy as np
import math
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import logging
import torch.nn.functional as F
from funasr.models.scama.utils import sequence_mask
from funasr.models.scama.utils import sequence_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class overlap_chunk():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(self,
        chunk_size: tuple = (16,),
        stride: tuple = (10,),
        pad_left: tuple = (0,),
        encoder_att_look_back_factor: tuple = (1,),
    """
    def __init__(self,
        chunk_size: tuple = (16,),
        stride: tuple = (10,),
        pad_left: tuple = (0,),
        encoder_att_look_back_factor: tuple = (1,),
        shfit_fsmn: int = 0,
        decoder_att_look_back_factor: tuple = (1,),
    ):
    ):
        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
        self.shfit_fsmn = shfit_fsmn
        self.x_add_mask = None
        self.x_rm_mask = None
        self.x_len = None
        self.mask_shfit_chunk = None
        self.mask_chunk_predictor = None
        self.mask_att_chunk_encoder = None
        self.mask_shift_att_chunk_decoder = None
        self.chunk_outs = None
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
            = None, None, None, None, None
        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
        self.shfit_fsmn = shfit_fsmn
        self.x_add_mask = None
        self.x_rm_mask = None
        self.x_len = None
        self.mask_shfit_chunk = None
        self.mask_chunk_predictor = None
        self.mask_att_chunk_encoder = None
        self.mask_shift_att_chunk_decoder = None
        self.chunk_outs = None
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
            = None, None, None, None, None
    def check_chunk_size_args(self, chunk_size, x):
        if len(x) < len(chunk_size):
            x = [x[0] for i in chunk_size]
        return x
    def check_chunk_size_args(self, chunk_size, x):
        if len(x) < len(chunk_size):
            x = [x[0] for i in chunk_size]
        return x
    def get_chunk_size(self,
        ind: int = 0
    ):
        # with torch.no_grad:
        chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
    def get_chunk_size(self,
        ind: int = 0
    ):
        # with torch.no_grad:
        chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
    def random_choice(self, training=True, decoding_ind=None):
        chunk_num = len(self.chunk_size)
        ind = 0
        if training and chunk_num > 1:
            ind = torch.randint(0, chunk_num, ()).cpu().item()
        if not training and decoding_ind is not None:
            ind = int(decoding_ind)
    def random_choice(self, training=True, decoding_ind=None):
        chunk_num = len(self.chunk_size)
        ind = 0
        if training and chunk_num > 1:
            ind = torch.randint(0, chunk_num, ()).cpu().item()
        if not training and decoding_ind is not None:
            ind = int(decoding_ind)
        return ind
        return ind
    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
        with torch.no_grad():
            x_len = x_len.cpu().numpy()
            x_len_max = x_len.max()
        with torch.no_grad():
            x_len = x_len.cpu().numpy()
            x_len_max = x_len.max()
            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
            shfit_fsmn = self.shfit_fsmn
            pad_right = chunk_size - stride - pad_left
            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
            shfit_fsmn = self.shfit_fsmn
            pad_right = chunk_size - stride - pad_left
            chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
            x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
            x_len_chunk = x_len_chunk.astype(x_len.dtype)
            x_len_chunk_max = x_len_chunk.max()
            chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
            x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
            x_len_chunk = x_len_chunk.astype(x_len.dtype)
            x_len_chunk_max = x_len_chunk.max()
            chunk_num = int(math.ceil(x_len_max/stride))
            dtype = np.int32
            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
            mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
            for chunk_ids in range(chunk_num):
                # x_mask add
                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
            chunk_num = int(math.ceil(x_len_max/stride))
            dtype = np.int32
            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
            mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
            for chunk_ids in range(chunk_num):
                # x_mask add
                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
                # x_mask rm
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
                x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
                # x_mask rm
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
                x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
                # fsmn_padding_mask
                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
                # fsmn_padding_mask
                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
                # predictor mask
                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
                # predictor mask
                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
                # encoder att mask
                zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
                # encoder att mask
                zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_2_mid = np.ones([stride, stride], dtype=dtype)
                zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
                zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_2_mid = np.ones([stride, stride], dtype=dtype)
                zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
                zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
                zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
                zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
                # decoder fsmn_shift_att_mask
                zeros_1 = np.zeros([shfit_fsmn, 1])
                ones_1 = np.ones([chunk_size, 1])
                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                mask_shift_att_chunk_decoder = np.concatenate(
                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
                # decoder fsmn_shift_att_mask
                zeros_1 = np.zeros([shfit_fsmn, 1])
                ones_1 = np.ones([chunk_size, 1])
                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                mask_shift_att_chunk_decoder = np.concatenate(
                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
            self.x_len_chunk = x_len_chunk
            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
            self.x_len = x_len
            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
            self.chunk_outs = (self.x_add_mask,
                self.x_len_chunk,
                self.x_rm_mask,
                self.x_len,
                self.mask_shfit_chunk,
                self.mask_chunk_predictor,
                self.mask_att_chunk_encoder,
                self.mask_shift_att_chunk_decoder)
            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
            self.x_len_chunk = x_len_chunk
            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
            self.x_len = x_len
            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
            self.chunk_outs = (self.x_add_mask,
                self.x_len_chunk,
                self.x_rm_mask,
                self.x_len,
                self.mask_shfit_chunk,
                self.mask_chunk_predictor,
                self.mask_att_chunk_encoder,
                self.mask_shift_att_chunk_decoder)
        return self.chunk_outs
        return self.chunk_outs
    def split_chunk(self, x, x_len, chunk_outs):
        """
        :param x: (b, t, d)
        :param x_length: (b)
        :param ind: int
        :return:
        """
        x = x[:, :x_len.max(), :]
        b, t, d = x.size()
        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
            x.device)
        x *= x_len_mask[:, :, None]
    def split_chunk(self, x, x_len, chunk_outs):
        """
        :param x: (b, t, d)
        :param x_length: (b)
        :param ind: int
        :return:
        """
        x = x[:, :x_len.max(), :]
        b, t, d = x.size()
        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
            x.device)
        x *= x_len_mask[:, :, None]
        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
        pad = (0, 0, self.pad_left_cur, 0)
        x = F.pad(x, pad, "constant", 0.0)
        b, t, d = x.size()
        x = torch.transpose(x, 1, 0)
        x = torch.reshape(x, [t, -1])
        x_chunk = torch.mm(x_add_mask, x)
        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
        pad = (0, 0, self.pad_left_cur, 0)
        x = F.pad(x, pad, "constant", 0.0)
        b, t, d = x.size()
        x = torch.transpose(x, 1, 0)
        x = torch.reshape(x, [t, -1])
        x_chunk = torch.mm(x_add_mask, x)
        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
        return x_chunk, x_len_chunk
        return x_chunk, x_len_chunk
    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
        b, t, d = x_chunk.size()
        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
            x_chunk.device)
        x_chunk *= x_len_chunk_mask[:, :, None]
    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
        b, t, d = x_chunk.size()
        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
            x_chunk.device)
        x_chunk *= x_len_chunk_mask[:, :, None]
        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
        x_chunk = torch.transpose(x_chunk, 1, 0)
        x_chunk = torch.reshape(x_chunk, [t, -1])
        x = torch.mm(x_rm_mask, x_chunk)
        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
        x_chunk = torch.transpose(x_chunk, 1, 0)
        x_chunk = torch.reshape(x_chunk, [t, -1])
        x = torch.mm(x_rm_mask, x_chunk)
        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
        return x, x_len
        return x, x_len
    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
            x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
def build_scama_mask_for_cross_attention_decoder(
                              predictor_alignments: torch.Tensor,
                              predictor_alignments: torch.Tensor,
                              encoder_sequence_length: torch.Tensor,
                              chunk_size: int = 5,
                              encoder_chunk_size: int = 5,
@@ -291,100 +289,100 @@
                              attention_chunk_size: int = 1,
                              attention_chunk_type: str = 'chunk',
                              step=None,
                              predictor_mask_chunk_hopping: torch.Tensor = None,
                              decoder_att_look_back_factor: int = 1,
                              mask_shift_att_chunk_decoder: torch.Tensor = None,
                              target_length: torch.Tensor = None,
                              is_training=True,
                              predictor_mask_chunk_hopping: torch.Tensor = None,
                              decoder_att_look_back_factor: int = 1,
                              mask_shift_att_chunk_decoder: torch.Tensor = None,
                              target_length: torch.Tensor = None,
                              is_training=True,
                              dtype: torch.dtype = torch.float32):
    with torch.no_grad():
        device = predictor_alignments.device
        batch_size, chunk_num = predictor_alignments.size()
        maximum_encoder_length = encoder_sequence_length.max().item()
        int_type = predictor_alignments.dtype
        if not is_training:
            target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
        maximum_target_length = target_length.max()
        predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
        predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
        index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, chunk_num)
        index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div == 0
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
        index_div_bool_zeros_count *= chunk_size
        index_div_bool_zeros_count += attention_chunk_center_bias
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
        index_div_bool_zeros_count_ori = index_div_bool_zeros_count
        index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
        max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
        mask_flip, mask_flip2 = None, None
        if attention_chunk_size is not None:
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip = 1 - index_div_bool_zeros_count_beg_mask
            attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
        mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
        if predictor_mask_chunk_hopping is not None:
                b, k, t = mask.size()
                predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
                mask_mask_flip = mask
                if mask_flip is not None:
                        mask_mask_flip = mask_flip * mask
                def _fn():
                        mask_sliced = mask[:b, :k, encoder_chunk_size:t]
                        zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
                        mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
                        _, _, tt = predictor_mask_chunk_hopping.size()
                        pad_right_p = max_len_chunk - tt
                        predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
                        masked = mask_sliced * predictor_mask_chunk_hopping_pad
                        mask_true = mask_mask_flip + masked
                        return mask_true
                mask = _fn() if t > chunk_size else mask_mask_flip
        if mask_flip2 is not None:
            mask *= mask_flip2
        mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
        mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
        mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
        mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
        if attention_chunk_type == 'full':
            mask = torch.ones_like(mask).to(device)
        if mask_shift_att_chunk_decoder is not None:
            mask = mask * mask_shift_att_chunk_decoder
        mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
    with torch.no_grad():
        device = predictor_alignments.device
        batch_size, chunk_num = predictor_alignments.size()
        maximum_encoder_length = encoder_sequence_length.max().item()
        int_type = predictor_alignments.dtype
        if not is_training:
            target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
        maximum_target_length = target_length.max()
        predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
        predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
        index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, chunk_num)
        index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div == 0
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
        index_div_bool_zeros_count *= chunk_size
        index_div_bool_zeros_count += attention_chunk_center_bias
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
        index_div_bool_zeros_count_ori = index_div_bool_zeros_count
        index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
        max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
        mask_flip, mask_flip2 = None, None
        if attention_chunk_size is not None:
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip = 1 - index_div_bool_zeros_count_beg_mask
            attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
        mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
        if predictor_mask_chunk_hopping is not None:
                b, k, t = mask.size()
                predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
                mask_mask_flip = mask
                if mask_flip is not None:
                        mask_mask_flip = mask_flip * mask
                def _fn():
                        mask_sliced = mask[:b, :k, encoder_chunk_size:t]
                        zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
                        mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
                        _, _, tt = predictor_mask_chunk_hopping.size()
                        pad_right_p = max_len_chunk - tt
                        predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
                        masked = mask_sliced * predictor_mask_chunk_hopping_pad
                        mask_true = mask_mask_flip + masked
                        return mask_true
                mask = _fn() if t > chunk_size else mask_mask_flip
        if mask_flip2 is not None:
            mask *= mask_flip2
        mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
        mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
        mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
        mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
        if attention_chunk_type == 'full':
            mask = torch.ones_like(mask).to(device)
        if mask_shift_att_chunk_decoder is not None:
            mask = mask * mask_shift_att_chunk_decoder
        mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
    return mask
    return mask
funasr/models/scama/utils.py
@@ -1,29 +1,30 @@
import os
import torch
from torch.nn import functional as F
import yaml
import torch
import numpy as np
from torch.nn import functional as F
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a mask that is 1 at position ``j`` wherever ``j < length``.

    Args:
        lengths: integer tensor of lengths; a trailing axis of size
            ``maxlen`` is appended to its shape in the result.
        maxlen: size of the mask axis; defaults to ``lengths.max()``.
        dtype: torch dtype of the returned mask.
        device: optional device to move the mask to; when None the mask
            stays on ``lengths.device``.

    Returns:
        Tensor of shape ``lengths.shape + (maxlen,)`` and dtype ``dtype``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    # Broadcasting compares every position index against every length.
    mask = (row_vector < matrix).detach()
    mask = mask.type(dtype)
    return mask.to(device) if device is not None else mask
def apply_cmvn(inputs, mvn):
    """Apply a per-dimension shift and scale to a feature matrix.

    Row 0 of ``mvn`` is subtracted from every frame and the result is
    multiplied by row 1. Both steps are applied exactly once and modify
    ``inputs`` in place.

    Args:
        inputs: 2-D tensor of shape (frame, dim); modified in place.
        mvn: NumPy array with at least two rows; only the first ``dim``
            columns of rows 0 and 1 are used.

    Returns:
        The normalized features converted to ``torch.float32``.
    """
    device = inputs.device
    dtype = inputs.dtype
    _, dim = inputs.shape

    # (1, dim) rows broadcast over the frame axis, so no tiling is needed.
    shift = torch.from_numpy(mvn[0:1, :dim]).type(dtype).to(device)
    scale = torch.from_numpy(mvn[1:2, :dim]).type(dtype).to(device)
    inputs -= shift
    inputs *= scale

    return inputs.type(torch.float32)
@@ -36,56 +37,56 @@
    outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
    outputs *= stoch_layer_coeff
    outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
    outputs *= stoch_layer_coeff
    input_dim = inputs.size(-1)
    output_dim = outputs.size(-1)
    input_dim = inputs.size(-1)
    output_dim = outputs.size(-1)
    if input_dim == output_dim:
        outputs += inputs
    return outputs
    if input_dim == output_dim:
        outputs += inputs
    return outputs
def proc_tf_vocab(vocab_path):
    """Read a vocabulary file into a token list that contains ``<unk>``.

    Args:
        vocab_path: path of a UTF-8 text file with one token per line.

    Returns:
        List of tokens with trailing whitespace stripped, in file order;
        ``'<unk>'`` is appended when the file does not already list it.
    """
    with open(vocab_path, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    if '<unk>' not in token_list:
        token_list.append('<unk>')
    return token_list
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
    """Write ``config.yaml`` with the vocabulary embedded as ``token_list``.

    Args:
        config_path: path of the source YAML configuration.
        vocab_path: path of the vocabulary file read by ``proc_tf_vocab``.
        output_dir: directory that receives ``config.yaml``; created when
            it does not exist.
    """
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config['token_list'] = token_list

    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
class NoAliasSafeDumper(yaml.SafeDumper):
    """SafeDumper that never emits YAML anchors or aliases."""

    def ignore_aliases(self, data):
        """Return True for every node so repeated objects are written out in full."""
        return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias.

    Args:
        data: object to serialize.
        stream: optional stream passed to ``yaml.dump``.
        **kwargs: extra keyword arguments forwarded to ``yaml.dump``.

    Returns:
        Whatever ``yaml.dump`` returns for the given ``stream``.
    """
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )
if __name__ == '__main__':
    # Usage: python utils.py <config_path> <vocab_path> <output_dir>
    import sys

    config_path = sys.argv[1]
    vocab_path = sys.argv[2]
    output_dir = sys.argv[3]

    # Run the conversion once with the three positional arguments.
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)
funasr/models/seaco_paraformer/model.py
@@ -212,7 +212,7 @@
                               ys_pad_lens, 
                               hw_list,
                               nfilter=50,
                                 seaco_weight=1.0):
                               seaco_weight=1.0):
        # decoder forward
        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
@@ -254,10 +254,9 @@
            
            dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
            dha_pred = torch.log_softmax(dha_output, dim=-1)
            # import pdb; pdb.set_trace()
            def _merge_res(dec_output, dha_output):
                lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
                dha_ids = dha_output.max(-1)[-1][0]
                dha_ids = dha_output.max(-1)[-1]# [0]
                dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
                a = (1 - lmbd) / lmbd
                b = 1 / lmbd
@@ -267,6 +266,7 @@
                logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                return logits
            merged_pred = _merge_res(decoder_pred, dha_pred)
            # import pdb; pdb.set_trace()
            return merged_pred
        else:
            return decoder_pred
funasr/models/uniasr/e2e_uni_asr.py
@@ -541,20 +541,20 @@
                        speech_lengths: (Batch, )
        """
        # with autocast(False):
        #     # 1. Extract feats
        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        #     # 1. Extract feats
        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        #
        #     # 2. Data augmentation
        #     if self.specaug is not None and self.training:
        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
        #     # 2. Data augmentation
        #     if self.specaug is not None and self.training:
        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
        #
        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        #     if self.normalize is not None:
        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        #     if self.normalize is not None:
        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        # if self.preencoder is not None:
        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
@@ -584,9 +584,9 @@
        # # Post-encoder, e.g. NLU
        # if self.postencoder is not None:
        #     encoder_out, encoder_out_lens = self.postencoder(
        #         encoder_out, encoder_out_lens
        #     )
        #     encoder_out, encoder_out_lens = self.postencoder(
        #         encoder_out, encoder_out_lens
        #     )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
funasr/optimizers/__init__.py
@@ -3,15 +3,15 @@
from funasr.optimizers.sgd import SGD
# Registry mapping an optimizer name to its class.
# Each key appears exactly once: repeating a keyword inside dict(...) is a
# SyntaxError.
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
    adagrad=torch.optim.Adagrad,
    adamax=torch.optim.Adamax,
    asgd=torch.optim.ASGD,
    lbfgs=torch.optim.LBFGS,
    rmsprop=torch.optim.RMSprop,
    rprop=torch.optim.Rprop,
)
funasr/schedulers/__init__.py
@@ -8,16 +8,16 @@
from funasr.schedulers.warmup_lr import WarmupLR
# Registry mapping a learning-rate scheduler name to its class.
# Each key appears exactly once: repeating a keyword inside dict(...) is a
# SyntaxError.
scheduler_classes = dict(
    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
    lambdalr=torch.optim.lr_scheduler.LambdaLR,
    steplr=torch.optim.lr_scheduler.StepLR,
    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
funasr/tokenizer/abs_tokenizer.py
@@ -1,100 +1,94 @@
from abc import ABC
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import json
import numpy as np
from abc import ABC
from pathlib import Path
from abc import abstractmethod
from typing import Union, Iterable, List, Dict
class AbsTokenizer(ABC):
    """Abstract interface for converting between text and token sequences."""

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Convert the text ``line`` into a list of tokens."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Convert ``tokens`` back into a text string."""
        raise NotImplementedError
class BaseTokenizer(ABC):
    """Base tokenizer that maps tokens to integer ids through a vocabulary.

    Subclasses implement ``text2tokens`` and ``tokens2text``; this class
    supplies the vocabulary loading and the token <-> id conversions.
    """

    def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        """Load the vocabulary and build the token -> id table.

        Args:
            token_list: One of
                * a path ending in ``.txt`` (one token per line),
                * a path ending in ``.json`` (content loaded with ``json.load``),
                * any other iterable of tokens.
                When ``None``, no vocabulary attribute is set on the instance.
            unk_symbol: Token used for out-of-vocabulary lookups; it must be
                present in the vocabulary.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If a token is duplicated or ``unk_symbol`` is absent.
        """
        if token_list is not None:
            # ``Path`` has no ``endswith``; test the suffix on the string form
            # so that both ``str`` and ``Path`` inputs are accepted.
            is_path = isinstance(token_list, (Path, str))
            path_str = str(token_list) if is_path else ""

            if is_path and path_str.endswith(".txt"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with token_list.open("r", encoding="utf-8") as f:
                    self.token_list: List[str] = [line.rstrip() for line in f]
            elif is_path and path_str.endswith(".json"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with open(token_list, 'r', encoding='utf-8') as f:
                    self.token_list: List[str] = json.load(f)
            else:
                self.token_list: List[str] = list(token_list)
                # Short preview: the first three tokens plus the vocabulary size.
                self.token_list_repr = ""
                for t in self.token_list[:3]:
                    self.token_list_repr += f"{t}, "
                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"

            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i

            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]

    def encode(self, text):
        """Tokenize ``text`` and return the corresponding list of ids."""
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)
        return text_ints

    def decode(self, text_ints):
        """Map ids back to tokens and join them into text."""
        token = self.ids2tokens(text_ints)
        text = self.tokens2text(token)
        return text

    def get_num_vocabulary_size(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Look up the token for every id in ``integers``.

        Raises:
            ValueError: If ``integers`` is an ndarray that is not 1-dimensional.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Look up the id of every token; unknown tokens map to ``unk_id``."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Convert the text ``line`` into a list of tokens."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Convert ``tokens`` back into a text string."""
        raise NotImplementedError
funasr/train_utils/trainer.py
@@ -1,233 +1,235 @@
import torch
import os
from funasr.train_utils.device_funcs import to_device
import logging
import time
import torch
import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from contextlib import nullcontext
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """

    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): Rank used to decide whether this process updates the progress description.
            use_ddp (bool): Whether the loss is averaged across workers as for DistributedDataParallel.
            use_fsdp (bool): Same loss handling as ``use_ddp``, for FSDP-wrapped models.
            **kwargs: Additional keyword arguments:
                      max_epoch (int): The maximum number of epochs for training. Default is 100.
                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                      resume (str, optional): The file path to a checkpoint to resume training from.
                      accum_grad (int): Number of batches to accumulate gradients over. Default is 1.
                      grad_clip (float, optional): Max gradient norm; clipping is disabled when absent.
                      grad_clip_type (float): Norm type used for clipping. Default is 2.0.
        """
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = kwargs.get('output_dir', './')
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.kwargs = kwargs

        # ``resume`` defaults to the flag ``True``, which is not a path.
        # Passing a bool to os.path.isfile() makes it stat file descriptor 1,
        # so only attempt a restore when an actual path was supplied.
        if isinstance(self.resume, (str, os.PathLike)):
            self._resume_checkpoint(self.resume)

        # Query the process group only when one exists instead of relying on
        # a bare ``except`` around get_rank().
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if os.path.isfile(resume_path):
            # Load onto CPU first; load_state_dict() copies the values into
            # the existing parameters/optimizer state on their own devices.
            checkpoint = torch.load(resume_path, map_location="cpu")
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at '{resume_path}', starting from scratch")

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        # NOTE(review): the inclusive upper bound runs max_epoch + 1 epochs when
        # starting from 0, and the scheduler is stepped here as well as after
        # every optimizer update in _train_epoch - confirm both are intended.
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            # self._validate_epoch(epoch)
            if self.rank == 0:
                self._save_checkpoint(epoch)
            self.scheduler.step()

    def _train_epoch(self, epoch):
        """
        Defines the training process for a single epoch with gradient accumulation.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        num_batches = len(self.dataloader_train)
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=num_batches,
                    dynamic_ncols=True)

        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_train):
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
            batch = to_device(batch, self.device)

            # The optimizer steps after the last micro-batch of each
            # accumulation window (or the last batch of the epoch).
            is_update_step = (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == num_batches
            # Gradient synchronisation may only be skipped on micro-batches
            # that do NOT precede an optimizer step, and only for wrappers
            # that actually provide ``no_sync``.
            if not is_update_step and hasattr(self.model, "no_sync"):
                my_context = self.model.no_sync
            else:
                my_context = nullcontext
            with my_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if is_update_step:
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.kwargs.get("grad_clip", 10.0),
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        self.optim.zero_grad()  # Reset gradients
                        # Keep the progress bar in step with consumed batches.
                        pbar.update(1)
                        continue

                # Execute an optimization step (update model parameters)
                self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad()
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
                speed_stats["total_time"] = total_time

            pbar.update(1)
            if self.local_rank == 0:
                description = (
                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
                    f"step {batch_idx}/{num_batches}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                )
                pbar.set_description(description)

        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            for data, target in self.dataloader_val:
                # Implement the model validation steps here
                pass
funasr/utils/load_utils.py
@@ -10,100 +10,100 @@
import logging
from torch.nn.utils.rnn import pad_sequence
try:
    from funasr.download.file import download_from_url
    from funasr.download.file import download_from_url
except:
    print("urllib is not installed, if you infer from url, please install it first.")
    print("urllib is not installed, if you infer from url, please install it first.")
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    """Load one input (or a list of inputs) and normalise it for the model.

    Handling by input kind:

    * ``list`` / ``tuple``: loaded element-wise by recursion. When
      ``data_type`` is itself a list/tuple, every element of the input is
      treated as one sample holding one field per entry of ``data_type``,
      and the result is grouped per field (``ret[j]`` collects field ``j``
      of every sample). Otherwise a flat list of loaded items is returned.
    * ``str`` starting with ``"http"``: downloaded with ``download_from_url``
      first, then handled like a local path.
    * ``str`` naming an existing path: ``"sound"``/``None`` is read with
      ``torchaudio.load`` (first channel only, ``audio_fs`` replaced by the
      file's rate); ``"text"`` with a tokenizer is encoded; ``"image"`` and
      ``"video"`` are not implemented and the path is returned unchanged.
      If ``kwargs`` has a ``"cache"`` entry its ``"is_final"`` flag is set.
    * any other ``str`` with ``data_type == "text"`` and a tokenizer:
      encoded with ``tokenizer.encode``.
    * ``np.ndarray``: converted to a squeezed ``torch.Tensor``.
    * anything else is passed through untouched.

    Finally, non-text data is resampled from ``audio_fs`` to ``fs`` when the
    two rates differ.

    Args:
        data_or_path_or_list: Raw data, a path/URL, or a list/tuple of those.
        fs: Target sampling rate.
        audio_fs: Sampling rate of the incoming audio.
        data_type: ``"sound"``, ``"text"``, ``"image"``, ``"video"``, ``None``,
            or a list/tuple of those for multi-field samples.
        tokenizer: Object exposing ``encode``; used for text inputs only.
        **kwargs: Forwarded to recursive calls; ``kwargs["cache"]`` is
            updated as described above.

    Returns:
        The loaded item, a list of loaded items, or a per-field list of lists.
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if isinstance(data_type, (list, tuple)):
            # One output bucket per declared field type.
            data_or_path_or_list_ret = [[] for _ in data_type]
            for sample in data_or_path_or_list:
                for j, (data_type_j, item) in enumerate(zip(data_type, sample)):
                    loaded = load_audio_text_image_video(item, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    data_or_path_or_list_ret[j].append(loaded)
            return data_or_path_or_list_ret
        # The tokenizer is forwarded so that a flat list of strings with
        # data_type="text" is tokenized the same way a single string is.
        return [
            load_audio_text_image_video(item, fs=fs, audio_fs=audio_fs, data_type=data_type, tokenizer=tokenizer, **kwargs)
            for item in data_or_path_or_list
        ]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            data_or_path_or_list = data_or_path_or_list[0, :]  # keep the first channel only
        elif data_type == "text" and tokenizer is not None:
            # NOTE(review): this encodes the path string itself, not the
            # file's contents - confirm that is intended.
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        elif data_type == "image":  # not implemented yet
            pass
        elif data_type == "video":  # not implemented yet
            pass

        # if data_in is a file or url, set is_final=True
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
    elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
    # Any other input is returned as-is (subject to the resampling below).

    if audio_fs != fs and data_type != "text":
        # NOTE(review): indexing with [None, :] assumes a 1-D tensor here;
        # unsupported inputs that reach this point with differing rates fail.
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list
def load_bytes(input):
    """Convert raw 16-bit PCM bytes into normalised float32 samples.

    The buffer is interpreted as native-endian ``int16`` samples and each
    sample is divided by 32768, so the output lies in ``[-1.0, 1.0)``.

    Args:
        input: Bytes-like object whose length is a multiple of 2. (The name
            shadows the builtin but is kept for keyword-argument callers.)

    Returns:
        np.ndarray: 1-D ``float32`` array with one value per input sample.

    Raises:
        ValueError: From ``np.frombuffer`` when the buffer length is not a
            multiple of the 2-byte sample size.
    """
    pcm = np.frombuffer(input, dtype=np.int16)
    info = np.iinfo(pcm.dtype)
    abs_max = 2 ** (info.bits - 1)  # 32768 for int16
    offset = info.min + abs_max     # 0 for a signed type; kept for clarity
    return (pcm.astype(np.float32) - offset) / abs_max
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
    """Batch the input and run it through ``frontend``.

    Args:
        data: A ``np.ndarray``, a ``torch.Tensor``, or a list/tuple of either.
            A single array/tensor with fewer than two dimensions gets a
            leading batch axis. A list/tuple is padded into one batch with
            ``pad_sequence``.
        data_len: Lengths passed to ``frontend``. For a single array/tensor
            it defaults to ``[data.shape[1]]``; for a list/tuple input it is
            always recomputed from the items and any passed value is ignored.
        data_type: Currently unused.
        frontend: Callable invoked as ``frontend(data, data_len, **kwargs)``
            and expected to return ``(features, lengths)``.
        **kwargs: Forwarded to ``frontend``.

    Returns:
        tuple: ``(features as float32, lengths as int32)``.
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        if len(data.shape) < 2:
            data = data[None, :]  # data: [batch, N]
        if data_len is None:
            data_len = [data.shape[1]]
    elif isinstance(data, (list, tuple)):
        items = [torch.from_numpy(item) if isinstance(item, np.ndarray) else item for item in data]
        data_len = [item.shape[0] for item in items]
        data = pad_sequence(items, batch_first=True)  # data: [batch, N]

    data, data_len = frontend(data, data_len, **kwargs)

    if isinstance(data_len, (list, tuple)):
        # NOTE(review): the extra brackets yield a (1, len) tensor rather than
        # (len,); kept as-is because callers may rely on it - confirm.
        data_len = torch.tensor([data_len])
    return data.to(torch.float32), data_len.to(torch.int32)
funasr/utils/vad_utils.py
@@ -1,31 +1,31 @@
import torch
from torch.nn.utils.rnn import pad_sequence
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut the first batch item of ``speech`` into VAD segments and pad them.

    Args:
        speech: Indexable as ``speech[0, start:end]``; only batch item 0 is used.
        speech_lengths: Indexable; ``speech_lengths[0]`` caps every segment end.
        vad_segments: Iterable of segments where ``segment[0][0]`` and
            ``segment[0][1]`` are the start and end, each multiplied by 16 to
            obtain an index (presumably milliseconds at 16 kHz - TODO confirm).

    Returns:
        tuple: ``(feats_pad, speech_lengths_pad)`` - the segments zero-padded
        to a common length with ``pad_sequence(batch_first=True)``, and their
        unpadded lengths as an int tensor.
    """
    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        beg_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
        speech_list.append(speech[0, beg_idx:end_idx])
        speech_lengths_list.append(end_idx - beg_idx)
    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
    speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
    return feats_pad, speech_lengths_pad
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Cut ``speech`` into VAD segments without padding.

    Args:
        speech: Sliceable sequence of samples (``speech[start:end]``).
        speech_lengths: Upper bound applied to every segment end index.
        vad_segments: Iterable of segments where ``segment[0][0]`` and
            ``segment[0][1]`` are the start and end, each multiplied by 16 to
            obtain an index (presumably milliseconds at 16 kHz - TODO confirm).

    Returns:
        tuple: ``(speech_list, speech_lengths_list)`` - the slices and, for
        each, ``end_idx - beg_idx``.
    """
    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        beg_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), speech_lengths)
        speech_list.append(speech[beg_idx:end_idx])
        speech_lengths_list.append(end_idx - beg_idx)
    return speech_list, speech_lengths_list
runtime/python/utils/test_cer.py
@@ -17,8 +17,8 @@
# Default to the libtorch implementation of Paraformer.
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
    # ONNX backend requested: rebind Paraformer to the onnxruntime implementation.
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
# The handle is consumed by the lines that follow this fragment.
wav_file_f = open(args.wav_file, 'r')
@@ -26,23 +26,23 @@
output_dir = args.output_dir
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(output_dir, exist_ok=True)
# Line terminator written to the result files: CRLF on Windows, LF elsewhere.
newline = '\r\n' if os.name == 'nt' else '\n'
# The context manager closes both result files even if decoding raises.
with open(os.path.join(output_dir, "text"), "w", newline=newline) as text_f, \
        open(os.path.join(output_dir, "token"), "w", newline=newline) as token_f:
    for wav_path_i in wav_files:
        # Each list line is "<utterance-id> <wav-path>".
        wav_name, wav_path = wav_path_i.strip().split()
        result = model(wav_path)
        text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
        token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
        # Flush after every utterance so partial results survive an interruption.
        text_f.write(text_i)
        text_f.flush()
        token_f.write(token_i)
        token_f.flush()
runtime/python/utils/test_rtf.py
@@ -16,8 +16,8 @@
# Default to the libtorch implementation of Paraformer.
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
    # ONNX backend requested: rebind Paraformer to the onnxruntime implementation.
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
# The handle is consumed by the lines that follow this fragment.
wav_file_f = open(args.wav_file, 'r')
@@ -28,28 +28,28 @@
# Warm-up: decode the first utterance `num` times and report running stats.
num = 30
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
for i in range(num):
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    beg_time = time.perf_counter()
    result = model(wav_path)
    end_time = time.perf_counter()
    duration = end_time-beg_time
    total += duration
    avg = total/(i+1)
    print(result)
    # NOTE(review): 5.53 is presumably the warm-up clip's duration in seconds,
    # and "num" prints the length of the path string - confirm both.
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, avg, avg/5.53))

# infer time
beg_time = time.perf_counter()
for i, wav_path_i in enumerate(wav_files):
    # List lines are "<id>\t<path>" or "<id> <path>".
    wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
    result = model(wav_path)
end_time = time.perf_counter()
duration = (end_time-beg_time)*1000
print("total_time_comput_ms: {}".format(int(duration)))

# Total audio duration, used as the RTF denominator.
duration_time = 0.0
for i, wav_path_i in enumerate(wav_files):
    wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
    waveform, _ = librosa.load(wav_path, sr=16000)
    # Loaded at 16000 samples/s, so samples / 16 is the length in milliseconds.
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))
runtime/python/utils/test_rtf_gpu.py
@@ -17,8 +17,8 @@
# Default to the libtorch implementation of Paraformer.
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
    # ONNX backend requested: rebind Paraformer to the onnxruntime implementation.
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
# The handle is consumed by the lines that follow this fragment.
wav_file_f = open(args.wav_file, 'r')
@@ -29,20 +29,20 @@
# Warm-up: decode the first utterance `num` times and report running stats.
num = 30
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
for i in range(num):
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    beg_time = time.perf_counter()
    result = model(wav_path)
    end_time = time.perf_counter()
    duration = end_time-beg_time
    total += duration
    avg = total/(i+1)
    print(result)
    # NOTE(review): 5.53 is presumably the warm-up clip's duration in seconds,
    # and "num" prints the length of the path string - confirm both.
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, avg, avg/5.53))

# infer time
wav_path = []
# The timed region deliberately includes building the path list, as before.
beg_time = time.perf_counter()
for i, wav_path_i in enumerate(wav_files):
    # List lines are "<id>\t<path>" or "<id> <path>"; parse each line exactly once.
    wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
    wav_path.append(wav_path_i)
# All utterances are passed to the model in a single call.
result = model(wav_path)
end_time = time.perf_counter()
duration = (end_time-beg_time)*1000
@@ -50,9 +50,9 @@
# Total audio duration, used as the RTF denominator.
duration_time = 0.0
for i, wav_path_i in enumerate(wav_files):
    # List lines are "<id>\t<path>" or "<id> <path>".
    wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
    waveform, _ = librosa.load(wav_path, sr=16000)
    # Loaded at 16000 samples/s, so samples / 16 is the length in milliseconds.
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))