shixian.shi
2024-01-17 b1857837dd5873a8308eb770c50f4fc4c8eab752
funasr/models/contextual_paraformer/decoder.py
@@ -1,22 +1,24 @@
from typing import List
from typing import Tuple
import logging
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import torch.nn as nn
import logging
import numpy as np
from funasr.models.scama import utils as myutils
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
from typing import Tuple
from funasr.register import tables
from funasr.models.scama import utils as myutils
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
class ContextualDecoderLayer(nn.Module):
class ContextualDecoderLayer(torch.nn.Module):
    def __init__(
        self,
        size,
@@ -38,12 +40,12 @@
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
        # tgt = self.dropout(tgt)
@@ -73,7 +75,7 @@
        return x, tgt_mask, x_self_attn, x_src_attn
class ContextualBiasDecoder(nn.Module):
class ContextualBiasDecoder(torch.nn.Module):
    def __init__(
        self,
        size,
@@ -87,7 +89,7 @@
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -183,7 +185,7 @@
                concat_after,
            ),
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.bias_decoder = ContextualBiasDecoder(
            size=attention_dim,
            src_attn=MultiHeadedAttentionCrossAtt(