From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fix hotwords bug

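Reformat sense_voice_encode_forward in the SenseVoice encoder: replace
tab indentation with 4-space indentation, drop stray blank lines, and
compact the multi-line output-length arithmetic. The computation itself
is unchanged; each downsampling conv still maps input lengths through
the standard 1-D convolution length formula (reproduced here for
reference, with kernel_size, padding, and stride taken from
conv1/conv2):

    olens = 1 + (ilens - kernel_size + 2 * padding) // stride

olens is then clamped to the positional-embedding length before the
padding mask is built.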
---
 funasr/models/sense_voice/encoder.py |  100 ++++++++++++++++++++++---------------------------
 1 file changed, 45 insertions(+), 55 deletions(-)

diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py
index 3870c52..d464f1c 100644
--- a/funasr/models/sense_voice/encoder.py
+++ b/funasr/models/sense_voice/encoder.py
@@ -8,60 +8,50 @@
 
 
 def sense_voice_encode_forward(
-	self,
-	x: torch.Tensor,
-	ilens: torch.Tensor = None,
-	**kwargs,
+    self,
+    x: torch.Tensor,
+    ilens: torch.Tensor = None,
+    **kwargs,
 ):
-	use_padmask = self.use_padmask
-	x = F.gelu(self.conv1(x))
-	x = F.gelu(self.conv2(x))
-	x = x.permute(0, 2, 1)
-	
-	n_frames = x.size(1)
-	max_pos = self.positional_embedding.size(0)
-	max_pos = n_frames if n_frames < max_pos else max_pos
-	x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
-	
-	
-	if ilens is not None:
-		if self.downsample_rate == 4:
-			olens = (
-				1
-				+ (
-					ilens
-					- self.conv1.kernel_size[0]
-					+ 2 * self.conv1.padding[0]
-				)
-				// self.conv1.stride[0]
-			)
-		else:
-			olens = ilens
-		olens = (
-			1
-			+ (
-				olens
-				- self.conv2.kernel_size[0]
-				+ 2 * self.conv2.padding[0]
-			)
-			// self.conv2.stride[0]
-		)
-		olens = torch.clamp(olens, max=max_pos)
-	else:
-		olens = None
-	
-	if use_padmask and olens is not None:
-		padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
-	else:
-		padding_mask = None
-	
-	for layer, block in enumerate(self.blocks):
-		x = block(x, mask=padding_mask, is_pad_mask=True)
-		
+    use_padmask = self.use_padmask
+    x = F.gelu(self.conv1(x))
+    x = F.gelu(self.conv2(x))
+    x = x.permute(0, 2, 1)
 
-	x = self.ln_post(x)
-	
-	if ilens is None:
-		return x
-	else:
-		return x, olens
+    n_frames = x.size(1)
+    max_pos = self.positional_embedding.size(0)
+    max_pos = n_frames if n_frames < max_pos else max_pos
+    x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
+
+    if ilens is not None:
+        if self.downsample_rate == 4:
+            olens = (
+                1
+                + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
+                // self.conv1.stride[0]
+            )
+        else:
+            olens = ilens
+        olens = (
+            1
+            + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
+            // self.conv2.stride[0]
+        )
+        olens = torch.clamp(olens, max=max_pos)
+    else:
+        olens = None
+
+    if use_padmask and olens is not None:
+        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+    else:
+        padding_mask = None
+
+    for layer, block in enumerate(self.blocks):
+        x = block(x, mask=padding_mask, is_pad_mask=True)
+
+    x = self.ln_post(x)
+
+    if ilens is None:
+        return x
+    else:
+        return x, olens

--
Gitblit v1.9.1