From c880db53646ab9fd26417f4baf004ab44cc24e1a Mon Sep 17 00:00:00 2001
From: lingji-yidong <75744976+lingji-yidong@users.noreply.github.com>
Date: 星期五, 28 六月 2024 01:28:24 +0800
Subject: [PATCH] Fix: Return tuple ('', []) when char_list is empty to prevent ValueError (#1857)
---
funasr/models/sense_voice/encoder.py | 100 ++++++++++++++++++++++---------------------------
1 files changed, 45 insertions(+), 55 deletions(-)
diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py
index 3870c52..d464f1c 100644
--- a/funasr/models/sense_voice/encoder.py
+++ b/funasr/models/sense_voice/encoder.py
@@ -8,60 +8,50 @@
def sense_voice_encode_forward(
- self,
- x: torch.Tensor,
- ilens: torch.Tensor = None,
- **kwargs,
+ self,
+ x: torch.Tensor,
+ ilens: torch.Tensor = None,
+ **kwargs,
):
- use_padmask = self.use_padmask
- x = F.gelu(self.conv1(x))
- x = F.gelu(self.conv2(x))
- x = x.permute(0, 2, 1)
-
- n_frames = x.size(1)
- max_pos = self.positional_embedding.size(0)
- max_pos = n_frames if n_frames < max_pos else max_pos
- x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
-
-
- if ilens is not None:
- if self.downsample_rate == 4:
- olens = (
- 1
- + (
- ilens
- - self.conv1.kernel_size[0]
- + 2 * self.conv1.padding[0]
- )
- // self.conv1.stride[0]
- )
- else:
- olens = ilens
- olens = (
- 1
- + (
- olens
- - self.conv2.kernel_size[0]
- + 2 * self.conv2.padding[0]
- )
- // self.conv2.stride[0]
- )
- olens = torch.clamp(olens, max=max_pos)
- else:
- olens = None
-
- if use_padmask and olens is not None:
- padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
- else:
- padding_mask = None
-
- for layer, block in enumerate(self.blocks):
- x = block(x, mask=padding_mask, is_pad_mask=True)
-
+ use_padmask = self.use_padmask
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
- x = self.ln_post(x)
-
- if ilens is None:
- return x
- else:
- return x, olens
+ n_frames = x.size(1)
+ max_pos = self.positional_embedding.size(0)
+ max_pos = n_frames if n_frames < max_pos else max_pos
+ x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
+
+ if ilens is not None:
+ if self.downsample_rate == 4:
+ olens = (
+ 1
+ + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
+ // self.conv1.stride[0]
+ )
+ else:
+ olens = ilens
+ olens = (
+ 1
+ + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
+ // self.conv2.stride[0]
+ )
+ olens = torch.clamp(olens, max=max_pos)
+ else:
+ olens = None
+
+ if use_padmask and olens is not None:
+ padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+ else:
+ padding_mask = None
+
+ for layer, block in enumerate(self.blocks):
+ x = block(x, mask=padding_mask, is_pad_mask=True)
+
+ x = self.ln_post(x)
+
+ if ilens is None:
+ return x
+ else:
+ return x, olens
--
Gitblit v1.9.1