aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
funasr/models/encoder/chunk_encoder.py
File was renamed from funasr/models_transducer/encoder/encoder.py
@@ -1,26 +1,23 @@
"""Encoder for Transducer model."""
from typing import Any, Dict, List, Tuple
import torch
from typeguard import check_argument_types
from funasr.models_transducer.encoder.building import (
from funasr.models.encoder.chunk_encoder_utils.building import (
    build_body_blocks,
    build_input_block,
    build_main_parameters,
    build_positional_encoding,
)
from funasr.models_transducer.encoder.validation import validate_architecture
from funasr.models_transducer.utils import (
from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
from funasr.modules.nets_utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)
class Encoder(torch.nn.Module):
class ChunkEncoder(torch.nn.Module):
    """Encoder module definition.
    Args:
@@ -61,10 +58,9 @@
        self.unified_model_training = main_params["unified_model_training"]
        self.default_chunk_size = main_params["default_chunk_size"]
        self.jitter_range = main_params["jitter_range"]
        self.jitter_range = main_params["jitter_range"]
        self.time_reduction_factor = main_params["time_reduction_factor"]
        self.time_reduction_factor = main_params["time_reduction_factor"]
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
@@ -79,7 +75,7 @@
        """
        return self.embed.get_size_before_subsampling(size) * hop_length
    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
@@ -157,7 +153,7 @@
                mask,
                chunk_mask=chunk_mask,
            )
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:,::self.time_reduction_factor,:]
@@ -194,14 +190,14 @@
            mask,
            chunk_mask=chunk_mask,
        )
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
        return x, olens
    def simu_chunk_forward(
        self,
        x: torch.Tensor,
@@ -290,7 +286,7 @@
        if right_context > 0:
            x = x[:, 0:-right_context, :]
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
        return x