From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/mossformer/e2e_ss.py | 24 ++++++++++++------------
1 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/funasr/models/mossformer/e2e_ss.py b/funasr/models/mossformer/e2e_ss.py
index 1a46b3f..40d30ca 100644
--- a/funasr/models/mossformer/e2e_ss.py
+++ b/funasr/models/mossformer/e2e_ss.py
@@ -4,7 +4,7 @@
import torch.nn.functional as F
import copy
from funasr.models.base_model import FunASRModel
-from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
+from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
@@ -48,7 +48,9 @@
super(MossFormer, self).__init__()
self.num_spks = num_spks
# Encoding
- self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
+ self.enc = MossFormerEncoder(
+ kernel_size=kernel_size, out_channels=in_channels, in_channels=1
+ )
##Compute Mask
self.mask_net = MossFormer_MaskNet(
@@ -62,12 +64,13 @@
max_length=max_length,
)
self.dec = MossFormerDecoder(
- in_channels=out_channels,
- out_channels=1,
- kernel_size=kernel_size,
- stride = kernel_size//2,
- bias=False
+ in_channels=out_channels,
+ out_channels=1,
+ kernel_size=kernel_size,
+ stride=kernel_size // 2,
+ bias=False,
)
+
def forward(self, input):
x = self.enc(input)
mask = self.mask_net(x)
@@ -76,10 +79,7 @@
# Decoding
est_source = torch.cat(
- [
- self.dec(sep_x[i]).unsqueeze(-1)
- for i in range(self.num_spks)
- ],
+ [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)],
dim=-1,
)
T_origin = input.size(1)
@@ -91,5 +91,5 @@
out = []
for spk in range(self.num_spks):
- out.append(est_source[:,:,spk])
+ out.append(est_source[:, :, spk])
return out
--
Gitblit v1.9.1