From eaf9dda9e4d970af3d09db695e9e10c83ef94e25 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 17 四月 2024 15:05:37 +0800
Subject: [PATCH] Dev gzf exp (#1624)

---
 funasr/models/sense_voice/decoder.py                         |   66 +++++++
 funasr/auto/auto_model.py                                    |    2 
 funasr/losses/label_smoothing_loss.py                        |    4 
 funasr/models/sense_voice/encoder.py                         |   67 +++++++
 funasr/tokenizer/whisper_tokenizer.py                        |   22 ++
 funasr/datasets/audio_datasets/index_ds.py                   |   23 +
 funasr/datasets/sense_voice_datasets/__init__.py             |    0 
 examples/industrial_data_pretraining/sense_voice/demo.py     |    4 
 funasr/datasets/sense_voice_datasets/datasets.py             |  118 +++++++++++++
 funasr/models/sense_voice/model.py                           |  131 ++++++++++++++
 examples/industrial_data_pretraining/sense_voice/finetune.sh |   69 +++++++
 funasr/bin/train.py                                          |    2 
 funasr/models/sense_voice/whisper_lib/model.py               |   27 ++
 13 files changed, 513 insertions(+), 22 deletions(-)

diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py
index b2fca47..0d8ef97 100644
--- a/examples/industrial_data_pretraining/sense_voice/demo.py
+++ b/examples/industrial_data_pretraining/sense_voice/demo.py
@@ -5,13 +5,13 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice",
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
                   vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
 				  vad_kwargs={"max_single_segment_time": 30000},
                   )
 
 
-input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/SenseVoice/aed_ser/asr_bgm.wav"
+input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
 
 DecodingOptions = {
 	"task": ("ASR", "AED", "SER"),
diff --git a/examples/industrial_data_pretraining/sense_voice/finetune.sh b/examples/industrial_data_pretraining/sense_voice/finetune.sh
new file mode 100644
index 0000000..cb07901
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/finetune.sh
@@ -0,0 +1,69 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+# which gpu to train or finetune
+export CUDA_VISIBLE_DEVICES="0"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download model automatically
+model_name_or_model_dir="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_name_or_model_dir="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope"
+## option 2, download model by git
+#local_path_root=${workspace}/modelscope_models
+#mkdir -p ${local_path_root}/${model_name_or_model_dir}
+#git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir}
+#model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+
+
+# data dir, which contains: train.json, val.json
+data_dir="../../../data/list"
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+
+# generate train.jsonl and val.jsonl from wav.scp and text.txt
+scp2jsonl \
+++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_out="${train_data}"
+
+scp2jsonl \
+++scp_file_list='["../../../data/list/val_wav.scp", "../../../data/list/val_text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_out="${val_data}"
+
+
+# exp output dir
+output_dir="./outputs"
+log_file="${output_dir}/log.txt"
+
+
+mkdir -p ${output_dir}
+echo "log_file: ${log_file}"
+
+#torchrun \
+#--nnodes 1 \
+#--node_rank 0 \
+#--nproc_per_node ${gpu_num} \
+python \
+../../../funasr/bin/train.py \
+++model="${model_name_or_model_dir}" \
+++train_data_set_list="${train_data}" \
+++valid_data_set_list="${val_data}" \
+++dataset_conf.batch_size=500 \
+++dataset_conf.batch_type="token" \
+++dataset_conf.num_workers=0 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=1 \
+++train_conf.resume=false \
+++train_conf.validate_interval=2000 \
+++train_conf.save_checkpoint_interval=2000 \
+++train_conf.keep_nbest_models=20 \
+++train_conf.avg_nbest_model=10 \
+++optim_conf.lr=0.0002 \
+++debug=true \
+++device="cpu" \
+++output_dir="${output_dir}" #&> ${log_file}
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 630c390..d173a53 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -175,6 +175,8 @@
             kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
             kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
             vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+                vocab_size = tokenizer.get_vocab_size()
         else:
             vocab_size = -1
         kwargs["tokenizer"] = tokenizer
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 880bb63..353ce68 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -102,7 +102,7 @@
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(model, device_ids=[local_rank],
-                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
     elif use_fsdp:
         # model = FSDP(model).cuda(local_rank)
 
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 34f7b4f..5396c8a 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -92,7 +92,7 @@
             for line in fin:
                 data = json.loads(line.strip())
                 if "text" in data:  # for sft
-                    self.contents.append(data['text'])
+                    contents.append(data['text'])
                 if "source" in data:  # for speech lab pretrain
                     prompt = data.get("prompt", "<ASR>")
                     source = data["source"]
@@ -101,13 +101,20 @@
                     target_len = data.get("target_len", 0)
                     if "aishell" in source:
                         target = target.replace(" ", "")
-                    contents.append({"source": source,
-                                     "prompt": prompt,
-                                     "target": target,
-                                     "source_len": source_len,
-                                     "target_len": target_len,
-                                     }
-                                    )
+
+                    contents_i = {"source": source,
+                                 "prompt": prompt,
+                                 "target": target,
+                                 "source_len": source_len,
+                                 "target_len": target_len,
+                                 }
+                    text_language = data.get("text_language", None)
+                    if text_language is not None:
+                        contents_i["text_language"] = text_language
+                    audio_language = data.get("audio_language", None)
+                    if audio_language is not None:
+                        contents_i["audio_language"] = audio_language
+                    contents.append(contents_i)
 
         self.contents = contents
         
diff --git a/funasr/datasets/sense_voice_datasets/__init__.py b/funasr/datasets/sense_voice_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/sense_voice_datasets/__init__.py
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
new file mode 100644
index 0000000..956cf79
--- /dev/null
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -0,0 +1,118 @@
+import torch
+import random
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "SenseVoiceDataset")
+class SenseVoiceDataset(torch.utils.data.Dataset):
+    """
+    SenseVoiceDataset
+    """
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                  **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+
+        self.int_pad_value = int_pad_value
+        self.float_pad_value = float_pad_value
+        self.sos = kwargs.get("sos", "<|startoftranscript|>")
+        self.eos = kwargs.get("eos", "<|endoftext|>")
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        # import pdb;
+        # pdb.set_trace()
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
+        speech = speech.permute(0, 2, 1)
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+        
+        task = item.get("prompt", "<|ASR|>")
+        text_language = item.get("text_language", "<|zh|>")
+
+        prompt = f"{self.sos}{task}{text_language}"
+        prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+        prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
+
+        target_ids = self.tokenizer.encode(target, allowed_special="all")
+        target_ids_len = len(target_ids) + 1 # [lid, text]
+        
+        eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
+        
+        ids = prompt_ids + target_ids + eos
+        ids_lengths = len(ids)
+        
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
+
+        target_mask = [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
+        target_mask = torch.tensor(target_mask, dtype=torch.float32)
+
+        return {"speech": speech[0, :, :],
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "target_mask": target_mask,
+                }
+    
+    
+    def collator(self, samples: list=None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+    
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+                
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
+
+
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 8f0809a..385025d 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -50,8 +50,8 @@
         """
         assert x.size(2) == self.size
         batch_size = x.size(0)
-        x = x.view(-1, self.size)
-        target = target.view(-1)
+        x = x.contiguous().view(-1, self.size)
+        target = target.contiguous().view(-1)
         with torch.no_grad():
             true_dist = x.clone()
             true_dist.fill_(self.smoothing / (self.size - 1))
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
new file mode 100644
index 0000000..bae2832
--- /dev/null
+++ b/funasr/models/sense_voice/decoder.py
@@ -0,0 +1,66 @@
+import copy
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
+def sense_voice_decode_forward(
+	self,
+	x: torch.Tensor,
+	xa: torch.Tensor,
+	kv_cache: Optional[dict] = None,
+	**kwargs,
+):
+	"""Forward decoder.
+
+	Args:
+		hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+		hlens: (batch)
+		ys_in_pad:
+			input token ids, int64 (batch, maxlen_out)
+			if input_layer == "embed"
+			input tensor (batch, maxlen_out, #mels) in the other cases
+		ys_in_lens: (batch)
+	Returns:
+		(tuple): tuple containing:
+
+		x: decoded token score before softmax (batch, maxlen_out, token)
+			if use_output_layer is True,
+		olens: (batch, )
+	"""
+	# import pdb;pdb.set_trace()
+	use_padmask = self.use_padmask
+	hlens = kwargs.get("hlens", None)
+	
+	ys_in_lens = kwargs.get("ys_in_lens", None)
+	
+	offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+	tgt, memory = x, xa
+	tgt[tgt==-1] = 0
+	tgt = (
+		self.token_embedding(tgt)
+		+ self.positional_embedding[offset : offset + tgt.size(1)]
+	)
+	# tgt = self.dropout(tgt)
+	
+	x = tgt.to(memory.dtype)
+	
+	if use_padmask and hlens is not None:
+		memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+	else:
+		memory_mask = None
+	
+	for layer, block in enumerate(self.blocks):
+		x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+
+
+	x = self.ln(x)
+	x = (
+		x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+	).float()
+	
+	
+	return x
+	
\ No newline at end of file
diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py
new file mode 100644
index 0000000..3870c52
--- /dev/null
+++ b/funasr/models/sense_voice/encoder.py
@@ -0,0 +1,67 @@
+import copy
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
+
+def sense_voice_encode_forward(
+	self,
+	x: torch.Tensor,
+	ilens: torch.Tensor = None,
+	**kwargs,
+):
+	use_padmask = self.use_padmask
+	x = F.gelu(self.conv1(x))
+	x = F.gelu(self.conv2(x))
+	x = x.permute(0, 2, 1)
+	
+	n_frames = x.size(1)
+	max_pos = self.positional_embedding.size(0)
+	max_pos = n_frames if n_frames < max_pos else max_pos
+	x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
+	
+	
+	if ilens is not None:
+		if self.downsample_rate == 4:
+			olens = (
+				1
+				+ (
+					ilens
+					- self.conv1.kernel_size[0]
+					+ 2 * self.conv1.padding[0]
+				)
+				// self.conv1.stride[0]
+			)
+		else:
+			olens = ilens
+		olens = (
+			1
+			+ (
+				olens
+				- self.conv2.kernel_size[0]
+				+ 2 * self.conv2.padding[0]
+			)
+			// self.conv2.stride[0]
+		)
+		olens = torch.clamp(olens, max=max_pos)
+	else:
+		olens = None
+	
+	if use_padmask and olens is not None:
+		padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+	else:
+		padding_mask = None
+	
+	for layer, block in enumerate(self.blocks):
+		x = block(x, mask=padding_mask, is_pad_mask=True)
+		
+
+	x = self.ln_post(x)
+	
+	if ilens is None:
+		return x
+	else:
+		return x, olens
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 4ee2fa5..b5272a1 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1,35 +1,158 @@
 from dataclasses import dataclass
 from typing import Dict
 from typing import Iterable, Optional
+import types
 import time
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch import nn
+from torch.cuda.amp import autocast
+from funasr.metrics.compute_acc import compute_accuracy
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.train_utils.device_funcs import force_gatherable
 from . import whisper_lib as whisper
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 from funasr.register import tables
 
 
+
+
 @tables.register("model_classes", "SenseVoice")
 class SenseVoice(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
-        hub = kwargs.get("hub", "funasr")
-
+        
         dims = kwargs.get("dims", {})
         dims = whisper.model.ModelDimensions(**dims)
         model = whisper.model.Whisper(dims=dims)
+        
+        # encoder
+        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
+        model.encoder.use_padmask = kwargs.get("use_padmask", True)
+        from .encoder import sense_voice_encode_forward
+        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
+        
+        # decoder
+        model.decoder.use_padmask = kwargs.get("use_padmask", True)
+        from .decoder import sense_voice_decode_forward
+        model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder)
         
         self.model = model
         
         self.encoder_output_size = self.model.dims.n_audio_state
         
-    def forward(self, ):
-        pass
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+        self.ignore_id = kwargs.get("ignore_id", -1)
+        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
+        self.criterion_att = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=kwargs.get("lsm_weight", 0.0),
+            normalize_length=self.length_normalized_loss,
+        )
+        
+        specaug = kwargs.get("specaug", None)
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
+        self.specaug = specaug
+
+ 
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
     
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+    
+        batch_size = speech.shape[0]
+
+        if self.activation_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+            encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+        loss = loss_att
+        stats = {}
+        stats["acc"] = acc_att
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+        
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) :
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+
+        # Forward encoder
+        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
+    
+        return encoder_out, encoder_out_lens
+
+
+    def _calc_att_loss(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+            **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
+        stats = {}
+        
+        # 1. Forward decoder
+        decoder_out = self.model.decoder(
+            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+        )
+        
+        # 2. Compute attention loss
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1-target_mask)).to(torch.int64)
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id)
+
+        return loss_att, acc_att, None, None
+
+
     def inference(self,
                   data_in,
                   data_lengths=None,
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 0e8f09b..ca960f1 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -74,7 +74,10 @@
         xa: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
+        **kwargs,
     ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+
         q = self.query(x)
 
         if kv_cache is None or xa is None or self.key not in kv_cache:
@@ -87,12 +90,13 @@
             k = kv_cache[self.key]
             v = kv_cache[self.value]
 
-        wv, qk = self.qkv_attention(q, k, v, mask)
+        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
         return self.out(wv), qk
 
     def qkv_attention(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
     ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -101,10 +105,20 @@
 
         qk = q @ k
         if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
+            if not is_pad_mask:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            else:
+                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+                min_value = float(
+                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
+                )
+                qk = qk.masked_fill(mask, min_value)
+                
         qk = qk.float()
 
         w = F.softmax(qk, dim=-1).to(q.dtype)
+        if mask is not None and is_pad_mask:
+            w = w.masked_fill(mask, 0.0)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
@@ -132,10 +146,13 @@
         xa: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
+        **kwargs,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
diff --git a/funasr/tokenizer/whisper_tokenizer.py b/funasr/tokenizer/whisper_tokenizer.py
index 6684f25..0a34d19 100644
--- a/funasr/tokenizer/whisper_tokenizer.py
+++ b/funasr/tokenizer/whisper_tokenizer.py
@@ -22,3 +22,25 @@
 	
 	return tokenizer
 
+
+@tables.register("tokenizer_classes", "SenseVoiceTokenizer")
+def SenseVoiceTokenizer(**kwargs):
+	try:
+		from funasr.models.sense_voice.whisper_lib.tokenizer import get_tokenizer
+	except:
+		print("Notice: If you want to use whisper, please `pip install -U openai-whisper`")
+	
+	language = kwargs.get("language", None)
+	task = kwargs.get("task", None)
+	is_multilingual = kwargs.get("is_multilingual", True)
+	num_languages = kwargs.get("num_languages", 8749)
+	vocab_path = kwargs.get("vocab_path", None)
+	tokenizer = get_tokenizer(
+		multilingual=is_multilingual,
+		num_languages=num_languages,
+		language=language,
+		task=task,
+		vocab_path=vocab_path,
+	)
+	
+	return tokenizer

--
Gitblit v1.9.1