
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder

from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat

from funasr.register import tables


class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.


        x = residual + self.dropout(self.src_attn(x, memory, memory_mask))


        return x, tgt_mask, memory, memory_mask, cache

    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features.

        Args:


        return x, memory, fsmn_cache, opt_cache


@tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
    """
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01712
    """

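    # Illustrative usage sketch (not part of the original source; the argument values
    # and tensor names below are assumptions). During training the decoder consumes the
    # encoder output `hs_pad` of shape (batch, T_enc, encoder_output_size) and the
    # padded target ids `ys_in_pad` of shape (batch, T_dec):
    #
    #   decoder = FsmnDecoderSCAMAOpt(vocab_size=4235, encoder_output_size=256)
    #   scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    #   # scores: (batch, T_dec, vocab_size) token scores, olens: decoded lengths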
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = None,
        concat_embeds: bool = False,
        attention_dim: int = None,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
        embed_tensor_name_prefix_tf: str = None,
    ):
        super().__init__(
            vocab_size=vocab_size,

                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    encoder_output_size=encoder_output_size,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,

            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim,
                    self_attention_dropout_rate,
                    kernel_size,
                    sanm_shfit=sanm_shfit,
                ),
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),

                attention_dim + encoder_output_size,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(
                    attention_dim + encoder_output_size,
                    linear_units,
                    dropout_rate,
                    adim=attention_dim,
                ),
                dropout_rate,
                normalize_before,
                concat_after,

        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf

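    # Added commentary (based on the constructor fragments above and on forward() below):
    # the decoder is organised as three stacks that are applied in order -- `decoders`
    # (FSMN memory block + cross-attention + FFN layers), an optional `decoders2`
    # (FSMN + FFN layers without cross-attention, src_attn is None), and `decoders3`
    # (FFN-only layers whose input dimension is attention_dim + encoder_output_size,
    # i.e. the decoder state concatenated with an encoder-derived vector).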
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
| | | """Forward decoder. |
| | | |
| | |
| | | x = torch.cat((x, pre_acoustic_embeds), dim=-1) |
| | | x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None) |
| | | |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders( |
| | | x, tgt_mask, memory, memory_mask |
| | | ) |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask) |
| | | if self.decoders2 is not None: |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders2( |
| | | x, tgt_mask, memory, memory_mask |
| | | ) |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders3( |
| | | x, tgt_mask, memory, memory_mask |
| | | ) |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask) |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask) |
| | | if self.normalize_before: |
| | | x = self.after_norm(x) |
| | | if self.output_layer is not None: |
| | |
| | | olens = tgt_mask.sum(1) |
| | | return x, olens |
| | | |
    def score(
        self,
        ys,
        state,
        x,
        x_mask=None,
        pre_acoustic_embeds: torch.Tensor = None,
    ):
        """Score."""
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            x.unsqueeze(0),
            memory_mask=x_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
            cache=state,
        )
        return logp.squeeze(0), state

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.


        y = torch.log_softmax(y, dim=-1)

        return y, new_cache

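    # Hedged sketch (not from the original source): step-by-step decoding with
    # forward_one_step. `memory` is the encoder output for one utterance, `cache`
    # carries the per-layer FSMN / attention state between steps, and `sos_id` /
    # `max_len` are assumed to be defined by the caller.
    #
    #   cache = None
    #   ys = torch.tensor([[sos_id]], dtype=torch.long)
    #   for _ in range(max_len):
    #       ys_mask = myutils.sequence_mask(
    #           torch.tensor([ys.size(1)], dtype=torch.int32), device=memory.device
    #       )[:, :, None]
    #       logp, cache = decoder.forward_one_step(ys, ys_mask, memory, cache=cache)
    #       next_id = logp.argmax(-1)[..., -1:]  # greedy pick of the newest position
    #       # (slicing assumes logp covers all positions; adjust if it returns only the last step)
    #       ys = torch.cat([ys, next_id], dim=-1)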
    def gen_tf2torch_map_dict(self):

        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        embed_tensor_name_prefix_tf = (
            self.embed_tensor_name_prefix_tf
            if self.embed_tensor_name_prefix_tf is not None
            else tensor_name_prefix_tf
        )
        map_dict_local = {

            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)

            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)

            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)
        }
        return map_dict_local

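    # Added commentary on the mapping table above: each key is a template for a torch
    # parameter name (with the literal "layeridx" standing in for the layer index) and
    # each value names the corresponding TF variable, together with an optional
    # "squeeze" axis and "transpose" permutation used, for example, to turn TF conv1d
    # kernels of shape (1, in_dim, out_dim) into torch Linear weights of shape
    # (out_dim, in_dim). The shape comments on each entry read (torch shape),(tf shape).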
    def convert_tf2torch(
        self,
        var_dict_tf,
        var_dict_torch,
    ):

        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
| | | if names[1] == "decoders": |
| | | layeridx = int(names[2]) |
| | | name_q = name.replace(".{}.".format(layeridx), ".layeridx.") |
| | | layeridx_bias = 0 |
| | | layeridx += layeridx_bias |
| | | decoder_layeridx_sets.add(layeridx) |
| | | if name_q in map_dict.keys(): |
| | | name_v = map_dict[name_q]["name"] |
| | | name_tf = name_v.replace("layeridx", "{}".format(layeridx)) |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name_q]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) |
| | | if map_dict[name_q]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | elif names[1] == "decoders2": |
| | | layeridx = int(names[2]) |
| | | name_q = name.replace(".{}.".format(layeridx), ".layeridx.") |
| | | name_q = name_q.replace("decoders2", "decoders") |
| | | layeridx_bias = len(decoder_layeridx_sets) |
| | | |
| | | layeridx += layeridx_bias |
| | | if "decoders." in name: |
| | | decoder_layeridx_sets.add(layeridx) |
| | | if name_q in map_dict.keys(): |
| | | name_v = map_dict[name_q]["name"] |
| | | name_tf = name_v.replace("layeridx", "{}".format(layeridx)) |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name_q]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) |
| | | if map_dict[name_q]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | elif names[1] == "decoders3": |
| | | layeridx = int(names[2]) |
| | | name_q = name.replace(".{}.".format(layeridx), ".layeridx.") |
| | | |
| | | layeridx_bias = 0 |
| | | layeridx += layeridx_bias |
| | | if "decoders." in name: |
| | | decoder_layeridx_sets.add(layeridx) |
| | | if name_q in map_dict.keys(): |
| | | name_v = map_dict[name_q]["name"] |
| | | name_tf = name_v.replace("layeridx", "{}".format(layeridx)) |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name_q]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) |
| | | if map_dict[name_q]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | elif names[1] == "embed" or names[1] == "output_layer": |
| | | name_tf = map_dict[name]["name"] |
| | | if isinstance(name_tf, list): |
| | | idx_list = 0 |
| | | if name_tf[idx_list] in var_dict_tf.keys(): |
| | | pass |
| | | else: |
| | | idx_list = 1 |
| | | data_tf = var_dict_tf[name_tf[idx_list]] |
| | | if map_dict[name]["squeeze"][idx_list] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list]) |
| | | if map_dict[name]["transpose"][idx_list] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), |
| | | name_tf[idx_list], |
| | | var_dict_tf[name_tf[ |
| | | idx_list]].shape)) |
| | | |
| | | else: |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) |
| | | if map_dict[name]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | elif names[1] == "after_norm": |
| | | name_tf = map_dict[name]["name"] |
| | | data_tf = var_dict_tf[name_tf] |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | elif names[1] == "embed_concat_ffn": |
| | | layeridx = int(names[2]) |
| | | name_q = name.replace(".{}.".format(layeridx), ".layeridx.") |
| | | |
| | | layeridx_bias = 0 |
| | | layeridx += layeridx_bias |
| | | if "decoders." in name: |
| | | decoder_layeridx_sets.add(layeridx) |
| | | if name_q in map_dict.keys(): |
| | | name_v = map_dict[name_q]["name"] |
| | | name_tf = name_v.replace("layeridx", "{}".format(layeridx)) |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name_q]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) |
| | | if map_dict[name_q]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | |
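
# Hedged usage sketch (not part of the original source): one way to feed
# convert_tf2torch from a TensorFlow checkpoint. `ckpt_path` and `model` are
# placeholders; tf.train.load_checkpoint / get_variable_to_shape_map / get_tensor
# are standard TensorFlow checkpoint-reader calls.
#
#   import tensorflow as tf
#
#   reader = tf.train.load_checkpoint(ckpt_path)
#   var_dict_tf = {name: reader.get_tensor(name) for name in reader.get_variable_to_shape_map()}
#   var_dict_torch = model.state_dict()
#   var_dict_torch_update = model.convert_tf2torch(var_dict_tf, var_dict_torch)
#   model.load_state_dict(var_dict_torch_update, strict=False)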