From 783a051f6586d392bec6e1494607f79635b71741 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 09:56:55 +0800
Subject: [PATCH] Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed merge
---
funasr/models/sense_voice/whisper_lib/model.py | 19 +++++++++++++++++--
1 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 8b3d3ab..3d0d6a8 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
n_text_layer: int
+# class LayerNorm(nn.LayerNorm):
+# def forward(self, x: Tensor) -> Tensor:
+# return super().forward(x.float()).type(x.dtype)
+
+
class LayerNorm(nn.LayerNorm):
- def forward(self, x: Tensor) -> Tensor:
- return super().forward(x.float()).type(x.dtype)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
class Linear(nn.Linear):
--
Gitblit v1.9.1