funasr/models/transformer/positionwise_feed_forward.py
class PositionwiseFeedForwardDecoderSANMExport(torch.nn.Module):
    """Feed-forward wrapper that reuses the sub-layers of an existing model.

    The wrapped ``model`` must expose ``w_1``, ``w_2``, ``activation`` and
    ``norm`` attributes. They are referenced directly (not copied), so this
    module shares them with ``model``.

    ``forward`` computes ``w_2(norm(activation(w_1(x))))``; no dropout is
    applied anywhere in this class.
    """

    def __init__(self, model):
        """Bind the four sub-layers of ``model`` onto this module.

        Args:
            model: Object providing ``w_1``, ``w_2``, ``activation`` and
                ``norm`` attributes (presumably a trained feed-forward
                block -- confirm against the caller).
        """
        super().__init__()
        # Assignment order (w_1, w_2, activation, norm) is kept so that the
        # submodule registration order stays the same.
        self.w_1, self.w_2 = model.w_1, model.w_2
        self.activation, self.norm = model.activation, model.norm

    def forward(self, x):
        """Apply the projection, activation, normalisation and output layer.

        Args:
            x: Input accepted by ``w_1``.

        Returns:
            The result of ``w_2(norm(activation(w_1(x))))``.
        """
        hidden = self.w_1(x)
        hidden = self.activation(hidden)
        hidden = self.norm(hidden)
        return self.w_2(hidden)