From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 21:58:14 +0800
Subject: [PATCH] funasr2
---
funasr/models/sa_asr/transformer_decoder.py | 442 +-----------------------------------------------------
1 files changed, 11 insertions(+), 431 deletions(-)
diff --git a/funasr/models/transformer/transformer_decoder.py b/funasr/models/sa_asr/transformer_decoder.py
similarity index 65%
rename from funasr/models/transformer/transformer_decoder.py
rename to funasr/models/sa_asr/transformer_decoder.py
index b2bea68..3319212 100644
--- a/funasr/models/transformer/transformer_decoder.py
+++ b/funasr/models/sa_asr/transformer_decoder.py
@@ -10,9 +10,9 @@
import torch
from torch import nn
-from funasr.models.decoder.abs_decoder import AbsDecoder
+
from funasr.models.transformer.attention import MultiHeadedAttention
-from funasr.models.transformer.attention import CosineDistanceAttention
+from funasr.models.sa_asr.attention import CosineDistanceAttention
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.embedding import PositionalEncoding
@@ -24,9 +24,10 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
+from funasr.utils.register import register_class, registry_tables
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -150,7 +151,7 @@
return x, tgt_mask, memory, memory_mask
-class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
"""Base class of Transfomer decoder module.
Args:
@@ -352,7 +353,7 @@
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
-
+@register_class("decoder_classes", "TransformerDecoder")
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -401,6 +402,7 @@
)
+@register_class("decoder_classes", "ParaformerDecoderSAN")
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -514,7 +516,7 @@
else:
return x, olens
-
+@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -575,7 +577,7 @@
),
)
-
+@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -637,6 +639,7 @@
)
+@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -697,7 +700,7 @@
),
)
-
+@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -757,426 +760,3 @@
concat_after,
),
)
-
-class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
-
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- spker_embedding_dim: int = 256,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- input_layer: str = "embed",
- use_asr_output_layer: bool = True,
- use_spk_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- ):
- super().__init__()
- attention_dim = encoder_output_size
-
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_asr_output_layer:
- self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.asr_output_layer = None
-
- if use_spk_output_layer:
- self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
- else:
- self.spk_output_layer = None
-
- self.cos_distance_att = CosineDistanceAttention()
-
- self.decoder1 = None
- self.decoder2 = None
- self.decoder3 = None
- self.decoder4 = None
-
- def forward(
- self,
- asr_hs_pad: torch.Tensor,
- spk_hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- profile: torch.Tensor,
- profile_lens: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-
- tgt = ys_in_pad
- # tgt_mask: (B, 1, L)
- tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
- # m: (1, L, L)
- m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
- # tgt_mask: (B, L, L)
- tgt_mask = tgt_mask & m
-
- asr_memory = asr_hs_pad
- spk_memory = spk_hs_pad
- memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
- # Spk decoder
- x = self.embed(tgt)
-
- x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
- x, tgt_mask, asr_memory, spk_memory, memory_mask
- )
- x, tgt_mask, spk_memory, memory_mask = self.decoder2(
- x, tgt_mask, spk_memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.spk_output_layer is not None:
- x = self.spk_output_layer(x)
- dn, weights = self.cos_distance_att(x, profile, profile_lens)
- # Asr decoder
- x, tgt_mask, asr_memory, memory_mask = self.decoder3(
- z, tgt_mask, asr_memory, memory_mask, dn
- )
- x, tgt_mask, asr_memory, memory_mask = self.decoder4(
- x, tgt_mask, asr_memory, memory_mask
- )
-
- if self.normalize_before:
- x = self.after_norm(x)
- if self.asr_output_layer is not None:
- x = self.asr_output_layer(x)
-
- olens = tgt_mask.sum(1)
- return x, weights, olens
-
-
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- asr_memory: torch.Tensor,
- spk_memory: torch.Tensor,
- profile: torch.Tensor,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-
- x = self.embed(tgt)
-
- if cache is None:
- cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
- new_cache = []
- x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
- x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
- )
- new_cache.append(x)
- for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
- x, tgt_mask, spk_memory, _ = decoder(
- x, tgt_mask, spk_memory, None, cache=c
- )
- new_cache.append(x)
- if self.normalize_before:
- x = self.after_norm(x)
- else:
- x = x
- if self.spk_output_layer is not None:
- x = self.spk_output_layer(x)
- dn, weights = self.cos_distance_att(x, profile, None)
-
- x, tgt_mask, asr_memory, _ = self.decoder3(
- z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
- )
- new_cache.append(x)
-
- for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
- x, tgt_mask, asr_memory, _ = decoder(
- x, tgt_mask, asr_memory, None, cache=c
- )
- new_cache.append(x)
-
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.asr_output_layer is not None:
- y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
-
- return y, weights, new_cache
-
- def score(self, ys, state, asr_enc, spk_enc, profile):
- """Score."""
- ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
- logp, weights, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
- )
- return logp.squeeze(0), weights.squeeze(), state
-
-class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- spker_embedding_dim: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- asr_num_blocks: int = 6,
- spk_num_blocks: int = 3,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_asr_output_layer: bool = True,
- use_spk_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- ):
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- spker_embedding_dim=spker_embedding_dim,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_asr_output_layer=use_asr_output_layer,
- use_spk_output_layer=use_spk_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
-
- attention_dim = encoder_output_size
-
- self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
- self.decoder2 = repeat(
- spk_num_blocks - 1,
- lambda lnum: DecoderLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
-
- self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
- attention_dim,
- spker_embedding_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
- self.decoder4 = repeat(
- asr_num_blocks - 1,
- lambda lnum: DecoderLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
-class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
-
- def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
- self.size = size
- self.self_attn = self_attn
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
-
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
- # compute only the last frame query keeping dim: max_time_out -> 1
- assert cache.shape == (
- tgt.shape[0],
- tgt.shape[1] - 1,
- self.size,
- ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- if self.concat_after:
- tgt_concat = torch.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
- )
- x = residual + self.concat_linear1(tgt_concat)
- else:
- x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
- if not self.normalize_before:
- x = self.norm1(x)
- z = x
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
-
- if self.concat_after:
- x_concat = torch.cat(
- (x, skip), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x = residual + self.dropout(skip)
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
-
-class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
-
- def __init__(
- self,
- size,
- d_size,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
- self.size = size
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- self.spk_linear = nn.Linear(d_size, size, bias=False)
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
-
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
-
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- x = tgt_q
- if self.normalize_before:
- x = self.norm2(x)
- if self.concat_after:
- x_concat = torch.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
- if not self.normalize_before:
- x = self.norm2(x)
- residual = x
-
- if dn!=None:
- x = x + self.spk_linear(dn)
- if self.normalize_before:
- x = self.norm3(x)
-
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm3(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, memory, memory_mask
\ No newline at end of file
--
Gitblit v1.9.1