Dev gzf exp (#1624)
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
| | |
| | | |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice", |
| | | model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope", |
| | | vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", |
| | | vad_kwargs={"max_single_segment_time": 30000}, |
| | | ) |
| | | |
| | | |
| | | input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/SenseVoice/aed_ser/asr_bgm.wav" |
| | | input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" |
| | | |
| | | DecodingOptions = { |
| | | "task": ("ASR", "AED", "SER"), |
| New file |
| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | # which gpu to train or finetune |
| | | export CUDA_VISIBLE_DEVICES="0" |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | |
| | | # model_name from model_hub, or model_dir in local path |
| | | |
| | | ## option 1, download model automatically |
| | | model_name_or_model_dir="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" |
| | | model_name_or_model_dir="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope" |
| | | ## option 2, download model by git |
| | | #local_path_root=${workspace}/modelscope_models |
| | | #mkdir -p ${local_path_root}/${model_name_or_model_dir} |
| | | #git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir} |
| | | #model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir} |
| | | |
| | | |
| | | # data dir, which contains: train.json, val.json |
| | | data_dir="../../../data/list" |
| | | |
| | | train_data="${data_dir}/train.jsonl" |
| | | val_data="${data_dir}/val.jsonl" |
| | | |
| | | # generate train.jsonl and val.jsonl from wav.scp and text.txt |
| | | scp2jsonl \ |
| | | ++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \ |
| | | ++data_type_list='["source", "target"]' \ |
| | | ++jsonl_file_out="${train_data}" |
| | | |
| | | scp2jsonl \ |
| | | ++scp_file_list='["../../../data/list/val_wav.scp", "../../../data/list/val_text.txt"]' \ |
| | | ++data_type_list='["source", "target"]' \ |
| | | ++jsonl_file_out="${val_data}" |
| | | |
| | | |
| | | # exp output dir |
| | | output_dir="./outputs" |
| | | log_file="${output_dir}/log.txt" |
| | | |
| | | |
| | | mkdir -p ${output_dir} |
| | | echo "log_file: ${log_file}" |
| | | |
| | | #torchrun \ |
| | | #--nnodes 1 \ |
| | | #--node_rank 0 \ |
| | | #--nproc_per_node ${gpu_num} \ |
| | | python \ |
| | | ../../../funasr/bin/train.py \ |
| | | ++model="${model_name_or_model_dir}" \ |
| | | ++train_data_set_list="${train_data}" \ |
| | | ++valid_data_set_list="${val_data}" \ |
| | | ++dataset_conf.batch_size=500 \ |
| | | ++dataset_conf.batch_type="token" \ |
| | | ++dataset_conf.num_workers=0 \ |
| | | ++train_conf.max_epoch=50 \ |
| | | ++train_conf.log_interval=1 \ |
| | | ++train_conf.resume=false \ |
| | | ++train_conf.validate_interval=2000 \ |
| | | ++train_conf.save_checkpoint_interval=2000 \ |
| | | ++train_conf.keep_nbest_models=20 \ |
| | | ++train_conf.avg_nbest_model=10 \ |
| | | ++optim_conf.lr=0.0002 \ |
| | | ++debug=true \ |
| | | ++device="cpu" \ |
| | | ++output_dir="${output_dir}" #&> ${log_file} |
| | |
| | | kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None |
| | | kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"] |
| | | vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 |
| | | if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): |
| | | vocab_size = tokenizer.get_vocab_size() |
| | | else: |
| | | vocab_size = -1 |
| | | kwargs["tokenizer"] = tokenizer |
| | |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP(model, device_ids=[local_rank], |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True)) |
| | | elif use_fsdp: |
| | | # model = FSDP(model).cuda(local_rank) |
| | | |
| | |
| | | for line in fin: |
| | | data = json.loads(line.strip()) |
| | | if "text" in data: # for sft |
| | | self.contents.append(data['text']) |
| | | contents.append(data['text']) |
| | | if "source" in data: # for speech lab pretrain |
| | | prompt = data.get("prompt", "<ASR>") |
| | | source = data["source"] |
| | |
| | | target_len = data.get("target_len", 0) |
| | | if "aishell" in source: |
| | | target = target.replace(" ", "") |
| | | contents.append({"source": source, |
| | | "prompt": prompt, |
| | | "target": target, |
| | | "source_len": source_len, |
| | | "target_len": target_len, |
| | | } |
| | | ) |
| | | |
| | | contents_i = {"source": source, |
| | | "prompt": prompt, |
| | | "target": target, |
| | | "source_len": source_len, |
| | | "target_len": target_len, |
| | | } |
| | | text_language = data.get("text_language", None) |
| | | if text_language is not None: |
| | | contents_i["text_language"] = text_language |
| | | audio_language = data.get("audio_language", None) |
| | | if audio_language is not None: |
| | | contents_i["audio_language"] = audio_language |
| | | contents.append(contents_i) |
| | | |
| | | self.contents = contents |
| | | |
| New file |
| | |
| | | import torch |
| | | import random |
| | | |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video |
| | | |
| | | |
| | | @tables.register("dataset_classes", "SenseVoiceDataset") |
| | | class SenseVoiceDataset(torch.utils.data.Dataset): |
| | | """ |
| | | SenseVoiceDataset |
| | | """ |
| | | def __init__(self, |
| | | path, |
| | | index_ds: str = None, |
| | | frontend=None, |
| | | tokenizer=None, |
| | | int_pad_value: int = -1, |
| | | float_pad_value: float = 0.0, |
| | | **kwargs): |
| | | super().__init__() |
| | | index_ds_class = tables.index_ds_classes.get(index_ds) |
| | | self.index_ds = index_ds_class(path, **kwargs) |
| | | preprocessor_speech = kwargs.get("preprocessor_speech", None) |
| | | if preprocessor_speech: |
| | | preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech) |
| | | preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) |
| | | self.preprocessor_speech = preprocessor_speech |
| | | preprocessor_text = kwargs.get("preprocessor_text", None) |
| | | if preprocessor_text: |
| | | preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text) |
| | | preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) |
| | | self.preprocessor_text = preprocessor_text |
| | | |
| | | self.frontend = frontend |
| | | self.fs = 16000 if frontend is None else frontend.fs |
| | | self.data_type = "sound" |
| | | self.tokenizer = tokenizer |
| | | |
| | | self.int_pad_value = int_pad_value |
| | | self.float_pad_value = float_pad_value |
| | | self.sos = kwargs.get("sos", "<|startoftranscript|>") |
| | | self.eos = kwargs.get("eos", "<|endoftext|>") |
| | | |
| | | def get_source_len(self, index): |
| | | item = self.index_ds[index] |
| | | return self.index_ds.get_source_len(item) |
| | | |
| | | def get_target_len(self, index): |
| | | item = self.index_ds[index] |
| | | return self.index_ds.get_target_len(item) |
| | | |
| | | def __len__(self): |
| | | return len(self.index_ds) |
| | | |
| | | def __getitem__(self, index): |
| | | item = self.index_ds[index] |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | source = item["source"] |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | if self.preprocessor_speech: |
| | | data_src = self.preprocessor_speech(data_src, fs=self.fs) |
| | | speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d] |
| | | speech = speech.permute(0, 2, 1) |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | | target = self.preprocessor_text(target) |
| | | |
| | | task = item.get("prompt", "<|ASR|>") |
| | | text_language = item.get("text_language", "<|zh|>") |
| | | |
| | | prompt = f"{self.sos}{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | prompt_ids_len = len(prompt_ids) - 1 # [sos, task] |
| | | |
| | | target_ids = self.tokenizer.encode(target, allowed_special="all") |
| | | target_ids_len = len(target_ids) + 1 # [lid, text] |
| | | |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | |
| | | ids = prompt_ids + target_ids + eos |
| | | ids_lengths = len(ids) |
| | | |
| | | text = torch.tensor(ids, dtype=torch.int64) |
| | | text_lengths = torch.tensor([ids_lengths], dtype=torch.int32) |
| | | |
| | | target_mask = [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1] # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] |
| | | target_mask = torch.tensor(target_mask, dtype=torch.float32) |
| | | |
| | | return {"speech": speech[0, :, :], |
| | | "speech_lengths": speech_lengths, |
| | | "text": text, |
| | | "text_lengths": text_lengths, |
| | | "target_mask": target_mask, |
| | | } |
| | | |
| | | |
| | | def collator(self, samples: list=None): |
| | | outputs = {} |
| | | for sample in samples: |
| | | for key in sample.keys(): |
| | | if key not in outputs: |
| | | outputs[key] = [] |
| | | outputs[key].append(sample[key]) |
| | | |
| | | for key, data_list in outputs.items(): |
| | | if isinstance(data_list[0], torch.Tensor): |
| | | if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32: |
| | | |
| | | pad_value = self.int_pad_value |
| | | else: |
| | | pad_value = self.float_pad_value |
| | | |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) |
| | | return outputs |
| | | |
| | | |
| | |
| | | """ |
| | | assert x.size(2) == self.size |
| | | batch_size = x.size(0) |
| | | x = x.view(-1, self.size) |
| | | target = target.view(-1) |
| | | x = x.contiguous().view(-1, self.size) |
| | | target = target.contiguous().view(-1) |
| | | with torch.no_grad(): |
| | | true_dist = x.clone() |
| | | true_dist.fill_(self.smoothing / (self.size - 1)) |
| New file |
| | |
| | | import copy |
| | | from typing import Optional, Tuple, Union |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | |
| | | def sense_voice_decode_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | xa: torch.Tensor, |
| | | kv_cache: Optional[dict] = None, |
| | | **kwargs, |
| | | ): |
| | | """Forward decoder. |
| | | |
| | | Args: |
| | | hs_pad: encoded memory, float32 (batch, maxlen_in, feat) |
| | | hlens: (batch) |
| | | ys_in_pad: |
| | | input token ids, int64 (batch, maxlen_out) |
| | | if input_layer == "embed" |
| | | input tensor (batch, maxlen_out, #mels) in the other cases |
| | | ys_in_lens: (batch) |
| | | Returns: |
| | | (tuple): tuple containing: |
| | | |
| | | x: decoded token score before softmax (batch, maxlen_out, token) |
| | | if use_output_layer is True, |
| | | olens: (batch, ) |
| | | """ |
| | | # import pdb;pdb.set_trace() |
| | | use_padmask = self.use_padmask |
| | | hlens = kwargs.get("hlens", None) |
| | | |
| | | ys_in_lens = kwargs.get("ys_in_lens", None) |
| | | |
| | | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 |
| | | tgt, memory = x, xa |
| | | tgt[tgt==-1] = 0 |
| | | tgt = ( |
| | | self.token_embedding(tgt) |
| | | + self.positional_embedding[offset : offset + tgt.size(1)] |
| | | ) |
| | | # tgt = self.dropout(tgt) |
| | | |
| | | x = tgt.to(memory.dtype) |
| | | |
| | | if use_padmask and hlens is not None: |
| | | memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device) |
| | | else: |
| | | memory_mask = None |
| | | |
| | | for layer, block in enumerate(self.blocks): |
| | | x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True) |
| | | |
| | | |
| | | x = self.ln(x) |
| | | x = ( |
| | | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) |
| | | ).float() |
| | | |
| | | |
| | | return x |
| | | |
| New file |
| | |
| | | import copy |
| | | from typing import Optional, Tuple, Union |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | |
| | | |
| | | def sense_voice_encode_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | ilens: torch.Tensor = None, |
| | | **kwargs, |
| | | ): |
| | | use_padmask = self.use_padmask |
| | | x = F.gelu(self.conv1(x)) |
| | | x = F.gelu(self.conv2(x)) |
| | | x = x.permute(0, 2, 1) |
| | | |
| | | n_frames = x.size(1) |
| | | max_pos = self.positional_embedding.size(0) |
| | | max_pos = n_frames if n_frames < max_pos else max_pos |
| | | x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype) |
| | | |
| | | |
| | | if ilens is not None: |
| | | if self.downsample_rate == 4: |
| | | olens = ( |
| | | 1 |
| | | + ( |
| | | ilens |
| | | - self.conv1.kernel_size[0] |
| | | + 2 * self.conv1.padding[0] |
| | | ) |
| | | // self.conv1.stride[0] |
| | | ) |
| | | else: |
| | | olens = ilens |
| | | olens = ( |
| | | 1 |
| | | + ( |
| | | olens |
| | | - self.conv2.kernel_size[0] |
| | | + 2 * self.conv2.padding[0] |
| | | ) |
| | | // self.conv2.stride[0] |
| | | ) |
| | | olens = torch.clamp(olens, max=max_pos) |
| | | else: |
| | | olens = None |
| | | |
| | | if use_padmask and olens is not None: |
| | | padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device) |
| | | else: |
| | | padding_mask = None |
| | | |
| | | for layer, block in enumerate(self.blocks): |
| | | x = block(x, mask=padding_mask, is_pad_mask=True) |
| | | |
| | | |
| | | x = self.ln_post(x) |
| | | |
| | | if ilens is None: |
| | | return x |
| | | else: |
| | | return x, olens |
| | |
| | | from dataclasses import dataclass |
| | | from typing import Dict |
| | | from typing import Iterable, Optional |
| | | import types |
| | | import time |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import Tensor |
| | | from torch import nn |
| | | from torch.cuda.amp import autocast |
| | | from funasr.metrics.compute_acc import compute_accuracy |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from . import whisper_lib as whisper |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | |
| | | |
| | | @tables.register("model_classes", "SenseVoice") |
| | | class SenseVoice(nn.Module): |
| | | def __init__(self, *args, **kwargs): |
| | | super().__init__() |
| | | hub = kwargs.get("hub", "funasr") |
| | | |
| | | |
| | | dims = kwargs.get("dims", {}) |
| | | dims = whisper.model.ModelDimensions(**dims) |
| | | model = whisper.model.Whisper(dims=dims) |
| | | |
| | | # encoder |
| | | model.encoder.downsample_rate = kwargs.get("downsample_rate", 4) |
| | | model.encoder.use_padmask = kwargs.get("use_padmask", True) |
| | | from .encoder import sense_voice_encode_forward |
| | | model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder) |
| | | |
| | | # decoder |
| | | model.decoder.use_padmask = kwargs.get("use_padmask", True) |
| | | from .decoder import sense_voice_decode_forward |
| | | model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder) |
| | | |
| | | self.model = model |
| | | |
| | | self.encoder_output_size = self.model.dims.n_audio_state |
| | | |
| | | def forward(self, ): |
| | | pass |
| | | self.activation_checkpoint = kwargs.get("activation_checkpoint", False) |
| | | self.ignore_id = kwargs.get("ignore_id", -1) |
| | | self.vocab_size = kwargs.get("vocab_size", -1) |
| | | self.length_normalized_loss = kwargs.get("length_normalized_loss", True) |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=self.vocab_size, |
| | | padding_idx=self.ignore_id, |
| | | smoothing=kwargs.get("lsm_weight", 0.0), |
| | | normalize_length=self.length_normalized_loss, |
| | | ) |
| | | |
| | | specaug = kwargs.get("specaug", None) |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**kwargs.get("specaug_conf", {})) |
| | | self.specaug = specaug |
| | | |
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | if self.activation_checkpoint: |
| | | from torch.utils.checkpoint import checkpoint |
| | | encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False) |
| | | else: |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask |
| | | ) |
| | | loss = loss_att |
| | | stats = {} |
| | | stats["acc"] = acc_att |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["batch_size"] = batch_size |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | ) : |
| | | """Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | ind: int |
| | | """ |
| | | with autocast(False): |
| | | |
| | | # Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| | | |
| | | |
| | | # Forward encoder |
| | | encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths) |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | **kwargs, |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | stats = {} |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out = self.model.decoder( |
| | | x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens |
| | | ) |
| | | |
| | | # 2. Compute attention loss |
| | | mask = torch.ones_like(ys_pad) * (-1) |
| | | ys_pad_mask = (ys_pad * target_mask + mask * (1-target_mask)).to(torch.int64) |
| | | ys_pad_mask[ys_pad_mask == 0] = -1 |
| | | loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:]) |
| | | |
| | | with torch.no_grad(): |
| | | preds = torch.argmax(decoder_out, -1) |
| | | acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id) |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | |
| | | xa: Optional[Tensor] = None, |
| | | mask: Optional[Tensor] = None, |
| | | kv_cache: Optional[dict] = None, |
| | | **kwargs, |
| | | ): |
| | | is_pad_mask = kwargs.get("is_pad_mask", False) |
| | | |
| | | q = self.query(x) |
| | | |
| | | if kv_cache is None or xa is None or self.key not in kv_cache: |
| | |
| | | k = kv_cache[self.key] |
| | | v = kv_cache[self.value] |
| | | |
| | | wv, qk = self.qkv_attention(q, k, v, mask) |
| | | wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask) |
| | | return self.out(wv), qk |
| | | |
| | | def qkv_attention( |
| | | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None |
| | | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs, |
| | | ): |
| | | is_pad_mask = kwargs.get("is_pad_mask", False) |
| | | n_batch, n_ctx, n_state = q.shape |
| | | scale = (n_state // self.n_head) ** -0.25 |
| | | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale |
| | |
| | | |
| | | qk = q @ k |
| | | if mask is not None: |
| | | qk = qk + mask[:n_ctx, :n_ctx] |
| | | if not is_pad_mask: |
| | | qk = qk + mask[:n_ctx, :n_ctx] |
| | | else: |
| | | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) |
| | | min_value = float( |
| | | np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min |
| | | ) |
| | | qk = qk.masked_fill(mask, min_value) |
| | | |
| | | qk = qk.float() |
| | | |
| | | w = F.softmax(qk, dim=-1).to(q.dtype) |
| | | if mask is not None and is_pad_mask: |
| | | w = w.masked_fill(mask, 0.0) |
| | | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() |
| | | |
| | | |
| | |
| | | xa: Optional[Tensor] = None, |
| | | mask: Optional[Tensor] = None, |
| | | kv_cache: Optional[dict] = None, |
| | | **kwargs, |
| | | ): |
| | | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] |
| | | is_pad_mask = kwargs.get("is_pad_mask", False) |
| | | is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False) |
| | | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0] |
| | | if self.cross_attn: |
| | | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] |
| | | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0] |
| | | x = x + self.mlp(self.mlp_ln(x)) |
| | | return x |
| | | |
| | |
| | | |
| | | return tokenizer |
| | | |
| | | |
| | | @tables.register("tokenizer_classes", "SenseVoiceTokenizer") |
| | | def SenseVoiceTokenizer(**kwargs): |
| | | try: |
| | | from funasr.models.sense_voice.whisper_lib.tokenizer import get_tokenizer |
| | | except: |
| | | print("Notice: If you want to use whisper, please `pip install -U openai-whisper`") |
| | | |
| | | language = kwargs.get("language", None) |
| | | task = kwargs.get("task", None) |
| | | is_multilingual = kwargs.get("is_multilingual", True) |
| | | num_languages = kwargs.get("num_languages", 8749) |
| | | vocab_path = kwargs.get("vocab_path", None) |
| | | tokenizer = get_tokenizer( |
| | | multilingual=is_multilingual, |
| | | num_languages=num_languages, |
| | | language=language, |
| | | task=task, |
| | | vocab_path=vocab_path, |
| | | ) |
| | | |
| | | return tokenizer |