Zhihao Du
2023-03-16 38de2af5bf9976d2f14f087d9a0d31991daf6783
funasr/modules/eend_ola/encoder.py
@@ -1,5 +1,5 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
@@ -81,10 +81,16 @@
        return self.dropout(x)
class TransformerEncoder(nn.Module):
    # NOTE(review): within this diff fragment the constructor body contains
    # nothing beyond the super() call — this looks like the pre-rename class
    # that the EENDOLATransformerEncoder definition below supersedes (same
    # parameter list, but h defaults to 8 here vs 4 there). Confirm against
    # the full file/commit before relying on or restoring this class.
    def __init__(self, idim, n_layers, n_units,
                 e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
        """Legacy Transformer encoder constructor (visible lines only).

        Args:
            idim: input feature dimension (not used in the visible lines).
            n_layers: number of encoder layers (not used in the visible lines).
            n_units: model/attention hidden size (not used in the visible lines).
            e_units: feed-forward inner size; default 2048.
            h: number of attention heads; default 8.
            dropout_rate: dropout probability; default 0.1.
            use_pos_emb: whether to add positional embeddings; default False.
        """
        super(TransformerEncoder, self).__init__()
class EENDOLATransformerEncoder(nn.Module):
    def __init__(self,
                 idim: int,
                 n_layers: int,
                 n_units: int,
                 e_units: int = 2048,
                 h: int = 4,
                 dropout_rate: float = 0.1,
                 use_pos_emb: bool = False):
        super(EENDOLATransformerEncoder, self).__init__()
        self.lnorm_in = nn.LayerNorm(n_units)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_rate)