Dev gzf exp (#1682)
* resume from step
| | |
| | | export CUDA_VISIBLE_DEVICES="0,1" |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | |
| | | torchrun --nnodes 1 --nproc_per_node ${gpu_num} \ |
| | | torchrun --nnodes 1 --nproc_per_node ${gpu_num} --master_port 12345 \ |
| | | ../../../funasr/bin/train.py ${train_args} |
| | | ``` |
| | | --nnodes represents the total number of participating nodes, while --nproc_per_node indicates the number of processes running on each node. |
| | | --nnodes represents the total number of participating nodes, --nproc_per_node indicates the number of processes running on each node, and --master_port specifies the port used for communication (12345 in this example). |
| | | |
| | | ##### Multi-Machine Multi-GPU Training |
| | | |
| | |
| | | export CUDA_VISIBLE_DEVICES="0,1" |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | |
| | | torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \ |
| | | torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr 192.168.1.1 --master_port 12345 \ |
| | | ../../../funasr/bin/train.py ${train_args} |
| | | ``` |
| | | On the worker node (assuming the IP is 192.168.1.2), make sure that --master_addr and --master_port match those of the master node, and then run the same command with --node_rank set to 1: |
| | |
| | | export CUDA_VISIBLE_DEVICES="0,1" |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | |
| | | torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \ |
| | | torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr 192.168.1.1 --master_port 12345 \ |
| | | ../../../funasr/bin/train.py ${train_args} |
| | | ``` |
| | | |
| | |
| | | export CUDA_VISIBLE_DEVICES="0,1" |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | |
| | | torchrun --nnodes 1 --nproc_per_node ${gpu_num} \ |
| | | torchrun --nnodes 1 --nproc_per_node ${gpu_num} --master_port 12345 \ |
| | | ../../../funasr/bin/train.py ${train_args} |
| | | ``` |
| | | --nnodes 表示参与的节点总数,--nproc_per_node 表示每个节点上运行的进程数 |
| | | --nnodes 表示参与的节点总数,--nproc_per_node 表示每个节点上运行的进程数,--master_port 表示端口号 |
| | | |
| | | ##### 多机多gpu训练 |
| | | |
| | |
| | | ../../../funasr/bin/train.py ${train_args} |
| | | ``` |
| | | |
| | | --nnodes 表示参与的节点总数,--node_rank 表示当前节点id,--nproc_per_node 表示每个节点上运行的进程数(通常为gpu个数) |
| | | --nnodes 表示参与的节点总数,--node_rank 表示当前节点id,--nproc_per_node 表示每个节点上运行的进程数(通常为gpu个数),--master_port 表示端口号 |
| | | |
| | | #### 准备数据 |
| | | |
| | |
| | | tag="exp1" |
| | | workspace=`pwd` |
| | | |
| | | master_port=12345 |
| | | |
| | | . utils/parse_options.sh || exit 1; |
| | | |
| | | # Set bash to 'debug' mode, it will exit on : |
| | |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | | --nproc_per_node ${gpu_num} \ |
| | | --master_port ${master_port} \ |
| | | ../../../funasr/bin/train.py \ |
| | | --config-path "${workspace}/conf" \ |
| | | --config-name "${config}" \ |
| | |
| | | tag="exp1" |
| | | workspace=`pwd` |
| | | |
| | | master_port=12345 |
| | | |
| | | . utils/parse_options.sh || exit 1; |
| | | |
| | | # Set bash to 'debug' mode, it will exit on : |
| | |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | | --nproc_per_node ${gpu_num} \ |
| | | --master_port ${master_port} \ |
| | | ../../../funasr/bin/train.py \ |
| | | --config-path "${workspace}/conf" \ |
| | | --config-name "${config}" \ |
| | |
| | | tag="exp1" |
| | | workspace=`pwd` |
| | | |
| | | master_port=12345 |
| | | |
| | | . utils/parse_options.sh || exit 1; |
| | | |
| | | # Set bash to 'debug' mode, it will exit on : |
| | |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | | --nproc_per_node ${gpu_num} \ |
| | | --master_port ${master_port} \ |
| | | ../../../funasr/bin/train.py \ |
| | | --config-path "${workspace}/conf" \ |
| | | --config-name "${config}" \ |
| | |
| | | tag="exp1" |
| | | workspace=`pwd` |
| | | |
| | | master_port=12345 |
| | | |
| | | . utils/parse_options.sh || exit 1; |
| | | |
| | | # Set bash to 'debug' mode, it will exit on : |
| | |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | | --nproc_per_node ${gpu_num} \ |
| | | --master_port ${master_port} \ |
| | | ../../../funasr/bin/train.py \ |
| | | --config-path "${workspace}/conf" \ |
| | | --config-name "${config}" \ |
| | |
| | | tag="exp1" |
| | | workspace=`pwd` |
| | | |
| | | master_port=12345 |
| | | |
| | | . utils/parse_options.sh || exit 1; |
| | | |
| | | # Set bash to 'debug' mode, it will exit on : |
| | |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | | --nproc_per_node ${gpu_num} \ |
| | | --master_port ${master_port} \ |
| | | ../../../funasr/bin/train.py \ |
| | | --config-path "${workspace}/conf" \ |
| | | --config-name "${config}" \ |
| | |
| | | dataloader_tr, dataloader_val = dataloader.build_iter( |
| | | epoch, data_split_i=data_split_i, start_step=trainer.start_step |
| | | ) |
| | | trainer.start_step = 0 |
| | | |
| | | trainer.train_epoch( |
| | | model=model, |
| | |
| | | writer=writer, |
| | | data_split_i=data_split_i, |
| | | data_split_num=dataloader.data_split_num, |
| | | start_step=trainer.start_step, |
| | | ) |
| | | trainer.start_step = 0 |
| | | |
| | | torch.cuda.empty_cache() |
| | | |
| | |
| | | self.max_token_length = kwargs.get("max_token_length", 2048) |
| | | self.min_token_length = kwargs.get("min_token_length", 0) |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | self.start_step = 0 |
| | | self.start_step = start_step |
| | | if self.start_step > 0: |
| | | logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}") |
| | | # super().__init__(dataset, num_replicas=num_replicas, rank=rank, |
| | |
| | | start_idx = self.rank * batches_per_rank |
| | | end_idx = start_idx + batches_per_rank |
| | | rank_batches = buffer_batches[start_idx + self.start_step : end_idx] |
| | | |
| | | if self.start_step > 0: |
| | | logging.info( |
| | | f"Warning, rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num_before: {end_idx-start_idx}, now: {len(rank_batches)}" |
| | | ) |
| | | # Return an iterator over the batches for the current rank |
| | | return iter(rank_batches) |
| | | |
| | |
| | | with open(path, encoding="utf-8") as fin: |
| | | file_list_all = fin.readlines() |
| | | |
| | | num_per_slice = (len(file_list_all) - 1) // data_split_num + 1 |
| | | num_per_slice = (len(file_list_all) - 1) // data_split_num + 1 # 16 |
| | | file_list = file_list_all[ |
| | | data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice |
| | | ] |
| | |
| | | with open(data_file, "r") as f: |
| | | |
| | | data_file_lists = f.readlines() |
| | | print("") |
| | | lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1 |
| | | task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1 |
| | | # import pdb;pdb.set_trace() |
| New file |
| | |
| | | import os |
| | | import json |
| | | import torch |
| | | import logging |
| | | import hydra |
| | | from omegaconf import DictConfig, OmegaConf |
| | | import concurrent.futures |
| | | import librosa |
| | | import torch.distributed as dist |
| | | from tqdm import tqdm |
| | | |
| | | |
def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source",), jsonl_file_out: str = None, **kwargs
):
    """Write a ``<key> <length>`` line for every entry of a ``key value`` list file.

    Only rank 0 does the work; when a process group with more than one rank
    is initialized, all ranks synchronize on a barrier afterwards. Despite the
    function name, the output is plain text (one ``key length`` pair per
    line), not JSON.

    Args:
        path: Path of the input list file (one ``key value`` pair per line).
        data_type_list: Sequence of data type names. Only the first entry is
            used; it selects the ``<data_type>_len`` field that is written.
        jsonl_file_out: Path of the output file. Required.
        **kwargs: Accepted for call compatibility and ignored.
    """
    # Ask whether a process group exists instead of hiding every possible
    # error behind a bare ``except``.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        data_type = data_type_list[0]
        with open(path, "r", encoding="utf-8") as f:
            data_file_lists = f.readlines()
        print("")

        lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
        task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1

        entries = {}
        if task_num > 1:
            with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                futures = [
                    executor.submit(
                        parse_context_length,
                        data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
                        data_type,
                        i,
                    )
                    for i in range(task_num)
                ]
                # Collect in submission order so the output follows the input
                # order rather than the order in which the threads finish.
                for future in futures:
                    entries.update(future.result())
        else:
            entries.update(parse_context_length(data_file_lists, data_type))

        len_key = f"{data_type}_len"
        with open(jsonl_file_out, "w", encoding="utf-8") as f:
            f.writelines(f"{key} {info[len_key]}\n" for key, info in entries.items())
        print(f"processed {len(entries)} samples")

    if world_size > 1:
        dist.barrier()
| | | |
| | | |
def parse_context_length(data_list: list, data_type: str, id=0):
    """Compute a length for every ``key value`` line in ``data_list``.

    If the value is an existing path it is loaded as 16 kHz audio and the
    length is its duration in units of 10 ms. Otherwise the value is treated
    as text: the length is the number of whitespace-separated tokens when it
    contains a space, else the number of characters.

    Args:
        data_list: Lines of the form ``key value`` (the value may be missing).
        data_type: Name used for the result fields (``data_type`` and
            ``<data_type>_len``).
        id: Worker index, shown only in the progress-bar description.

    Returns:
        Dict mapping each key to
        ``{data_type: value, f"{data_type}_len": length}``.
    """
    res = {}
    # The context manager closes the bar even if loading a file raises.
    with tqdm(total=len(data_list), dynamic_ncols=True) as pbar:
        pbar.set_description(f"cpu: {id}")
        for line in data_list:
            pbar.update(1)
            fields = line.strip().split(maxsplit=1)
            if not fields:
                # Blank line: there is no key to record, so skip it instead
                # of failing on fields[0].
                continue
            key = fields[0]
            content = fields[1].strip() if len(fields) > 1 else ""
            if os.path.exists(content):
                waveform, _ = librosa.load(content, sr=16000)
                sample_num = len(waveform)
                # samples -> seconds -> milliseconds -> 10 ms units
                context_len = int(sample_num / 16000 * 1000 / 10)
            else:
                context_len = len(content.split()) if " " in content else len(content)
            res[key] = {data_type: content, f"{data_type}_len": context_len}
    return res
| | | |
| | | |
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Hydra entry point: read the options and generate the length list.

    Options read from ``cfg`` (each has a hard-coded default):
    ``scp_file_list``, ``data_type_list`` and ``jsonl_file_out``.
    """
    options = OmegaConf.to_container(cfg, resolve=True)
    print(options)

    default_scp = "/Users/zhifu/funasr1.0/data/list/train_wav.scp"
    default_out = "/Users/zhifu/funasr1.0/data/list/wav_len.txt"

    gen_jsonl_from_wav_text_list(
        options.get("scp_file_list", default_scp),
        data_type_list=options.get("data_type_list", ("source",)),
        jsonl_file_out=options.get("jsonl_file_out", default_out),
    )
| | | |
| | | |
| | | """ |
| | | python -m funasr.datasets.audio_datasets.scp2jsonl \ |
| | | ++scp_file_list=/Users/zhifu/funasr1.0/test_local/wav.scp \ |
| | | ++data_type_list='["source"]' \ |
| | | ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/wav_len.txt |
| | | """ |
| | | |
| | | if __name__ == "__main__": |
| | | main_hydra() |
| New file |
| | |
| | | import os |
| | | import json |
| | | import torch |
| | | import logging |
| | | import hydra |
| | | from omegaconf import DictConfig, OmegaConf |
| | | import concurrent.futures |
| | | import librosa |
| | | import torch.distributed as dist |
| | | import threading |
| | | from tqdm import tqdm |
| | | from concurrent.futures import ThreadPoolExecutor |
| | | |
| | | |
def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
    """Rewrite a jsonl file with refreshed audio lengths.

    Every line of ``jsonl_file`` is passed to ``update_data``, which replaces
    it in place with the updated JSON string. The result is written to
    ``jsonl_file_out``, one record per line.

    Args:
        jsonl_file: Path of the input jsonl file.
        jsonl_file_out: Path of the output jsonl file.
        ncpu: Number of worker threads. Values <= 1 process serially.
    """
    with open(jsonl_file, encoding="utf-8") as fin:
        lines = fin.readlines()

    num_total = len(lines)
    if ncpu > 1:
        # Cap the number of concurrent threads with a pool.
        with ThreadPoolExecutor(max_workers=ncpu) as executor:
            futures = {
                executor.submit(update_data, lines, i): i for i in tqdm(range(num_total))
            }
            # Wait for every task. A failed task leaves its line untouched;
            # report it instead of dropping the exception silently.
            for future in concurrent.futures.as_completed(futures):
                err = future.exception()
                if err is not None:
                    logging.warning(
                        f"failed to update line {futures[future]} of {jsonl_file}: {err!r}"
                    )
    else:
        for i in range(num_total):
            update_data(lines, i)
    logging.info("All audio durations have been processed.")

    # Open the output only once the data is ready, and always close it.
    # Lines that were not rewritten still end in a newline; strip it so each
    # record is followed by exactly one.
    with open(jsonl_file_out, "w", encoding="utf-8") as fout:
        fout.writelines(line.rstrip("\r\n") + "\n" for line in lines)
| | | |
| | | |
def update_data(lines, i):
    """Recompute ``source_len`` for record ``i`` and store it back in ``lines``.

    ``lines[i]`` is parsed as JSON. The ``/cpfs01`` part of its ``source``
    path is replaced with ``/cpfs_speech/data``, the audio is loaded at
    16 kHz, and ``source_len`` is set to the duration in units of 10 ms.
    ``lines[i]`` is then replaced by the re-serialized record (no trailing
    newline).

    Args:
        lines: List of JSON strings. Modified in place.
        i: Index of the record to update.
    """
    data = json.loads(lines[i].strip())

    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
    waveform, _ = librosa.load(wav_path, sr=16000)
    sample_num = len(waveform)
    # samples -> seconds -> milliseconds -> 10 ms units
    data["source_len"] = int(sample_num / 16000 * 1000 / 10)
    data["source"] = wav_path
    lines[i] = json.dumps(data, ensure_ascii=False)
| | | |
| | | |
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
    """Refresh audio lengths for every jsonl file named in a list file.

    Each listed jsonl file is converted by ``gen_scp_from_jsonl`` into a file
    with the same base name inside ``jsonl_file_out_dir``.

    Args:
        jsonl_file_list_in: Text file with one jsonl path per line.
        jsonl_file_out_dir: Output directory. Created if it does not exist.
        ncpu: Number of worker threads passed on to ``gen_scp_from_jsonl``.
    """
    os.makedirs(jsonl_file_out_dir, exist_ok=True)
    with open(jsonl_file_list_in, "r", encoding="utf-8") as f:
        # Blank lines would produce an empty path (and the output directory
        # itself as the target), so drop them up front.
        jsonl_paths = [line.strip() for line in f if line.strip()]

    for i, jsonl in enumerate(jsonl_paths):
        jsonl_file_out = os.path.join(jsonl_file_out_dir, os.path.basename(jsonl))
        logging.info(f"{i}/{len(jsonl_paths)}, jsonl: {jsonl}, {jsonl_file_out}")

        gen_scp_from_jsonl(jsonl, jsonl_file_out, ncpu)
| | | |
| | | |
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Hydra entry point: read the options and refresh the audio lengths.

    Options read from ``cfg`` (each has a hard-coded default):
    ``jsonl_file_list_in``, ``jsonl_file_out_dir`` and ``ncpu``.
    """
    options = OmegaConf.to_container(cfg, resolve=True)
    logging.info(options)

    default_list = "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
    default_out_dir = "/Users/zhifu/funasr1.0/data_tmp"

    list_in = options.get("jsonl_file_list_in", default_list)
    out_dir = options.get("jsonl_file_out_dir", default_out_dir)
    workers = options.get("ncpu", 1)
    update_wav_len(list_in, out_dir, workers)
| | | |
| | | |
| | | """ |
| | | python -m funasr.datasets.audio_datasets.json2scp \ |
| | | ++jsonl_file_list_in=/Users/zhifu/funasr1.0/data/list/data_jsonl.list \ |
| | | ++jsonl_file_out_dir=/Users/zhifu/funasr1.0/data_tmp \ |
| | | ++ncpu=4 |
| | | """ |
| | | |
| | | if __name__ == "__main__": |
| | | main_hydra() |
| | |
| | | x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() |
| | | |
| | | return x |
| | | |
| | | |
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block: a depthwise 1-D convolution over time with a residual.

    Args:
        n_feat (int): The number of features (channels).
        dropout_rate (float): Dropout rate applied to the block output.
        kernel_size (int): Kernel size of the depthwise convolution.
        sanm_shfit (int): Extra left padding; only applied when > 0.

    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Build the dropout, the depthwise convolution and the padding."""
        super().__init__()

        self.dropout = nn.Dropout(p=dropout_rate)

        # groups == channels: every feature is filtered independently.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )

        # Pad kernel_size - 1 frames in total so the convolution keeps the
        # sequence length; a positive shift moves padding to the left side.
        shift = sanm_shfit if sanm_shfit > 0 else 0
        left_padding = (kernel_size - 1) // 2 + shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
        """Apply the memory block.

        :param inputs: (#batch, time, size).
        :param mask: Mask tensor reshaped to (#batch, -1, 1), or None.
        :param cache: Padded input kept from a previous call, or None.
        :param mask_shfit_chunk: Optional tensor multiplied into ``mask``.
        :return: Tuple ``(output, cache)``.
        """
        batch, _, _ = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (batch, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        feats = inputs.transpose(1, 2)
        _, _, steps = feats.size()
        if cache is not None:
            # Drop the oldest cached frame, append the new frames, and keep
            # just enough history for the convolution to emit `steps` frames.
            feats = torch.cat((cache[:, :, 1:], feats), dim=2)
            feats = feats[:, :, -(self.kernel_size + steps - 1) :]
            cache = feats
        else:
            feats = self.pad_fn(feats)
            # A cache is only started outside training mode.
            if not self.training:
                cache = feats

        out = self.fsmn_block(feats).transpose(1, 2)
        if out.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]

        out = self.dropout(out + inputs)
        if mask is not None:
            out = out * mask
        return out, cache
| | | |
| | | |
class ResidualAttentionBlockFSMN(nn.Module):
    """Pre-norm residual block: FSMN memory, optional cross-attention, MLP.

    Args:
        n_state: Feature size of the block.
        n_head: Number of heads of the cross-attention.
        cross_attention: Whether to add a cross-attention sub-layer.
        **kwargs: Optional ``self_attention_dropout_rate`` (defaults to 0.0),
            ``kernel_size`` (defaults to 20) and ``sanm_shfit`` (defaults to
            10) for the FSMN memory block. Other keys are ignored.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
        super().__init__()

        # A missing or None rate would reach nn.Dropout as p=None and fail;
        # fall back to no dropout instead.
        dropout_rate = kwargs.get("self_attention_dropout_rate")
        if dropout_rate is None:
            dropout_rate = 0.0

        self.attn = MultiHeadedAttentionSANMDecoder(
            n_state,
            dropout_rate,
            kwargs.get("kernel_size", 20),
            kwargs.get("sanm_shfit", 10),
        )
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Run the block.

        Args:
            x: Input tensor of the block.
            xa: Tensor attended to by the cross-attention, if present.
            mask: Accepted for interface compatibility; the FSMN memory block
                is always called with ``mask=None``.
            kv_cache: Cache dict forwarded to the sub-layers.
            **kwargs: ``is_pad_mask`` and ``is_pad_memory_mask`` flags
                (default False) forwarded to the sub-layers.

        Returns:
            The transformed tensor.
        """
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        # Explicit None check: do not rely on the truthiness of a module.
        if self.cross_attn is not None:
            x = (
                x
                + self.cross_attn(
                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
                )[0]
            )
        x = x + self.mlp(self.mlp_ln(x))
        return x
| | | |
| | | |
@tables.register("decoder_classes", "SenseVoiceDecoderFSMN")
class SenseVoiceDecoderFSMN(nn.Module):
    """Decoder built from ``ResidualAttentionBlockFSMN`` layers with tied output weights.

    Args:
        n_vocab: Vocabulary size of the token embedding.
        n_ctx: Number of positions in the positional embedding.
        n_state: Feature size of the decoder.
        n_head: Number of heads of the cross-attention in each block.
        n_layer: Number of blocks.
        **kwargs: Forwarded to every block; ``use_padmask`` (default True)
            is also read here.
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockFSMN(
                    n_state, n_head, cross_attention=True, layer_id=i, **kwargs
                )
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        # Upper-triangular -inf mask (zeros on and below the diagonal).
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

        self.use_padmask = kwargs.get("use_padmask", True)

    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Forward decoder.

        Args:
            x: Input token ids; entries equal to -1 are embedded as id 0.
            xa: Encoded memory passed to every block.
            kv_cache: Optional cache dict; when non-empty, the size of
                dimension 1 of its first value is used as the position offset.
            **kwargs: ``hlens`` (lengths of ``xa``, used to build the memory
                padding mask) and ``ys_in_lens`` (read but not used).
        Returns:
            Token scores before softmax, as float32, with the vocabulary on
            the last dimension.
        """
        use_padmask = self.use_padmask
        hlens = kwargs.get("hlens", None)

        ys_in_lens = kwargs.get("ys_in_lens", None)

        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        memory = xa
        # Replace the -1 ids on a copy so the caller's tensor is not modified.
        tgt = x.masked_fill(x == -1, 0)
        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]

        x = tgt.to(memory.dtype)

        if use_padmask and hlens is not None:
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None

        for block in self.blocks:
            x = block(
                x,
                memory,
                mask=self.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
            )

        x = self.ln(x)
        # Output projection shares its weights with the token embedding.
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return x
| | |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size, frames, _ = speech.shape |
| | | _, text_tokens = text.shape |
| | | |
| | | if self.activation_checkpoint: |
| | | from torch.utils.checkpoint import checkpoint |
| | |
| | | stats["batch_size_x_frames"] = frames * batch_size |
| | | stats["batch_size_real_frames"] = speech_lengths.sum().item() |
| | | stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] |
| | | stats["batch_size_x_tokens"] = text_tokens * batch_size |
| | | stats["batch_size_real_tokens"] = text_lengths.sum().item() |
| | | stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] |
| | | stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| | | |
| | | |
@tables.register("model_classes", "SenseVoiceFSMN")
class SenseVoiceFSMN(nn.Module):
    """Whisper-based model whose decoder is replaced by a registered decoder class.

    The Whisper encoder's ``forward`` is swapped for
    ``sense_voice_encode_forward`` and the Whisper decoder is replaced by the
    class looked up in ``tables.decoder_classes``.
    """

    def __init__(self, *args, **kwargs):
        """Build the model from keyword configuration.

        Keys read from ``kwargs``: ``dims``, ``downsample_rate``,
        ``use_padmask``, ``decoder``, ``decoder_conf``,
        ``activation_checkpoint``, ``ignore_id``, ``vocab_size``,
        ``length_normalized_loss``, ``lsm_weight``, ``specaug`` and
        ``specaug_conf``.
        """
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_conf = kwargs.get("decoder_conf", {})
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=dims.n_vocab,
            encoder_output_size=dims.n_audio_state,
            **decoder_conf,
        )
        model.decoder = decoder

        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Compute the training loss.

        Args:
            speech: Input features, 3-D with the batch first.
            speech_lengths: Lengths of ``speech``; if 2-D, column 0 is used.
            text: Token ids, 2-D with the batch first.
            text_lengths: Lengths of ``text``; if 2-D, column 0 is used.
            **kwargs: Optional ``target_mask`` selecting the target positions
                that contribute to the loss.

        Returns:
            Tuple ``(loss, stats, weight)`` as produced by ``force_gatherable``.
        """
        target_mask = kwargs.get("target_mask", None)

        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape
        _, text_tokens = text.shape

        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att
        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        stats["batch_size_x_tokens"] = text_tokens * batch_size
        stats["batch_size_real_tokens"] = text_lengths.sum().item()
        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
        Returns:
                Tuple ``(encoder_out, encoder_out_lens)``.
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Forward encoder (the encoder receives the last two axes swapped)
        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        """Run the decoder and compute the attention loss and accuracy.

        Args:
            encoder_out: Encoder output passed to the decoder as memory.
            encoder_out_lens: Lengths of ``encoder_out``.
            ys_pad: Token ids, used both as decoder input and as targets.
            ys_pad_lens: Lengths of ``ys_pad``.
            **kwargs: Optional ``target_mask`` with the shape of ``ys_pad``;
                positions where it is 0 are excluded from the loss. When
                absent, every position is kept.

        Returns:
            Tuple ``(loss_att, acc_att, None, None)``.
        """
        target_mask = kwargs.get("target_mask", None)
        if target_mask is None:
            # forward() passes None when the batch carries no target mask;
            # keep every position instead of failing on `ys_pad * None`.
            target_mask = torch.ones_like(ys_pad)

        # 1. Forward decoder
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )
        # 2. Compute attention loss
        mask = torch.ones_like(ys_pad) * (-1)
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        # Id 0 is turned into -1 as well, so it is ignored like the masked positions.
        ys_pad_mask[ys_pad_mask == 0] = -1
        # Next-token prediction: output at step t is scored against target t + 1.
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode a single utterance with ``whisper.decode``.

        Args:
            data_in: Input data; a tensor of fbank features when
                ``data_type == "fbank"``, otherwise anything accepted by
                ``load_audio_text_image_video``.
            data_lengths: Feature lengths for the fbank case.
            key: List of utterance ids; ``key[0]`` labels the result.
            tokenizer: Tokenizer forwarded to the data loader.
            frontend: Feature extractor; a ``WhisperFrontend`` is created and
                stored on the model when none is given or stored yet.
            **kwargs: Requires ``device`` and ``tokenizer_conf``; optional
                ``batch_size``, ``data_type``, ``fs``, ``do_pad_trim``,
                ``DecodingOptions`` and ``initial_prompt``.

        Returns:
            Tuple ``(results, meta_data)`` where ``results`` is a one-element
            list of ``{"key": ..., "text": ...}``.

        Raises:
            NotImplementedError: If ``batch_size`` is greater than 1.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to the target device below.
                speech_lengths = torch.tensor([speech.shape[1]])
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Work on a copy so the caller's options dict is left untouched.
        DecodingOptions = dict(kwargs.get("DecodingOptions", {}))
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True

        options = whisper.DecodingOptions(**DecodingOptions)

        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        results = []
        result_i = {"key": key[0], "text": text}

        results.append(result_i)

        return results, meta_data
| | |
| | | batch_num_epoch = len(dataloader_train) |
| | | self.log( |
| | | epoch, |
| | | batch_idx, |
| | | batch_idx + kwargs.get("start_step", 0), |
| | | step_in_epoch=self.step_in_epoch, |
| | | batch_num_epoch=batch_num_epoch, |
| | | lr=lr, |