| | |
| | | |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.joint_network import JointNetwork |
| | |
| | | |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.joint_network import JointNetwork |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | |
| | | StreamingRelPositionalEncoding, |
| | | ) |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.normalization import get_normalization |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | | from funasr.modules.multi_layer_conv import MultiLayeredConv1d |
| | | from funasr.modules.nets_utils import get_activation |
| | |
| | | default_chunk_size: int = 16, |
| | | jitter_range: int = 4, |
| | | subsampling_factor: int = 1, |
| | | **activation_parameters, |
| | | ) -> None: |
| | | """Construct an Encoder object.""" |
| | | super().__init__() |
| | |
| | | ) |
| | | |
| | | activation = get_activation( |
| | | activation_type, **activation_parameters |
| | | activation_type |
| | | ) |
| | | |
| | | pos_wise_args = ( |
| | |
| | | simplified_att_score, |
| | | ) |
| | | |
| | | norm_class, norm_args = get_normalization( |
| | | norm_type, |
| | | ) |
| | | |
| | | fn_modules = [] |
| | | for _ in range(num_blocks): |
| | |
| | | PositionwiseFeedForward(*pos_wise_args), |
| | | PositionwiseFeedForward(*pos_wise_args), |
| | | CausalConvolution(*conv_mod_args), |
| | | norm_class=norm_class, |
| | | norm_args=norm_args, |
| | | dropout_rate=dropout_rate, |
| | | ) |
| | | fn_modules.append(module) |
| | |
| | | self.encoders = MultiBlocks( |
| | | [fn() for fn in fn_modules], |
| | | output_size, |
| | | norm_class=norm_class, |
| | | norm_args=norm_args, |
| | | ) |
| | | |
| | | self.output_size = output_size |
| | |
| | | |
| | | import torch |
| | | |
| | | from funasr.modules.activation import get_activation |
| | | from funasr.modules.nets_utils import get_activation |
| | | |
| | | |
| | | class JointNetwork(torch.nn.Module): |
| | |
| | | decoder_size: int, |
| | | joint_space_size: int = 256, |
| | | joint_activation_type: str = "tanh", |
| | | **activation_parameters, |
| | | ) -> None: |
| | | """Construct a JointNetwork object.""" |
| | | super().__init__() |
| | |
| | | self.lin_out = torch.nn.Linear(joint_space_size, output_size) |
| | | |
| | | self.joint_activation = get_activation( |
| | | joint_activation_type, **activation_parameters |
| | | joint_activation_type |
| | | ) |
| | | |
| | | def forward( |
| File was renamed from funasr/models/rnnt_decoder/rnn_decoder.py |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.beam_search.beam_search_transducer import Hypothesis |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | |
| | | class RNNDecoder(AbsDecoder): |
| File was renamed from funasr/models/rnnt_decoder/stateless_decoder.py |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.beam_search.beam_search_transducer import Hypothesis |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | |
| | | class StatelessDecoder(AbsDecoder): |
| | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.joint_network import JointNetwork |
| | | |
| | | |
| | |
| | | import torch |
| | | |
| | | from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.joint_network import JointNetwork |
| | | |
| | | def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): |
| | |
| | | block_list: List[torch.nn.Module], |
| | | output_size: int, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_args: Optional[Dict] = None, |
| | | ) -> None: |
| | | """Construct a MultiBlocks object.""" |
| | | super().__init__() |
| | | |
| | | self.blocks = torch.nn.ModuleList(block_list) |
| | | self.norm_blocks = norm_class(output_size, **norm_args) |
| | | self.norm_blocks = norm_class(output_size) |
| | | |
| | | self.num_blocks = len(block_list) |
| | | |
| | |
| | | LightweightConvolutionTransformerDecoder, |
| | | TransformerDecoder, |
| | | ) |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder |
| | | from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder |
| | | from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder |
| | | from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder |
| | | from funasr.models.e2e_transducer import TransducerModel |
| | | from funasr.models.e2e_transducer_unified import UnifiedTransducerModel |