From 1233c0d3ff9cf7fd6131862e7d0b208d3981f6da Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Mon, 15 Jan 2024 20:34:47 +0800
Subject: [PATCH] code update
---
funasr/download/runtime_sdk_download_tool.py | 76
funasr/models/scama/chunk_utilis.py | 644 ++++++------
funasr/models/branchformer/model.py | 14
funasr/train_utils/trainer.py | 444 ++++----
funasr/models/sanm/model.py | 14
funasr/download/download_from_hub.py | 193 ++--
funasr/models/conformer/model.py | 16
funasr/models/scama/utils.py | 107 +-
funasr/utils/load_utils.py | 170 +-
runtime/python/utils/test_cer.py | 28
funasr/datasets/audio_datasets/index_ds.py | 108 +-
funasr/models/e_branchformer/model.py | 14
funasr/models/uniasr/e2e_uni_asr.py | 24
funasr/bin/inference.py | 3
runtime/python/utils/test_rtf.py | 28
funasr/utils/vad_utils.py | 50
examples/industrial_data_pretraining/paraformer/demo.py | 4
funasr/bin/train.py | 294 +++---
funasr/tokenizer/abs_tokenizer.py | 178 +-
funasr/schedulers/__init__.py | 24
funasr/optimizers/__init__.py | 22
runtime/python/utils/test_rtf_gpu.py | 28
funasr/datasets/audio_datasets/samplers.py | 141 +-
funasr/datasets/audio_datasets/datasets.py | 171 +-
24 files changed, 1,391 insertions(+), 1,404 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 20f0f64..6dbe33d 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -18,5 +18,5 @@
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
for batch_idx, fbank_dict in enumerate(fbanks):
- res = model(**fbank_dict)
- print(res)
\ No newline at end of file
+ res = model(**fbank_dict)
+ print(res)
\ No newline at end of file
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index ca8771d..7368d16 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -309,10 +309,7 @@
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
-
- # if kwargs["device"] == "cpu":
- # batch_size = 0
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 878eb24..0881cb2 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,178 +1,180 @@
-import argparse
-import logging
import os
import sys
-from io import BytesIO
-from collections.abc import Sequence
import torch
import hydra
+import logging
+import argparse
+from io import BytesIO
+import torch.distributed as dist
+from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.models.lora.utils import mark_only_lora_as_trainable
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from funasr.register import tables
from funasr.optimizers import optim_classes
+from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.train_utils.initialize import initialize
+from funasr.download.download_from_hub import download_model
+from funasr.models.lora.utils import mark_only_lora_as_trainable
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.train_utils.trainer import Trainer
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from funasr.download.download_from_hub import download_model
-from funasr.register import tables
+
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
- if kwargs.get("debug", False):
- import pdb; pdb.set_trace()
+ if kwargs.get("debug", False):
+ import pdb; pdb.set_trace()
- assert "model" in kwargs
- if "model_conf" not in kwargs:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
- kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
+
- main(**kwargs)
+ main(**kwargs)
def main(**kwargs):
- # preprocess_config(kwargs)
- # import pdb; pdb.set_trace()
- # set random seed
- tables.print()
- set_all_random_seed(kwargs.get("seed", 0))
- torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
- torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
- torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-
- local_rank = int(os.environ.get('LOCAL_RANK', 0))
- # Check if we are using DDP or FSDP
- use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
- use_fsdp = kwargs.get("use_fsdp", None)
- if use_ddp or use_fsdp:
- dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
- torch.cuda.set_device(local_rank)
-
- # save config.yaml
- if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
- os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
- yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
- OmegaConf.save(config=kwargs, f=yaml_file)
- logging.info("config.yaml is saved to: %s", yaml_file)
+ # preprocess_config(kwargs)
+ # import pdb; pdb.set_trace()
+ # set random seed
+ tables.print()
+ set_all_random_seed(kwargs.get("seed", 0))
+ torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+ torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+ torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ # Check if we are using DDP or FSDP
+ use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
+ use_fsdp = kwargs.get("use_fsdp", None)
+ if use_ddp or use_fsdp:
+ dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
+ torch.cuda.set_device(local_rank)
+
+ # save config.yaml
+ if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+ OmegaConf.save(config=kwargs, f=yaml_file)
+ logging.info("config.yaml is saved to: %s", yaml_file)
- tokenizer = kwargs.get("tokenizer", None)
- if tokenizer is not None:
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
- tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
- kwargs["tokenizer"] = tokenizer
-
- # build frontend if frontend is none None
- frontend = kwargs.get("frontend", None)
- if frontend is not None:
- frontend_class = tables.frontend_classes.get(frontend)
- frontend = frontend_class(**kwargs["frontend_conf"])
- kwargs["frontend"] = frontend
- kwargs["input_size"] = frontend.output_size()
-
- # import pdb;
- # pdb.set_trace()
- # build model
- model_class = tables.model_classes.get(kwargs["model"])
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend if frontend is none None
+ frontend = kwargs.get("frontend", None)
+ if frontend is not None:
+ frontend_class = tables.frontend_classes.get(frontend)
+ frontend = frontend_class(**kwargs["frontend_conf"])
+ kwargs["frontend"] = frontend
+ kwargs["input_size"] = frontend.output_size()
+
+ # import pdb;
+ # pdb.set_trace()
+ # build model
+ model_class = tables.model_classes.get(kwargs["model"])
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- if not isinstance(init_param, (list, tuple)):
- init_param = (init_param,)
- logging.info("init_param is not None: %s", init_param)
- for p in init_param:
- logging.info(f"Loading pretrained params from {p}")
- load_pretrained_model(
- model=model,
- init_param=p,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
- oss_bucket=kwargs.get("oss_bucket", None),
- )
- else:
- initialize(model, kwargs.get("init", "kaiming_normal"))
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ if not isinstance(init_param, (list, tuple)):
+ init_param = (init_param,)
+ logging.info("init_param is not None: %s", init_param)
+ for p in init_param:
+ logging.info(f"Loading pretrained params from {p}")
+ load_pretrained_model(
+ model=model,
+ init_param=p,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ )
+ else:
+ initialize(model, kwargs.get("init", "kaiming_normal"))
- # freeze_param
- freeze_param = kwargs.get("freeze_param", None)
- if freeze_param is not None:
- freeze_param = eval(freeze_param)
- if isinstance(freeze_param, Sequence):
- freeze_param = (freeze_param,)
- logging.info("freeze_param is not None: %s", freeze_param)
- for t in freeze_param:
- for k, p in model.named_parameters():
- if k.startswith(t + ".") or k == t:
- logging.info(f"Setting {k}.requires_grad = False")
- p.requires_grad = False
-
+    # freeze_param: parameter names (or dotted prefixes) whose tensors get requires_grad = False
+    freeze_param = kwargs.get("freeze_param", None)
+    if freeze_param is not None:
+        freeze_param = eval(freeze_param)  # NOTE(review): eval() on a config value; ast.literal_eval is safer if only literals are expected
+        if isinstance(freeze_param, str):  # wrap a lone name only; a list/tuple of names is iterated as-is
+            freeze_param = (freeze_param,)
+        logging.info("freeze_param is not None: %s", freeze_param)
+        for t in freeze_param:
+            for k, p in model.named_parameters():
+                if k.startswith(t + ".") or k == t:  # exact name, or any parameter below that prefix
+                    logging.info(f"Setting {k}.requires_grad = False")
+                    p.requires_grad = False
+
- if use_ddp:
- model = model.cuda(local_rank)
- model = DDP(model, device_ids=[local_rank],
- find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
- elif use_fsdp:
- model = FSDP(model).cuda(local_rank)
- else:
- model = model.to(device=kwargs.get("device", "cuda"))
-
-
- # optim
- optim = kwargs.get("optim", "adam")
- assert optim in optim_classes
- optim_class = optim_classes.get(optim)
- optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-
- # scheduler
- scheduler = kwargs.get("scheduler", "warmuplr")
- assert scheduler in scheduler_classes
- scheduler_class = scheduler_classes.get(scheduler)
- scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+ if use_ddp:
+ model = model.cuda(local_rank)
+ model = DDP(model, device_ids=[local_rank],
+ find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
+ elif use_fsdp:
+ model = FSDP(model).cuda(local_rank)
+ else:
+ model = model.to(device=kwargs.get("device", "cuda"))
+
+
+ # optim
+ optim = kwargs.get("optim", "adam")
+ assert optim in optim_classes
+ optim_class = optim_classes.get(optim)
+ optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+
+ # scheduler
+ scheduler = kwargs.get("scheduler", "warmuplr")
+ assert scheduler in scheduler_classes
+ scheduler_class = scheduler_classes.get(scheduler)
+ scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
- # import pdb;
- # pdb.set_trace()
- # dataset
- dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+ # import pdb;
+ # pdb.set_trace()
+ # dataset
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
+ dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
- # dataloader
- batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
- batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
- if batch_sampler is not None:
- batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
- dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
- collate_fn=dataset_tr.collator,
- batch_sampler=batch_sampler,
- num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
- pin_memory=True)
-
+ # dataloader
+ batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+ if batch_sampler is not None:
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+ dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
+ collate_fn=dataset_tr.collator,
+ batch_sampler=batch_sampler,
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+ pin_memory=True)
+
- trainer = Trainer(
- model=model,
- optim=optim,
- scheduler=scheduler,
- dataloader_train=dataloader_tr,
- dataloader_val=None,
- local_rank=local_rank,
- use_ddp=use_ddp,
- use_fsdp=use_fsdp,
- **kwargs.get("train_conf"),
- )
- trainer.run()
-
- if use_ddp or use_fsdp:
- torch.distributed.destroy_process_group()
+ trainer = Trainer(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ dataloader_train=dataloader_tr,
+ dataloader_val=None,
+ local_rank=local_rank,
+ use_ddp=use_ddp,
+ use_fsdp=use_fsdp,
+ **kwargs.get("train_conf"),
+ )
+ trainer.run()
+
+ if use_ddp or use_fsdp:
+ torch.distributed.destroy_process_group()
-
+
if __name__ == "__main__":
- main_hydra()
\ No newline at end of file
+ main_hydra()
\ No newline at end of file
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 7839ff9..edf127f 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -1,102 +1,93 @@
import torch
-import json
-import torch.distributed as dist
-import numpy as np
-import kaldiio
-import librosa
-import torchaudio
-import time
-import logging
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank
+
@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
- """
- AudioDataset
- """
- def __init__(self,
- path,
- index_ds: str = None,
- frontend=None,
- tokenizer=None,
- int_pad_value: int = -1,
- float_pad_value: float = 0.0,
- **kwargs):
- super().__init__()
- index_ds_class = tables.index_ds_classes.get(index_ds)
- self.index_ds = index_ds_class(path)
- preprocessor_speech = kwargs.get("preprocessor_speech", None)
- if preprocessor_speech:
- preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
- preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
- self.preprocessor_speech = preprocessor_speech
- preprocessor_text = kwargs.get("preprocessor_text", None)
- if preprocessor_text:
- preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
- preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
- self.preprocessor_text = preprocessor_text
-
- self.frontend = frontend
- self.fs = 16000 if frontend is None else frontend.fs
- self.data_type = "sound"
- self.tokenizer = tokenizer
+ """
+ AudioDataset
+ """
+ def __init__(self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs):
+ super().__init__()
+ index_ds_class = tables.index_ds_classes.get(index_ds)
+ self.index_ds = index_ds_class(path)
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
+ preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+ self.preprocessor_speech = preprocessor_speech
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+ self.preprocessor_text = preprocessor_text
+
+ self.frontend = frontend
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ self.tokenizer = tokenizer
- self.int_pad_value = int_pad_value
- self.float_pad_value = float_pad_value
-
- def get_source_len(self, index):
- item = self.index_ds[index]
- return self.index_ds.get_source_len(item)
-
- def get_target_len(self, index):
- item = self.index_ds[index]
- return self.index_ds.get_target_len(item)
-
- def __len__(self):
- return len(self.index_ds)
-
- def __getitem__(self, index):
- item = self.index_ds[index]
- # import pdb;
- # pdb.set_trace()
- source = item["source"]
- data_src = load_audio(source, fs=self.fs)
- if self.preprocessor_speech:
- data_src = self.preprocessor_speech(data_src)
- speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+ self.int_pad_value = int_pad_value
+ self.float_pad_value = float_pad_value
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_source_len(item)
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def __getitem__(self, index):
+ item = self.index_ds[index]
+ # import pdb;
+ # pdb.set_trace()
+ source = item["source"]
+ data_src = load_audio(source, fs=self.fs)
+ if self.preprocessor_speech:
+ data_src = self.preprocessor_speech(data_src)
+ speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
- target = item["target"]
- if self.preprocessor_text:
- target = self.preprocessor_text(target)
- ids = self.tokenizer.encode(target)
- ids_lengths = len(ids)
- text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
+ target = item["target"]
+ if self.preprocessor_text:
+ target = self.preprocessor_text(target)
+ ids = self.tokenizer.encode(target)
+ ids_lengths = len(ids)
+ text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
- return {"speech": speech[0, :, :],
- "speech_lengths": speech_lengths,
- "text": text,
- "text_lengths": text_lengths,
- }
-
-
- def collator(self, samples: list=None):
+ return {"speech": speech[0, :, :],
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ }
+
+
+ def collator(self, samples: list=None):
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+ for key, data_list in outputs.items():
+ if data_list[0].dtype == torch.int64:
- outputs = {}
- for sample in samples:
- for key in sample.keys():
- if key not in outputs:
- outputs[key] = []
- outputs[key].append(sample[key])
-
- for key, data_list in outputs.items():
- if data_list[0].dtype == torch.int64:
-
- pad_value = self.int_pad_value
- else:
- pad_value = self.float_pad_value
- outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
- return outputs
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+ return outputs
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 79bb26e..8e5b05c 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -1,64 +1,64 @@
-import torch
import json
-import torch.distributed as dist
-import time
+import torch
import logging
+import torch.distributed as dist
from funasr.register import tables
+
@tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset):
-
- def __init__(self, path):
- super().__init__()
-
- contents = []
- with open(path, encoding='utf-8') as fin:
- for line in fin:
- data = json.loads(line.strip())
- if "text" in data: # for sft
- self.contents.append(data['text'])
- if "source" in data: # for speech lab pretrain
- prompt = data["prompt"]
- source = data["source"]
- target = data["target"]
- source_len = data["source_len"]
- target_len = data["target_len"]
+
+    def __init__(self, path):
+        super().__init__()  # index is built from the jsonl file at `path`, one JSON object per line
+
+        contents = []
+        with open(path, encoding='utf-8') as fin:
+            for line in fin:
+                data = json.loads(line.strip())
+                if "text" in data:  # for sft
+                    contents.append(data['text'])  # self.contents is not assigned yet; collect into the local list
+                if "source" in data:  # for speech lab pretrain
+                    prompt = data["prompt"]
+                    source = data["source"]
+                    target = data["target"]
+                    source_len = data["source_len"]
+                    target_len = data["target_len"]
- contents.append({"source": source,
- "prompt": prompt,
- "target": target,
- "source_len": source_len,
- "target_len": target_len,
- }
- )
-
- self.contents = []
- total_num = len(contents)
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- logging.warning("distributed is not initialized, only single shard")
- num_per_rank = total_num // world_size
-
- # rank = 0
- # import ipdb; ipdb.set_trace()
- self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-
- logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
+                    contents.append({"source": source,
+                                     "prompt": prompt,
+                                     "target": target,
+                                     "source_len": source_len,
+                                     "target_len": target_len,
+                                     }
+                                    )
+
+        self.contents = []
+        total_num = len(contents)
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except Exception:  # no usable process group; a bare except would also trap KeyboardInterrupt/SystemExit
+            rank = 0
+            world_size = 1
+            logging.warning("distributed is not initialized, only single shard")
+        num_per_rank = total_num // world_size
+
+        # Each rank keeps one contiguous slice of num_per_rank entries; the
+        # total_num % world_size trailing entries are not assigned to any rank.
+        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+
+        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
- def __len__(self):
- return len(self.contents)
-
- def __getitem__(self, index):
- return self.contents[index]
-
- def get_source_len(self, data_dict):
- return data_dict["source_len"]
+ def __len__(self):
+ return len(self.contents)
+
+ def __getitem__(self, index):
+ return self.contents[index]
+
+ def get_source_len(self, data_dict):
+ return data_dict["source_len"]
- def get_target_len(self, data_dict):
-
- return data_dict["target_len"] if "target_len" in data_dict else 0
+ def get_target_len(self, data_dict):
+
+ return data_dict["target_len"] if "target_len" in data_dict else 0
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 9c87245..bc71b28 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -1,5 +1,4 @@
import torch
-
import numpy as np
from funasr.register import tables
@@ -7,74 +6,74 @@
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = False,
- shuffle: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = batch_size
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 5000)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle
-
- def __len__(self):
- return self.total_samples
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- sample_len_cur = self.dataset.get_source_len(idx_map) + \
- self.dataset.get_target_len(idx_map)
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- if self.batch_type == 'length':
- max_token_padding *= max_token_cur
- if max_token_padding <= self.batch_size:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- yield batch
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
+
+ def __init__(self, dataset,
+ batch_type: str = "example",
+ batch_size: int = 100,
+ buffer_size: int = 30,
+ drop_last: bool = False,
+ shuffle: bool = True,
+ **kwargs):
+
+ self.drop_last = drop_last
+ self.pre_idx = -1
+ self.dataset = dataset
+ self.total_samples = len(dataset)
+ self.batch_type = batch_type
+ self.batch_size = batch_size
+ self.buffer_size = buffer_size
+ self.max_token_length = kwargs.get("max_token_length", 5000)
+ self.shuffle_idx = np.arange(self.total_samples)
+ self.shuffle = shuffle
+
+ def __len__(self):
+ return self.total_samples
+
+ def set_epoch(self, epoch):
+ np.random.seed(epoch)
+
+    def __iter__(self):
+        """Yield lists of dataset indices, length-sorted inside each window of buffer_size samples."""
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+
+        batch = []
+        max_token = 0
+        num_sample = 0
+
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # iter_num: number of buffer_size-wide windows needed to cover total_samples
+        for buffer_idx in range(self.pre_idx + 1, iter_num):
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = buffer_idx * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+
+                idx_map = self.shuffle_idx[idx]
+                # idx is a position in the (shuffled) order; idx_map is the dataset index stored there
+                sample_len_cur = self.dataset.get_source_len(idx_map) + \
+                                 self.dataset.get_target_len(idx_map)
+                # keep the index the length was measured on, so batches hold the shuffled samples
+                datalen_with_index.append([idx_map, sample_len_cur])
+
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                if self.batch_type == 'length':  # cost = sample count x longest sample; otherwise just sample count
+                    max_token_padding *= max_token_cur
+                if max_token_padding <= self.batch_size:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    yield batch
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 57e8c41..cde4b7d 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -1,110 +1,111 @@
-import json
import os
+import json
from omegaconf import OmegaConf
-import torch
+
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
+
def download_model(**kwargs):
- model_hub = kwargs.get("model_hub", "ms")
- if model_hub == "ms":
- kwargs = download_from_ms(**kwargs)
-
- return kwargs
+ model_hub = kwargs.get("model_hub", "ms")
+ if model_hub == "ms":
+ kwargs = download_from_ms(**kwargs)
+
+ return kwargs
def download_from_ms(**kwargs):
- model_or_path = kwargs.get("model")
- if model_or_path in name_maps_ms:
- model_or_path = name_maps_ms[model_or_path]
- model_revision = kwargs.get("model_revision")
- if not os.path.exists(model_or_path):
- model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
- kwargs["model_path"] = model_or_path
-
- config = os.path.join(model_or_path, "config.yaml")
- if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-
- config = OmegaConf.load(config)
- kwargs = OmegaConf.merge(config, kwargs)
- init_param = os.path.join(model_or_path, "model.pb")
- kwargs["init_param"] = init_param
- if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
- kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
- if os.path.exists(os.path.join(model_or_path, "tokens.json")):
- kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
- if os.path.exists(os.path.join(model_or_path, "seg_dict")):
- kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
- if os.path.exists(os.path.join(model_or_path, "bpe.model")):
- kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
- kwargs["model"] = config["model"]
- if os.path.exists(os.path.join(model_or_path, "am.mvn")):
- kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
- if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
- kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
- elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
- with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
- conf_json = json.load(f)
- cfg = {}
- add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
- cfg.update(kwargs)
- config = OmegaConf.load(cfg["config"])
- kwargs = OmegaConf.merge(config, cfg)
- kwargs["model"] = config["model"]
- return OmegaConf.to_container(kwargs, resolve=True)
+    model_or_path = kwargs.get("model")  # a name registered in name_maps_ms, a hub model id, or a local path
+    if model_or_path in name_maps_ms:
+        model_or_path = name_maps_ms[model_or_path]
+    model_revision = kwargs.get("model_revision")
+    if not os.path.exists(model_or_path):
+        model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("check_latest", True))
+    kwargs["model_path"] = model_or_path
+
+    config = os.path.join(model_or_path, "config.yaml")
+    if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
+        # layout 1: config.yaml + model.pb sit directly in the model directory
+        config = OmegaConf.load(config)
+        kwargs = OmegaConf.merge(config, kwargs)
+        init_param = os.path.join(model_or_path, "model.pb")
+        kwargs["init_param"] = init_param
+        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+        if os.path.exists(os.path.join(model_or_path, "tokens.json")):  # overrides tokens.txt when both exist
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+        kwargs["model"] = config["model"]  # the model name always comes from the loaded config
+        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+    elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
+        with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+            conf_json = json.load(f)
+            # layout 2: configuration.json lists file paths relative to the model directory
+            cfg = {}
+            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+            cfg.update(kwargs)
+            config = OmegaConf.load(cfg["config"])
+            kwargs = OmegaConf.merge(config, cfg)
+        kwargs["model"] = config["model"]
+    return OmegaConf.to_container(kwargs, resolve=True)
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
-
- if isinstance(file_path_metas, dict):
- for k, v in file_path_metas.items():
- if isinstance(v, str):
- p = os.path.join(model_or_path, v)
- if os.path.exists(p):
- cfg[k] = p
- elif isinstance(v, dict):
- if k not in cfg:
- cfg[k] = {}
- return add_file_root_path(model_or_path, v, cfg[k])
-
- return cfg
+
+ if isinstance(file_path_metas, dict):
+ for k, v in file_path_metas.items():
+ if isinstance(v, str):
+ p = os.path.join(model_or_path, v)
+ if os.path.exists(p):
+ cfg[k] = p
+ elif isinstance(v, dict):
+ if k not in cfg:
+ cfg[k] = {}
+ return add_file_root_path(model_or_path, v, cfg[k])
+
+ return cfg
def get_or_download_model_dir(
- model,
- model_revision=None,
- is_training=False,
- check_latest=True,
- ):
- """ Get local model directory or download model if necessary.
+ model,
+ model_revision=None,
+ is_training=False,
+ check_latest=True,
+ ):
+ """ Get local model directory or download model if necessary.
- Args:
- model (str): model id or path to local model directory.
- model_revision (str, optional): model version number.
- :param is_training:
- """
- from modelscope.hub.check_model import check_local_model_is_latest
- from modelscope.hub.snapshot_download import snapshot_download
+ Args:
+ model (str): model id or path to local model directory.
+ model_revision (str, optional): model version number.
+ :param is_training:
+ """
+ from modelscope.hub.check_model import check_local_model_is_latest
+ from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.utils.constant import Invoke, ThirdParty
-
- key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
-
- if os.path.exists(model) and check_latest:
- model_cache_dir = model if os.path.isdir(
- model) else os.path.dirname(model)
- try:
- check_local_model_is_latest(
- model_cache_dir,
- user_agent={
- Invoke.KEY: key,
- ThirdParty.KEY: "funasr"
- })
- except:
- print("could not check the latest version")
- else:
- model_cache_dir = snapshot_download(
- model,
- revision=model_revision,
- user_agent={
- Invoke.KEY: key,
- ThirdParty.KEY: "funasr"
- })
- return model_cache_dir
\ No newline at end of file
+ from modelscope.utils.constant import Invoke, ThirdParty
+
+ key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
+
+ if os.path.exists(model) and check_latest:
+ model_cache_dir = model if os.path.isdir(
+ model) else os.path.dirname(model)
+ try:
+ check_local_model_is_latest(
+ model_cache_dir,
+ user_agent={
+ Invoke.KEY: key,
+ ThirdParty.KEY: "funasr"
+ })
+ except:
+ print("could not check the latest version")
+ else:
+ model_cache_dir = snapshot_download(
+ model,
+ revision=model_revision,
+ user_agent={
+ Invoke.KEY: key,
+ ThirdParty.KEY: "funasr"
+ })
+ return model_cache_dir
\ No newline at end of file
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 92416f4..1981347 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -1,45 +1,47 @@
-from pathlib import Path
import os
import argparse
+from pathlib import Path
+
from funasr.utils.types import str2bool
+
def main():
- parser = argparse.ArgumentParser()
- parser.add_argument('--model-name', type=str, required=True)
- parser.add_argument('--export-dir', type=str, required=True)
- parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
- parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
- parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
- parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
- parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
- parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
- parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
- parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
- args = parser.parse_args()
-
- model_dir = args.model_name
- if not Path(args.model_name).exists():
- from modelscope.hub.snapshot_download import snapshot_download
- try:
- model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
- except:
- raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
- (model_dir)
- if args.export:
- model_file = os.path.join(model_dir, 'model.onnx')
- if args.quantize:
- model_file = os.path.join(model_dir, 'model_quant.onnx')
- if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
- from funasr.bin.export_model import ModelExport
- export_model = ModelExport(
- cache_dir=args.export_dir,
- onnx=True,
- device="cpu",
- quant=args.quantize,
- )
- export_model.export(model_dir)
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model-name', type=str, required=True)
+ parser.add_argument('--export-dir', type=str, required=True)
+ parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
+ parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+ parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+ parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+ parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+ parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+ parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+ parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+ args = parser.parse_args()
+
+ model_dir = args.model_name
+ if not Path(args.model_name).exists():
+ from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
+ (model_dir)
+ if args.export:
+ model_file = os.path.join(model_dir, 'model.onnx')
+ if args.quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+ if not os.path.exists(model_file):
+ print(".onnx is not exist, begin to export onnx")
+ from funasr.bin.export_model import ModelExport
+ export_model = ModelExport(
+ cache_dir=args.export_dir,
+ onnx=True,
+ device="cpu",
+ quant=args.quantize,
+ )
+ export_model.export(model_dir)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
\ No newline at end of file
diff --git a/funasr/models/branchformer/model.py b/funasr/models/branchformer/model.py
index 53f254d..7fa99b3 100644
--- a/funasr/models/branchformer/model.py
+++ b/funasr/models/branchformer/model.py
@@ -5,12 +5,12 @@
@tables.register("model_classes", "Branchformer")
class Branchformer(Transformer):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """CTC-attention hybrid Encoder-Decoder model"""
- def __init__(
- self,
- *args,
- **kwargs,
- ):
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
- super().__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
diff --git a/funasr/models/conformer/model.py b/funasr/models/conformer/model.py
index 2c26753..171014b 100644
--- a/funasr/models/conformer/model.py
+++ b/funasr/models/conformer/model.py
@@ -7,13 +7,13 @@
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """CTC-attention hybrid Encoder-Decoder model"""
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
- super().__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
diff --git a/funasr/models/e_branchformer/model.py b/funasr/models/e_branchformer/model.py
index 4ffeb3e..14c8c4d 100644
--- a/funasr/models/e_branchformer/model.py
+++ b/funasr/models/e_branchformer/model.py
@@ -5,12 +5,12 @@
@tables.register("model_classes", "EBranchformer")
class EBranchformer(Transformer):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """CTC-attention hybrid Encoder-Decoder model"""
- def __init__(
- self,
- *args,
- **kwargs,
- ):
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
- super().__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py
index d51478f..4dc8825 100644
--- a/funasr/models/sanm/model.py
+++ b/funasr/models/sanm/model.py
@@ -7,12 +7,12 @@
@tables.register("model_classes", "SANM")
class SANM(Transformer):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """CTC-attention hybrid Encoder-Decoder model"""
- def __init__(
- self,
- *args,
- **kwargs,
- ):
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
- super().__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
diff --git a/funasr/models/scama/chunk_utilis.py b/funasr/models/scama/chunk_utilis.py
index e90ab62..245d282 100644
--- a/funasr/models/scama/chunk_utilis.py
+++ b/funasr/models/scama/chunk_utilis.py
@@ -1,289 +1,287 @@
-
+import math
import torch
import numpy as np
-import math
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import logging
import torch.nn.functional as F
-from funasr.models.scama.utils import sequence_mask
+from funasr.models.scama.utils import sequence_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
class overlap_chunk():
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- San-m: Memory equipped self-attention for end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ San-m: Memory equipped self-attention for end-to-end speech recognition
+ https://arxiv.org/abs/2006.01713
- """
- def __init__(self,
- chunk_size: tuple = (16,),
- stride: tuple = (10,),
- pad_left: tuple = (0,),
- encoder_att_look_back_factor: tuple = (1,),
+ """
+ def __init__(self,
+ chunk_size: tuple = (16,),
+ stride: tuple = (10,),
+ pad_left: tuple = (0,),
+ encoder_att_look_back_factor: tuple = (1,),
shfit_fsmn: int = 0,
decoder_att_look_back_factor: tuple = (1,),
- ):
+ ):
- pad_left = self.check_chunk_size_args(chunk_size, pad_left)
- encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
- decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
- self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
- = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
- self.shfit_fsmn = shfit_fsmn
- self.x_add_mask = None
- self.x_rm_mask = None
- self.x_len = None
- self.mask_shfit_chunk = None
- self.mask_chunk_predictor = None
- self.mask_att_chunk_encoder = None
- self.mask_shift_att_chunk_decoder = None
- self.chunk_outs = None
- self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
- = None, None, None, None, None
+ pad_left = self.check_chunk_size_args(chunk_size, pad_left)
+ encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
+ decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
+ self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
+ = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
+ self.shfit_fsmn = shfit_fsmn
+ self.x_add_mask = None
+ self.x_rm_mask = None
+ self.x_len = None
+ self.mask_shfit_chunk = None
+ self.mask_chunk_predictor = None
+ self.mask_att_chunk_encoder = None
+ self.mask_shift_att_chunk_decoder = None
+ self.chunk_outs = None
+ self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
+ = None, None, None, None, None
- def check_chunk_size_args(self, chunk_size, x):
- if len(x) < len(chunk_size):
- x = [x[0] for i in chunk_size]
- return x
+ def check_chunk_size_args(self, chunk_size, x):
+ if len(x) < len(chunk_size):
+ x = [x[0] for i in chunk_size]
+ return x
- def get_chunk_size(self,
- ind: int = 0
- ):
- # with torch.no_grad:
- chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
- self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
- self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
- = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
- return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
+ def get_chunk_size(self,
+ ind: int = 0
+ ):
+ # with torch.no_grad:
+ chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
+ self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
+ self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
+ = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
+ return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
- def random_choice(self, training=True, decoding_ind=None):
- chunk_num = len(self.chunk_size)
- ind = 0
- if training and chunk_num > 1:
- ind = torch.randint(0, chunk_num, ()).cpu().item()
- if not training and decoding_ind is not None:
- ind = int(decoding_ind)
+ def random_choice(self, training=True, decoding_ind=None):
+ chunk_num = len(self.chunk_size)
+ ind = 0
+ if training and chunk_num > 1:
+ ind = torch.randint(0, chunk_num, ()).cpu().item()
+ if not training and decoding_ind is not None:
+ ind = int(decoding_ind)
- return ind
+ return ind
- def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
+ def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
- with torch.no_grad():
- x_len = x_len.cpu().numpy()
- x_len_max = x_len.max()
+ with torch.no_grad():
+ x_len = x_len.cpu().numpy()
+ x_len_max = x_len.max()
- chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
- shfit_fsmn = self.shfit_fsmn
- pad_right = chunk_size - stride - pad_left
+ chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
+ shfit_fsmn = self.shfit_fsmn
+ pad_right = chunk_size - stride - pad_left
- chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
- x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
- x_len_chunk = x_len_chunk.astype(x_len.dtype)
- x_len_chunk_max = x_len_chunk.max()
+ chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
+ x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
+ x_len_chunk = x_len_chunk.astype(x_len.dtype)
+ x_len_chunk_max = x_len_chunk.max()
- chunk_num = int(math.ceil(x_len_max/stride))
- dtype = np.int32
- max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
- x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
- x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
- mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
- mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
- mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
- mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
- for chunk_ids in range(chunk_num):
- # x_mask add
- fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
- x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
- x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
- x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
- x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
- x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
- x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
- x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
+ chunk_num = int(math.ceil(x_len_max/stride))
+ dtype = np.int32
+ max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
+ x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
+ x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
+ mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
+ mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
+ mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
+ mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
+ for chunk_ids in range(chunk_num):
+ # x_mask add
+ fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
+ x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
+ x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
+ x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
+ x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
+ x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
+ x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
+ x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
- # x_mask rm
- fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
- padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
- padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
- x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
- x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
- x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
- x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
- x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
- x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
- x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
+ # x_mask rm
+ fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
+ padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
+ padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
+ x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
+ x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
+ x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
+ x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
+ x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
+ x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
+ x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
- # fsmn_padding_mask
- pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
- ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
- mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
- mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
+ # fsmn_padding_mask
+ pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
+ ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
+ mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
+ mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
- # predictor mask
- zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
- ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
- zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
- ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
- mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
- mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
+ # predictor mask
+ zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
+ ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
+ zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
+ ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
+ mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
+ mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
- # encoder att mask
- zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
+ # encoder att mask
+ zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
- zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
- zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
+ zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
+ zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
- encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
- zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
- ones_2_mid = np.ones([stride, stride], dtype=dtype)
- zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
- zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
- ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
- ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
- ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
+ encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
+ zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+ ones_2_mid = np.ones([stride, stride], dtype=dtype)
+ zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
+ zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
+ ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
+ ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
+ ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
- zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
- ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
- ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
+ zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+ ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
+ ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
- zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
- zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
+ zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
+ zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
- ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
- mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
- mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
+ ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
+ mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
+ mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
- # decoder fsmn_shift_att_mask
- zeros_1 = np.zeros([shfit_fsmn, 1])
- ones_1 = np.ones([chunk_size, 1])
- mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
- mask_shift_att_chunk_decoder = np.concatenate(
- [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
+ # decoder fsmn_shift_att_mask
+ zeros_1 = np.zeros([shfit_fsmn, 1])
+ ones_1 = np.ones([chunk_size, 1])
+ mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
+ mask_shift_att_chunk_decoder = np.concatenate(
+ [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
- self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
- self.x_len_chunk = x_len_chunk
- self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
- self.x_len = x_len
- self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
- self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
- self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
- self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
- self.chunk_outs = (self.x_add_mask,
- self.x_len_chunk,
- self.x_rm_mask,
- self.x_len,
- self.mask_shfit_chunk,
- self.mask_chunk_predictor,
- self.mask_att_chunk_encoder,
- self.mask_shift_att_chunk_decoder)
+ self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
+ self.x_len_chunk = x_len_chunk
+ self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
+ self.x_len = x_len
+ self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
+ self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
+ self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
+ self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
+ self.chunk_outs = (self.x_add_mask,
+ self.x_len_chunk,
+ self.x_rm_mask,
+ self.x_len,
+ self.mask_shfit_chunk,
+ self.mask_chunk_predictor,
+ self.mask_att_chunk_encoder,
+ self.mask_shift_att_chunk_decoder)
- return self.chunk_outs
+ return self.chunk_outs
- def split_chunk(self, x, x_len, chunk_outs):
- """
- :param x: (b, t, d)
- :param x_length: (b)
- :param ind: int
- :return:
- """
- x = x[:, :x_len.max(), :]
- b, t, d = x.size()
- x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
- x.device)
- x *= x_len_mask[:, :, None]
+ def split_chunk(self, x, x_len, chunk_outs):
+ """
+ :param x: (b, t, d)
+ :param x_length: (b)
+ :param ind: int
+ :return:
+ """
+ x = x[:, :x_len.max(), :]
+ b, t, d = x.size()
+ x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
+ x.device)
+ x *= x_len_mask[:, :, None]
- x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
- x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
- pad = (0, 0, self.pad_left_cur, 0)
- x = F.pad(x, pad, "constant", 0.0)
- b, t, d = x.size()
- x = torch.transpose(x, 1, 0)
- x = torch.reshape(x, [t, -1])
- x_chunk = torch.mm(x_add_mask, x)
- x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
+ x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
+ x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
+ pad = (0, 0, self.pad_left_cur, 0)
+ x = F.pad(x, pad, "constant", 0.0)
+ b, t, d = x.size()
+ x = torch.transpose(x, 1, 0)
+ x = torch.reshape(x, [t, -1])
+ x_chunk = torch.mm(x_add_mask, x)
+ x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
- return x_chunk, x_len_chunk
+ return x_chunk, x_len_chunk
- def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
- x_chunk = x_chunk[:, :x_len_chunk.max(), :]
- b, t, d = x_chunk.size()
- x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
- x_chunk.device)
- x_chunk *= x_len_chunk_mask[:, :, None]
+ def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
+ x_chunk = x_chunk[:, :x_len_chunk.max(), :]
+ b, t, d = x_chunk.size()
+ x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
+ x_chunk.device)
+ x_chunk *= x_len_chunk_mask[:, :, None]
- x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
- x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
- x_chunk = torch.transpose(x_chunk, 1, 0)
- x_chunk = torch.reshape(x_chunk, [t, -1])
- x = torch.mm(x_rm_mask, x_chunk)
- x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
+ x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
+ x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
+ x_chunk = torch.transpose(x_chunk, 1, 0)
+ x_chunk = torch.reshape(x_chunk, [t, -1])
+ x = torch.mm(x_rm_mask, x_chunk)
+ x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
- return x, x_len
+ return x, x_len
- def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
- def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
- with torch.no_grad():
- x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
- x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
- x = torch.from_numpy(x).type(dtype).to(device)
- return x
+ def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
+ with torch.no_grad():
+ x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
+ x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
+ x = torch.from_numpy(x).type(dtype).to(device)
+ return x
def build_scama_mask_for_cross_attention_decoder(
- predictor_alignments: torch.Tensor,
+ predictor_alignments: torch.Tensor,
encoder_sequence_length: torch.Tensor,
chunk_size: int = 5,
encoder_chunk_size: int = 5,
@@ -291,100 +289,100 @@
attention_chunk_size: int = 1,
attention_chunk_type: str = 'chunk',
step=None,
- predictor_mask_chunk_hopping: torch.Tensor = None,
- decoder_att_look_back_factor: int = 1,
- mask_shift_att_chunk_decoder: torch.Tensor = None,
- target_length: torch.Tensor = None,
- is_training=True,
+ predictor_mask_chunk_hopping: torch.Tensor = None,
+ decoder_att_look_back_factor: int = 1,
+ mask_shift_att_chunk_decoder: torch.Tensor = None,
+ target_length: torch.Tensor = None,
+ is_training=True,
dtype: torch.dtype = torch.float32):
- with torch.no_grad():
- device = predictor_alignments.device
- batch_size, chunk_num = predictor_alignments.size()
- maximum_encoder_length = encoder_sequence_length.max().item()
- int_type = predictor_alignments.dtype
- if not is_training:
- target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
- maximum_target_length = target_length.max()
- predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
- predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
-
-
- index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
- index = torch.cumsum(index, dim=1)
- index = index[:, :, None].repeat(1, 1, chunk_num)
-
- index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
- index_div_bool_zeros = index_div == 0
- index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
-
- index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
-
- index_div_bool_zeros_count *= chunk_size
- index_div_bool_zeros_count += attention_chunk_center_bias
- index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
- index_div_bool_zeros_count_ori = index_div_bool_zeros_count
-
- index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
- max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
-
- mask_flip, mask_flip2 = None, None
- if attention_chunk_size is not None:
- index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
- index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
- index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
- mask_flip = 1 - index_div_bool_zeros_count_beg_mask
- attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
- index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
-
- index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
- index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
- mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
-
- mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
-
- if predictor_mask_chunk_hopping is not None:
- b, k, t = mask.size()
- predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
-
- mask_mask_flip = mask
- if mask_flip is not None:
- mask_mask_flip = mask_flip * mask
-
- def _fn():
- mask_sliced = mask[:b, :k, encoder_chunk_size:t]
- zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
- mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
- _, _, tt = predictor_mask_chunk_hopping.size()
- pad_right_p = max_len_chunk - tt
- predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
- masked = mask_sliced * predictor_mask_chunk_hopping_pad
-
- mask_true = mask_mask_flip + masked
- return mask_true
-
- mask = _fn() if t > chunk_size else mask_mask_flip
-
-
-
- if mask_flip2 is not None:
- mask *= mask_flip2
-
- mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
- mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
-
-
-
- mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
- mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
-
-
-
-
- if attention_chunk_type == 'full':
- mask = torch.ones_like(mask).to(device)
- if mask_shift_att_chunk_decoder is not None:
- mask = mask * mask_shift_att_chunk_decoder
- mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
+ with torch.no_grad():
+ device = predictor_alignments.device
+ batch_size, chunk_num = predictor_alignments.size()
+ maximum_encoder_length = encoder_sequence_length.max().item()
+ int_type = predictor_alignments.dtype
+ if not is_training:
+ target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
+ maximum_target_length = target_length.max()
+ predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
+ predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
+
+
+ index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
+ index = torch.cumsum(index, dim=1)
+ index = index[:, :, None].repeat(1, 1, chunk_num)
+
+ index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
+ index_div_bool_zeros = index_div == 0
+ index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
+
+ index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
+
+ index_div_bool_zeros_count *= chunk_size
+ index_div_bool_zeros_count += attention_chunk_center_bias
+ index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
+ index_div_bool_zeros_count_ori = index_div_bool_zeros_count
+
+ index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
+ max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
+
+ mask_flip, mask_flip2 = None, None
+ if attention_chunk_size is not None:
+ index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
+ index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
+ index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
+ mask_flip = 1 - index_div_bool_zeros_count_beg_mask
+ attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
+ index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
+
+ index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
+ index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
+ mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
+
+ mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
+
+ if predictor_mask_chunk_hopping is not None:
+ b, k, t = mask.size()
+ predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
+
+ mask_mask_flip = mask
+ if mask_flip is not None:
+ mask_mask_flip = mask_flip * mask
+
+ def _fn():
+ mask_sliced = mask[:b, :k, encoder_chunk_size:t]
+ zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
+ mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
+ _, _, tt = predictor_mask_chunk_hopping.size()
+ pad_right_p = max_len_chunk - tt
+ predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
+ masked = mask_sliced * predictor_mask_chunk_hopping_pad
+
+ mask_true = mask_mask_flip + masked
+ return mask_true
+
+ mask = _fn() if t > chunk_size else mask_mask_flip
+
+
+
+ if mask_flip2 is not None:
+ mask *= mask_flip2
+
+ mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
+ mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
+
+
+
+ mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
+ mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
+
+
+
+
+ if attention_chunk_type == 'full':
+ mask = torch.ones_like(mask).to(device)
+ if mask_shift_att_chunk_decoder is not None:
+ mask = mask * mask_shift_att_chunk_decoder
+ mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
- return mask
+ return mask
diff --git a/funasr/models/scama/utils.py b/funasr/models/scama/utils.py
index 4bb9d4f..8832596 100644
--- a/funasr/models/scama/utils.py
+++ b/funasr/models/scama/utils.py
@@ -1,29 +1,30 @@
import os
-import torch
-from torch.nn import functional as F
import yaml
+import torch
import numpy as np
+from torch.nn import functional as F
+
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
- if maxlen is None:
- maxlen = lengths.max()
- row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
- matrix = torch.unsqueeze(lengths, dim=-1)
- mask = row_vector < matrix
- mask = mask.detach()
+ if maxlen is None:
+ maxlen = lengths.max()
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
+ matrix = torch.unsqueeze(lengths, dim=-1)
+ mask = row_vector < matrix
+ mask = mask.detach()
- return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def apply_cmvn(inputs, mvn):
- device = inputs.device
- dtype = inputs.dtype
- frame, dim = inputs.shape
- meams = np.tile(mvn[0:1, :dim], (frame, 1))
- vars = np.tile(mvn[1:2, :dim], (frame, 1))
- inputs -= torch.from_numpy(meams).type(dtype).to(device)
- inputs *= torch.from_numpy(vars).type(dtype).to(device)
+ device = inputs.device
+ dtype = inputs.dtype
+ frame, dim = inputs.shape
+ meams = np.tile(mvn[0:1, :dim], (frame, 1))
+ vars = np.tile(mvn[1:2, :dim], (frame, 1))
+ inputs -= torch.from_numpy(meams).type(dtype).to(device)
+ inputs *= torch.from_numpy(vars).type(dtype).to(device)
- return inputs.type(torch.float32)
+ return inputs.type(torch.float32)
@@ -36,56 +37,56 @@
- outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
- outputs *= stoch_layer_coeff
+ outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
+ outputs *= stoch_layer_coeff
- input_dim = inputs.size(-1)
- output_dim = outputs.size(-1)
+ input_dim = inputs.size(-1)
+ output_dim = outputs.size(-1)
- if input_dim == output_dim:
- outputs += inputs
- return outputs
+ if input_dim == output_dim:
+ outputs += inputs
+ return outputs
def proc_tf_vocab(vocab_path):
- with open(vocab_path, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
- if '<unk>' not in token_list:
- token_list.append('<unk>')
- return token_list
+ with open(vocab_path, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ if '<unk>' not in token_list:
+ token_list.append('<unk>')
+ return token_list
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
- token_list = proc_tf_vocab(vocab_path)
- with open(config_path, encoding="utf-8") as f:
- config = yaml.safe_load(f)
-
- config['token_list'] = token_list
-
- if not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
- yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
+ token_list = proc_tf_vocab(vocab_path)
+ with open(config_path, encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ config['token_list'] = token_list
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
+ yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
class NoAliasSafeDumper(yaml.SafeDumper):
- # Disable anchor/alias in yaml because looks ugly
- def ignore_aliases(self, data):
- return True
+ # Disable anchor/alias in yaml because looks ugly
+ def ignore_aliases(self, data):
+ return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
- """Safe-dump in yaml with no anchor/alias"""
- return yaml.dump(
- data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
- )
+ """Safe-dump in yaml with no anchor/alias"""
+ return yaml.dump(
+ data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
+ )
if __name__ == '__main__':
- import sys
-
- config_path = sys.argv[1]
- vocab_path = sys.argv[2]
- output_dir = sys.argv[3]
- gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
+ import sys
+
+ config_path = sys.argv[1]
+ vocab_path = sys.argv[2]
+ output_dir = sys.argv[3]
+ gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/e2e_uni_asr.py
index de7ed29..390d274 100644
--- a/funasr/models/uniasr/e2e_uni_asr.py
+++ b/funasr/models/uniasr/e2e_uni_asr.py
@@ -541,20 +541,20 @@
speech_lengths: (Batch, )
"""
# with autocast(False):
- # # 1. Extract feats
- # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ # # 1. Extract feats
+ # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
#
- # # 2. Data augmentation
- # if self.specaug is not None and self.training:
- # feats, feats_lengths = self.specaug(feats, feats_lengths)
+ # # 2. Data augmentation
+ # if self.specaug is not None and self.training:
+ # feats, feats_lengths = self.specaug(feats, feats_lengths)
#
- # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- # if self.normalize is not None:
- # feats, feats_lengths = self.normalize(feats, feats_lengths)
+ # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ # if self.normalize is not None:
+ # feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
# if self.preencoder is not None:
- # feats, feats_lengths = self.preencoder(feats, feats_lengths)
+ # feats, feats_lengths = self.preencoder(feats, feats_lengths)
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out,
encoder_out_lens,
@@ -584,9 +584,9 @@
# # Post-encoder, e.g. NLU
# if self.postencoder is not None:
- # encoder_out, encoder_out_lens = self.postencoder(
- # encoder_out, encoder_out_lens
- # )
+ # encoder_out, encoder_out_lens = self.postencoder(
+ # encoder_out, encoder_out_lens
+ # )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
diff --git a/funasr/optimizers/__init__.py b/funasr/optimizers/__init__.py
index 177f89e..a1a57a5 100644
--- a/funasr/optimizers/__init__.py
+++ b/funasr/optimizers/__init__.py
@@ -3,15 +3,15 @@
from funasr.optimizers.sgd import SGD
optim_classes = dict(
- adam=torch.optim.Adam,
- fairseq_adam=FairseqAdam,
- adamw=torch.optim.AdamW,
- sgd=SGD,
- adadelta=torch.optim.Adadelta,
- adagrad=torch.optim.Adagrad,
- adamax=torch.optim.Adamax,
- asgd=torch.optim.ASGD,
- lbfgs=torch.optim.LBFGS,
- rmsprop=torch.optim.RMSprop,
- rprop=torch.optim.Rprop,
+ adam=torch.optim.Adam,
+ fairseq_adam=FairseqAdam,
+ adamw=torch.optim.AdamW,
+ sgd=SGD,
+ adadelta=torch.optim.Adadelta,
+ adagrad=torch.optim.Adagrad,
+ adamax=torch.optim.Adamax,
+ asgd=torch.optim.ASGD,
+ lbfgs=torch.optim.LBFGS,
+ rmsprop=torch.optim.RMSprop,
+ rprop=torch.optim.Rprop,
)
\ No newline at end of file
diff --git a/funasr/schedulers/__init__.py b/funasr/schedulers/__init__.py
index cba286a..0d1a578 100644
--- a/funasr/schedulers/__init__.py
+++ b/funasr/schedulers/__init__.py
@@ -8,16 +8,16 @@
from funasr.schedulers.warmup_lr import WarmupLR
scheduler_classes = dict(
- ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
- lambdalr=torch.optim.lr_scheduler.LambdaLR,
- steplr=torch.optim.lr_scheduler.StepLR,
- multisteplr=torch.optim.lr_scheduler.MultiStepLR,
- exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
- CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
- noamlr=NoamLR,
- warmuplr=WarmupLR,
- tri_stage=TriStageLR,
- cycliclr=torch.optim.lr_scheduler.CyclicLR,
- onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
- CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+ ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+ lambdalr=torch.optim.lr_scheduler.LambdaLR,
+ steplr=torch.optim.lr_scheduler.StepLR,
+ multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+ exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+ CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+ noamlr=NoamLR,
+ warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
+ cycliclr=torch.optim.lr_scheduler.CyclicLR,
+ onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+ CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index 548bf06..136be13 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -1,100 +1,94 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
import json
-
import numpy as np
+from abc import ABC
+from pathlib import Path
+from abc import abstractmethod
+from typing import Union, Iterable, List, Dict
class AbsTokenizer(ABC):
- @abstractmethod
- def text2tokens(self, line: str) -> List[str]:
- raise NotImplementedError
-
- @abstractmethod
- def tokens2text(self, tokens: Iterable[str]) -> str:
- raise NotImplementedError
+ @abstractmethod
+ def text2tokens(self, line: str) -> List[str]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ raise NotImplementedError
class BaseTokenizer(ABC):
- def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
- unk_symbol: str = "<unk>",
- **kwargs,
- ):
-
- if token_list is not None:
- if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
- token_list = Path(token_list)
- self.token_list_repr = str(token_list)
- self.token_list: List[str] = []
-
- with token_list.open("r", encoding="utf-8") as f:
- for idx, line in enumerate(f):
- line = line.rstrip()
- self.token_list.append(line)
- elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
- token_list = Path(token_list)
- self.token_list_repr = str(token_list)
- self.token_list: List[str] = []
-
- with open(token_list, 'r', encoding='utf-8') as f:
- self.token_list = json.load(f)
-
-
- else:
- self.token_list: List[str] = list(token_list)
- self.token_list_repr = ""
- for i, t in enumerate(self.token_list):
- if i == 3:
- break
- self.token_list_repr += f"{t}, "
- self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-
- self.token2id: Dict[str, int] = {}
- for i, t in enumerate(self.token_list):
- if t in self.token2id:
- raise RuntimeError(f'Symbol "{t}" is duplicated')
- self.token2id[t] = i
-
- self.unk_symbol = unk_symbol
- if self.unk_symbol not in self.token2id:
- raise RuntimeError(
- f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
- )
- self.unk_id = self.token2id[self.unk_symbol]
-
- def encode(self, text):
- tokens = self.text2tokens(text)
- text_ints = self.tokens2ids(tokens)
-
- return text_ints
-
- def decode(self, text_ints):
- token = self.ids2tokens(text_ints)
- text = self.tokens2text(token)
- return text
-
- def get_num_vocabulary_size(self) -> int:
- return len(self.token_list)
-
- def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
- if isinstance(integers, np.ndarray) and integers.ndim != 1:
- raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
- return [self.token_list[i] for i in integers]
-
- def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
- return [self.token2id.get(i, self.unk_id) for i in tokens]
-
- @abstractmethod
- def text2tokens(self, line: str) -> List[str]:
- raise NotImplementedError
-
- @abstractmethod
- def tokens2text(self, tokens: Iterable[str]) -> str:
- raise NotImplementedError
\ No newline at end of file
+ def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
+ unk_symbol: str = "<unk>",
+ **kwargs,
+ ):
+
+ if token_list is not None:
+ if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with token_list.open("r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ line = line.rstrip()
+ self.token_list.append(line)
+ elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with open(token_list, 'r', encoding='utf-8') as f:
+ self.token_list = json.load(f)
+
+
+ else:
+ self.token_list: List[str] = list(token_list)
+ self.token_list_repr = ""
+ for i, t in enumerate(self.token_list):
+ if i == 3:
+ break
+ self.token_list_repr += f"{t}, "
+ self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+
+ self.token2id: Dict[str, int] = {}
+ for i, t in enumerate(self.token_list):
+ if t in self.token2id:
+ raise RuntimeError(f'Symbol "{t}" is duplicated')
+ self.token2id[t] = i
+
+ self.unk_symbol = unk_symbol
+ if self.unk_symbol not in self.token2id:
+ raise RuntimeError(
+ f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+ )
+ self.unk_id = self.token2id[self.unk_symbol]
+
+ def encode(self, text):
+ tokens = self.text2tokens(text)
+ text_ints = self.tokens2ids(tokens)
+
+ return text_ints
+
+ def decode(self, text_ints):
+ token = self.ids2tokens(text_ints)
+ text = self.tokens2text(token)
+ return text
+
+ def get_num_vocabulary_size(self) -> int:
+ return len(self.token_list)
+
+ def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+ if isinstance(integers, np.ndarray) and integers.ndim != 1:
+ raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+ return [self.token_list[i] for i in integers]
+
+ def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+ @abstractmethod
+ def text2tokens(self, line: str) -> List[str]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 59aeaf0..0f0acc2 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -1,233 +1,235 @@
-import torch
import os
-from funasr.train_utils.device_funcs import to_device
-import logging
import time
+import torch
+import logging
from tqdm import tqdm
-from contextlib import nullcontext
import torch.distributed as dist
+from contextlib import nullcontext
+
+from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
+
class Trainer:
- """
- A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
- and optionally resuming from a saved checkpoint.
+ """
+ A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
+ and optionally resuming from a saved checkpoint.
- Attributes:
- max_epoch (int): Maximum number of epochs for training.
- model (torch.nn.Module): The model to be trained.
- optim (torch.optim.Optimizer): The optimizer to use for training.
- scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
- dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
- dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
- output_dir (str): Directory where model checkpoints will be saved.
- resume (str, optional): Path to a checkpoint to resume training from.
- """
-
- def __init__(self, model,
- optim,
- scheduler,
- dataloader_train,
- dataloader_val,
- local_rank,
- use_ddp=False,
- use_fsdp=False,
- **kwargs):
- """
- Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
+ Attributes:
+ max_epoch (int): Maximum number of epochs for training.
+ model (torch.nn.Module): The model to be trained.
+ optim (torch.optim.Optimizer): The optimizer to use for training.
+ scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+ dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
+ dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+ output_dir (str): Directory where model checkpoints will be saved.
+ resume (str, optional): Path to a checkpoint to resume training from.
+ """
+
+ def __init__(self, model,
+ optim,
+ scheduler,
+ dataloader_train,
+ dataloader_val,
+ local_rank,
+ use_ddp=False,
+ use_fsdp=False,
+ **kwargs):
+ """
+ Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
- Args:
- model (torch.nn.Module): The model to be trained.
- optim (torch.optim.Optimizer): The optimizer to use for training.
- scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
- dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
- dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
- **kwargs: Additional keyword arguments:
- max_epoch (int): The maximum number of epochs for training.
- output_dir (str): The directory where model checkpoints will be saved. Default is './'.
- resume (str, optional): The file path to a checkpoint to resume training from.
- """
-
- self.model = model
- self.optim = optim
- self.scheduler = scheduler
- self.dataloader_train = dataloader_train
- self.dataloader_val = dataloader_val
- self.output_dir = kwargs.get('output_dir', './')
- self.resume = kwargs.get('resume', True)
- self.start_epoch = 0
- self.max_epoch = kwargs.get('max_epoch', 100)
- self.local_rank = local_rank
- self.use_ddp = use_ddp
- self.use_fsdp = use_fsdp
- self.device = next(model.parameters()).device
- self.kwargs = kwargs
-
- if self.resume:
- self._resume_checkpoint(self.resume)
-
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- logging.warning("distributed is not initialized, only single shard")
- self.rank = rank
- self.world_size = world_size
-
- def _save_checkpoint(self, epoch):
- """
- Saves a checkpoint containing the model's state, the optimizer's state,
- and the scheduler's state at the end of the given epoch. This method is
- intended to be called at the end of each epoch to save the training progress.
+ Args:
+ model (torch.nn.Module): The model to be trained.
+ optim (torch.optim.Optimizer): The optimizer to use for training.
+ scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+ dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
+ dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
+ **kwargs: Additional keyword arguments:
+ max_epoch (int): The maximum number of epochs for training.
+ output_dir (str): The directory where model checkpoints will be saved. Default is './'.
+ resume (str, optional): The file path to a checkpoint to resume training from.
+ """
+
+ self.model = model
+ self.optim = optim
+ self.scheduler = scheduler
+ self.dataloader_train = dataloader_train
+ self.dataloader_val = dataloader_val
+ self.output_dir = kwargs.get('output_dir', './')
+ self.resume = kwargs.get('resume', True)
+ self.start_epoch = 0
+ self.max_epoch = kwargs.get('max_epoch', 100)
+ self.local_rank = local_rank
+ self.use_ddp = use_ddp
+ self.use_fsdp = use_fsdp
+ self.device = next(model.parameters()).device
+ self.kwargs = kwargs
+
+ if self.resume:
+ self._resume_checkpoint(self.resume)
+
+ try:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ except:
+ rank = 0
+ world_size = 1
+ logging.warning("distributed is not initialized, only single shard")
+ self.rank = rank
+ self.world_size = world_size
+
+ def _save_checkpoint(self, epoch):
+ """
+ Saves a checkpoint containing the model's state, the optimizer's state,
+ and the scheduler's state at the end of the given epoch. This method is
+ intended to be called at the end of each epoch to save the training progress.
- Args:
- epoch (int): The epoch number at which the checkpoint is being saved.
- """
- state = {
- 'epoch': epoch,
- 'state_dict': self.model.state_dict(),
- 'optimizer': self.optim.state_dict(),
- 'scheduler': self.scheduler.state_dict(),
- }
- # Create output directory if it does not exist
- os.makedirs(self.output_dir, exist_ok=True)
- filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
- torch.save(state, filename)
- print(f'Checkpoint saved to {filename}')
-
- def _resume_checkpoint(self, resume_path):
- """
- Resumes training from a checkpoint at the given file path.
- Loads the model's state, the optimizer's state, and the scheduler's state.
+ Args:
+ epoch (int): The epoch number at which the checkpoint is being saved.
+ """
+ state = {
+ 'epoch': epoch,
+ 'state_dict': self.model.state_dict(),
+ 'optimizer': self.optim.state_dict(),
+ 'scheduler': self.scheduler.state_dict(),
+ }
+ # Create output directory if it does not exist
+ os.makedirs(self.output_dir, exist_ok=True)
+ filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+ torch.save(state, filename)
+ print(f'Checkpoint saved to {filename}')
+
+ def _resume_checkpoint(self, resume_path):
+ """
+ Resumes training from a checkpoint at the given file path.
+ Loads the model's state, the optimizer's state, and the scheduler's state.
- Args:
- resume_path (str): The file path to the checkpoint to resume from.
- """
- if os.path.isfile(resume_path):
- checkpoint = torch.load(resume_path)
- self.start_epoch = checkpoint['epoch'] + 1
- self.model.load_state_dict(checkpoint['state_dict'])
- self.optim.load_state_dict(checkpoint['optimizer'])
- self.scheduler.load_state_dict(checkpoint['scheduler'])
- print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
- else:
- print(f"No checkpoint found at '{resume_path}', starting from scratch")
-
- def run(self):
- """
- Starts the training process, iterating over epochs, training the model,
- and saving checkpoints at the end of each epoch.
- """
- for epoch in range(self.start_epoch, self.max_epoch + 1):
- self._train_epoch(epoch)
- # self._validate_epoch(epoch)
- if self.rank == 0:
- self._save_checkpoint(epoch)
- self.scheduler.step()
-
- def _train_epoch(self, epoch):
- """
- Defines the training process for a single epoch with gradient accumulation.
- Args:
- epoch (int): The current epoch number.
- """
- self.model.train()
- pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
- dynamic_ncols=True)
-
- # Set the number of steps for gradient accumulation
- accum_grad = self.kwargs.get("accum_grad", 1)
- # Initialize the gradient accumulation
- self.optim.zero_grad()
- speed_stats = {}
- time5 = time.perf_counter()
- for batch_idx, batch in enumerate(self.dataloader_train):
- time1 = time.perf_counter()
- speed_stats["data_load"] = f"{time1-time5:0.3f}"
- # import pdb;
- # pdb.set_trace()
- batch = to_device(batch, self.device)
-
- my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
- with my_context():
- time2 = time.perf_counter()
- retval = self.model(**batch)
- time3 = time.perf_counter()
- speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
- loss, stats, weight = retval
- stats = {k: v for k, v in stats.items() if v is not None}
- if self.use_ddp or self.use_fsdp:
- # Apply weighted averaging for loss and stats
- loss = (loss * weight.type(loss.dtype)).sum()
- # if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed=True)
- # Now weight is summation over all workers
- loss /= weight
- # Multiply world_size because DistributedDataParallel
- # automatically normalizes the gradient by world_size.
- loss *= self.world_size
- # Scale the loss since we're not updating for every mini-batch
- loss = loss / accum_grad
- loss.backward()
- time4 = time.perf_counter()
- speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
-
- # Perform an optimizer step only after accumulating enough gradients
- if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
- # Perform gradient clipping if it is set
- if self.kwargs.get("grad_clip", None) is not None:
- grad_norm = torch.nn.utils.clip_grad_norm_(
- self.model.parameters(),
- max_norm=self.kwargs.get("grad_clip", 10.0),
- norm_type=self.kwargs.get("grad_clip_type", 2.0),
- )
- if not torch.isfinite(grad_norm):
- logging.warning(
- f"The grad norm is {grad_norm}. Skipping updating the model."
- )
- self.optim.zero_grad() # Reset gradients
- continue
-
- # Execute an optimization step (update model parameters)
- self.optim.step()
- self.scheduler.step()
- # Clear gradients for the next accumulation stage
- self.optim.zero_grad()
- total_time = f"{time.perf_counter() - time5:0.3f}"
- time5 = time.perf_counter()
- speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
-
- speed_stats["total_time"] = total_time
-
- # import pdb;
- # pdb.set_trace()
- pbar.update(1)
- if self.local_rank == 0:
- description = (
- f"Epoch: {epoch + 1}/{self.max_epoch}, "
- f"step {batch_idx}/{len(self.dataloader_train)}, "
- f"{speed_stats}, "
- f"(loss: {loss.detach().cpu().item():.3f}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
- )
- pbar.set_description(description)
-
- # if batch_idx == 2:
- # break
- pbar.close()
+ Args:
+ resume_path (str): The file path to the checkpoint to resume from.
+ """
+ if os.path.isfile(resume_path):
+ checkpoint = torch.load(resume_path)
+ self.start_epoch = checkpoint['epoch'] + 1
+ self.model.load_state_dict(checkpoint['state_dict'])
+ self.optim.load_state_dict(checkpoint['optimizer'])
+ self.scheduler.load_state_dict(checkpoint['scheduler'])
+ print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+ else:
+ print(f"No checkpoint found at '{resume_path}', starting from scratch")
+
+ def run(self):
+ """
+ Starts the training process, iterating over epochs, training the model,
+ and saving checkpoints at the end of each epoch.
+ """
+ for epoch in range(self.start_epoch, self.max_epoch + 1):
+ self._train_epoch(epoch)
+ # self._validate_epoch(epoch)
+ if self.rank == 0:
+ self._save_checkpoint(epoch)
+ self.scheduler.step()
+
+ def _train_epoch(self, epoch):
+ """
+ Defines the training process for a single epoch with gradient accumulation.
+ Args:
+ epoch (int): The current epoch number.
+ """
+ self.model.train()
+ pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
+ dynamic_ncols=True)
+
+ # Set the number of steps for gradient accumulation
+ accum_grad = self.kwargs.get("accum_grad", 1)
+ # Initialize the gradient accumulation
+ self.optim.zero_grad()
+ speed_stats = {}
+ time5 = time.perf_counter()
+ for batch_idx, batch in enumerate(self.dataloader_train):
+ time1 = time.perf_counter()
+ speed_stats["data_load"] = f"{time1-time5:0.3f}"
+ # import pdb;
+ # pdb.set_trace()
+ batch = to_device(batch, self.device)
+
+ my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
+ with my_context():
+ time2 = time.perf_counter()
+ retval = self.model(**batch)
+ time3 = time.perf_counter()
+ speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+ loss, stats, weight = retval
+ stats = {k: v for k, v in stats.items() if v is not None}
+ if self.use_ddp or self.use_fsdp:
+ # Apply weighted averaging for loss and stats
+ loss = (loss * weight.type(loss.dtype)).sum()
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed=True)
+ # Now weight is summation over all workers
+ loss /= weight
+ # Multiply world_size because DistributedDataParallel
+ # automatically normalizes the gradient by world_size.
+ loss *= self.world_size
+ # Scale the loss since we're not updating for every mini-batch
+ loss = loss / accum_grad
+ loss.backward()
+ time4 = time.perf_counter()
+ speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+
+ # Perform an optimizer step only after accumulating enough gradients
+ if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
+ # Perform gradient clipping if it is set
+ if self.kwargs.get("grad_clip", None) is not None:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(),
+ max_norm=self.kwargs.get("grad_clip", 10.0),
+ norm_type=self.kwargs.get("grad_clip_type", 2.0),
+ )
+ if not torch.isfinite(grad_norm):
+ logging.warning(
+ f"The grad norm is {grad_norm}. Skipping updating the model."
+ )
+ self.optim.zero_grad() # Reset gradients
+ continue
+
+ # Execute an optimization step (update model parameters)
+ self.optim.step()
+ self.scheduler.step()
+ # Clear gradients for the next accumulation stage
+ self.optim.zero_grad()
+ total_time = f"{time.perf_counter() - time5:0.3f}"
+ time5 = time.perf_counter()
+ speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+
+ speed_stats["total_time"] = total_time
+
+ # import pdb;
+ # pdb.set_trace()
+ pbar.update(1)
+ if self.local_rank == 0:
+ description = (
+ f"Epoch: {epoch + 1}/{self.max_epoch}, "
+ f"step {batch_idx}/{len(self.dataloader_train)}, "
+ f"{speed_stats}, "
+ f"(loss: {loss.detach().cpu().item():.3f}), "
+ f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+ )
+ pbar.set_description(description)
+
+ # if batch_idx == 2:
+ # break
+ pbar.close()
- def _validate_epoch(self, epoch):
- """
- Defines the validation process for a single epoch.
- Should be implemented with the actual model validation steps.
-
- Args:
- epoch (int): The current epoch number.
- """
- self.model.eval()
- with torch.no_grad():
- for data, target in self.dataloader_val:
- # Implement the model validation steps here
- pass
+ def _validate_epoch(self, epoch):
+ """
+ Defines the validation process for a single epoch.
+ Should be implemented with the actual model validation steps.
+
+ Args:
+ epoch (int): The current epoch number.
+ """
+ self.model.eval()
+ with torch.no_grad():
+ for data, target in self.dataloader_val:
+ # Implement the model validation steps here
+ pass
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 4e131a8..9cd3854 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -10,100 +10,100 @@
import logging
from torch.nn.utils.rnn import pad_sequence
try:
- from funasr.download.file import download_from_url
+ from funasr.download.file import download_from_url
except:
- print("urllib is not installed, if you infer from url, please install it first.")
+ print("urllib is not installed, if you infer from url, please install it first.")
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
- if isinstance(data_or_path_or_list, (list, tuple)):
- if data_type is not None and isinstance(data_type, (list, tuple)):
+ if isinstance(data_or_path_or_list, (list, tuple)):
+ if data_type is not None and isinstance(data_type, (list, tuple)):
- data_types = [data_type] * len(data_or_path_or_list)
- data_or_path_or_list_ret = [[] for d in data_type]
- for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
-
- for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
-
- data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
- data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
+ data_types = [data_type] * len(data_or_path_or_list)
+ data_or_path_or_list_ret = [[] for d in data_type]
+ for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
+
+ for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
+
+ data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
+ data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
- return data_or_path_or_list_ret
- else:
- return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
-
- if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
- data_or_path_or_list = download_from_url(data_or_path_or_list)
-
- if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
- if data_type is None or data_type == "sound":
- data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
- data_or_path_or_list = data_or_path_or_list[0, :]
- elif data_type == "text" and tokenizer is not None:
- data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
- elif data_type == "image": # undo
- pass
- elif data_type == "video": # undo
- pass
-
- # if data_in is a file or url, set is_final=True
- if "cache" in kwargs:
- kwargs["cache"]["is_final"] = True
- elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
- data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
- elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
- data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
- else:
- pass
- # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
-
- if audio_fs != fs and data_type != "text":
- resampler = torchaudio.transforms.Resample(audio_fs, fs)
- data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
- return data_or_path_or_list
+ return data_or_path_or_list_ret
+ else:
+ return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+
+ if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
+ data_or_path_or_list = download_from_url(data_or_path_or_list)
+
+ if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
+ if data_type is None or data_type == "sound":
+ data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
+ data_or_path_or_list = data_or_path_or_list[0, :]
+ elif data_type == "text" and tokenizer is not None:
+ data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+ elif data_type == "image": # undo
+ pass
+ elif data_type == "video": # undo
+ pass
+
+ # if data_in is a file or url, set is_final=True
+ if "cache" in kwargs:
+ kwargs["cache"]["is_final"] = True
+ elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
+ data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+ elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
+ data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
+ else:
+ pass
+ # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
+
+ if audio_fs != fs and data_type != "text":
+ resampler = torchaudio.transforms.Resample(audio_fs, fs)
+ data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
+ return data_or_path_or_list
def load_bytes(input):
- middle_data = np.frombuffer(input, dtype=np.int16)
- middle_data = np.asarray(middle_data)
- if middle_data.dtype.kind not in 'iu':
- raise TypeError("'middle_data' must be an array of integers")
- dtype = np.dtype('float32')
- if dtype.kind != 'f':
- raise TypeError("'dtype' must be a floating point type")
-
- i = np.iinfo(middle_data.dtype)
- abs_max = 2 ** (i.bits - 1)
- offset = i.min + abs_max
- array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
- return array
+ middle_data = np.frombuffer(input, dtype=np.int16)
+ middle_data = np.asarray(middle_data)
+ if middle_data.dtype.kind not in 'iu':
+ raise TypeError("'middle_data' must be an array of integers")
+ dtype = np.dtype('float32')
+ if dtype.kind != 'f':
+ raise TypeError("'dtype' must be a floating point type")
+
+ i = np.iinfo(middle_data.dtype)
+ abs_max = 2 ** (i.bits - 1)
+ offset = i.min + abs_max
+ array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+ return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
- # import pdb;
- # pdb.set_trace()
- if isinstance(data, np.ndarray):
- data = torch.from_numpy(data)
- if len(data.shape) < 2:
- data = data[None, :] # data: [batch, N]
- data_len = [data.shape[1]] if data_len is None else data_len
- elif isinstance(data, torch.Tensor):
- if len(data.shape) < 2:
- data = data[None, :] # data: [batch, N]
- data_len = [data.shape[1]] if data_len is None else data_len
- elif isinstance(data, (list, tuple)):
- data_list, data_len = [], []
- for data_i in data:
- if isinstance(data_i, np.ndarray):
- data_i = torch.from_numpy(data_i)
- data_list.append(data_i)
- data_len.append(data_i.shape[0])
- data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
- # import pdb;
- # pdb.set_trace()
- # if data_type == "sound":
- data, data_len = frontend(data, data_len, **kwargs)
-
- if isinstance(data_len, (list, tuple)):
- data_len = torch.tensor([data_len])
- return data.to(torch.float32), data_len.to(torch.int32)
+ # import pdb;
+ # pdb.set_trace()
+ if isinstance(data, np.ndarray):
+ data = torch.from_numpy(data)
+ if len(data.shape) < 2:
+ data = data[None, :] # data: [batch, N]
+ data_len = [data.shape[1]] if data_len is None else data_len
+ elif isinstance(data, torch.Tensor):
+ if len(data.shape) < 2:
+ data = data[None, :] # data: [batch, N]
+ data_len = [data.shape[1]] if data_len is None else data_len
+ elif isinstance(data, (list, tuple)):
+ data_list, data_len = [], []
+ for data_i in data:
+ if isinstance(data_i, np.ndarray):
+ data_i = torch.from_numpy(data_i)
+ data_list.append(data_i)
+ data_len.append(data_i.shape[0])
+ data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
+ # import pdb;
+ # pdb.set_trace()
+ # if data_type == "sound":
+ data, data_len = frontend(data, data_len, **kwargs)
+
+ if isinstance(data_len, (list, tuple)):
+ data_len = torch.tensor([data_len])
+ return data.to(torch.float32), data_len.to(torch.int32)
diff --git a/funasr/utils/vad_utils.py b/funasr/utils/vad_utils.py
index f84e2b9..af7c8f2 100644
--- a/funasr/utils/vad_utils.py
+++ b/funasr/utils/vad_utils.py
@@ -1,31 +1,31 @@
import torch
from torch.nn.utils.rnn import pad_sequence
-def slice_padding_fbank(speech, speech_lengths, vad_segments):
- speech_list = []
- speech_lengths_list = []
- for i, segment in enumerate(vad_segments):
-
- bed_idx = int(segment[0][0]*16)
- end_idx = min(int(segment[0][1]*16), speech_lengths[0])
- speech_i = speech[0, bed_idx: end_idx]
- speech_lengths_i = end_idx-bed_idx
- speech_list.append(speech_i)
- speech_lengths_list.append(speech_lengths_i)
- feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
- speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
- return feats_pad, speech_lengths_pad
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+
+ bed_idx = int(segment[0][0]*16)
+ end_idx = min(int(segment[0][1]*16), speech_lengths[0])
+ speech_i = speech[0, bed_idx: end_idx]
+ speech_lengths_i = end_idx-bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+ return feats_pad, speech_lengths_pad
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
- speech_list = []
- speech_lengths_list = []
- for i, segment in enumerate(vad_segments):
- bed_idx = int(segment[0][0] * 16)
- end_idx = min(int(segment[0][1] * 16), speech_lengths)
- speech_i = speech[bed_idx: end_idx]
- speech_lengths_i = end_idx - bed_idx
- speech_list.append(speech_i)
- speech_lengths_list.append(speech_lengths_i)
-
- return speech_list, speech_lengths_list
\ No newline at end of file
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx: end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+
+ return speech_list, speech_lengths_list
\ No newline at end of file
diff --git a/runtime/python/utils/test_cer.py b/runtime/python/utils/test_cer.py
index e27e393..d795d33 100644
--- a/runtime/python/utils/test_cer.py
+++ b/runtime/python/utils/test_cer.py
@@ -17,8 +17,8 @@
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
- from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
-
+ from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
wav_file_f = open(args.wav_file, 'r')
@@ -26,23 +26,23 @@
output_dir = args.output_dir
if not os.path.exists(output_dir):
- os.makedirs(output_dir)
+ os.makedirs(output_dir)
if os.name == 'nt': # Windows
- newline = '\r\n'
+ newline = '\r\n'
else: # Linux Mac
- newline = '\n'
+ newline = '\n'
text_f = open(os.path.join(output_dir, "text"), "w", newline=newline)
token_f = open(os.path.join(output_dir, "token"), "w", newline=newline)
for i, wav_path_i in enumerate(wav_files):
- wav_name, wav_path = wav_path_i.strip().split()
- result = model(wav_path)
- text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
- token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
- text_f.write(text_i)
- text_f.flush()
- token_f.write(token_i)
- token_f.flush()
+ wav_name, wav_path = wav_path_i.strip().split()
+ result = model(wav_path)
+ text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
+ token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
+ text_f.write(text_i)
+ text_f.flush()
+ token_f.write(token_i)
+ token_f.flush()
text_f.close()
token_f.close()
-
+
diff --git a/runtime/python/utils/test_rtf.py b/runtime/python/utils/test_rtf.py
index 391a0ac..3fe96a3 100644
--- a/runtime/python/utils/test_rtf.py
+++ b/runtime/python/utils/test_rtf.py
@@ -16,8 +16,8 @@
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
- from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
-
+ from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
wav_file_f = open(args.wav_file, 'r')
@@ -28,28 +28,28 @@
num = 30
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
for i in range(num):
- beg_time = time.time()
- result = model(wav_path)
- end_time = time.time()
- duration = end_time-beg_time
- total += duration
- print(result)
- print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
+ beg_time = time.time()
+ result = model(wav_path)
+ end_time = time.time()
+ duration = end_time-beg_time
+ total += duration
+ print(result)
+ print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
# infer time
beg_time = time.time()
for i, wav_path_i in enumerate(wav_files):
- wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
- result = model(wav_path)
+ wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+ result = model(wav_path)
end_time = time.time()
duration = (end_time-beg_time)*1000
print("total_time_comput_ms: {}".format(int(duration)))
duration_time = 0.0
for i, wav_path_i in enumerate(wav_files):
- wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
- waveform, _ = librosa.load(wav_path, sr=16000)
- duration_time += len(waveform)/16.0
+ wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+ waveform, _ = librosa.load(wav_path, sr=16000)
+ duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))
print("total_rtf: {:.5}".format(duration/duration_time))
\ No newline at end of file
diff --git a/runtime/python/utils/test_rtf_gpu.py b/runtime/python/utils/test_rtf_gpu.py
index 84cd2c7..02d5ac6 100644
--- a/runtime/python/utils/test_rtf_gpu.py
+++ b/runtime/python/utils/test_rtf_gpu.py
@@ -17,8 +17,8 @@
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
- from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
-
+ from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+
model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
wav_file_f = open(args.wav_file, 'r')
@@ -29,20 +29,20 @@
num = 30
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
for i in range(num):
- beg_time = time.time()
- result = model(wav_path)
- end_time = time.time()
- duration = end_time-beg_time
- total += duration
- print(result)
- print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
+ beg_time = time.time()
+ result = model(wav_path)
+ end_time = time.time()
+ duration = end_time-beg_time
+ total += duration
+ print(result)
+ print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
# infer time
wav_path = []
beg_time = time.time()
for i, wav_path_i in enumerate(wav_files):
- wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
- wav_path += [wav_path_i]
+ wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+ wav_path += [wav_path_i]
result = model(wav_path)
end_time = time.time()
duration = (end_time-beg_time)*1000
@@ -50,9 +50,9 @@
duration_time = 0.0
for i, wav_path_i in enumerate(wav_files):
- wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
- waveform, _ = librosa.load(wav_path, sr=16000)
- duration_time += len(waveform)/16.0
+ wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+ waveform, _ = librosa.load(wav_path, sr=16000)
+ duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))
print("total_rtf: {:.5}".format(duration/duration_time))
\ No newline at end of file
--
Gitblit v1.9.1