From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/encoder.py | 100 ++++++++++++++++++++++---------------
 1 file changed, 45 insertions(+), 55 deletions(-)

diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py
index 3870c52..d464f1c 100644
--- a/funasr/models/sense_voice/encoder.py
+++ b/funasr/models/sense_voice/encoder.py
@@ -8,60 +8,50 @@
 def sense_voice_encode_forward(
-        self,
-        x: torch.Tensor,
-        ilens: torch.Tensor = None,
-        **kwargs,
+    self,
+    x: torch.Tensor,
+    ilens: torch.Tensor = None,
+    **kwargs,
 ):
-        use_padmask = self.use_padmask
-        x = F.gelu(self.conv1(x))
-        x = F.gelu(self.conv2(x))
-        x = x.permute(0, 2, 1)
-
-        n_frames = x.size(1)
-        max_pos = self.positional_embedding.size(0)
-        max_pos = n_frames if n_frames < max_pos else max_pos
-        x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
-
-
-        if ilens is not None:
-            if self.downsample_rate == 4:
-                olens = (
-                    1
-                    + (
-                        ilens
-                        - self.conv1.kernel_size[0]
-                        + 2 * self.conv1.padding[0]
-                    )
-                    // self.conv1.stride[0]
-                )
-            else:
-                olens = ilens
-            olens = (
-                1
-                + (
-                    olens
-                    - self.conv2.kernel_size[0]
-                    + 2 * self.conv2.padding[0]
-                )
-                // self.conv2.stride[0]
-            )
-            olens = torch.clamp(olens, max=max_pos)
-        else:
-            olens = None
-
-        if use_padmask and olens is not None:
-            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
-        else:
-            padding_mask = None
-
-        for layer, block in enumerate(self.blocks):
-            x = block(x, mask=padding_mask, is_pad_mask=True)
-
+    use_padmask = self.use_padmask
+    x = F.gelu(self.conv1(x))
+    x = F.gelu(self.conv2(x))
+    x = x.permute(0, 2, 1)
-        x = self.ln_post(x)
-
-        if ilens is None:
-            return x
-        else:
-            return x, olens
+    n_frames = x.size(1)
+    max_pos = self.positional_embedding.size(0)
+    max_pos = n_frames if n_frames < max_pos else max_pos
+    x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
+
+    if ilens is not None:
+        if self.downsample_rate == 4:
+            olens = (
+                1
+                + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
+                // self.conv1.stride[0]
+            )
+        else:
+            olens = ilens
+        olens = (
+            1
+            + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
+            // self.conv2.stride[0]
+        )
+        olens = torch.clamp(olens, max=max_pos)
+    else:
+        olens = None
+
+    if use_padmask and olens is not None:
+        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+    else:
+        padding_mask = None
+
+    for layer, block in enumerate(self.blocks):
+        x = block(x, mask=padding_mask, is_pad_mask=True)
+
+    x = self.ln_post(x)
+
+    if ilens is None:
+        return x
+    else:
+        return x, olens
--
Gitblit v1.9.1
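
Editorial note, not part of the patch: the olens arithmetic in the hunk above is the
standard Conv1d output-length formula, L_out = 1 + (L_in - kernel_size + 2 * padding) // stride,
applied once per convolution and then clamped to the positional-embedding length. Below is a
minimal standalone sketch that checks the closed-form lengths against the frame count the
convolutions actually produce. The 80-mel input and the kernel/stride/padding values are
assumptions in the style of a Whisper-like front end, not values read from FunASR, and the
downsample_rate == 4 branch is ignored (the sketch always applies both formulas):

import torch
import torch.nn as nn

# Assumed layer shapes; FunASR's actual conv1/conv2 configuration may differ.
conv1 = nn.Conv1d(80, 512, kernel_size=3, stride=1, padding=1)
conv2 = nn.Conv1d(512, 512, kernel_size=3, stride=2, padding=1)

ilens = torch.tensor([100, 73, 58])       # valid frames per utterance
x = torch.randn(3, 80, int(ilens.max()))  # padded batch, shape (B, mel, T)

# Closed-form lengths, mirroring the patch: apply the formula per conv layer.
olens = 1 + (ilens - conv1.kernel_size[0] + 2 * conv1.padding[0]) // conv1.stride[0]
olens = 1 + (olens - conv2.kernel_size[0] + 2 * conv2.padding[0]) // conv2.stride[0]

# The longest utterance has no padding, so its olens must equal the real frame count.
n_frames = conv2(conv1(x)).size(-1)
assert int(olens.max()) == n_frames
print(olens)  # tensor([50, 37, 29])

Keeping olens in closed form avoids running the convolutions on a dummy input just to
learn the output lengths, which is why the patch computes them directly from ilens.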