From 2ae59b6ce06305724e2eaf30b9f9e93447a7832e Mon Sep 17 00:00:00 2001 From: 维石 <shixian.shi@alibaba-inc.com> Date: 星期一, 22 七月 2024 16:58:27 +0800 Subject: [PATCH] ONNX and torchscript export for sensevoice --- funasr/models/transformer/positionwise_feed_forward.py | 21 ++++++++++----------- 1 files changed, 10 insertions(+), 11 deletions(-) diff --git a/funasr/models/transformer/positionwise_feed_forward.py b/funasr/models/transformer/positionwise_feed_forward.py index 081ff5b..7cfa5f9 100644 --- a/funasr/models/transformer/positionwise_feed_forward.py +++ b/funasr/models/transformer/positionwise_feed_forward.py @@ -35,15 +35,14 @@ class PositionwiseFeedForwardDecoderSANMExport(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.w_1 = model.w_1 - self.w_2 = model.w_2 - self.activation = model.activation - self.norm = model.norm - - def forward(self, x): - x = self.activation(self.w_1(x)) - x = self.w_2(self.norm(x)) - return x + def __init__(self, model): + super().__init__() + self.w_1 = model.w_1 + self.w_2 = model.w_2 + self.activation = model.activation + self.norm = model.norm + def forward(self, x): + x = self.activation(self.w_1(x)) + x = self.w_2(self.norm(x)) + return x -- Gitblit v1.9.1