Xingchen Song(宋星辰)
2024-06-11 0bd1a4d6a9893e45438505514a063d9deee91f21
funasr/models/llm_asr/adaptor.py
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@@ -84,13 +86,13 @@
        self.blocks = nn.ModuleList(
            [
                EncoderLayer(
                    output_size,
                    llm_dim,
                    MultiHeadedAttention(
                        kwargs.get("attention_heads", 8),
                        llm_dim,
                        kwargs.get("attention_dropout_rate", 0.0),
                    ),
                    positionwise_layer(
                    PositionwiseFeedForward(
                        llm_dim,
                        llm_dim // 4,
                        kwargs.get("dropout_rate", 0.0),
@@ -119,9 +121,8 @@
        x = self.linear2(x)
        olens = None
        if ilens is not None:
            olens = (ilens - 1) // self.k + 1
            mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        olens = (ilens - 1) // self.k + 1
        masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        for layer, block in enumerate(self.blocks):
            x, masks = block(x, masks)
        return x, olens