From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 12 Apr 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/models/encoder/chunk_encoder.py |   26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models/encoder/chunk_encoder.py
similarity index 95%
rename from funasr/models_transducer/encoder/encoder.py
rename to funasr/models/encoder/chunk_encoder.py
index b486a11..c6fc292 100644
--- a/funasr/models_transducer/encoder/encoder.py
+++ b/funasr/models/encoder/chunk_encoder.py
@@ -1,26 +1,23 @@
-"""Encoder for Transducer model."""
-
 from typing import Any, Dict, List, Tuple
 
 import torch
 from typeguard import check_argument_types
 
-from funasr.models_transducer.encoder.building import (
+from funasr.models.encoder.chunk_encoder_utils.building import (
     build_body_blocks,
     build_input_block,
     build_main_parameters,
     build_positional_encoding,
 )
-from funasr.models_transducer.encoder.validation import validate_architecture
-from funasr.models_transducer.utils import (
+from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
+from funasr.modules.nets_utils import (
     TooShortUttError,
     check_short_utt,
     make_chunk_mask,
     make_source_mask,
 )
 
-
-class Encoder(torch.nn.Module):
+class ChunkEncoder(torch.nn.Module):
     """Encoder module definition.
 
     Args:
@@ -61,10 +58,9 @@
 
         self.unified_model_training = main_params["unified_model_training"]
         self.default_chunk_size = main_params["default_chunk_size"]
-        self.jitter_range = main_params["jitter_range"]       
+        self.jitter_range = main_params["jitter_range"]
 
-        self.time_reduction_factor = main_params["time_reduction_factor"] 
-
+        self.time_reduction_factor = main_params["time_reduction_factor"]
     def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
         """Return the corresponding number of sample for a given chunk size, in frames.
 
@@ -79,7 +75,7 @@
 
         """
         return self.embed.get_size_before_subsampling(size) * hop_length
-    
+
     def get_encoder_input_size(self, size: int) -> int:
         """Return the corresponding number of sample for a given chunk size, in frames.
 
@@ -157,7 +153,7 @@
                 mask,
                 chunk_mask=chunk_mask,
             )
-       
+
             olens = mask.eq(0).sum(1)
             if self.time_reduction_factor > 1:
                 x_utt = x_utt[:,::self.time_reduction_factor,:]
@@ -194,14 +190,14 @@
             mask,
             chunk_mask=chunk_mask,
         )
-        
+
         olens = mask.eq(0).sum(1)
         if self.time_reduction_factor > 1:
             x = x[:,::self.time_reduction_factor,:]
             olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
 
         return x, olens
-     
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,
@@ -290,7 +286,7 @@
 
         if right_context > 0:
             x = x[:, 0:-right_context, :]
-        
+
         if self.time_reduction_factor > 1:
             x = x[:,::self.time_reduction_factor,:]
         return x

--
Gitblit v1.9.1