        x = self.output_layer(x)
        return x, olens

    def gen_tf2torch_map_dict(self):

        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
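            # Keys are torch parameter names; values name the matching TF variable.
            # The literal token "layeridx" on both sides is replaced with the real
            # layer index in convert_tf2torch, and "squeeze"/"transpose" describe
            # how to bring the TF array into the torch layout (trailing comments
            # give torch shape, TF shape).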
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)
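            # When "name" is a list, the two entries above give two candidate TF
            # variables (e.g. a dedicated dense/kernel vs. tied input embeddings
            # w_embs for the output layer); convert_tf2torch uses whichever is
            # present in the checkpoint, with the matching per-candidate
            # squeeze/transpose.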
            ## clas decoder
            # src att
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (1024,256),(1,256,1024)
        }
        return map_dict_local
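
    # Illustrative sketch (not in the original source) of how convert_tf2torch
    # applies one entry of the map above: TF stores FFN weights as conv1d
    # kernels of shape (1, in_dim, out_dim); squeezing axis 0 and transposing
    # (1, 0) gives the (out_dim, in_dim) layout of torch.nn.Linear. With dummy
    # data:
    #
    #     entry = {"squeeze": 0, "transpose": (1, 0)}
    #     data_tf = np.zeros((1, 256, 1024), dtype=np.float32)  # TF conv1d kernel
    #     data_tf = np.squeeze(data_tf, axis=entry["squeeze"])  # -> (256, 1024)
    #     data_tf = np.transpose(data_tf, entry["transpose"])   # -> (1024, 256)
    #     weight = torch.from_numpy(data_tf)  # matches feed_forward.w_1.weight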

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        # For each torch parameter, look up the matching TF variable via the map
        # above, then squeeze/transpose the TF array into the torch layout.
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

                elif names[1] == "last_decoder":
                    # last_decoder maps onto TF fsmn layer 15
                    layeridx = 15
                    name_q = name.replace("last_decoder", "decoders.layeridx")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    # offset decoders2 indices past the decoder layers already seen
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

                elif names[1] == "bias_decoder":
                    name_q = name
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

                elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        # two candidate TF variables; use the one present in the checkpoint
                        idx_list = 0
                        if name_tf[idx_list] not in var_dict_tf.keys():
                            idx_list = 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

        return var_dict_torch_update


@tables.register("decoder_classes", "ContextualParaformerDecoderExport")
class ContextualParaformerDecoderExport(torch.nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,
                 **kwargs,):
        super().__init__()
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        from funasr.models.paraformer.decoder import DecoderLayerSANMExport
        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANMExport

        # Replace each submodule with its ONNX-exportable counterpart.
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANMExport(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

        # bias decoder
        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAttExport(self.model.bias_decoder.src_attn)
        self.bias_decoder = self.model.bias_decoder

        # last decoder
        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAttExport(self.model.last_decoder.src_attn)
        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoderExport(self.model.last_decoder.self_attn)
        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANMExport(
                self.model.last_decoder.feed_forward)
        self.last_decoder = self.model.last_decoder
        self.bias_output = self.model.bias_output
        self.dropout = self.model.dropout

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        bias_embed: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )

        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related
        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        contextual_mask = self.make_pad_mask(contextual_length)
        contextual_mask, _ = self.prepare_mask(contextual_mask)
        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)

        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, ys_in_lens
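

# Usage sketch (illustrative, not part of the original module): given an
# already-constructed ContextualParaformerDecoder and the path to a matching TF
# checkpoint, port the weights and wrap the decoder for export. The helper name
# `_demo_convert_from_tf` is hypothetical; `tf.train.load_checkpoint` is the
# standard TensorFlow checkpoint reader API.
def _demo_convert_from_tf(decoder, tf_ckpt_path):
    """Port TF checkpoint weights into `decoder`, then wrap it for ONNX export."""
    import tensorflow as tf

    reader = tf.train.load_checkpoint(tf_ckpt_path)
    # Materialize every TF variable into a name -> np.ndarray dict.
    var_dict_tf = {
        name: reader.get_tensor(name)
        for name in reader.get_variable_to_shape_map()
    }
    var_dict_torch_update = decoder.convert_tf2torch(var_dict_tf, decoder.state_dict())
    decoder.load_state_dict(var_dict_torch_update, strict=False)
    # Swap the decoder's submodules for their ONNX-friendly variants.
    return ContextualParaformerDecoderExport(decoder, onnx=True)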