shixian.shi
2024-01-15 97d648c255316ec1fff5d82e46749076faabdd2d
funasr/models/contextual_paraformer/decoder.py
@@ -1,22 +1,24 @@
-from typing import List
-from typing import Tuple
-import logging
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import torch
-import torch.nn as nn
+import logging
 import numpy as np
-
-from funasr.models.scama import utils as myutils
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from funasr.models.transformer.embedding import PositionalEncoding
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.utils.repeat import repeat
-from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from typing import Tuple
+
 from funasr.register import tables
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
 
 
-class ContextualDecoderLayer(nn.Module):
+class ContextualDecoderLayer(torch.nn.Module):
     def __init__(
         self,
         size,
@@ -38,12 +40,12 @@
             self.norm2 = LayerNorm(size)
         if src_attn is not None:
             self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before
         self.concat_after = concat_after
         if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
+            self.concat_linear1 = torch.nn.Linear(size + size, size)
+            self.concat_linear2 = torch.nn.Linear(size + size, size)
 
     def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
         # tgt = self.dropout(tgt)
@@ -73,7 +75,7 @@
         return x, tgt_mask, x_self_attn, x_src_attn
 
 
-class ContextualBiasDecoder(nn.Module):
+class ContextualBiasDecoder(torch.nn.Module):
     def __init__(
         self,
         size,
@@ -87,7 +89,7 @@
         self.src_attn = src_attn
         if src_attn is not None:
             self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before
 
     def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -183,7 +185,7 @@
                 concat_after,
             ),
         )
-        self.dropout = nn.Dropout(dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.bias_decoder = ContextualBiasDecoder(
             size=attention_dim,
             src_attn=MultiHeadedAttentionCrossAtt(
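
Apart from the import reshuffle in the first hunk, every change above is the same mechanical substitution: the `import torch.nn as nn` alias is dropped and the layers are spelled out as `torch.nn.*`. A minimal sketch of why this is behavior-preserving (the `size` value below is made up for illustration, not taken from the decoder config):

import torch
import torch.nn as nn

# `nn` is just an alias for the `torch.nn` module, so both spellings resolve
# to the very same classes; the refactor changes naming style, not behavior.
assert nn is torch.nn
assert nn.Dropout is torch.nn.Dropout
assert nn.Linear is torch.nn.Linear

size = 256  # illustrative only
concat_linear1 = torch.nn.Linear(size + size, size)  # same shape pattern as concat_linear1 in the diff
print(isinstance(concat_linear1, nn.Module))  # True: identical class either way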