From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2
---
funasr/models/scama/sanm_decoder.py | 552 +++++++++++++++++++++++++++++++-----------------------
 1 file changed, 314 insertions(+), 238 deletions(-)
diff --git a/funasr/models/paraformer/contextual_decoder.py b/funasr/models/scama/sanm_decoder.py
similarity index 70%
copy from funasr/models/paraformer/contextual_decoder.py
copy to funasr/models/scama/sanm_decoder.py
index 626cdef..53423d0 100644
--- a/funasr/models/paraformer/contextual_decoder.py
+++ b/funasr/models/scama/sanm_decoder.py
@@ -6,17 +6,38 @@
import numpy as np
from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from funasr.models.transformer.decoder import BaseTransformerDecoder
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ src_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
-class ContextualDecoderLayer(nn.Module):
+ """
+
def __init__(
self,
size,
@@ -28,7 +49,7 @@
concat_after=False,
):
"""Construct an DecoderLayer object."""
- super(ContextualDecoderLayer, self).__init__()
+ super(DecoderLayerSANM, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
@@ -45,85 +66,161 @@
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
- def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
# tgt = self.dropout(tgt)
- if isinstance(tgt, Tuple):
- tgt, _ = tgt
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
- if self.normalize_before:
- tgt = self.norm2(tgt)
- if self.training:
- cache = None
- x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
- x = residual + self.dropout(x)
- x_self_attn = x
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, _ = self.self_attn(tgt, tgt_mask)
+ x = residual + self.dropout(x)
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
- x = self.src_attn(x, memory, memory_mask)
- x_src_attn = x
-
- x = residual + self.dropout(x)
- return x, tgt_mask, x_self_attn, x_src_attn
-
-
-class ContextualBiasDecoder(nn.Module):
- def __init__(
- self,
- size,
- src_attn,
- dropout_rate,
- normalize_before=True,
- ):
- """Construct an DecoderLayer object."""
- super(ContextualBiasDecoder, self).__init__()
- self.size = size
- self.src_attn = src_attn
- if src_attn is not None:
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
-
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
- x = tgt
if self.src_attn is not None:
+ residual = x
if self.normalize_before:
x = self.norm3(x)
- x = self.dropout(self.src_attn(x, memory, memory_mask))
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
return x, tgt_mask, memory, memory_mask, cache
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
-class ContextualParaformerDecoder(ParaformerSANMDecoder):
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ if self.training:
+ cache = None
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+ x = residual + x
+
+ return x, memory, fsmn_cache, opt_cache
+
+@register_class("decoder_classes", "FsmnDecoderSCAMAOpt")
+class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
+
"""
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = 0,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = None,
+ concat_embeds: bool = False,
+ attention_dim: int = None,
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ embed_tensor_name_prefix_tf: str = None,
):
super().__init__(
vocab_size=vocab_size,
@@ -135,14 +232,12 @@
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
+ if attention_dim is None:
+ attention_dim = encoder_output_size
- attention_dim = encoder_output_size
- if input_layer == 'none':
- self.embed = None
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
- # pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
@@ -168,14 +263,14 @@
if sanm_shfit is None:
sanm_shfit = (kernel_size - 1) // 2
self.decoders = repeat(
- att_layer_num - 1,
+ att_layer_num,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
+ attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
@@ -183,29 +278,6 @@
concat_after,
),
)
- self.dropout = nn.Dropout(dropout_rate)
- self.bias_decoder = ContextualBiasDecoder(
- size=attention_dim,
- src_attn=MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- dropout_rate=dropout_rate,
- normalize_before=True,
- )
- self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
- self.last_decoder = ContextualDecoderLayer(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
if num_blocks - att_layer_num <= 0:
self.decoders2 = None
else:
@@ -214,7 +286,7 @@
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -236,16 +308,35 @@
concat_after,
),
)
+ if concat_embeds:
+ self.embed_concat_ffn = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim + encoder_output_size,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
+ adim=attention_dim),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ else:
+ self.embed_concat_ffn = None
+ self.concat_embeds = concat_embeds
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- contextual_info: torch.Tensor,
- clas_scale: float = 1.0,
- return_hidden: bool = False,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ chunk_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -269,46 +360,122 @@
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ if chunk_mask is not None:
+ memory_mask = memory_mask * chunk_mask
+ if tgt_mask.size(1) != memory_mask.size(1):
+ memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
- x = tgt
+ x = self.embed(tgt)
+
+ if pre_acoustic_embeds is not None and self.concat_embeds:
+ x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+ x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
x, tgt_mask, memory, memory_mask, _ = self.decoders(
x, tgt_mask, memory, memory_mask
)
- _, _, x_self_attn, x_src_attn = self.last_decoder(
- x, tgt_mask, memory, memory_mask
- )
-
- # contextual paraformer related
- contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
- contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
- cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
-
- if self.bias_output is not None:
- x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
- x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D
- x = x_self_attn + self.dropout(x)
-
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(
x, tgt_mask, memory, memory_mask
)
-
x, tgt_mask, memory, memory_mask, _ = self.decoders3(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
- olens = tgt_mask.sum(1)
- if self.output_layer is not None and return_hidden is False:
+ if self.output_layer is not None:
x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
return x, olens
- def gen_tf2torch_map_dict(self):
+ def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+ """Score."""
+ ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
+ cache=state
+ )
+ return logp.squeeze(0), state
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ """
+
+ x = tgt[:, -1:]
+ tgt_mask = None
+ x = self.embed(x)
+
+ if pre_acoustic_embeds is not None and self.concat_embeds:
+ x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+ x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
+ if cache is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ cache = [None] * cache_layer_num
+ new_cache = []
+ # for c, decoder in zip(cache, self.decoders):
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(c_ret)
+
+ if self.num_blocks - self.att_layer_num >= 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ c = cache[j]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(c_ret)
+
+ for decoder in self.decoders3:
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=None
+ )
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = self.output_layer(y)
+ y = torch.log_softmax(y, dim=-1)
+
+ return y, new_cache
+
+ def gen_tf2torch_map_dict(self):
+
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
map_dict_local = {
-
+
## decoder
# ffn
"{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
@@ -346,7 +513,7 @@
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
-
+
# fsmn
"{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
@@ -443,7 +610,7 @@
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
-
+
# embed_concat_ffn
"{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -480,7 +647,7 @@
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
-
+
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -492,17 +659,18 @@
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
-
+
# in embed
"{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(tensor_name_prefix_tf),
+ {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (4235,256),(4235,256)
-
+
# out layer
"{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
+ {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
+ "{}/w_embs".format(embed_tensor_name_prefix_tf)],
"squeeze": [None, None],
"transpose": [(1, 0), None],
}, # (4235,256),(256,4235)
@@ -512,56 +680,7 @@
"squeeze": [None, None],
"transpose": [None, None],
}, # (4235,),(4235,)
-
- ## clas decoder
- # src att
- "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.bias_output.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (1024,256),(1,256,1024)
-
+
}
return map_dict_local
@@ -569,6 +688,7 @@
var_dict_tf,
var_dict_torch,
):
+
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
decoder_layeridx_sets = set()
@@ -598,37 +718,13 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
- elif names[1] == "last_decoder":
- layeridx = 15
- name_q = name.replace("last_decoder", "decoders.layeridx")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
-
+
elif names[1] == "decoders2":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
name_q = name_q.replace("decoders2", "decoders")
layeridx_bias = len(decoder_layeridx_sets)
-
+
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
@@ -649,11 +745,11 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
-
+
elif names[1] == "decoders3":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
@@ -675,29 +771,8 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
- elif names[1] == "bias_decoder":
- name_q = name
-
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
-
- elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
+
+ elif names[1] == "embed" or names[1] == "output_layer":
name_tf = map_dict[name]["name"]
if isinstance(name_tf, list):
idx_list = 0
@@ -720,7 +795,7 @@
name_tf[idx_list],
var_dict_tf[name_tf[
idx_list]].shape))
-
+
else:
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
@@ -736,7 +811,7 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
-
+
elif names[1] == "after_norm":
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
@@ -745,11 +820,11 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
-
+
elif names[1] == "embed_concat_ffn":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
@@ -771,5 +846,6 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
-
+
return var_dict_torch_update
+
--
Gitblit v1.9.1