From 012903e42ec890ab5c50137beb365c3d94e731d1 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 30 六月 2023 11:21:28 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR

---
 funasr/models/encoder/conformer_encoder.py |   58 ++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20dee..e5fac62 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -12,7 +12,6 @@
 
 import torch
 from torch import nn
-from typeguard import check_argument_types
 
 from funasr.models.ctc import CTC
 from funasr.modules.attention import (
@@ -533,7 +532,6 @@
             interctc_use_conditioning: bool = False,
             stochastic_depth_rate: Union[float, List[float]] = 0.0,
     ):
-        assert check_argument_types()
         super().__init__()
         self._output_size = output_size
 
@@ -943,7 +941,6 @@
         """Construct an Encoder object."""
         super().__init__()
 
-        assert check_argument_types()
 
         self.embed = StreamingConvInput(
             input_size,
@@ -1081,7 +1078,10 @@
         mask = make_source_mask(x_len).to(x.device)
 
         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
@@ -1113,12 +1113,15 @@
 
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
 
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
             else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+                chunk_size = self.default_chunk_size
 
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
@@ -1147,6 +1150,45 @@
 
         return x, olens, None
 
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder outputs lenghts. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:,::self.time_reduction_factor,:]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,

--
Gitblit v1.9.1