
init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
freeze: true

adaptor: Linear
adaptor_conf:
    downsample_rate: 1
    llm_dim: 4096
    encoder_dim: 512

# frontend related
frontend: WavFrontend
frontend_conf:
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    dither: 0.0
    lfr_m: 7
    lfr_n: 6
    cmvn_file: "/root/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"

specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1

train_conf:
    accum_grad: 1
    grad_clip: 5
    max_epoch: 150
    keep_nbest_models: 10
    log_interval: 10

optim: adamw
optim_conf:
    lr: 0.0001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1500

dataset: AudioLLMDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: RankFullLocalShuffleBatchSampler
    batch_type: example # example or length
    batch_size: 8 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len + target_token_len
    max_token_length: 2048 # filter out samples whose source_token_len + target_token_len exceeds max_token_length
    buffer_size: 500
    shuffle: true
    num_workers: 4
    preprocessor_text: TextPreprocessRemovePunctuation

tokenizer: HuggingfaceTokenizer
tokenizer_conf:
    unk_symbol: <unk>
    init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"

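# Note on the frontend settings above: with lfr_m: 7 / lfr_n: 6, WavFrontend
# stacks 7 consecutive 80-dim fbank frames (10 ms hop) and subsamples by 6,
# yielding 560-dim features at a ~60 ms hop, which is what the Paraformer cmvn
# stats expect. A minimal sketch of the stacking, assuming the usual Paraformer
# LFR convention (FunASR's implementation differs in its edge-padding details):
import numpy as np

def apply_lfr(feats, lfr_m=7, lfr_n=6):
    # feats: (T, 80) fbank -> (ceil(T / lfr_n), 80 * lfr_m)
    T, _ = feats.shape
    left = (lfr_m - 1) // 2
    feats = np.pad(feats, ((left, lfr_m), (0, 0)), mode="edge")
    n_out = int(np.ceil(T / lfr_n))
    return np.stack([feats[i * lfr_n:i * lfr_n + lfr_m].reshape(-1) for i in range(n_out)])

print(apply_lfr(np.zeros((100, 80))).shape)  # (17, 560)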

    tokenizer_class = tables.tokenizer_classes.get(tokenizer)
    tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
    kwargs["tokenizer"] = tokenizer
    kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
    kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
    vocab_size = len(kwargs["token_list"])
else:
    vocab_size = -1
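# The two-step fallback above covers both tokenizer flavors: FunASR's own
# tokenizers expose .token_list, while HuggingfaceTokenizer wraps an object
# exposing .get_vocab(), and the second assignment wins when both exist.
# Illustration (the model id is only an example):
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
token_list = tok.get_vocab() if hasattr(tok, "get_vocab") else None
print(len(token_list))  # 32000 for the LLaMA/vicuna tokenizer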


# build model
model_class = tables.model_classes.get(kwargs["model"])
vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)

preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
    preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
    preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
    preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
    preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
self.preprocessor_text = preprocessor_text

self.frontend = frontend

self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
    self.prompt)  # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
self.prompt_af = ""
self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)

def get_source_len(self, index):
    item = self.index_ds[index]

if self.preprocessor_speech:
    data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True)  # speech: [b, T, d]
speech = speech.squeeze(0)

target = item["target"]
if self.preprocessor_text:
    target = self.preprocessor_text(target)

label_mask = labels_ids.ge(0)  # [False, False, True, True]
labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, input, eos]

audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
audio_mask = torch.tensor(audio_mask, dtype=torch.float32)

ids = self.tokenizer.encode(target)  # these token ids differ from labels_ids
text = torch.tensor(ids, dtype=torch.int64)
text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
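# Toy illustration of the mask layout (lengths assumed): with a 4-token prompt
# prefix and 3 audio embedding slots, the 1s mark where encoder outputs will
# later be spliced into the LLM input embeddings; the trailing 0 matches the
# extra right-pad position.
prompt_pre_length, audio_length = 4, 3
audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
print(audio_mask)  # [0, 0, 0, 0, 1, 1, 1, 0]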

import random
import re
import string

import torch
import torchaudio
from torch import nn

from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables


@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
            waveform = waveform.view(-1)

        return waveform
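# Example usage (input values assumed): each call randomly picks one of the
# listed speeds; the trailing "rate" effect resamples back to fs, so only the
# duration changes.
perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
wav = torch.randn(16000)          # 1 s of synthetic audio at 16 kHz
out = perturb(wav, fs=16000)
print(out.shape)                  # length scales by ~1/speed, e.g. ~17778 at 0.9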


@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessRemovePunctuation(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, text, **kwargs):
        # English punctuation
        en_punct = string.punctuation
        # common Chinese punctuation
        cn_punct = '。?!,、;:“”‘’()《》【】…—~·'
        # merge English and Chinese punctuation
        all_punct = en_punct + cn_punct
        # regex pattern matching any character in all_punct
        punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
        # drop those characters with re.sub
        return punct_pattern.sub('', text)
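# Example (input string assumed): punctuation from both sets is stripped;
# letters, digits, and spaces are kept.
pre = TextPreprocessRemovePunctuation()
print(pre("Hello, world! 你好,世界。"))  # -> "Hello world 你好世界"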

    """
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()  # (FIX:MZY): return torch.Tensor type
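# For context, a self-contained version of the accuracy helper above with a
# toy batch (values assumed): positions labeled -100 are excluded, so the
# score is 2 correct out of 3 counted tokens.
import torch

def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()

preds = torch.tensor([[5, 7, 9, 0]])
targets = torch.tensor([[5, 7, 8, -100]])
print(compute_accuracy(preds, targets, ignore_label=-100))  # tensor(0.6667)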

hub = encoder_conf.get("hub", None)
if hub == "funasr":
    from funasr import AutoModel

    init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    model = AutoModel(model=init_param_path, model_revision="v2.0.4")
    # frontend = model.kwargs.get("frontend")
    model.model.decoder = None


if input_ids is not None:
    input_ids[input_ids == -100] = 0
    if hasattr(self.llm.model, "embed_tokens"):
        inputs_embeds = self.llm.model.embed_tokens(input_ids)
    elif hasattr(self.llm.model.model, "embed_tokens"):
        inputs_embeds = self.llm.model.model.embed_tokens(input_ids)

batch_size, token_num, dims = inputs_embeds.shape
_, l, _ = encoder_out.shape
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - l - 1, 1, 0, 0), value=0.0)
inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)

model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
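# Toy walk-through of the splice above (shapes assumed): the encoder frames
# are left-padded so they land exactly on the mask's 1-positions, then blended
# with the token embeddings.
import torch
import torch.nn.functional as F

inputs_embeds = torch.ones(1, 6, 2)               # [b, token_num, d] token embeddings
encoder_out = torch.full((1, 3, 2), 5.0)          # [b, l, d] audio embeddings
audio_mask = torch.tensor([[0, 0, 1, 1, 1, 0.]])  # 1s mark the audio slots

b, token_num, d = inputs_embeds.shape
_, l, _ = encoder_out.shape
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num - l - 1, 1, 0, 0), value=0.0)
blended = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
print(blended[0, :, 0])  # tensor([1., 1., 5., 5., 5., 1.])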


stats = {}
with torch.no_grad():
    preds = torch.argmax(model_outputs.logits, -1)
    acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
    stats["acc"] = acc_att

stats["loss"] = torch.clone(loss.detach())


batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
with autocast(False):
    enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
    pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
        enc, mask=enc_mask, target_label_length=audio_token_lengths,
    )

return pre_acoustic_embeds, pre_token_length
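# Why the predictor runs under autocast(False): CIF sums thousands of small
# alpha weights, and sequential fp16 accumulation drifts badly once the
# running sum gets large. A rough demonstration (sizes assumed):
import torch

alphas = torch.full((3000,), 0.033)
print(alphas.sum().item())        # ~99.0 in fp32

acc = torch.zeros((), dtype=torch.float16)
for a in alphas.to(torch.float16):
    acc = acc + a                 # re-rounded to fp16 at every step
print(acc.item())                 # ends up far from 99 due to rounding error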


import logging

import numpy as np
import torch
from torch.cuda.amp import autocast

from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask


@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak
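# The cif(...) call above implements continuous integrate-and-fire: alpha
# weights are accumulated frame by frame, and a token embedding "fires" each
# time the accumulator crosses the threshold, splitting the boundary frame's
# weight between the finished token and the next one. A minimal single-
# utterance sketch (the FunASR version is batched and also returns the peaks):
import torch

def cif_toy(hidden, alphas, threshold=1.0):
    # hidden: (T, d), alphas: (T,)
    integrate = 0.0
    frame = torch.zeros(hidden.size(1))
    embeds = []
    for h, a in zip(hidden, alphas.tolist()):
        if integrate + a < threshold:
            integrate += a
            frame = frame + a * h
        else:
            part = threshold - integrate      # weight that completes this token
            embeds.append(frame + part * h)
            integrate = a - part              # remainder opens the next token
            frame = integrate * h
    return torch.stack(embeds) if embeds else torch.zeros(0, hidden.size(1))

emb = cif_toy(torch.randn(10, 4), torch.full((10,), 0.45))
print(emb.shape)  # alphas sum to 4.5 -> 4 fired tokens: torch.Size([4, 4])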

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            output = torch.relu(self.cif_conv1d(queries))
            output = output.transpose(1, 2)

            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            mask = mask.squeeze(-1)
            if target_label_length is not None:
                target_length = target_label_length.squeeze(-1)
            elif target_label is not None:
                target_length = (target_label != ignore_id).float().sum(-1)
            else:
                target_length = None
            token_num = alphas.sum(-1)
            if target_length is not None:
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
            elif self.tail_threshold > 0.0:
                if self.tail_mask:
                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
                else:
                    hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak


        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
    def gen_tf2torch_map_dict(self):

        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch): {
                "name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": (2, 1, 0),
            },  # (256,256,3),(3,256,256)
            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch): {
                "name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.cif_output.weight".format(tensor_name_prefix_torch): {
                "name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1,256),(1,256,1)
            "{}.cif_output.bias".format(tensor_name_prefix_torch): {
                "name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1,),(1,)
        }
        return map_dict_local
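# Why the (2, 1, 0) transpose in the map above: TF checkpoints store conv1d
# kernels as (kernel_size, in_channels, out_channels), while torch.nn.Conv1d
# weights are (out_channels, in_channels, kernel_size). Toy check, with shapes
# assumed from the comments above:
import numpy as np

tf_kernel = np.zeros((3, 256, 256))               # (k, in, out) in the checkpoint
torch_kernel = np.transpose(tf_kernel, (2, 1, 0))
print(torch_kernel.shape)                         # (256, 256, 3), matches cif_conv1d.weight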

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), \
                    "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

        return var_dict_torch_update


class mae_loss(torch.nn.Module):
| | | "umap_learn", |
| | | "jaconv", |
| | | "hydra-core>=1.3.2", |
| | | "tensorboardX", |
| | | ], |
| | | # train: The modules invoked when training only. |
| | | "train": [ |
| | | "editdistance", |
| | | "tensorboardX", |
| | | ], |
| | | # all: The modules should be optionally installled due to some reason. |
| | | # Please consider moving them to "install" occasionally |