From fc9595625855be5b63f86a38ac785e49c142c0ae Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期二, 21 三月 2023 14:10:03 +0800
Subject: [PATCH] embed debug

---
 funasr/models_transducer/encoder/blocks/conv_input.py |   15 ++++++++-------
 1 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py
index c68c73b..ffec93e 100644
--- a/funasr/models_transducer/encoder/blocks/conv_input.py
+++ b/funasr/models_transducer/encoder/blocks/conv_input.py
@@ -146,30 +146,31 @@
         if mask is not None:
             mask = self.create_new_mask(mask)
             olens = max(mask.eq(0).sum(1))
-        
-        b, t_input, f = x.size()
+
+        b, t, f = x.size()
         x = x.unsqueeze(1) # (b. 1. t. f)
+
         if chunk_size is not None:
             max_input_length = int(
-                chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) ))
+                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
             )
             x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
             x = list(x)
             x = torch.stack(x, dim=0)
             N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
             x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
         x = self.conv(x)
 
-        _, c, t, f = x.size()
-        
+        _, c, _, f = x.size()
         if chunk_size is not None:
             x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
         else:
-            x = x.transpose(1, 2).contiguous().view(b, t, c * f)
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
 
         if self.output is not None:
             x = self.output(x)
-        
+
         return x, mask[:,:olens][:,:x.size(1)]
 
     def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:

--
Gitblit v1.9.1