From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 12 Apr 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg
---
funasr/models/encoder/chunk_encoder.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)
diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models/encoder/chunk_encoder.py
similarity index 95%
rename from funasr/models_transducer/encoder/encoder.py
rename to funasr/models/encoder/chunk_encoder.py
index b486a11..c6fc292 100644
--- a/funasr/models_transducer/encoder/encoder.py
+++ b/funasr/models/encoder/chunk_encoder.py
@@ -1,26 +1,23 @@
-"""Encoder for Transducer model."""
-
from typing import Any, Dict, List, Tuple
import torch
from typeguard import check_argument_types
-from funasr.models_transducer.encoder.building import (
+from funasr.models.encoder.chunk_encoder_utils.building import (
build_body_blocks,
build_input_block,
build_main_parameters,
build_positional_encoding,
)
-from funasr.models_transducer.encoder.validation import validate_architecture
-from funasr.models_transducer.utils import (
+from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
+from funasr.modules.nets_utils import (
TooShortUttError,
check_short_utt,
make_chunk_mask,
make_source_mask,
)
-
-class Encoder(torch.nn.Module):
+class ChunkEncoder(torch.nn.Module):
"""Encoder module definition.
Args:
@@ -61,10 +58,9 @@
self.unified_model_training = main_params["unified_model_training"]
self.default_chunk_size = main_params["default_chunk_size"]
- self.jitter_range = main_params["jitter_range"]
+ self.jitter_range = main_params["jitter_range"]
- self.time_reduction_factor = main_params["time_reduction_factor"]
-
+ self.time_reduction_factor = main_params["time_reduction_factor"]
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@@ -79,7 +75,7 @@
"""
return self.embed.get_size_before_subsampling(size) * hop_length
-
+
def get_encoder_input_size(self, size: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
@@ -157,7 +153,7 @@
mask,
chunk_mask=chunk_mask,
)
-
+
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x_utt = x_utt[:,::self.time_reduction_factor,:]
@@ -194,14 +190,14 @@
mask,
chunk_mask=chunk_mask,
)
-
+
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
return x, olens
-
+
def simu_chunk_forward(
self,
x: torch.Tensor,
@@ -290,7 +286,7 @@
if right_context > 0:
x = x[:, 0:-right_context, :]
-
+
if self.time_reduction_factor > 1:
x = x[:,::self.time_reduction_factor,:]
return x
--
Gitblit v1.9.1