From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/mossformer/e2e_ss.py |   24 ++++++++++++------------
 1 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/funasr/models/mossformer/e2e_ss.py b/funasr/models/mossformer/e2e_ss.py
index 1a46b3f..40d30ca 100644
--- a/funasr/models/mossformer/e2e_ss.py
+++ b/funasr/models/mossformer/e2e_ss.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 import copy
 from funasr.models.base_model import FunASRModel
-from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet 
+from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
 from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
 
 
@@ -48,7 +48,9 @@
         super(MossFormer, self).__init__()
         self.num_spks = num_spks
         # Encoding
-        self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
+        self.enc = MossFormerEncoder(
+            kernel_size=kernel_size, out_channels=in_channels, in_channels=1
+        )
 
         ##Compute Mask
         self.mask_net = MossFormer_MaskNet(
@@ -62,12 +64,13 @@
             max_length=max_length,
         )
         self.dec = MossFormerDecoder(
-           in_channels=out_channels,
-           out_channels=1,
-           kernel_size=kernel_size,
-           stride = kernel_size//2,
-           bias=False
+            in_channels=out_channels,
+            out_channels=1,
+            kernel_size=kernel_size,
+            stride=kernel_size // 2,
+            bias=False,
         )
+
     def forward(self, input):
         x = self.enc(input)
         mask = self.mask_net(x)
@@ -76,10 +79,7 @@
 
         # Decoding
         est_source = torch.cat(
-            [
-                self.dec(sep_x[i]).unsqueeze(-1)
-                for i in range(self.num_spks)
-            ],
+            [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)],
             dim=-1,
         )
         T_origin = input.size(1)
@@ -91,5 +91,5 @@
 
         out = []
         for spk in range(self.num_spks):
-            out.append(est_source[:,:,spk])
+            out.append(est_source[:, :, spk])
         return out

--
Gitblit v1.9.1