zhifu gao
2024-04-30 a09aba419f305abadc185ec41c336211549e894b
Dev gzf exp (#1682)

* resume from step
14个文件已修改
2个文件已添加
705 ■■■■■ 已修改文件
examples/README.md 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/README_zh.md 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/branchformer/run.sh 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/conformer/run.sh 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_branchformer/run.sh 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/run.sh 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/transformer/run.sh 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/espnet_samplers.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/scp2jsonl.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/scp2len.py 121 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/update_jsonl.py 98 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/decoder.py 202 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py 240 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/README.md
@@ -248,10 +248,10 @@
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
torchrun --nnodes 1 --nproc_per_node ${gpu_num} --master_port 12345 \
../../../funasr/bin/train.py ${train_args}
```
--nnodes represents the total number of participating nodes, while --nproc_per_node indicates the number of processes running on each node.
--nnodes represents the total number of participating nodes, while --nproc_per_node indicates the number of processes running on each node. --master_port specifies the communication port of the master node (12345 in this example).
##### Multi-Machine Multi-GPU Training
@@ -260,7 +260,7 @@
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr 192.168.1.1 --master_port 12345 \
../../../funasr/bin/train.py ${train_args}
```
On the worker node (assuming the IP is 192.168.1.2), you need to ensure that the MASTER_ADDR and MASTER_PORT environment variables are set to match those of the master node, and then run the same command:
@@ -269,7 +269,7 @@
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr 192.168.1.1 --master_port 12345 \
../../../funasr/bin/train.py ${train_args}
```
examples/README_zh.md
@@ -256,10 +256,10 @@
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
torchrun --nnodes 1 --nproc_per_node ${gpu_num} --master_port 12345 \
../../../funasr/bin/train.py ${train_args}
```
--nnodes 表示参与的节点总数,--nproc_per_node 表示每个节点上运行的进程数
--nnodes 表示参与的节点总数,--nproc_per_node 表示每个节点上运行的进程数,--master_port 表示端口号
##### 多机多gpu训练
@@ -280,7 +280,7 @@
../../../funasr/bin/train.py ${train_args}
```
--nnodes 表示参与的节点总数,--node_rank 表示当前节点id,--nproc_per_node 表示每个节点上运行的进程数(通常为gpu个数)
--nnodes 表示参与的节点总数,--node_rank 表示当前节点id,--nproc_per_node 表示每个节点上运行的进程数(通常为gpu个数),--master_port 表示端口号
#### 准备数据
examples/aishell/branchformer/run.sh
@@ -27,6 +27,8 @@
tag="exp1"
workspace=`pwd`
master_port=12345
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
@@ -115,6 +117,7 @@
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  --master_port ${master_port} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
examples/aishell/conformer/run.sh
@@ -27,6 +27,8 @@
tag="exp1"
workspace=`pwd`
master_port=12345
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
@@ -114,6 +116,7 @@
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  --master_port ${master_port} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
examples/aishell/e_branchformer/run.sh
@@ -27,6 +27,8 @@
tag="exp1"
workspace=`pwd`
master_port=12345
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
@@ -115,6 +117,7 @@
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  --master_port ${master_port} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
examples/aishell/paraformer/run.sh
@@ -27,6 +27,8 @@
tag="exp1"
workspace=`pwd`
master_port=12345
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
@@ -113,6 +115,7 @@
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  --master_port ${master_port} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
examples/aishell/transformer/run.sh
@@ -27,6 +27,8 @@
tag="exp1"
workspace=`pwd`
master_port=12345
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
@@ -115,6 +117,7 @@
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  --master_port ${master_port} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
funasr/bin/train.py
@@ -205,7 +205,6 @@
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
            trainer.start_step = 0
            trainer.train_epoch(
                model=model,
@@ -218,7 +217,9 @@
                writer=writer,
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
                start_step=trainer.start_step,
            )
            trainer.start_step = 0
            torch.cuda.empty_cache()
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -71,7 +71,7 @@
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.min_token_length = kwargs.get("min_token_length", 0)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        self.start_step = 0
        self.start_step = start_step
        if self.start_step > 0:
            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
        # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
@@ -146,7 +146,10 @@
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
        if self.start_step > 0:
            logging.info(
                f"Warning, rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num_before: {end_idx-start_idx}, now: {len(rank_batches)}"
            )
        # Return an iterator over the batches for the current rank
        return iter(rank_batches)
funasr/datasets/audio_datasets/index_ds.py
@@ -35,7 +35,7 @@
            with open(path, encoding="utf-8") as fin:
                file_list_all = fin.readlines()
                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # 16
                file_list = file_list_all[
                    data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
                ]
funasr/datasets/audio_datasets/scp2jsonl.py
@@ -29,7 +29,6 @@
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                print("")
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                # import pdb;pdb.set_trace()
funasr/datasets/audio_datasets/scp2len.py
New file
@@ -0,0 +1,121 @@
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist
from tqdm import tqdm
def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source",), jsonl_file_out: str = None, **kwargs
):
    """Convert a wav.scp/text file into a plain "key length" list file.

    Only rank 0 does the work; when torch.distributed is initialized all
    ranks synchronize on a barrier at the end.

    Args:
        path: input scp file, one "key value" pair per line.
        data_type_list: data-type names; only the first entry is parsed here.
        jsonl_file_out: output path, one "key <len>" pair per line.
    """
    # Query the distributed context only when it is actually initialized,
    # instead of masking arbitrary failures with a bare except.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")

    if rank == 0:
        data_type = data_type_list[0]
        json_dict = {data_type: {}}
        with open(path, "r") as f:
            data_file_lists = f.readlines()

        lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
        task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
        if task_num > 1:
            # Shard the lines across threads; each worker returns a dict of
            # key -> metadata which is merged as workers finish.
            with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                futures = [
                    executor.submit(
                        parse_context_length,
                        data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
                        data_type,
                        i,
                    )
                    for i in range(task_num)
                ]
                for future in concurrent.futures.as_completed(futures):
                    json_dict[data_type].update(future.result())
        else:
            json_dict[data_type].update(parse_context_length(data_file_lists, data_type))

        # Write "key length" lines. Use the parsed data type's length key
        # (the previous hard-coded "source_len" broke for other types), and
        # rely on buffered writes instead of flushing per line.
        with open(jsonl_file_out, "w") as f:
            for key, meta in json_dict[data_type].items():
                f.write(f"{key} {meta[f'{data_type}_len']}\n")
        print(f"processed {len(json_dict[data_type])} samples")

    if world_size > 1:
        dist.barrier()
def parse_context_length(data_list: list, data_type: str, id=0):
    """Compute a per-sample length for each "key value" line.

    If the value is an existing file path, the length is the audio duration
    in 10ms frames (loaded at 16kHz); otherwise the value is treated as
    text and the length is its whitespace token count, or its character
    count when it contains no spaces.

    Args:
        data_list: raw lines ("key value" per line).
        data_type: name used for the output keys (e.g. "source").
        id: worker id, only used for the progress-bar label.

    Returns:
        dict: key -> {data_type: value, f"{data_type}_len": length}.
    """
    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
    res = {}
    try:
        for line in data_list:
            pbar.update(1)
            pbar.set_description(f"cpu: {id}")
            parts = line.strip().split(maxsplit=1)
            key = parts[0]
            value = parts[1].strip() if len(parts) > 1 else ""
            if os.path.exists(value):
                # Audio file: duration in 10ms frames at 16kHz.
                waveform, _ = librosa.load(value, sr=16000)
                context_len = int(len(waveform) / 16000 * 1000 / 10)
            else:
                # Text: word count if space-separated, else char count.
                context_len = len(value.split()) if " " in value else len(value)
            res[key] = {data_type: value, f"{data_type}_len": context_len}
    finally:
        # Always release the progress bar, even if a load fails.
        pbar.close()
    return res
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Hydra entry point: resolve CLI overrides and run the scp -> length conversion."""
    opts = OmegaConf.to_container(cfg, resolve=True)
    print(opts)
    scp_path = opts.get("scp_file_list", "/Users/zhifu/funasr1.0/data/list/train_wav.scp")
    dtypes = opts.get("data_type_list", ("source",))
    out_path = opts.get("jsonl_file_out", "/Users/zhifu/funasr1.0/data/list/wav_len.txt")
    gen_jsonl_from_wav_text_list(scp_path, data_type_list=dtypes, jsonl_file_out=out_path)
"""
python -m funasr.datasets.audio_datasets.scp2jsonl \
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
"""
# CLI entry point: parse hydra overrides and run the conversion.
if __name__ == "__main__":
    main_hydra()
funasr/datasets/audio_datasets/update_jsonl.py
New file
@@ -0,0 +1,98 @@
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist
import threading
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
    """Recompute the audio length of every jsonl entry and rewrite the file.

    Each line of *jsonl_file* is updated in place via update_data (possibly
    with *ncpu* worker threads) and written to *jsonl_file_out* in the
    original order.

    Args:
        jsonl_file: input jsonl path.
        jsonl_file_out: output jsonl path.
        ncpu: number of worker threads; <= 1 runs sequentially.
    """
    with open(jsonl_file, encoding="utf-8") as fin:
        lines = fin.readlines()
    num_total = len(lines)

    if ncpu > 1:
        # Bound concurrency with a thread pool. Track progress on task
        # completion (not submission), and call result() so worker
        # exceptions propagate instead of being silently swallowed.
        with ThreadPoolExecutor(max_workers=ncpu) as executor:
            futures = [executor.submit(update_data, lines, i) for i in range(num_total)]
            for future in tqdm(concurrent.futures.as_completed(futures), total=num_total):
                future.result()
    else:
        for i in range(num_total):
            update_data(lines, i)
    logging.info("All audio durations have been processed.")

    # "with" guarantees the output handle is closed even on error; a single
    # writelines replaces the per-line write+flush.
    with open(jsonl_file_out, "w") as fout:
        fout.writelines(line + "\n" for line in lines)
def update_data(lines, i):
    """Reload the audio for entry *i*, refresh its length, and store the
    updated json string back into ``lines[i]``.
    """
    record = json.loads(lines[i].strip())
    # Remap the storage prefix before reading the audio.
    audio_path = record["source"].replace("/cpfs01", "/cpfs_speech/data")
    samples, _ = librosa.load(audio_path, sr=16000)
    # Duration expressed in 10ms frames at 16kHz.
    frame_len = int(len(samples) / 16000 * 1000 / 10)
    record["source_len"] = frame_len
    record["source"] = audio_path
    lines[i] = json.dumps(record, ensure_ascii=False)
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
    """For every jsonl path listed in *jsonl_file_list_in*, refresh the audio
    lengths and write the result under *jsonl_file_out_dir* (same basename).
    """
    os.makedirs(jsonl_file_out_dir, exist_ok=True)
    with open(jsonl_file_list_in, "r") as f:
        jsonl_paths = f.readlines()
    total = len(jsonl_paths)
    for idx, raw in enumerate(jsonl_paths):
        src = raw.strip()
        dst = os.path.join(jsonl_file_out_dir, os.path.basename(src))
        logging.info(f"{idx}/{total}, jsonl: {raw}, {dst}")
        gen_scp_from_jsonl(src, dst, ncpu)
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Hydra entry point: resolve CLI overrides and refresh jsonl audio lengths."""
    opts = OmegaConf.to_container(cfg, resolve=True)
    logging.info(opts)
    list_in = opts.get(
        "jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
    )
    out_dir = opts.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
    workers = opts.get("ncpu", 1)
    update_wav_len(list_in, out_dir, workers)
"""
python -m funasr.datasets.audio_datasets.json2scp \
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_in=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
"""
# CLI entry point: parse hydra overrides and run the length update.
if __name__ == "__main__":
    main_hydra()
funasr/models/sense_voice/decoder.py
@@ -335,3 +335,205 @@
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return x
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used as the decoder "self-attention" substitute.

    A depthwise 1-D convolution over the time axis (groups == n_feat, no
    bias) with asymmetric zero padding controlled by ``kernel_size`` and
    ``sanm_shfit``, followed by a residual connection and dropout. Despite
    the class name, there are no attention heads here.

    Args:
        n_feat (int): Feature dimension of the input/output.
        dropout_rate (float): Dropout applied after the residual add.
        kernel_size (int): Length of the FSMN memory window.
        sanm_shfit (int): Extra left shift of the window; > 0 gives more
            left (past) context and correspondingly less right context.
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct the FSMN memory block."""
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Split (kernel_size - 1) zeros of padding between left and right;
        # sanm_shfit moves extra padding to the left (causal-leaning window).
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
        """Apply the FSMN block.

        Args:
            inputs: (#batch, time1, size) input features.
            mask: (#batch, 1, time) padding mask or None; reshaped to
                (#batch, time, 1) and multiplied into inputs and outputs.
            cache: previously padded conv input for streaming decode.
            mask_shfit_chunk: optional chunk-streaming mask, multiplied
                into ``mask``.

        Returns:
            Tuple (x, cache): output of shape (#batch, time1, size) and the
            updated cache (only recorded when not self.training).
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            # Zero out padded frames before the convolution.
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                # Record the padded input so streaming calls can continue
                # the convolution without re-padding.
                cache = x
        else:
            # Streaming: drop the oldest cached frame, append new frames,
            # then keep exactly the window needed to produce t outputs.
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1) :]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        # NOTE(review): when the conv output is shorter than the input,
        # only the last input frame is kept for the residual add —
        # presumably the single-frame streaming case; confirm.
        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
class ResidualAttentionBlockFSMN(nn.Module):
    """Decoder block where self-attention is replaced by an FSMN memory
    block, with optional cross-attention over the encoder memory and a
    GELU MLP, each wrapped in pre-LayerNorm + residual.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
        super().__init__()
        self.attn = MultiHeadedAttentionSANMDecoder(
            n_state,
            kwargs.get("self_attention_dropout_rate"),
            kwargs.get("kernel_size", 20),
            kwargs.get("sanm_shfit", 10),
        )
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, hidden), nn.GELU(), Linear(hidden, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)

        # FSMN memory block (the mask argument is intentionally not
        # forwarded here; the FSMN receives mask=None).
        fsmn_out = self.attn(
            self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask
        )
        x = x + fsmn_out[0]

        if self.cross_attn:
            cross_out = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
            )
            x = x + cross_out[0]

        return x + self.mlp(self.mlp_ln(x))
@tables.register("decoder_classes", "SenseVoiceDecoderFSMN")
class SenseVoiceDecoderFSMN(nn.Module):
    """Whisper-style token decoder whose self-attention layers are replaced
    by FSMN memory blocks (ResidualAttentionBlockFSMN) with cross-attention
    over the encoder output and a weight-tied output projection.

    Args:
        n_vocab: vocabulary size.
        n_ctx: maximum context length (size of the positional table).
        n_state: model/feature dimension.
        n_head: number of cross-attention heads.
        n_layer: number of decoder blocks.
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        # Learned absolute positional embedding; allocated uninitialized,
        # presumably filled from a checkpoint — TODO confirm.
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockFSMN(
                    n_state, n_head, cross_attention=True, layer_id=i, **kwargs
                )
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        # Causal mask (-inf above the diagonal); non-persistent so it is
        # not written into checkpoints.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        self.use_padmask = kwargs.get("use_padmask", True)

    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Forward decoder.

        Args:
            x: input token ids, int64 (batch, maxlen_out); id -1 is treated
                as padding and mapped to token 0 before embedding.
            xa: encoder memory (batch, maxlen_in, n_state).
            kv_cache: optional streaming cache; the time dimension of its
                first entry gives the positional offset.
            **kwargs: hlens — encoder output lengths used for the memory
                pad mask; ys_in_lens — accepted but unused here.

        Returns:
            Token scores before softmax, (batch, maxlen_out, n_vocab).
        """
        use_padmask = self.use_padmask
        hlens = kwargs.get("hlens", None)
        ys_in_lens = kwargs.get("ys_in_lens", None)  # unused; kept for API parity

        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        tgt, memory = x, xa
        # NOTE(review): this mutates the caller's tensor in place (padding
        # ids -1 -> 0 so they can be embedded); any caller reusing x after
        # this call sees the modified values — confirm this is intended.
        tgt[tgt == -1] = 0
        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
        x = tgt.to(memory.dtype)

        if use_padmask and hlens is not None:
            # True at valid memory positions, broadcastable over query time.
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None

        for layer, block in enumerate(self.blocks):
            x = block(
                x,
                memory,
                mask=self.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
            )

        x = self.ln(x)
        # Weight-tied output projection onto the token embedding matrix.
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return x
funasr/models/sense_voice/model.py
@@ -310,6 +310,7 @@
            speech_lengths = speech_lengths[:, 0]
        batch_size, frames, _ = speech.shape
        _, text_tokens = text.shape
        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint
@@ -331,6 +332,10 @@
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        stats["batch_size_x_tokens"] = text_tokens * batch_size
        stats["batch_size_real_tokens"] = text_lengths.sum().item()
        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
@@ -471,3 +476,238 @@
        results.append(result_i)
        return results, meta_data
@tables.register("model_classes", "SenseVoiceFSMN")
class SenseVoiceFSMN(nn.Module):
    """SenseVoice model built on a Whisper backbone: the Whisper encoder's
    forward is patched with the SenseVoice version, and the Whisper decoder
    is replaced by a decoder looked up from the registry (default
    "SenseVoiceDecoder").
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder: patch Whisper's encoder with the SenseVoice forward,
        # which reads downsample_rate / use_padmask set here.
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder: drop Whisper's own decoder and build one from the registry.
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_conf = kwargs.get("decoder_conf", {})
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=dims.n_vocab,
            encoder_output_size=dims.n_audio_state,
            **decoder_conf,
        )
        model.decoder = decoder
        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state
        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        # Label-smoothed cross entropy over the decoder output.
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        # Optional SpecAugment, applied during training only (see encode()).
        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Training forward: encode speech, compute attention loss and stats.

        Args:
            speech: (batch, frames, feat) input features.
            speech_lengths: (batch,) or (batch, k) — only column 0 is used.
            text: (batch, maxlen_out) target token ids.
            text_lengths: (batch,) or (batch, k) — only column 0 is used.
            **kwargs: target_mask — per-token mask forwarded to the loss.

        Returns:
            (loss, stats, weight) gathered for DataParallel.
        """
        target_mask = kwargs.get("target_mask", None)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape
        _, text_tokens = text.shape

        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            # Trade compute for memory by re-running the encoder in backward.
            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att

        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        # Padding-efficiency stats: padded vs. real frames/tokens per batch.
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        stats["batch_size_x_tokens"] = text_tokens * batch_size
        stats["batch_size_real_tokens"] = text_lengths.sum().item()
        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder. Note that this method is used by asr_inference.py.

        Applies SpecAugment (training only) and runs the patched Whisper
        encoder on (batch, feat, frames) input.

        Args:
            speech: (Batch, Length, ...) features; permuted to channel-first
                before the encoder.
            speech_lengths: (Batch,)

        Returns:
            (encoder_out, encoder_out_lens)
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Forward encoder
        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        """Compute label-smoothed attention loss and teacher-forced accuracy.

        Returns:
            (loss_att, acc_att, None, None) — CER/WER slots are unused here.
        """
        target_mask = kwargs.get("target_mask", None)
        stats = {}

        # 1. Forward decoder
        # NOTE(review): the FSMN decoder replaces -1 ids with 0 in ys_pad
        # in place before embedding, so ys_pad below may already be
        # modified — confirm the masking arithmetic accounts for this.
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss
        # Positions where target_mask is 0 are forced to -1 (ignore_id);
        # zeros are also remapped to -1 before the loss.
        mask = torch.ones_like(ys_pad) * (-1)
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        ys_pad_mask[ys_pad_mask == 0] = -1
        # Shifted by one: predict token t+1 from prefix up to t.
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Single-utterance decoding via whisper.decode.

        Args:
            data_in: audio path/tensor (or fbank tensor when
                data_type == "fbank").
            data_lengths: lengths for pre-computed fbank input.
            key: list of utterance ids; only key[0] is used.
            tokenizer: tokenizer forwarded to audio/text loading.
            frontend: feature frontend; a WhisperFrontend is built and
                cached on self when absent.

        Returns:
            (results, meta_data): results is [{"key": ..., "text": ...}].
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        # Drop the batch dim: whisper.decode expects a single sample.
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Build decoding options: task tags are joined into the initial
        # prompt; language "auto" means let whisper detect it.
        DecodingOptions = kwargs.get("DecodingOptions", {})
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True

        options = whisper.DecodingOptions(**DecodingOptions)

        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        results = []
        result_i = {"key": key[0], "text": text}
        results.append(result_i)

        return results, meta_data
funasr/train_utils/trainer.py
@@ -456,7 +456,7 @@
                    batch_num_epoch = len(dataloader_train)
                self.log(
                    epoch,
                    batch_idx,
                    batch_idx + kwargs.get("start_step", 0),
                    step_in_epoch=self.step_in_epoch,
                    batch_num_epoch=batch_num_epoch,
                    lr=lr,