From fa6f60fa762f271d096b8749f3cc9bfc61a6ed48 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 23 Feb 2024 14:01:44 +0800
Subject: [PATCH] update
---
/dev/null | 96 ------------
funasr/models/llm_asr/model.py | 25 +-
funasr/datasets/llm_datasets/datasets.py | 13
funasr/bin/train.py | 4
funasr/auto/auto_model.py | 6
funasr/datasets/llm_datasets/preprocessor.py | 40 +---
funasr/models/paraformer/cif_predictor.py | 200 +++++++++---------------
setup.py | 2
funasr/metrics/compute_acc.py | 4
examples/aishell/llm_asr_nar/conf/template.yaml | 34 ++--
10 files changed, 134 insertions(+), 290 deletions(-)
diff --git a/examples/aishell/llm_asr_nar/conf/template.yaml b/examples/aishell/llm_asr_nar/conf/template.yaml
index 0b26969..d529635 100644
--- a/examples/aishell/llm_asr_nar/conf/template.yaml
+++ b/examples/aishell/llm_asr_nar/conf/template.yaml
@@ -24,11 +24,11 @@
init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
freeze: true
-adaptor: linear
+adaptor: Linear
adaptor_conf:
downsample_rate: 1
llm_dim: 4096
- encoder_dim: 2048
+ encoder_dim: 512
# frontend related
frontend: WavFrontend
@@ -38,54 +38,56 @@
n_mels: 80
frame_length: 25
frame_shift: 10
- dither: 0.0
- lfr_m: 1
- lfr_n: 1
+ lfr_m: 7
+ lfr_n: 6
+ cmvn_file: "/root/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
-specaug: SpecAug
+specaug: SpecAugLFR
specaug_conf:
- apply_time_warp: true
+ apply_time_warp: false
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
- num_freq_mask: 2
+ lfr_rate: 6
+ num_freq_mask: 1
apply_time_mask: true
time_mask_width_range:
- 0
- - 40
- num_time_mask: 2
+ - 12
+ num_time_mask: 1
train_conf:
accum_grad: 1
grad_clip: 5
max_epoch: 150
keep_nbest_models: 10
- log_interval: 50
+ log_interval: 10
-optim: adam
+optim: adamw
optim_conf:
- lr: 0.001
+ lr: 0.0001
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
- warmup_steps: 35000
+ warmup_steps: 1500
dataset: AudioLLMDataset
dataset_conf:
index_ds: IndexDSJsonl
batch_sampler: RankFullLocalShuffleBatchSampler
batch_type: example # example or length
- batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+  batch_size: 8 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len;
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
buffer_size: 500
shuffle: True
num_workers: 4
+ preprocessor_text: TextPreprocessRemovePunctuation
tokenizer: HuggingfaceTokenizer
tokenizer_conf:
unk_symbol: <unk>
- init_param_path: null
+ init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
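The frontend change above (lfr_m: 7, lfr_n: 6) enables low frame rate processing: each output
frame stacks 7 consecutive 80-dim fbank frames (560 dims) and advances by 6, the configuration
the Paraformer-large CMVN stats (am.mvn) expect. A minimal sketch of LFR stacking, assuming the
usual left-pad-with-first-frame convention (an illustrative helper, not the FunASR implementation):

    import numpy as np

    def apply_lfr(feats, lfr_m=7, lfr_n=6):
        # feats: [T, 80] fbank; returns [ceil(T / lfr_n), lfr_m * 80]
        t, _ = feats.shape
        left = (lfr_m - 1) // 2                      # left-pad by repeating the first frame
        padded = np.vstack([np.repeat(feats[:1], left, axis=0), feats])
        frames = []
        for i in range(0, t, lfr_n):
            win = padded[i:i + lfr_m]
            if win.shape[0] < lfr_m:                 # repeat the last frame at the tail
                win = np.vstack([win, np.repeat(win[-1:], lfr_m - win.shape[0], axis=0)])
            frames.append(win.reshape(-1))
        return np.stack(frames)

    apply_lfr(np.random.randn(100, 80)).shape        # (17, 560)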
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e5faa2a..3b70ad6 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -157,8 +157,10 @@
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
- kwargs["token_list"] = tokenizer.token_list
- vocab_size = len(tokenizer.token_list)
+
+ kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
+ vocab_size = len(kwargs["token_list"])
else:
vocab_size = -1
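The fallback above exists because FunASR's built-in tokenizers expose a token_list while the
HuggingfaceTokenizer used here wraps a HF tokenizer exposing get_vocab() (a dict); len() works
on either. The same logic, isolated into a hypothetical helper:

    def resolve_vocab(tokenizer):
        # Prefer get_vocab() (HuggingFace), fall back to token_list (FunASR), else None.
        vocab = getattr(tokenizer, "token_list", None)
        if hasattr(tokenizer, "get_vocab"):
            vocab = tokenizer.get_vocab()
        return vocab

train.py below applies the same pattern to derive vocab_size for the model constructor.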
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 26b0f4a..44d84e7 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -85,7 +85,9 @@
# build model
model_class = tables.model_classes.get(kwargs["model"])
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
+ vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 20eb8aa..ab0e48a 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -24,12 +24,12 @@
preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
- preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+ preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
- preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
self.preprocessor_text = preprocessor_text
self.frontend = frontend
@@ -43,6 +43,7 @@
self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
self.prompt_af = ""
+ self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
def get_source_len(self, index):
item = self.index_ds[index]
@@ -64,7 +65,7 @@
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
- speech = speech.sequeeze(0)
+ speech = speech.squeeze(0)
target = item["target"]
if self.preprocessor_text:
@@ -91,10 +92,10 @@
label_mask = labels_ids.ge(0) # [False,False,True,True]
labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,input,eos]
- audio_mask = [0] * prompt_pre_length + [1] * audio_length
- torch.tensor(audio_mask, dtype=torch.float32)
+ audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+ audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
- ids = self.tokenizer.encode(target)
+        ids = self.tokenizer.encode(target) # these token ids differ from labels_ids
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
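For reference, the masking convention above: placeholder positions in labels_ids carry negative
ids and are remapped to IGNORE_INDEX (-100), which the HF loss skips, while audio_mask flags the
embedding slots the model later overwrites with encoder outputs (the trailing 0 covers the
shifted final position). An illustrative, self-contained rendering with made-up lengths:

    import torch

    IGNORE_INDEX = -100
    prompt_pre_length, audio_length = 2, 3

    labels_ids = torch.tensor([-1, -1, -1, -1, -1, 11, 12, 2])  # placeholders, then text + eos
    labels_ids[~labels_ids.ge(0)] = IGNORE_INDEX                # loss ignores the first five slots

    audio_mask = torch.tensor([0] * prompt_pre_length + [1] * audio_length + [0],
                              dtype=torch.float32)              # [0., 0., 1., 1., 1., 0.]
    # audio_mask is aligned to the LLM input embeddings in model.py, not to labels_ids here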
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
index ab75140..9f20672 100644
--- a/funasr/datasets/llm_datasets/preprocessor.py
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -11,41 +11,27 @@
from torch import nn
import random
import re
+import string
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
-@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
-class SpeechPreprocessSpeedPerturb(nn.Module):
- def __init__(self, speed_perturb: list=None, **kwargs):
- super().__init__()
- self.speed_perturb = speed_perturb
-
- def forward(self, waveform, fs, **kwargs):
- if self.speed_perturb is None:
- return waveform
- speed = random.choice(self.speed_perturb)
- if speed != 1.0:
- if not isinstance(waveform, torch.Tensor):
- waveform = torch.tensor(waveform)
- waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
- waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
- waveform = waveform.view(-1)
-
- return waveform
-
-@tables.register("preprocessor_classes", "TextPreprocessSegDict")
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessSegDict(nn.Module):
- def __init__(self, seg_dict: str = None,
- text_cleaner: Collection[str] = None,
- split_with_space: bool = False,
+ def __init__(self,
**kwargs):
super().__init__()
- self.text_cleaner = TextCleaner(text_cleaner)
def forward(self, text, **kwargs):
- text = self.text_cleaner(text)
-
- return text
+        # English punctuation
+        en_punct = string.punctuation
+        # Common Chinese punctuation (a frequently used subset)
+        cn_punct = '。？！，、；：“”‘’（）《》【】…—～·'
+        # Merge the English and Chinese punctuation sets
+        all_punct = en_punct + cn_punct
+        # Build a regex pattern matching any character in all_punct
+        punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
+        # Strip those characters with re.sub
+        return punct_pattern.sub('', text)
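A quick standalone check of the rewritten preprocessor's behavior (same pattern construction;
the class is wired in via preprocessor_text: TextPreprocessRemovePunctuation in template.yaml):

    import re, string

    cn_punct = '。？！，、；：“”‘’（）《》【】…—～·'
    pattern = re.compile('[{}]'.format(re.escape(string.punctuation + cn_punct)))
    assert pattern.sub('', 'Hello, world! 你好，世界。') == 'Hello world 你好世界'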
diff --git a/funasr/datasets/llm_datasets/scp2jsonl.py b/funasr/datasets/llm_datasets/scp2jsonl.py
deleted file mode 100644
index e09a84a..0000000
--- a/funasr/datasets/llm_datasets/scp2jsonl.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import os
-import json
-import torch
-import logging
-import hydra
-from omegaconf import DictConfig, OmegaConf
-import concurrent.futures
-import librosa
-import torch.distributed as dist
-
-
-
-def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
-
- cpu_cores = os.cpu_count() or 1
- print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
- if rank == 0:
- json_dict = {}
- for data_type, data_file in zip(data_type_list, path):
- json_dict[data_type] = {}
- with open(data_file, "r") as f:
-
- data_file_lists = f.readlines()
- lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
- task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
- with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
-
- futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
-
- for future in concurrent.futures.as_completed(futures):
-
- json_dict[data_type].update(future.result())
- # print(json_dict)
-
- with open(jsonl_file_out, "w") as f:
- for key in json_dict[data_type_list[0]].keys():
- jsonl_line = {"key": key}
- for data_file in data_type_list:
- jsonl_line.update(json_dict[data_file][key])
- jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
- f.write(jsonl_line+"\n")
- f.flush()
-
- else:
- pass
-
- if world_size > 1:
- dist.barrier()
-
-
-def parse_context_length(data_list: list, data_type: str):
-
- res = {}
- for i, line in enumerate(data_list):
- key, line = line.strip().split(maxsplit=1)
- line = line.strip()
- if os.path.exists(line):
- waveform, _ = librosa.load(line, sr=16000)
- sample_num = len(waveform)
- context_len = int(sample_num//16000*1000/10)
- else:
- context_len = len(line.split()) if " " in line else len(line)
- res[key] = {data_type: line, f"{data_type}_len": context_len}
- return res
-
-
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(cfg: DictConfig):
-
- kwargs = OmegaConf.to_container(cfg, resolve=True)
-
- scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
- if isinstance(scp_file_list, str):
- scp_file_list = eval(scp_file_list)
- data_type_list = kwargs.get("data_type_list", ("source", "target"))
- jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
- gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
-
-
-"""
-python -m funasr.datasets.audio_datasets.scp2jsonl \
-++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
-++data_type_list='["source", "target"]' \
-++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
-"""
-
-if __name__ == "__main__":
- main_hydra()
-
-
\ No newline at end of file
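For reference, the deleted scp2jsonl.py emitted one JSON object per utterance, combining each
data type with a pre-computed length (10 ms frames for wav paths, token/char count for text);
a line looked roughly like this (values illustrative):

    {"key": "utt1", "source": "/path/to/utt1.wav", "source_len": 998, "target": "今天天气怎么样", "target_len": 7}

The usage docstring above already points at funasr.datasets.audio_datasets.scp2jsonl, which
presumably supersedes this copy.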
diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py
index 73545c0..ec8067f 100644
--- a/funasr/metrics/compute_acc.py
+++ b/funasr/metrics/compute_acc.py
@@ -35,8 +35,6 @@
"""
mask = pad_targets != ignore_label
- numerator = torch.sum(
- pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
- )
+ numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
\ No newline at end of file
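The reformatted expression above is behavior-preserving; a tiny worked example of the masked
accuracy (values illustrative):

    import torch

    pad_outputs = torch.tensor([[5, 6, 7, 0]])      # argmax predictions
    pad_targets = torch.tensor([[5, 6, 8, -100]])   # -100 marks an ignored position
    mask = pad_targets != -100
    num = (pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)).sum()
    acc = num.float() / mask.sum().float()          # 2/3: the ignored slot counts in neither term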
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index a903262..06323c6 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -73,7 +73,7 @@
hub = encoder_conf.get("hub", None)
if hub == "funasr":
from funasr import AutoModel
- init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
model = AutoModel(model=init_param_path, model_revision="v2.0.4")
# frontend = model.kwargs.get("frontend")
model.model.decoder = None
@@ -179,6 +179,7 @@
if input_ids is not None:
input_ids[input_ids == -1] = 0
+ input_ids[input_ids == -100] = 0
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
@@ -190,7 +191,7 @@
batch_size, token_num, dims = inputs_embeds.shape
_, l, _ = encoder_out.shape
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
- inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+ inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
@@ -198,11 +199,10 @@
stats = {}
- if self.metric:
- with torch.no_grad():
- preds = torch.argmax(model_outputs.logits, -1)
- acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
- stats["acc"] = acc_att
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+ stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
@@ -221,11 +221,12 @@
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
- enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
- pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
- mask=enc_mask,
- target_label_length=audio_token_lengths,
- )
+ with autocast(False):
+ enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+ mask=enc_mask,
+ target_label_length=audio_token_lengths,
+ )
return pre_acoustic_embeds, pre_token_length
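The autocast(False) wrapper above (and the matching change in cif_predictor.py below) keeps the
CIF predictor in fp32 under mixed-precision training: the cumulative alpha sums and the
target_length / token_num rescaling are sensitive to fp16 rounding, which can shift the predicted
token count. The pattern in isolation, with stand-in modules (a sketch, not the real encoder):

    import torch
    from torch import nn
    from torch.cuda.amp import autocast

    encoder = nn.Linear(80, 512)          # stand-ins for the real encoder/predictor
    predictor = nn.Linear(512, 512)
    speech = torch.randn(2, 100, 80)

    with autocast():                      # reduced precision for the heavy encoder math
        enc = encoder(speech)
        with autocast(False):             # length-sensitive CIF ops stay in fp32
            out = predictor(enc.float())  # explicit cast: autocast(False) only stops further downcasting
    assert out.dtype == torch.float32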
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 60ddc24..4d9f5d8 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -10,7 +10,7 @@
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
+from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- memory = self.cif_conv1d(queries)
- output = memory + context
- output = self.dropout(output)
- output = output.transpose(1, 2)
- output = torch.relu(output)
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ with autocast(False):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.cif_conv1d(queries)
+ output = memory + context
+ output = self.dropout(output)
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
-
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -169,41 +171,43 @@
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
- output = output.transpose(1, 2)
-
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length.squeeze(-1)
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- if self.tail_mask:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ with autocast(False):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length.squeeze(-1)
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
else:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
-
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ if self.tail_mask:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+ else:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
@@ -370,62 +374,6 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
-
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## predictor
- "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (256,256,3),(3,256,256)
- "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.cif_output.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1,256),(1,256,1)
- "{}.cif_output.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1,),(1,)
- }
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
class mae_loss(torch.nn.Module):
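Context for the rescaling retained inside the autocast(False) blocks above: during training the
alphas are normalized so they integrate exactly to the label length, making the CIF firing count
match target_length. A tiny worked example:

    import torch

    alphas = torch.tensor([[0.4, 0.6, 0.5, 0.5]])   # per-frame weights, sum = 2.0
    target_length = torch.tensor([4.0])
    token_num = alphas.sum(-1)
    alphas = alphas * (target_length / token_num)[:, None]
    assert torch.isclose(alphas.sum(-1), target_length).all()  # now integrates to 4 tokens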
diff --git a/setup.py b/setup.py
index f703bb4..4e76c80 100644
--- a/setup.py
+++ b/setup.py
@@ -40,11 +40,11 @@
"umap_learn",
"jaconv",
"hydra-core>=1.3.2",
+ "tensorboardX",
],
# train: The modules invoked when training only.
"train": [
"editdistance",
- "tensorboardX",
],
# all: The modules should be optionally installled due to some reason.
# Please consider moving them to "install" occasionally
--
Gitblit v1.9.1