From cdf117b9746fdb72c6d0a2aa1ada4e1a131895ec Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: 星期二, 27 六月 2023 09:59:50 +0800
Subject: [PATCH] bug fix (#667)

---
 funasr/models/encoder/conformer_encoder.py |   55 ++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20dee..994607f 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -1081,7 +1081,10 @@
         mask = make_source_mask(x_len).to(x.device)
 
         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@
 
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
 
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
             else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+                chunk_size = self.default_chunk_size
 
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@
 
         return x, olens, None
 
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder outputs lenghts. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:,::self.time_reduction_factor,:]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,

--
Gitblit v1.9.1