From a09aba419f305abadc185ec41c336211549e894b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 30 四月 2024 12:52:58 +0800
Subject: [PATCH] Dev gzf exp (#1682)

---
 funasr/models/sense_voice/decoder.py              |  202 ++++++++++++++++
 funasr/datasets/audio_datasets/espnet_samplers.py |    7 
 funasr/train_utils/trainer.py                     |    2 
 examples/README.md                                |    8 
 examples/aishell/paraformer/run.sh                |    3 
 examples/README_zh.md                             |    6 
 funasr/datasets/audio_datasets/index_ds.py        |    2 
 examples/aishell/branchformer/run.sh              |    3 
 examples/aishell/e_branchformer/run.sh            |    3 
 funasr/datasets/audio_datasets/scp2jsonl.py       |    1 
 funasr/models/sense_voice/model.py                |  240 ++++++++++++++++++++
 funasr/datasets/audio_datasets/update_jsonl.py    |   98 ++++++++
 examples/aishell/conformer/run.sh                 |    3 
 funasr/bin/train.py                               |    3 
 funasr/datasets/audio_datasets/scp2len.py         |  121 ++++++++++
 examples/aishell/transformer/run.sh               |    3 
 16 files changed, 692 insertions(+), 13 deletions(-)

diff --git a/examples/README.md b/examples/README.md
index f87d5fa..0191a2d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -248,10 +248,10 @@
 export CUDA_VISIBLE_DEVICES="0,1"
 gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 
-torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+torchrun --nnodes 1 --nproc_per_node ${gpu_num} --master_port 12345 \
 ../../../funasr/bin/train.py ${train_args}
 ```
---nnodes represents the total number of participating nodes, while --nproc_per_node indicates the number of processes running on each node.
+--nnodes represents the total number of participating nodes, while --nproc_per_node indicates the number of processes running on each node. --master_port indicates the port is 12345
 
 ##### Multi-Machine Multi-GPU Training
 
@@ -260,7 +260,7 @@
 export CUDA_VISIBLE_DEVICES="0,1"
 gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 
-torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
+torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr 192.168.1.1 --master_port 12345 \
 ../../../funasr/bin/train.py ${train_args}
 ```
 On the worker node (assuming the IP is 192.168.1.2), you need to ensure that the MASTER_ADDR and MASTER_PORT environment variables are set to match those of the master node, and then run the same command:
@@ -269,7 +269,7 @@
 export CUDA_VISIBLE_DEVICES="0,1"
 gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 
-torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
+torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr 192.168.1.1 --master_port 12345 \
 ../../../funasr/bin/train.py ${train_args}
 ```
 
diff --git a/examples/README_zh.md b/examples/README_zh.md
index b016f9e..b0a6665 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -256,10 +256,10 @@
 export CUDA_VISIBLE_DEVICES="0,1"
 gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 
-torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+torchrun --nnodes 1 --nproc_per_node ${gpu_num} --master_port 12345 \
 ../../../funasr/bin/train.py ${train_args}
 ```
---nnodes 琛ㄧず鍙備笌鐨勮妭鐐规�绘暟锛�--nproc_per_node 琛ㄧず姣忎釜鑺傜偣涓婅繍琛岀殑杩涚▼鏁�
+--nnodes 琛ㄧず鍙備笌鐨勮妭鐐规�绘暟锛�--nproc_per_node 琛ㄧず姣忎釜鑺傜偣涓婅繍琛岀殑杩涚▼鏁帮紝--master_port 琛ㄧず绔彛鍙�
 
 ##### 澶氭満澶歡pu璁粌
 
@@ -280,7 +280,7 @@
 ../../../funasr/bin/train.py ${train_args}
 ```
 
---nnodes 琛ㄧず鍙備笌鐨勮妭鐐规�绘暟锛�--node_rank 琛ㄧず褰撳墠鑺傜偣id锛�--nproc_per_node 琛ㄧず姣忎釜鑺傜偣涓婅繍琛岀殑杩涚▼鏁帮紙閫氬父涓篻pu涓暟锛�
+--nnodes 琛ㄧず鍙備笌鐨勮妭鐐规�绘暟锛�--node_rank 琛ㄧず褰撳墠鑺傜偣id锛�--nproc_per_node 琛ㄧず姣忎釜鑺傜偣涓婅繍琛岀殑杩涚▼鏁帮紙閫氬父涓篻pu涓暟锛夛紝--master_port 琛ㄧず绔彛鍙�
 
 #### 鍑嗗鏁版嵁
 
diff --git a/examples/aishell/branchformer/run.sh b/examples/aishell/branchformer/run.sh
index 918aa9b..5b64954 100755
--- a/examples/aishell/branchformer/run.sh
+++ b/examples/aishell/branchformer/run.sh
@@ -27,6 +27,8 @@
 tag="exp1"
 workspace=`pwd`
 
+master_port=12345
+
 . utils/parse_options.sh || exit 1;
 
 # Set bash to 'debug' mode, it will exit on :
@@ -115,6 +117,7 @@
   torchrun \
   --nnodes 1 \
   --nproc_per_node ${gpu_num} \
+  --master_port ${master_port} \
   ../../../funasr/bin/train.py \
   --config-path "${workspace}/conf" \
   --config-name "${config}" \
diff --git a/examples/aishell/conformer/run.sh b/examples/aishell/conformer/run.sh
index ba8b43c..0c8ab50 100755
--- a/examples/aishell/conformer/run.sh
+++ b/examples/aishell/conformer/run.sh
@@ -27,6 +27,8 @@
 tag="exp1"
 workspace=`pwd`
 
+master_port=12345
+
 . utils/parse_options.sh || exit 1;
 
 # Set bash to 'debug' mode, it will exit on :
@@ -114,6 +116,7 @@
   torchrun \
   --nnodes 1 \
   --nproc_per_node ${gpu_num} \
+  --master_port ${master_port} \
   ../../../funasr/bin/train.py \
   --config-path "${workspace}/conf" \
   --config-name "${config}" \
diff --git a/examples/aishell/e_branchformer/run.sh b/examples/aishell/e_branchformer/run.sh
index be18599..452ec80 100755
--- a/examples/aishell/e_branchformer/run.sh
+++ b/examples/aishell/e_branchformer/run.sh
@@ -27,6 +27,8 @@
 tag="exp1"
 workspace=`pwd`
 
+master_port=12345
+
 . utils/parse_options.sh || exit 1;
 
 # Set bash to 'debug' mode, it will exit on :
@@ -115,6 +117,7 @@
   torchrun \
   --nnodes 1 \
   --nproc_per_node ${gpu_num} \
+  --master_port ${master_port} \
   ../../../funasr/bin/train.py \
   --config-path "${workspace}/conf" \
   --config-name "${config}" \
diff --git a/examples/aishell/paraformer/run.sh b/examples/aishell/paraformer/run.sh
index a957b93..ffef61e 100755
--- a/examples/aishell/paraformer/run.sh
+++ b/examples/aishell/paraformer/run.sh
@@ -27,6 +27,8 @@
 tag="exp1"
 workspace=`pwd`
 
+master_port=12345
+
 . utils/parse_options.sh || exit 1;
 
 # Set bash to 'debug' mode, it will exit on :
@@ -113,6 +115,7 @@
   torchrun \
   --nnodes 1 \
   --nproc_per_node ${gpu_num} \
+  --master_port ${master_port} \
   ../../../funasr/bin/train.py \
   --config-path "${workspace}/conf" \
   --config-name "${config}" \
diff --git a/examples/aishell/transformer/run.sh b/examples/aishell/transformer/run.sh
index 98c2829..3fb8465 100755
--- a/examples/aishell/transformer/run.sh
+++ b/examples/aishell/transformer/run.sh
@@ -27,6 +27,8 @@
 tag="exp1"
 workspace=`pwd`
 
+master_port=12345
+
 . utils/parse_options.sh || exit 1;
 
 # Set bash to 'debug' mode, it will exit on :
@@ -115,6 +117,7 @@
   torchrun \
   --nnodes 1 \
   --nproc_per_node ${gpu_num} \
+  --master_port ${master_port} \
   ../../../funasr/bin/train.py \
   --config-path "${workspace}/conf" \
   --config-name "${config}" \
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 97516eb..d20915c 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -205,7 +205,6 @@
             dataloader_tr, dataloader_val = dataloader.build_iter(
                 epoch, data_split_i=data_split_i, start_step=trainer.start_step
             )
-            trainer.start_step = 0
 
             trainer.train_epoch(
                 model=model,
@@ -218,7 +217,9 @@
                 writer=writer,
                 data_split_i=data_split_i,
                 data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
             )
+            trainer.start_step = 0
 
             torch.cuda.empty_cache()
 
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index cb30a28..e155cd7 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -71,7 +71,7 @@
         self.max_token_length = kwargs.get("max_token_length", 2048)
         self.min_token_length = kwargs.get("min_token_length", 0)
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        self.start_step = 0
+        self.start_step = start_step
         if self.start_step > 0:
             logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
         # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
@@ -146,7 +146,10 @@
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
         rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
-
+        if self.start_step > 0:
+            logging.info(
+                f"Warning, rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num_before: {end_idx-start_idx}, now: {len(rank_batches)}"
+            )
         # Return an iterator over the batches for the current rank
         return iter(rank_batches)
 
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 70581e8..385218a 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -35,7 +35,7 @@
             with open(path, encoding="utf-8") as fin:
                 file_list_all = fin.readlines()
 
-                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
+                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # 16
                 file_list = file_list_all[
                     data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
                 ]
diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f167173..f4c9d74 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -29,7 +29,6 @@
             with open(data_file, "r") as f:
 
                 data_file_lists = f.readlines()
-                print("")
                 lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                 task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                 # import pdb;pdb.set_trace()
diff --git a/funasr/datasets/audio_datasets/scp2len.py b/funasr/datasets/audio_datasets/scp2len.py
new file mode 100644
index 0000000..5d742b1
--- /dev/null
+++ b/funasr/datasets/audio_datasets/scp2len.py
@@ -0,0 +1,121 @@
+import os
+import json
+import torch
+import logging
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from tqdm import tqdm
+
+
+def gen_jsonl_from_wav_text_list(
+    path, data_type_list=("source",), jsonl_file_out: str = None, **kwargs
+):
+    try:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    except:
+        rank = 0
+        world_size = 1
+
+    cpu_cores = os.cpu_count() or 1
+    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
+    if rank == 0:
+        json_dict = {}
+        # for data_type, data_file in zip(data_type_list, path):
+        data_type = data_type_list[0]
+        data_file = path
+        json_dict[data_type] = {}
+        with open(data_file, "r") as f:
+
+            data_file_lists = f.readlines()
+            print("")
+            lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
+            task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
+            # import pdb;pdb.set_trace()
+            if task_num > 1:
+                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
+
+                    futures = [
+                        executor.submit(
+                            parse_context_length,
+                            data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
+                            data_type,
+                            i,
+                        )
+                        for i in range(task_num)
+                    ]
+
+                    for future in concurrent.futures.as_completed(futures):
+
+                        json_dict[data_type].update(future.result())
+            else:
+                res = parse_context_length(data_file_lists, data_type)
+                json_dict[data_type].update(res)
+
+        with open(jsonl_file_out, "w") as f:
+            for key in json_dict[data_type_list[0]].keys():
+                jsonl_line = {"key": key}
+                for data_file in data_type_list:
+                    jsonl_line.update(json_dict[data_file][key])
+                # jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
+                source_len = jsonl_line["source_len"]
+                jsonl_line = f"{key} {source_len}"
+                f.write(jsonl_line + "\n")
+                f.flush()
+        print(f"processed {len(json_dict[data_type_list[0]])} samples")
+
+    else:
+        pass
+
+    if world_size > 1:
+        dist.barrier()
+
+
+def parse_context_length(data_list: list, data_type: str, id=0):
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
+    res = {}
+    for i, line in enumerate(data_list):
+        pbar.update(1)
+        pbar.set_description(f"cpu: {id}")
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
+        line = line.strip()
+        if os.path.exists(line):
+            waveform, _ = librosa.load(line, sr=16000)
+            sample_num = len(waveform)
+            context_len = int(sample_num / 16000 * 1000 / 10)
+        else:
+            context_len = len(line.split()) if " " in line else len(line)
+        res[key] = {data_type: line, f"{data_type}_len": context_len}
+    return res
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+
+    kwargs = OmegaConf.to_container(cfg, resolve=True)
+    print(kwargs)
+
+    scp_file_list = kwargs.get("scp_file_list", "/Users/zhifu/funasr1.0/data/list/train_wav.scp")
+    # if isinstance(scp_file_list, str):
+    #     scp_file_list = eval(scp_file_list)
+    data_type_list = kwargs.get("data_type_list", ("source",))
+    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/data/list/wav_len.txt")
+    gen_jsonl_from_wav_text_list(
+        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
+    )
+
+
+"""
+python -m funasr.datasets.audio_datasets.scp2jsonl \
+++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+"""
+
+if __name__ == "__main__":
+    main_hydra()
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
new file mode 100644
index 0000000..05870fe
--- /dev/null
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -0,0 +1,98 @@
+import os
+import json
+import torch
+import logging
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+import threading
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
+
+
+def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
+    jsonl_file_out_f = open(jsonl_file_out, "w")
+    with open(jsonl_file, encoding="utf-8") as fin:
+        lines = fin.readlines()
+
+        num_total = len(lines)
+        if ncpu > 1:
+            # 浣跨敤ThreadPoolExecutor闄愬埗骞跺彂绾跨▼鏁�
+            with ThreadPoolExecutor(max_workers=ncpu) as executor:
+                # 鎻愪氦浠诲姟鍒扮嚎绋嬫睜
+                futures = {executor.submit(update_data, lines, i) for i in tqdm(range(num_total))}
+
+                # 绛夊緟鎵�鏈変换鍔″畬鎴愶紝杩欎細闃诲鐩村埌鎵�鏈夋彁浜ょ殑浠诲姟瀹屾垚
+                for future in concurrent.futures.as_completed(futures):
+                    # 杩欓噷鍙互娣诲姞棰濆鐨勯�昏緫鏉ュ鐞嗗畬鎴愮殑浠诲姟锛屼絾鍦ㄨ繖涓緥瀛愪腑鎴戜滑鍙槸绛夊緟
+                    pass
+        else:
+            for i in range(num_total):
+                update_data(lines, i)
+        logging.info("All audio durations have been processed.")
+
+        for line in lines:
+
+            jsonl_file_out_f.write(line + "\n")
+            jsonl_file_out_f.flush()
+
+    jsonl_file_out_f.close()
+
+
+def update_data(lines, i):
+    line = lines[i]
+    data = json.loads(line.strip())
+
+    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+    waveform, _ = librosa.load(wav_path, sr=16000)
+    sample_num = len(waveform)
+    source_len = int(sample_num / 16000 * 1000 / 10)
+    source_len_old = data["source_len"]
+    # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
+    #     logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
+    data["source_len"] = source_len
+    data["source"] = wav_path
+    jsonl_line = json.dumps(data, ensure_ascii=False)
+    lines[i] = jsonl_line
+
+
+def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
+
+    os.makedirs(jsonl_file_out_dir, exist_ok=True)
+    with open(jsonl_file_list_in, "r") as f:
+        data_file_lists = f.readlines()
+
+        for i, jsonl in enumerate(data_file_lists):
+            filename_with_extension = os.path.basename(jsonl.strip())
+            jsonl_file_out = os.path.join(jsonl_file_out_dir, filename_with_extension)
+            logging.info(f"{i}/{len(data_file_lists)}, jsonl: {jsonl}, {jsonl_file_out}")
+
+            gen_scp_from_jsonl(jsonl.strip(), jsonl_file_out, ncpu)
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+
+    kwargs = OmegaConf.to_container(cfg, resolve=True)
+    logging.info(kwargs)
+
+    jsonl_file_list_in = kwargs.get(
+        "jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
+    )
+    jsonl_file_out_dir = kwargs.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
+    ncpu = kwargs.get("ncpu", 1)
+    update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)
+    # gen_scp_from_jsonl(jsonl_file_list_in, jsonl_file_out_dir)
+
+
+"""
+python -m funasr.datasets.audio_datasets.json2scp \
+++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_in=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+"""
+
+if __name__ == "__main__":
+    main_hydra()
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 133508f..f5b8825 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -335,3 +335,205 @@
         x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
 
         return x
+
+
+class MultiHeadedAttentionSANMDecoder(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(
+            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
+        )
+        # padding
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.kernel_size = kernel_size
+
+    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
+        """
+        :param x: (#batch, time1, size).
+        :param mask: Mask tensor (#batch, 1, time)
+        :return:
+        """
+        # print("in fsmn, inputs", inputs.size())
+        b, t, d = inputs.size()
+        # logging.info(
+        #     "mask: {}".format(mask.size()))
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+            if mask_shfit_chunk is not None:
+                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
+                mask = mask * mask_shfit_chunk
+            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+            # print("in fsmn, mask", mask.size())
+            # print("in fsmn, inputs", inputs.size())
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        b, d, t = x.size()
+        if cache is None:
+            # print("in fsmn, cache is None, x", x.size())
+
+            x = self.pad_fn(x)
+            if not self.training:
+                cache = x
+        else:
+            # print("in fsmn, cache is not None, x", x.size())
+            # x = torch.cat((x, cache), dim=2)[:, :, :-1]
+            # if t < self.kernel_size:
+            #     x = self.pad_fn(x)
+            x = torch.cat((cache[:, :, 1:], x), dim=2)
+            x = x[:, :, -(self.kernel_size + t - 1) :]
+            # print("in fsmn, cache is not None, x_cat", x.size())
+            cache = x
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        # print("in fsmn, fsmn_out", x.size())
+        if x.size(1) != inputs.size(1):
+            inputs = inputs[:, -1, :]
+
+        x = x + inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x, cache
+
+
+class ResidualAttentionBlockFSMN(nn.Module):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
+        super().__init__()
+
+        self.attn = MultiHeadedAttentionSANMDecoder(
+            n_state,
+            kwargs.get("self_attention_dropout_rate"),
+            kwargs.get("kernel_size", 20),
+            kwargs.get("sanm_shfit", 10),
+        )
+        self.attn_ln = LayerNorm(n_state)
+
+        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+        n_mlp = n_state * 4
+        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
+        self.mlp_ln = LayerNorm(n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+        **kwargs,
+    ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+        x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+        if self.cross_attn:
+            x = (
+                x
+                + self.cross_attn(
+                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
+                )[0]
+            )
+        x = x + self.mlp(self.mlp_ln(x))
+        return x
+
+
+@tables.register("decoder_classes", "SenseVoiceDecoderFSMN")
+class SenseVoiceDecoderFSMN(nn.Module):
+    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(n_vocab, n_state)
+        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+        self.blocks = nn.ModuleList(
+            [
+                ResidualAttentionBlockFSMN(
+                    n_state, n_head, cross_attention=True, layer_id=i, **kwargs
+                )
+                for i in range(n_layer)
+            ]
+        )
+        self.ln = LayerNorm(n_state)
+
+        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+        self.register_buffer("mask", mask, persistent=False)
+
+        self.use_padmask = kwargs.get("use_padmask", True)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        xa: torch.Tensor,
+        kv_cache: Optional[dict] = None,
+        **kwargs,
+    ):
+        """Forward decoder.
+
+        Args:
+                hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+                hlens: (batch)
+                ys_in_pad:
+                        input token ids, int64 (batch, maxlen_out)
+                        if input_layer == "embed"
+                        input tensor (batch, maxlen_out, #mels) in the other cases
+                ys_in_lens: (batch)
+        Returns:
+                (tuple): tuple containing:
+
+                x: decoded token score before softmax (batch, maxlen_out, token)
+                        if use_output_layer is True,
+                olens: (batch, )
+        """
+        # import pdb;pdb.set_trace()
+        use_padmask = self.use_padmask
+        hlens = kwargs.get("hlens", None)
+
+        ys_in_lens = kwargs.get("ys_in_lens", None)
+
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        tgt, memory = x, xa
+        tgt[tgt == -1] = 0
+        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
+        # tgt = self.dropout(tgt)
+
+        x = tgt.to(memory.dtype)
+
+        if use_padmask and hlens is not None:
+            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+        else:
+            memory_mask = None
+
+        for layer, block in enumerate(self.blocks):
+            x = block(
+                x,
+                memory,
+                mask=self.mask,
+                memory_mask=memory_mask,
+                is_pad_mask=False,
+                is_pad_memory_mask=True,
+            )
+
+        x = self.ln(x)
+        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+
+        return x
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index ae20902..c12107e 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -310,6 +310,7 @@
             speech_lengths = speech_lengths[:, 0]
 
         batch_size, frames, _ = speech.shape
+        _, text_tokens = text.shape
 
         if self.activation_checkpoint:
             from torch.utils.checkpoint import checkpoint
@@ -331,6 +332,10 @@
         stats["batch_size_x_frames"] = frames * batch_size
         stats["batch_size_real_frames"] = speech_lengths.sum().item()
         stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
+        stats["batch_size_x_tokens"] = text_tokens * batch_size
+        stats["batch_size_real_tokens"] = text_lengths.sum().item()
+        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
+        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -471,3 +476,238 @@
         results.append(result_i)
 
         return results, meta_data
+
+
+@tables.register("model_classes", "SenseVoiceFSMN")
+class SenseVoiceFSMN(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        dims = kwargs.get("dims", {})
+        dims = whisper.model.ModelDimensions(**dims)
+        model = whisper.model.Whisper(dims=dims)
+
+        # encoder
+        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
+        model.encoder.use_padmask = kwargs.get("use_padmask", True)
+        from .encoder import sense_voice_encode_forward
+
+        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
+
+        # decoder
+        del model.decoder
+        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
+        decoder_conf = kwargs.get("decoder_conf", {})
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(
+            vocab_size=dims.n_vocab,
+            encoder_output_size=dims.n_audio_state,
+            **decoder_conf,
+        )
+        model.decoder = decoder
+
+        self.model = model
+
+        self.encoder_output_size = self.model.dims.n_audio_state
+
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+        self.ignore_id = kwargs.get("ignore_id", -1)
+        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
+        self.criterion_att = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=kwargs.get("lsm_weight", 0.0),
+            normalize_length=self.length_normalized_loss,
+        )
+
+        specaug = kwargs.get("specaug", None)
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
+        self.specaug = specaug
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
+
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size, frames, _ = speech.shape
+        _, text_tokens = text.shape
+
+        if self.activation_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+
+            encoder_out, encoder_out_lens = checkpoint(
+                self.encode, speech, speech_lengths, use_reentrant=False
+            )
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+        loss = loss_att
+        stats = {}
+        stats["acc"] = acc_att
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+        stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_real_frames"] = speech_lengths.sum().item()
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
+        stats["batch_size_x_tokens"] = text_tokens * batch_size
+        stats["batch_size_real_tokens"] = text_lengths.sum().item()
+        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
+        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
+
+        return encoder_out, encoder_out_lens
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
+        stats = {}
+
+        # 1. Forward decoder
+        decoder_out = self.model.decoder(
+            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+        )
+        # decoder_out, _ = self.model.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+        # 2. Compute attention loss
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(
+                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
+            )
+
+        return loss_att, acc_att, None, None
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(
+                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
+            )
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+
+        meta_data = {}
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])[0, :, :]
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        DecodingOptions = kwargs.get("DecodingOptions", {})
+        task = DecodingOptions.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+        DecodingOptions["initial_prompt"] = initial_prompt
+
+        language = DecodingOptions.get("language", None)
+        language = None if language == "auto" else language
+        DecodingOptions["language"] = language
+
+        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+
+        if "without_timestamps" not in DecodingOptions:
+            DecodingOptions["without_timestamps"] = True
+
+        options = whisper.DecodingOptions(**DecodingOptions)
+
+        result = whisper.decode(self.model, speech, options)
+        text = f"{result.text}"
+        results = []
+        result_i = {"key": key[0], "text": text}
+
+        results.append(result_i)
+
+        return results, meta_data
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index e86420c..a28ca51 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -456,7 +456,7 @@
                     batch_num_epoch = len(dataloader_train)
                 self.log(
                     epoch,
-                    batch_idx,
+                    batch_idx + kwargs.get("start_step", 0),
                     step_in_epoch=self.step_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,

--
Gitblit v1.9.1