From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/modules/eend_ola/encoder.py |   36 ++++++++++++++----------------------
 1 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py
index 17d11ac..3065884 100644
--- a/funasr/modules/eend_ola/encoder.py
+++ b/funasr/modules/eend_ola/encoder.py
@@ -1,5 +1,5 @@
 import math
-import numpy as np
+
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -81,10 +81,17 @@
         return self.dropout(x)
 
 
-class TransformerEncoder(nn.Module):
-    def __init__(self, idim, n_layers, n_units,
-                 e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
-        super(TransformerEncoder, self).__init__()
+class EENDOLATransformerEncoder(nn.Module):
+    def __init__(self,
+                 idim: int,
+                 n_layers: int,
+                 n_units: int,
+                 e_units: int = 2048,
+                 h: int = 4,
+                 dropout_rate: float = 0.1,
+                 use_pos_emb: bool = False):
+        super(EENDOLATransformerEncoder, self).__init__()
+        self.linear_in = nn.Linear(idim, n_units)
         self.lnorm_in = nn.LayerNorm(n_units)
         self.n_layers = n_layers
         self.dropout = nn.Dropout(dropout_rate)
@@ -98,25 +105,10 @@
             setattr(self, '{}{:d}'.format("ff_", i),
                     PositionwiseFeedForward(n_units, e_units, dropout_rate))
         self.lnorm_out = nn.LayerNorm(n_units)
-        if use_pos_emb:
-            self.pos_enc = torch.nn.Sequential(
-                torch.nn.Linear(idim, n_units),
-                torch.nn.LayerNorm(n_units),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                PositionalEncoding(n_units, dropout_rate),
-            )
-        else:
-            self.linear_in = nn.Linear(idim, n_units)
-            self.pos_enc = None
 
     def __call__(self, x, x_mask=None):
         BT_size = x.shape[0] * x.shape[1]
-        if self.pos_enc is not None:
-            e = self.pos_enc(x)
-            e = e.view(BT_size, -1)
-        else:
-            e = self.linear_in(x.reshape(BT_size, -1))
+        e = self.linear_in(x.reshape(BT_size, -1))
         for i in range(self.n_layers):
             e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
             s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
@@ -124,4 +116,4 @@
             e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
             s = getattr(self, '{}{:d}'.format("ff_", i))(e)
             e = e + self.dropout(s)
-        return self.lnorm_out(e)
+        return self.lnorm_out(e)
\ No newline at end of file

--
Gitblit v1.9.1