From 256035b6c1fa6115b6f33972ed243eb43f3e4299 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期五, 14 四月 2023 11:38:00 +0800
Subject: [PATCH] rnnt reorg
---
funasr/modules/embedding.py | 77 +++
funasr/models/e2e_transducer_unified.py | 2
/dev/null | 171 -------
funasr/models/encoder/conformer_encoder.py | 640 ++++++++++++++++++++++++++
funasr/modules/attention.py | 220 +++++++++
funasr/modules/repeat.py | 92 +++
funasr/modules/subsampling.py | 202 ++++++++
funasr/tasks/asr_transducer.py | 6
funasr/models/e2e_transducer.py | 2
funasr/modules/normalization.py | 0
10 files changed, 1,233 insertions(+), 179 deletions(-)
diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_transducer.py
index b669c9d..8630aec 100644
--- a/funasr/models/e2e_transducer.py
+++ b/funasr/models/e2e_transducer.py
@@ -12,7 +12,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_transducer_unified.py
index 6003542..124bc09 100644
--- a/funasr/models/e2e_transducer_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -11,7 +11,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/encoder/chunk_encoder.py b/funasr/models/encoder/chunk_encoder.py
deleted file mode 100644
index c6fc292..0000000
--- a/funasr/models/encoder/chunk_encoder.py
+++ /dev/null
@@ -1,292 +0,0 @@
-from typing import Any, Dict, List, Tuple
-
-import torch
-from typeguard import check_argument_types
-
-from funasr.models.encoder.chunk_encoder_utils.building import (
- build_body_blocks,
- build_input_block,
- build_main_parameters,
- build_positional_encoding,
-)
-from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
-from funasr.modules.nets_utils import (
- TooShortUttError,
- check_short_utt,
- make_chunk_mask,
- make_source_mask,
-)
-
-class ChunkEncoder(torch.nn.Module):
- """Encoder module definition.
-
- Args:
- input_size: Input size.
- body_conf: Encoder body configuration.
- input_conf: Encoder input configuration.
- main_conf: Encoder main configuration.
-
- """
-
- def __init__(
- self,
- input_size: int,
- body_conf: List[Dict[str, Any]],
- input_conf: Dict[str, Any] = {},
- main_conf: Dict[str, Any] = {},
- ) -> None:
- """Construct an Encoder object."""
- super().__init__()
-
- assert check_argument_types()
-
- embed_size, output_size = validate_architecture(
- input_conf, body_conf, input_size
- )
- main_params = build_main_parameters(**main_conf)
-
- self.embed = build_input_block(input_size, input_conf)
- self.pos_enc = build_positional_encoding(embed_size, main_params)
- self.encoders = build_body_blocks(body_conf, main_params, output_size)
-
- self.output_size = output_size
-
- self.dynamic_chunk_training = main_params["dynamic_chunk_training"]
- self.short_chunk_threshold = main_params["short_chunk_threshold"]
- self.short_chunk_size = main_params["short_chunk_size"]
- self.left_chunk_size = main_params["left_chunk_size"]
-
- self.unified_model_training = main_params["unified_model_training"]
- self.default_chunk_size = main_params["default_chunk_size"]
- self.jitter_range = main_params["jitter_range"]
-
- self.time_reduction_factor = main_params["time_reduction_factor"]
- def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
- """Return the corresponding number of sample for a given chunk size, in frames.
-
- Where size is the number of features frames after applying subsampling.
-
- Args:
- size: Number of frames after subsampling.
- hop_length: Frontend's hop length
-
- Returns:
- : Number of raw samples
-
- """
- return self.embed.get_size_before_subsampling(size) * hop_length
-
- def get_encoder_input_size(self, size: int) -> int:
- """Return the corresponding number of sample for a given chunk size, in frames.
-
- Where size is the number of features frames after applying subsampling.
-
- Args:
- size: Number of frames after subsampling.
-
- Returns:
- : Number of raw samples
-
- """
- return self.embed.get_size_before_subsampling(size)
-
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset encoder streaming cache.
-
- Args:
- left_context: Number of frames in left context.
- device: Device ID.
-
- """
- return self.encoders.reset_streaming_cache(left_context, device)
-
- def forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode input sequences.
-
- Args:
- x: Encoder input features. (B, T_in, F)
- x_len: Encoder input features lengths. (B,)
-
- Returns:
- x: Encoder outputs. (B, T_out, D_enc)
- x_len: Encoder outputs lenghts. (B,)
-
- """
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
-
- if short_status:
- raise TooShortUttError(
- f"has {x.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- x.size(1),
- limit_size,
- )
-
- mask = make_source_mask(x_len)
-
- if self.unified_model_training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
- x, mask = self.embed(x, mask, chunk_size)
- pos_enc = self.pos_enc(x)
- chunk_mask = make_chunk_mask(
- x.size(1),
- chunk_size,
- left_chunk_size=self.left_chunk_size,
- device=x.device,
- )
- x_utt = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=None,
- )
- x_chunk = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
-
- olens = mask.eq(0).sum(1)
- if self.time_reduction_factor > 1:
- x_utt = x_utt[:,::self.time_reduction_factor,:]
- x_chunk = x_chunk[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
- return x_utt, x_chunk, olens
-
- elif self.dynamic_chunk_training:
- max_len = x.size(1)
- chunk_size = torch.randint(1, max_len, (1,)).item()
-
- if chunk_size > (max_len * self.short_chunk_threshold):
- chunk_size = max_len
- else:
- chunk_size = (chunk_size % self.short_chunk_size) + 1
-
- x, mask = self.embed(x, mask, chunk_size)
- pos_enc = self.pos_enc(x)
-
- chunk_mask = make_chunk_mask(
- x.size(1),
- chunk_size,
- left_chunk_size=self.left_chunk_size,
- device=x.device,
- )
- else:
- x, mask = self.embed(x, mask, None)
- pos_enc = self.pos_enc(x)
- chunk_mask = None
- x = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
-
- olens = mask.eq(0).sum(1)
- if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
- return x, olens
-
- def simu_chunk_forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- ) -> torch.Tensor:
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
-
- if short_status:
- raise TooShortUttError(
- f"has {x.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- x.size(1),
- limit_size,
- )
-
- mask = make_source_mask(x_len)
-
- x, mask = self.embed(x, mask, chunk_size)
- pos_enc = self.pos_enc(x)
- chunk_mask = make_chunk_mask(
- x.size(1),
- chunk_size,
- left_chunk_size=self.left_chunk_size,
- device=x.device,
- )
-
- x = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
- olens = mask.eq(0).sum(1)
- if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
-
- return x
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- processed_frames: torch.tensor,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- ) -> torch.Tensor:
- """Encode input sequences as chunks.
-
- Args:
- x: Encoder input features. (1, T_in, F)
- x_len: Encoder input features lengths. (1,)
- processed_frames: Number of frames already seen.
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
-
- Returns:
- x: Encoder outputs. (B, T_out, D_enc)
-
- """
- mask = make_source_mask(x_len)
- x, mask = self.embed(x, mask, None)
-
- if left_context > 0:
- processed_mask = (
- torch.arange(left_context, device=x.device)
- .view(1, left_context)
- .flip(1)
- )
- processed_mask = processed_mask >= processed_frames
- mask = torch.cat([processed_mask, mask], dim=1)
- pos_enc = self.pos_enc(x, left_context=left_context)
- x = self.encoders.chunk_forward(
- x,
- pos_enc,
- mask,
- chunk_size=chunk_size,
- left_context=left_context,
- right_context=right_context,
- )
-
- if right_context > 0:
- x = x[:, 0:-right_context, :]
-
- if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
- return x
diff --git a/funasr/models/encoder/chunk_encoder_blocks/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/__init__.py
+++ /dev/null
diff --git a/funasr/models/encoder/chunk_encoder_blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py
deleted file mode 100644
index ba0b25d..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/branchformer.py
+++ /dev/null
@@ -1,178 +0,0 @@
-"""Branchformer block for Transducer encoder."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-class Branchformer(torch.nn.Module):
- """Branchformer module definition.
-
- Reference: https://arxiv.org/pdf/2207.02971.pdf
-
- Args:
- block_size: Input/output size.
- linear_size: Linear layers' hidden size.
- self_att: Self-attention module instance.
- conv_mod: Convolution module instance.
- norm_class: Normalization class.
- norm_args: Normalization module arguments.
- dropout_rate: Dropout rate.
-
- """
-
- def __init__(
- self,
- block_size: int,
- linear_size: int,
- self_att: torch.nn.Module,
- conv_mod: torch.nn.Module,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
- norm_args: Dict = {},
- dropout_rate: float = 0.0,
- ) -> None:
- """Construct a Branchformer object."""
- super().__init__()
-
- self.self_att = self_att
- self.conv_mod = conv_mod
-
- self.channel_proj1 = torch.nn.Sequential(
- torch.nn.Linear(block_size, linear_size), torch.nn.GELU()
- )
- self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size)
-
- self.merge_proj = torch.nn.Linear(block_size + block_size, block_size)
-
- self.norm_self_att = norm_class(block_size, **norm_args)
- self.norm_mlp = norm_class(block_size, **norm_args)
- self.norm_final = norm_class(block_size, **norm_args)
-
- self.dropout = torch.nn.Dropout(dropout_rate)
-
- self.block_size = block_size
- self.linear_size = linear_size
- self.cache = None
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset self-attention and convolution modules cache for streaming.
-
- Args:
- left_context: Number of left frames during chunk-by-chunk inference.
- device: Device to use for cache tensor.
-
- """
- self.cache = [
- torch.zeros(
- (1, left_context, self.block_size),
- device=device,
- ),
- torch.zeros(
- (
- 1,
- self.linear_size // 2,
- self.conv_mod.kernel_size - 1,
- ),
- device=device,
- ),
- ]
-
- def forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Encode input sequences.
-
- Args:
- x: Branchformer input sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask: Source mask. (B, T)
- chunk_mask: Chunk mask. (T_2, T_2)
-
- Returns:
- x: Branchformer output sequences. (B, T, D_block)
- mask: Source mask. (B, T)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
- """
- x1 = x
- x2 = x
-
- x1 = self.norm_self_att(x1)
-
- x1 = self.dropout(
- self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask)
- )
-
- x2 = self.norm_mlp(x2)
-
- x2 = self.channel_proj1(x2)
- x2, _ = self.conv_mod(x2)
- x2 = self.channel_proj2(x2)
-
- x2 = self.dropout(x2)
-
- x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))
-
- x = self.norm_final(x)
-
- return x, mask, pos_enc
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- left_context: int = 0,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode chunk of input sequence.
-
- Args:
- x: Branchformer input sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask: Source mask. (B, T_2)
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
-
- Returns:
- x: Branchformer output sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
- """
- x1 = x
- x2 = x
-
- x1 = self.norm_self_att(x1)
-
- if left_context > 0:
- key = torch.cat([self.cache[0], x1], dim=1)
- else:
- key = x1
- val = key
-
- if right_context > 0:
- att_cache = key[:, -(left_context + right_context) : -right_context, :]
- else:
- att_cache = key[:, -left_context:, :]
-
- x1 = self.self_att(x1, key, val, pos_enc, mask=mask, left_context=left_context)
-
- x2 = self.norm_mlp(x2)
- x2 = self.channel_proj1(x2)
-
- x2, conv_cache = self.conv_mod(
- x2, cache=self.cache[1], right_context=right_context
- )
-
- x2 = self.channel_proj2(x2)
-
- x = x + self.merge_proj(torch.cat([x1, x2], dim=-1))
-
- x = self.norm_final(x)
- self.cache = [att_cache, conv_cache]
-
- return x, pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py
deleted file mode 100644
index 0b9bbbf..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/conformer.py
+++ /dev/null
@@ -1,198 +0,0 @@
-"""Conformer block for Transducer encoder."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-class Conformer(torch.nn.Module):
- """Conformer module definition.
-
- Args:
- block_size: Input/output size.
- self_att: Self-attention module instance.
- feed_forward: Feed-forward module instance.
- feed_forward_macaron: Feed-forward module instance for macaron network.
- conv_mod: Convolution module instance.
- norm_class: Normalization module class.
- norm_args: Normalization module arguments.
- dropout_rate: Dropout rate.
-
- """
-
- def __init__(
- self,
- block_size: int,
- self_att: torch.nn.Module,
- feed_forward: torch.nn.Module,
- feed_forward_macaron: torch.nn.Module,
- conv_mod: torch.nn.Module,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
- norm_args: Dict = {},
- dropout_rate: float = 0.0,
- ) -> None:
- """Construct a Conformer object."""
- super().__init__()
-
- self.self_att = self_att
-
- self.feed_forward = feed_forward
- self.feed_forward_macaron = feed_forward_macaron
- self.feed_forward_scale = 0.5
-
- self.conv_mod = conv_mod
-
- self.norm_feed_forward = norm_class(block_size, **norm_args)
- self.norm_self_att = norm_class(block_size, **norm_args)
-
- self.norm_macaron = norm_class(block_size, **norm_args)
- self.norm_conv = norm_class(block_size, **norm_args)
- self.norm_final = norm_class(block_size, **norm_args)
-
- self.dropout = torch.nn.Dropout(dropout_rate)
-
- self.block_size = block_size
- self.cache = None
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset self-attention and convolution modules cache for streaming.
-
- Args:
- left_context: Number of left frames during chunk-by-chunk inference.
- device: Device to use for cache tensor.
-
- """
- self.cache = [
- torch.zeros(
- (1, left_context, self.block_size),
- device=device,
- ),
- torch.zeros(
- (
- 1,
- self.block_size,
- self.conv_mod.kernel_size - 1,
- ),
- device=device,
- ),
- ]
-
- def forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Encode input sequences.
-
- Args:
- x: Conformer input sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask: Source mask. (B, T)
- chunk_mask: Chunk mask. (T_2, T_2)
-
- Returns:
- x: Conformer output sequences. (B, T, D_block)
- mask: Source mask. (B, T)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
- """
- residual = x
-
- x = self.norm_macaron(x)
- x = residual + self.feed_forward_scale * self.dropout(
- self.feed_forward_macaron(x)
- )
-
- residual = x
- x = self.norm_self_att(x)
- x_q = x
- x = residual + self.dropout(
- self.self_att(
- x_q,
- x,
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
- )
-
- residual = x
-
- x = self.norm_conv(x)
- x, _ = self.conv_mod(x)
- x = residual + self.dropout(x)
- residual = x
-
- x = self.norm_feed_forward(x)
- x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
-
- x = self.norm_final(x)
- return x, mask, pos_enc
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_size: int = 16,
- left_context: int = 0,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode chunk of input sequence.
-
- Args:
- x: Conformer input sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask: Source mask. (B, T_2)
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
-
- Returns:
- x: Conformer output sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
- """
- residual = x
-
- x = self.norm_macaron(x)
- x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
-
- residual = x
- x = self.norm_self_att(x)
- if left_context > 0:
- key = torch.cat([self.cache[0], x], dim=1)
- else:
- key = x
- val = key
-
- if right_context > 0:
- att_cache = key[:, -(left_context + right_context) : -right_context, :]
- else:
- att_cache = key[:, -left_context:, :]
- x = residual + self.self_att(
- x,
- key,
- val,
- pos_enc,
- mask,
- left_context=left_context,
- )
-
- residual = x
- x = self.norm_conv(x)
- x, conv_cache = self.conv_mod(
- x, cache=self.cache[1], right_context=right_context
- )
- x = residual + x
- residual = x
-
- x = self.norm_feed_forward(x)
- x = residual + self.feed_forward_scale * self.feed_forward(x)
-
- x = self.norm_final(x)
- self.cache = [att_cache, conv_cache]
-
- return x, pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py
deleted file mode 100644
index f79cc37..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/conv1d.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""Conv1d block for Transducer encoder."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-
-
-class Conv1d(torch.nn.Module):
- """Conv1d module definition.
-
- Args:
- input_size: Input dimension.
- output_size: Output dimension.
- kernel_size: Size of the convolving kernel.
- stride: Stride of the convolution.
- dilation: Spacing between the kernel points.
- groups: Number of blocked connections from input channels to output channels.
- bias: Whether to add a learnable bias to the output.
- batch_norm: Whether to use batch normalization after convolution.
- relu: Whether to use a ReLU activation after convolution.
- causal: Whether to use causal convolution (set to True if streaming).
- dropout_rate: Dropout rate.
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- kernel_size: Union[int, Tuple],
- stride: Union[int, Tuple] = 1,
- dilation: Union[int, Tuple] = 1,
- groups: Union[int, Tuple] = 1,
- bias: bool = True,
- batch_norm: bool = False,
- relu: bool = True,
- causal: bool = False,
- dropout_rate: float = 0.0,
- ) -> None:
- """Construct a Conv1d object."""
- super().__init__()
-
- if causal:
- self.lorder = kernel_size - 1
- stride = 1
- else:
- self.lorder = 0
- stride = stride
-
- self.conv = torch.nn.Conv1d(
- input_size,
- output_size,
- kernel_size,
- stride=stride,
- dilation=dilation,
- groups=groups,
- bias=bias,
- )
-
- self.dropout = torch.nn.Dropout(p=dropout_rate)
-
- if relu:
- self.relu_func = torch.nn.ReLU()
-
- if batch_norm:
- self.bn = torch.nn.BatchNorm1d(output_size)
-
- self.out_pos = torch.nn.Linear(input_size, output_size)
-
- self.input_size = input_size
- self.output_size = output_size
-
- self.relu = relu
- self.batch_norm = batch_norm
- self.causal = causal
-
- self.kernel_size = kernel_size
- self.padding = dilation * (kernel_size - 1)
- self.stride = stride
-
- self.cache = None
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset Conv1d cache for streaming.
-
- Args:
- left_context: Number of left frames during chunk-by-chunk inference.
- device: Device to use for cache tensor.
-
- """
- self.cache = torch.zeros(
- (1, self.input_size, self.kernel_size - 1), device=device
- )
-
- def forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: Optional[torch.Tensor] = None,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Encode input sequences.
-
- Args:
- x: Conv1d input sequences. (B, T, D_in)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
- mask: Source mask. (B, T)
- chunk_mask: Chunk mask. (T_2, T_2)
-
- Returns:
- x: Conv1d output sequences. (B, sub(T), D_out)
- mask: Source mask. (B, T) or (B, sub(T))
- pos_enc: Positional embedding sequences.
- (B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out)
-
- """
- x = x.transpose(1, 2)
-
- if self.lorder > 0:
- x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
- else:
- mask = self.create_new_mask(mask)
- pos_enc = self.create_new_pos_enc(pos_enc)
-
- x = self.conv(x)
-
- if self.batch_norm:
- x = self.bn(x)
-
- x = self.dropout(x)
-
- if self.relu:
- x = self.relu_func(x)
-
- x = x.transpose(1, 2)
-
- return x, mask, self.out_pos(pos_enc)
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- left_context: int = 0,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode chunk of input sequence.
-
- Args:
- x: Conv1d input sequences. (B, T, D_in)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
- mask: Source mask. (B, T)
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
-
- Returns:
- x: Conv1d output sequences. (B, T, D_out)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out)
-
- """
- x = torch.cat([self.cache, x.transpose(1, 2)], dim=2)
-
- if right_context > 0:
- self.cache = x[:, :, -(self.lorder + right_context) : -right_context]
- else:
- self.cache = x[:, :, -self.lorder :]
-
- x = self.conv(x)
-
- if self.batch_norm:
- x = self.bn(x)
-
- x = self.dropout(x)
-
- if self.relu:
- x = self.relu_func(x)
-
- x = x.transpose(1, 2)
-
- return x, self.out_pos(pos_enc)
-
- def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
- """Create new mask for output sequences.
-
- Args:
- mask: Mask of input sequences. (B, T)
-
- Returns:
- mask: Mask of output sequences. (B, sub(T))
-
- """
- if self.padding != 0:
- mask = mask[:, : -self.padding]
-
- return mask[:, :: self.stride]
-
- def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor:
- """Create new positional embedding vector.
-
- Args:
- pos_enc: Input sequences positional embedding.
- (B, 2 * (T - 1), D_in)
-
- Returns:
- pos_enc: Output sequences positional embedding.
- (B, 2 * (sub(T) - 1), D_in)
-
- """
- pos_enc_positive = pos_enc[:, : pos_enc.size(1) // 2 + 1, :]
- pos_enc_negative = pos_enc[:, pos_enc.size(1) // 2 :, :]
-
- if self.padding != 0:
- pos_enc_positive = pos_enc_positive[:, : -self.padding, :]
- pos_enc_negative = pos_enc_negative[:, : -self.padding, :]
-
- pos_enc_positive = pos_enc_positive[:, :: self.stride, :]
- pos_enc_negative = pos_enc_negative[:, :: self.stride, :]
-
- pos_enc = torch.cat([pos_enc_positive, pos_enc_negative[:, 1:, :]], dim=1)
-
- return pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
deleted file mode 100644
index b9bd2fd..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
+++ /dev/null
@@ -1,222 +0,0 @@
-"""ConvInput block for Transducer encoder."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-import math
-
-from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
-
-
-class ConvInput(torch.nn.Module):
- """ConvInput module definition.
-
- Args:
- input_size: Input size.
- conv_size: Convolution size.
- subsampling_factor: Subsampling factor.
- vgg_like: Whether to use a VGG-like network.
- output_size: Block output dimension.
-
- """
-
- def __init__(
- self,
- input_size: int,
- conv_size: Union[int, Tuple],
- subsampling_factor: int = 4,
- vgg_like: bool = True,
- output_size: Optional[int] = None,
- ) -> None:
- """Construct a ConvInput object."""
- super().__init__()
- if vgg_like:
- if subsampling_factor == 1:
- conv_size1, conv_size2 = conv_size
-
- self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.MaxPool2d((1, 2)),
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.MaxPool2d((1, 2)),
- )
-
- output_proj = conv_size2 * ((input_size // 2) // 2)
-
- self.subsampling_factor = 1
-
- self.stride_1 = 1
-
- self.create_new_mask = self.create_new_vgg_mask
-
- else:
- conv_size1, conv_size2 = conv_size
-
- kernel_1 = int(subsampling_factor / 2)
-
- self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.MaxPool2d((kernel_1, 2)),
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
- torch.nn.ReLU(),
- torch.nn.MaxPool2d((2, 2)),
- )
-
- output_proj = conv_size2 * ((input_size // 2) // 2)
-
- self.subsampling_factor = subsampling_factor
-
- self.create_new_mask = self.create_new_vgg_mask
-
- self.stride_1 = kernel_1
-
- else:
- if subsampling_factor == 1:
- self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
- torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
- torch.nn.ReLU(),
- )
-
- output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
-
- self.subsampling_factor = subsampling_factor
- self.kernel_2 = 3
- self.stride_2 = 1
-
- self.create_new_mask = self.create_new_conv2d_mask
-
- else:
- kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
- subsampling_factor,
- input_size,
- )
-
- self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size, 3, 2),
- torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
- torch.nn.ReLU(),
- )
-
- output_proj = conv_size * conv_2_output_size
-
- self.subsampling_factor = subsampling_factor
- self.kernel_2 = kernel_2
- self.stride_2 = stride_2
-
- self.create_new_mask = self.create_new_conv2d_mask
-
- self.vgg_like = vgg_like
- self.min_frame_length = 7
-
- if output_size is not None:
- self.output = torch.nn.Linear(output_proj, output_size)
- self.output_size = output_size
- else:
- self.output = None
- self.output_size = output_proj
-
- def forward(
- self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode input sequences.
-
- Args:
- x: ConvInput input sequences. (B, T, D_feats)
- mask: Mask of input sequences. (B, 1, T)
-
- Returns:
- x: ConvInput output sequences. (B, sub(T), D_out)
- mask: Mask of output sequences. (B, 1, sub(T))
-
- """
- if mask is not None:
- mask = self.create_new_mask(mask)
- olens = max(mask.eq(0).sum(1))
-
- b, t, f = x.size()
- x = x.unsqueeze(1) # (b. 1. t. f)
-
- if chunk_size is not None:
- max_input_length = int(
- chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
- )
- x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
- x = list(x)
- x = torch.stack(x, dim=0)
- N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
- x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
-
- x = self.conv(x)
-
- _, c, _, f = x.size()
- if chunk_size is not None:
- x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
- else:
- x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
-
- if self.output is not None:
- x = self.output(x)
-
- return x, mask[:,:olens][:,:x.size(1)]
-
- def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
- """Create a new mask for VGG output sequences.
-
- Args:
- mask: Mask of input sequences. (B, T)
-
- Returns:
- mask: Mask of output sequences. (B, sub(T))
-
- """
- if self.subsampling_factor > 1:
- vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
- mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
-
- vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
- mask = mask[:, :vgg2_t_len][:, ::2]
- else:
- mask = mask
-
- return mask
-
- def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
- """Create new conformer mask for Conv2d output sequences.
-
- Args:
- mask: Mask of input sequences. (B, T)
-
- Returns:
- mask: Mask of output sequences. (B, sub(T))
-
- """
- if self.subsampling_factor > 1:
- return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
- else:
- return mask
-
- def get_size_before_subsampling(self, size: int) -> int:
- """Return the original size before subsampling for a given size.
-
- Args:
- size: Number of frames after subsampling.
-
- Returns:
- : Number of frames before subsampling.
-
- """
- return size * self.subsampling_factor
diff --git a/funasr/models/encoder/chunk_encoder_blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py
deleted file mode 100644
index 9bb9698..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/linear_input.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""LinearInput block for Transducer encoder."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-
-class LinearInput(torch.nn.Module):
- """ConvInput module definition.
-
- Args:
- input_size: Input size.
- conv_size: Convolution size.
- subsampling_factor: Subsampling factor.
- vgg_like: Whether to use a VGG-like network.
- output_size: Block output dimension.
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: Optional[int] = None,
- subsampling_factor: int = 1,
- ) -> None:
- """Construct a ConvInput object."""
- super().__init__()
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(0.1),
- )
- self.subsampling_factor = subsampling_factor
- self.min_frame_length = 1
-
- def forward(
- self, x: torch.Tensor, mask: Optional[torch.Tensor]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-
- x = self.embed(x)
- return x, mask
-
- def get_size_before_subsampling(self, size: int) -> int:
- """Return the original size before subsampling for a given size.
-
- Args:
- size: Number of frames after subsampling.
-
- Returns:
- : Number of frames before subsampling.
-
- """
- return size
diff --git a/funasr/models/encoder/chunk_encoder_modules/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/__init__.py
+++ /dev/null
diff --git a/funasr/models/encoder/chunk_encoder_modules/attention.py b/funasr/models/encoder/chunk_encoder_modules/attention.py
deleted file mode 100644
index 53e7087..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/attention.py
+++ /dev/null
@@ -1,246 +0,0 @@
-"""Multi-Head attention layers with relative positional encoding."""
-
-import math
-from typing import Optional, Tuple
-
-import torch
-
-
-class RelPositionMultiHeadedAttention(torch.nn.Module):
- """RelPositionMultiHeadedAttention definition.
-
- Args:
- num_heads: Number of attention heads.
- embed_size: Embedding size.
- dropout_rate: Dropout rate.
-
- """
-
- def __init__(
- self,
- num_heads: int,
- embed_size: int,
- dropout_rate: float = 0.0,
- simplified_attention_score: bool = False,
- ) -> None:
- """Construct an MultiHeadedAttention object."""
- super().__init__()
-
- self.d_k = embed_size // num_heads
- self.num_heads = num_heads
-
- assert self.d_k * num_heads == embed_size, (
- "embed_size (%d) must be divisible by num_heads (%d)",
- (embed_size, num_heads),
- )
-
- self.linear_q = torch.nn.Linear(embed_size, embed_size)
- self.linear_k = torch.nn.Linear(embed_size, embed_size)
- self.linear_v = torch.nn.Linear(embed_size, embed_size)
-
- self.linear_out = torch.nn.Linear(embed_size, embed_size)
-
- if simplified_attention_score:
- self.linear_pos = torch.nn.Linear(embed_size, num_heads)
-
- self.compute_att_score = self.compute_simplified_attention_score
- else:
- self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
-
- self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
- self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
- torch.nn.init.xavier_uniform_(self.pos_bias_u)
- torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
- self.compute_att_score = self.compute_attention_score
-
- self.dropout = torch.nn.Dropout(p=dropout_rate)
- self.attn = None
-
- def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
- """Compute relative positional encoding.
-
- Args:
- x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
- left_context: Number of frames in left context.
-
- Returns:
- x: Output sequence. (B, H, T_1, T_2)
-
- """
- batch_size, n_heads, time1, n = x.shape
- time2 = time1 + left_context
-
- batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
-
- return x.as_strided(
- (batch_size, n_heads, time1, time2),
- (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
- storage_offset=(n_stride * (time1 - 1)),
- )
-
- def compute_simplified_attention_score(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- pos_enc: torch.Tensor,
- left_context: int = 0,
- ) -> torch.Tensor:
- """Simplified attention score computation.
-
- Reference: https://github.com/k2-fsa/icefall/pull/458
-
- Args:
- query: Transformed query tensor. (B, H, T_1, d_k)
- key: Transformed key tensor. (B, H, T_2, d_k)
- pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
- left_context: Number of frames in left context.
-
- Returns:
- : Attention score. (B, H, T_1, T_2)
-
- """
- pos_enc = self.linear_pos(pos_enc)
-
- matrix_ac = torch.matmul(query, key.transpose(2, 3))
-
- matrix_bd = self.rel_shift(
- pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
- left_context=left_context,
- )
-
- return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
- def compute_attention_score(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- pos_enc: torch.Tensor,
- left_context: int = 0,
- ) -> torch.Tensor:
- """Attention score computation.
-
- Args:
- query: Transformed query tensor. (B, H, T_1, d_k)
- key: Transformed key tensor. (B, H, T_2, d_k)
- pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
- left_context: Number of frames in left context.
-
- Returns:
- : Attention score. (B, H, T_1, T_2)
-
- """
- p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
-
- query = query.transpose(1, 2)
- q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
- q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
-
- matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
-
- matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
- matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
-
- return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
- def forward_qkv(
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Transform query, key and value.
-
- Args:
- query: Query tensor. (B, T_1, size)
- key: Key tensor. (B, T_2, size)
- v: Value tensor. (B, T_2, size)
-
- Returns:
- q: Transformed query tensor. (B, H, T_1, d_k)
- k: Transformed key tensor. (B, H, T_2, d_k)
- v: Transformed value tensor. (B, H, T_2, d_k)
-
- """
- n_batch = query.size(0)
-
- q = (
- self.linear_q(query)
- .view(n_batch, -1, self.num_heads, self.d_k)
- .transpose(1, 2)
- )
- k = (
- self.linear_k(key)
- .view(n_batch, -1, self.num_heads, self.d_k)
- .transpose(1, 2)
- )
- v = (
- self.linear_v(value)
- .view(n_batch, -1, self.num_heads, self.d_k)
- .transpose(1, 2)
- )
-
- return q, k, v
-
- def forward_attention(
- self,
- value: torch.Tensor,
- scores: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- """Compute attention context vector.
-
- Args:
- value: Transformed value. (B, H, T_2, d_k)
- scores: Attention score. (B, H, T_1, T_2)
- mask: Source mask. (B, T_2)
- chunk_mask: Chunk mask. (T_1, T_1)
-
- Returns:
- attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
-
- """
- batch_size = scores.size(0)
- mask = mask.unsqueeze(1).unsqueeze(2)
- if chunk_mask is not None:
- mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
- scores = scores.masked_fill(mask, float("-inf"))
- self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-
- attn_output = self.dropout(self.attn)
- attn_output = torch.matmul(attn_output, value)
-
- attn_output = self.linear_out(
- attn_output.transpose(1, 2)
- .contiguous()
- .view(batch_size, -1, self.num_heads * self.d_k)
- )
-
- return attn_output
-
- def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- left_context: int = 0,
- ) -> torch.Tensor:
- """Compute scaled dot product attention with rel. positional encoding.
-
- Args:
- query: Query tensor. (B, T_1, size)
- key: Key tensor. (B, T_2, size)
- value: Value tensor. (B, T_2, size)
- pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
- mask: Source mask. (B, T_2)
- chunk_mask: Chunk mask. (T_1, T_1)
- left_context: Number of frames in left context.
-
- Returns:
- : Output tensor. (B, T_1, H * d_k)
-
- """
- q, k, v = self.forward_qkv(query, key, value)
- scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
- return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/models/encoder/chunk_encoder_modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py
deleted file mode 100644
index 012538a..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/convolution.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""Convolution modules for X-former blocks."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-class ConformerConvolution(torch.nn.Module):
- """ConformerConvolution module definition.
-
- Args:
- channels: The number of channels.
- kernel_size: Size of the convolving kernel.
- activation: Type of activation function.
- norm_args: Normalization module arguments.
- causal: Whether to use causal convolution (set to True if streaming).
-
- """
-
- def __init__(
- self,
- channels: int,
- kernel_size: int,
- activation: torch.nn.Module = torch.nn.ReLU(),
- norm_args: Dict = {},
- causal: bool = False,
- ) -> None:
- """Construct an ConformerConvolution object."""
- super().__init__()
-
- assert (kernel_size - 1) % 2 == 0
-
- self.kernel_size = kernel_size
-
- self.pointwise_conv1 = torch.nn.Conv1d(
- channels,
- 2 * channels,
- kernel_size=1,
- stride=1,
- padding=0,
- )
-
- if causal:
- self.lorder = kernel_size - 1
- padding = 0
- else:
- self.lorder = 0
- padding = (kernel_size - 1) // 2
-
- self.depthwise_conv = torch.nn.Conv1d(
- channels,
- channels,
- kernel_size,
- stride=1,
- padding=padding,
- groups=channels,
- )
- self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
- self.pointwise_conv2 = torch.nn.Conv1d(
- channels,
- channels,
- kernel_size=1,
- stride=1,
- padding=0,
- )
-
- self.activation = activation
-
- def forward(
- self,
- x: torch.Tensor,
- cache: Optional[torch.Tensor] = None,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute convolution module.
-
- Args:
- x: ConformerConvolution input sequences. (B, T, D_hidden)
- cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
- right_context: Number of frames in right context.
-
- Returns:
- x: ConformerConvolution output sequences. (B, T, D_hidden)
- cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
-
- """
- x = self.pointwise_conv1(x.transpose(1, 2))
- x = torch.nn.functional.glu(x, dim=1)
-
- if self.lorder > 0:
- if cache is None:
- x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
- else:
- x = torch.cat([cache, x], dim=2)
-
- if right_context > 0:
- cache = x[:, :, -(self.lorder + right_context) : -right_context]
- else:
- cache = x[:, :, -self.lorder :]
-
- x = self.depthwise_conv(x)
- x = self.activation(self.norm(x))
-
- x = self.pointwise_conv2(x).transpose(1, 2)
-
- return x, cache
-
-
-class ConvolutionalSpatialGatingUnit(torch.nn.Module):
- """Convolutional Spatial Gating Unit module definition.
-
- Args:
- size: Initial size to determine the number of channels.
- kernel_size: Size of the convolving kernel.
- norm_class: Normalization module class.
- norm_args: Normalization module arguments.
- dropout_rate: Dropout rate.
- causal: Whether to use causal convolution (set to True if streaming).
-
- """
-
- def __init__(
- self,
- size: int,
- kernel_size: int,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
- norm_args: Dict = {},
- dropout_rate: float = 0.0,
- causal: bool = False,
- ) -> None:
- """Construct a ConvolutionalSpatialGatingUnit object."""
- super().__init__()
-
- channels = size // 2
-
- self.kernel_size = kernel_size
-
- if causal:
- self.lorder = kernel_size - 1
- padding = 0
- else:
- self.lorder = 0
- padding = (kernel_size - 1) // 2
-
- self.conv = torch.nn.Conv1d(
- channels,
- channels,
- kernel_size,
- stride=1,
- padding=padding,
- groups=channels,
- )
-
- self.norm = norm_class(channels, **norm_args)
- self.activation = torch.nn.Identity()
-
- self.dropout = torch.nn.Dropout(dropout_rate)
-
- def forward(
- self,
- x: torch.Tensor,
- cache: Optional[torch.Tensor] = None,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute convolution module.
-
- Args:
- x: ConvolutionalSpatialGatingUnit input sequences. (B, T, D_hidden)
- cache: ConvolutionalSpationGatingUnit input cache.
- (1, conv_kernel, D_hidden)
- right_context: Number of frames in right context.
-
- Returns:
- x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2)
-
- """
- x_r, x_g = x.chunk(2, dim=-1)
-
- x_g = self.norm(x_g).transpose(1, 2)
-
- if self.lorder > 0:
- if cache is None:
- x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0)
- else:
- x_g = torch.cat([cache, x_g], dim=2)
-
- if right_context > 0:
- cache = x_g[:, :, -(self.lorder + right_context) : -right_context]
- else:
- cache = x_g[:, :, -self.lorder :]
-
- x_g = self.conv(x_g).transpose(1, 2)
-
- x = self.dropout(x_r * self.activation(x_g))
-
- return x, cache
diff --git a/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
deleted file mode 100644
index 14aca8b..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
+++ /dev/null
@@ -1,105 +0,0 @@
-"""MultiBlocks for encoder architecture."""
-
-from typing import Dict, List, Optional
-
-import torch
-
-
-class MultiBlocks(torch.nn.Module):
- """MultiBlocks definition.
-
- Args:
- block_list: Individual blocks of the encoder architecture.
- output_size: Architecture output size.
- norm_class: Normalization module class.
- norm_args: Normalization module arguments.
-
- """
-
- def __init__(
- self,
- block_list: List[torch.nn.Module],
- output_size: int,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
- norm_args: Optional[Dict] = None,
- ) -> None:
- """Construct a MultiBlocks object."""
- super().__init__()
-
- self.blocks = torch.nn.ModuleList(block_list)
- self.norm_blocks = norm_class(output_size, **norm_args)
-
- self.num_blocks = len(block_list)
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset encoder streaming cache.
-
- Args:
- left_context: Number of left frames during chunk-by-chunk inference.
- device: Device to use for cache tensor.
-
- """
- for idx in range(self.num_blocks):
- self.blocks[idx].reset_streaming_cache(left_context, device)
-
- def forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- """Forward each block of the encoder architecture.
-
- Args:
- x: MultiBlocks input sequences. (B, T, D_block_1)
- pos_enc: Positional embedding sequences.
- mask: Source mask. (B, T)
- chunk_mask: Chunk mask. (T_2, T_2)
-
- Returns:
- x: Output sequences. (B, T, D_block_N)
-
- """
- for block_index, block in enumerate(self.blocks):
- x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
-
- x = self.norm_blocks(x)
-
- return x
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_size: int = 0,
- left_context: int = 0,
- right_context: int = 0,
- ) -> torch.Tensor:
- """Forward each block of the encoder architecture.
-
- Args:
- x: MultiBlocks input sequences. (B, T, D_block_1)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
- mask: Source mask. (B, T_2)
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
-
- Returns:
- x: MultiBlocks output sequences. (B, T, D_block_N)
-
- """
- for block_idx, block in enumerate(self.blocks):
- x, pos_enc = block.chunk_forward(
- x,
- pos_enc,
- mask,
- chunk_size=chunk_size,
- left_context=left_context,
- right_context=right_context,
- )
-
- x = self.norm_blocks(x)
-
- return x
diff --git a/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
deleted file mode 100644
index 5b56e26..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""Positional encoding modules."""
-
-import math
-
-import torch
-
-from funasr.modules.embedding import _pre_hook
-
-
-class RelPositionalEncoding(torch.nn.Module):
- """Relative positional encoding.
-
- Args:
- size: Module size.
- max_len: Maximum input length.
- dropout_rate: Dropout rate.
-
- """
-
- def __init__(
- self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
- ) -> None:
- """Construct a RelativePositionalEncoding object."""
- super().__init__()
-
- self.size = size
-
- self.pe = None
- self.dropout = torch.nn.Dropout(p=dropout_rate)
-
- self.extend_pe(torch.tensor(0.0).expand(1, max_len))
- self._register_load_state_dict_pre_hook(_pre_hook)
-
- def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
- """Reset positional encoding.
-
- Args:
- x: Input sequences. (B, T, ?)
- left_context: Number of frames in left context.
-
- """
- time1 = x.size(1) + left_context
-
- if self.pe is not None:
- if self.pe.size(1) >= time1 * 2 - 1:
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
- self.pe = self.pe.to(device=x.device, dtype=x.dtype)
- return
-
- pe_positive = torch.zeros(time1, self.size)
- pe_negative = torch.zeros(time1, self.size)
-
- position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, self.size, 2, dtype=torch.float32)
- * -(math.log(10000.0) / self.size)
- )
-
- pe_positive[:, 0::2] = torch.sin(position * div_term)
- pe_positive[:, 1::2] = torch.cos(position * div_term)
- pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-
- pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
- pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
- pe_negative = pe_negative[1:].unsqueeze(0)
-
- self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
- dtype=x.dtype, device=x.device
- )
-
- def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
- """Compute positional encoding.
-
- Args:
- x: Input sequences. (B, T, ?)
- left_context: Number of frames in left context.
-
- Returns:
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
-
- """
- self.extend_pe(x, left_context=left_context)
-
- time1 = x.size(1) + left_context
-
- pos_enc = self.pe[
- :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
- ]
- pos_enc = self.dropout(pos_enc)
-
- return pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_utils/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py
deleted file mode 100644
index 21611aa..0000000
--- a/funasr/models/encoder/chunk_encoder_utils/building.py
+++ /dev/null
@@ -1,352 +0,0 @@
-"""Set of methods to build Transducer encoder architecture."""
-
-from typing import Any, Dict, List, Optional, Union
-
-from funasr.modules.activation import get_activation
-from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
-from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
-from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
-from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
-from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
-from funasr.models.encoder.chunk_encoder_modules.attention import ( # noqa: H301
- RelPositionMultiHeadedAttention,
-)
-from funasr.models.encoder.chunk_encoder_modules.convolution import ( # noqa: H301
- ConformerConvolution,
- ConvolutionalSpatialGatingUnit,
-)
-from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
-from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
-from funasr.models.encoder.chunk_encoder_modules.positional_encoding import ( # noqa: H301
- RelPositionalEncoding,
-)
-from funasr.modules.positionwise_feed_forward import (
- PositionwiseFeedForward,
-)
-
-
-def build_main_parameters(
- pos_wise_act_type: str = "swish",
- conv_mod_act_type: str = "swish",
- pos_enc_dropout_rate: float = 0.0,
- pos_enc_max_len: int = 5000,
- simplified_att_score: bool = False,
- norm_type: str = "layer_norm",
- conv_mod_norm_type: str = "layer_norm",
- after_norm_eps: Optional[float] = None,
- after_norm_partial: Optional[float] = None,
- dynamic_chunk_training: bool = False,
- short_chunk_threshold: float = 0.75,
- short_chunk_size: int = 25,
- left_chunk_size: int = 0,
- time_reduction_factor: int = 1,
- unified_model_training: bool = False,
- default_chunk_size: int = 16,
- jitter_range: int =4,
- **activation_parameters,
-) -> Dict[str, Any]:
- """Build encoder main parameters.
-
- Args:
- pos_wise_act_type: Conformer position-wise feed-forward activation type.
- conv_mod_act_type: Conformer convolution module activation type.
- pos_enc_dropout_rate: Positional encoding dropout rate.
- pos_enc_max_len: Positional encoding maximum length.
- simplified_att_score: Whether to use simplified attention score computation.
- norm_type: X-former normalization module type.
- conv_mod_norm_type: Conformer convolution module normalization type.
- after_norm_eps: Epsilon value for the final normalization.
- after_norm_partial: Value for the final normalization with RMSNorm.
- dynamic_chunk_training: Whether to use dynamic chunk training.
- short_chunk_threshold: Threshold for dynamic chunk selection.
- short_chunk_size: Minimum number of frames during dynamic chunk training.
- left_chunk_size: Number of frames in left context.
- **activations_parameters: Parameters of the activation functions.
- (See espnet2/asr_transducer/activation.py)
-
- Returns:
- : Main encoder parameters
-
- """
- main_params = {}
-
- main_params["pos_wise_act"] = get_activation(
- pos_wise_act_type, **activation_parameters
- )
-
- main_params["conv_mod_act"] = get_activation(
- conv_mod_act_type, **activation_parameters
- )
-
- main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate
- main_params["pos_enc_max_len"] = pos_enc_max_len
-
- main_params["simplified_att_score"] = simplified_att_score
-
- main_params["norm_type"] = norm_type
- main_params["conv_mod_norm_type"] = conv_mod_norm_type
-
- (
- main_params["after_norm_class"],
- main_params["after_norm_args"],
- ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial)
-
- main_params["dynamic_chunk_training"] = dynamic_chunk_training
- main_params["short_chunk_threshold"] = max(0, short_chunk_threshold)
- main_params["short_chunk_size"] = max(0, short_chunk_size)
- main_params["left_chunk_size"] = max(0, left_chunk_size)
-
- main_params["unified_model_training"] = unified_model_training
- main_params["default_chunk_size"] = max(0, default_chunk_size)
- main_params["jitter_range"] = max(0, jitter_range)
-
- main_params["time_reduction_factor"] = time_reduction_factor
-
- return main_params
-
-
-def build_positional_encoding(
- block_size: int, configuration: Dict[str, Any]
-) -> RelPositionalEncoding:
- """Build positional encoding block.
-
- Args:
- block_size: Input/output size.
- configuration: Positional encoding configuration.
-
- Returns:
- : Positional encoding module.
-
- """
- return RelPositionalEncoding(
- block_size,
- configuration.get("pos_enc_dropout_rate", 0.0),
- max_len=configuration.get("pos_enc_max_len", 5000),
- )
-
-
-def build_input_block(
- input_size: int,
- configuration: Dict[str, Union[str, int]],
-) -> ConvInput:
- """Build encoder input block.
-
- Args:
- input_size: Input size.
- configuration: Input block configuration.
-
- Returns:
- : ConvInput block function.
-
- """
- if configuration["linear"]:
- return LinearInput(
- input_size,
- configuration["output_size"],
- configuration["subsampling_factor"],
- )
- else:
- return ConvInput(
- input_size,
- configuration["conv_size"],
- configuration["subsampling_factor"],
- vgg_like=configuration["vgg_like"],
- output_size=configuration["output_size"],
- )
-
-
-def build_branchformer_block(
- configuration: List[Dict[str, Any]],
- main_params: Dict[str, Any],
-) -> Conformer:
- """Build Branchformer block.
-
- Args:
- configuration: Branchformer block configuration.
- main_params: Encoder main parameters.
-
- Returns:
- : Branchformer block function.
-
- """
- hidden_size = configuration["hidden_size"]
- linear_size = configuration["linear_size"]
-
- dropout_rate = configuration.get("dropout_rate", 0.0)
-
- conv_mod_norm_class, conv_mod_norm_args = get_normalization(
- main_params["conv_mod_norm_type"],
- eps=configuration.get("conv_mod_norm_eps"),
- partial=configuration.get("conv_mod_norm_partial"),
- )
-
- conv_mod_args = (
- linear_size,
- configuration["conv_mod_kernel_size"],
- conv_mod_norm_class,
- conv_mod_norm_args,
- dropout_rate,
- main_params["dynamic_chunk_training"],
- )
-
- mult_att_args = (
- configuration.get("heads", 4),
- hidden_size,
- configuration.get("att_dropout_rate", 0.0),
- main_params["simplified_att_score"],
- )
-
- norm_class, norm_args = get_normalization(
- main_params["norm_type"],
- eps=configuration.get("norm_eps"),
- partial=configuration.get("norm_partial"),
- )
-
- return lambda: Branchformer(
- hidden_size,
- linear_size,
- RelPositionMultiHeadedAttention(*mult_att_args),
- ConvolutionalSpatialGatingUnit(*conv_mod_args),
- norm_class=norm_class,
- norm_args=norm_args,
- dropout_rate=dropout_rate,
- )
-
-
-def build_conformer_block(
- configuration: List[Dict[str, Any]],
- main_params: Dict[str, Any],
-) -> Conformer:
- """Build Conformer block.
-
- Args:
- configuration: Conformer block configuration.
- main_params: Encoder main parameters.
-
- Returns:
- : Conformer block function.
-
- """
- hidden_size = configuration["hidden_size"]
- linear_size = configuration["linear_size"]
-
- pos_wise_args = (
- hidden_size,
- linear_size,
- configuration.get("pos_wise_dropout_rate", 0.0),
- main_params["pos_wise_act"],
- )
-
- conv_mod_norm_args = {
- "eps": configuration.get("conv_mod_norm_eps", 1e-05),
- "momentum": configuration.get("conv_mod_norm_momentum", 0.1),
- }
-
- conv_mod_args = (
- hidden_size,
- configuration["conv_mod_kernel_size"],
- main_params["conv_mod_act"],
- conv_mod_norm_args,
- main_params["dynamic_chunk_training"] or main_params["unified_model_training"],
- )
-
- mult_att_args = (
- configuration.get("heads", 4),
- hidden_size,
- configuration.get("att_dropout_rate", 0.0),
- main_params["simplified_att_score"],
- )
-
- norm_class, norm_args = get_normalization(
- main_params["norm_type"],
- eps=configuration.get("norm_eps"),
- partial=configuration.get("norm_partial"),
- )
-
- return lambda: Conformer(
- hidden_size,
- RelPositionMultiHeadedAttention(*mult_att_args),
- PositionwiseFeedForward(*pos_wise_args),
- PositionwiseFeedForward(*pos_wise_args),
- ConformerConvolution(*conv_mod_args),
- norm_class=norm_class,
- norm_args=norm_args,
- dropout_rate=configuration.get("dropout_rate", 0.0),
- )
-
-
-def build_conv1d_block(
- configuration: List[Dict[str, Any]],
- causal: bool,
-) -> Conv1d:
- """Build Conv1d block.
-
- Args:
- configuration: Conv1d block configuration.
-
- Returns:
- : Conv1d block function.
-
- """
- return lambda: Conv1d(
- configuration["input_size"],
- configuration["output_size"],
- configuration["kernel_size"],
- stride=configuration.get("stride", 1),
- dilation=configuration.get("dilation", 1),
- groups=configuration.get("groups", 1),
- bias=configuration.get("bias", True),
- relu=configuration.get("relu", True),
- batch_norm=configuration.get("batch_norm", False),
- causal=causal,
- dropout_rate=configuration.get("dropout_rate", 0.0),
- )
-
-
-def build_body_blocks(
- configuration: List[Dict[str, Any]],
- main_params: Dict[str, Any],
- output_size: int,
-) -> MultiBlocks:
- """Build encoder body blocks.
-
- Args:
- configuration: Body blocks configuration.
- main_params: Encoder main parameters.
- output_size: Architecture output size.
-
- Returns:
- MultiBlocks function encapsulation all encoder blocks.
-
- """
- fn_modules = []
- extended_conf = []
-
- for c in configuration:
- if c.get("num_blocks") is not None:
- extended_conf += c["num_blocks"] * [
- {c_i: c[c_i] for c_i in c if c_i != "num_blocks"}
- ]
- else:
- extended_conf += [c]
-
- for i, c in enumerate(extended_conf):
- block_type = c["block_type"]
-
- if block_type == "branchformer":
- module = build_branchformer_block(c, main_params)
- elif block_type == "conformer":
- module = build_conformer_block(c, main_params)
- elif block_type == "conv1d":
- module = build_conv1d_block(c, main_params["dynamic_chunk_training"])
- else:
- raise NotImplementedError
-
- fn_modules.append(module)
-
- return MultiBlocks(
- [fn() for fn in fn_modules],
- output_size,
- norm_class=main_params["after_norm_class"],
- norm_args=main_params["after_norm_args"],
- )
diff --git a/funasr/models/encoder/chunk_encoder_utils/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py
deleted file mode 100644
index 1103cb9..0000000
--- a/funasr/models/encoder/chunk_encoder_utils/validation.py
+++ /dev/null
@@ -1,171 +0,0 @@
-"""Set of methods to validate encoder architecture."""
-
-from typing import Any, Dict, List, Tuple
-
-from funasr.modules.nets_utils import sub_factor_to_params
-
-
-def validate_block_arguments(
- configuration: Dict[str, Any],
- block_id: int,
- previous_block_output: int,
-) -> Tuple[int, int]:
- """Validate block arguments.
-
- Args:
- configuration: Architecture configuration.
- block_id: Block ID.
- previous_block_output: Previous block output size.
-
- Returns:
- input_size: Block input size.
- output_size: Block output size.
-
- """
- block_type = configuration.get("block_type")
-
- if block_type is None:
- raise ValueError(
- "Block %d in encoder doesn't have a type assigned. " % block_id
- )
-
- if block_type in ["branchformer", "conformer"]:
- if configuration.get("linear_size") is None:
- raise ValueError(
- "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id
- )
-
- if configuration.get("conv_mod_kernel_size") is None:
- raise ValueError(
- "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)"
- % block_id
- )
-
- input_size = configuration.get("hidden_size")
- output_size = configuration.get("hidden_size")
-
- elif block_type == "conv1d":
- output_size = configuration.get("output_size")
-
- if output_size is None:
- raise ValueError(
- "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id
- )
-
- if configuration.get("kernel_size") is None:
- raise ValueError(
- "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id
- )
-
- input_size = configuration["input_size"] = previous_block_output
- else:
- raise ValueError("Block type: %s is not supported." % block_type)
-
- return input_size, output_size
-
-
-def validate_input_block(
- configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int
-) -> int:
- """Validate input block.
-
- Args:
- configuration: Encoder input block configuration.
- body_first_conf: Encoder first body block configuration.
- input_size: Encoder input block input size.
-
- Return:
- output_size: Encoder input block output size.
-
- """
- vgg_like = configuration.get("vgg_like", False)
- linear = configuration.get("linear", False)
- next_block_type = body_first_conf.get("block_type")
- allowed_next_block_type = ["branchformer", "conformer", "conv1d"]
-
- if next_block_type is None or (next_block_type not in allowed_next_block_type):
- return -1
-
- if configuration.get("subsampling_factor") is None:
- configuration["subsampling_factor"] = 4
-
- if vgg_like:
- conv_size = configuration.get("conv_size", (64, 128))
-
- if isinstance(conv_size, int):
- conv_size = (conv_size, conv_size)
- else:
- conv_size = configuration.get("conv_size", None)
-
- if isinstance(conv_size, tuple):
- conv_size = conv_size[0]
-
- if next_block_type == "conv1d":
- if vgg_like:
- output_size = conv_size[1] * ((input_size // 2) // 2)
- else:
- if conv_size is None:
- conv_size = body_first_conf.get("output_size", 64)
-
- sub_factor = configuration["subsampling_factor"]
-
- _, _, conv_osize = sub_factor_to_params(sub_factor, input_size)
- assert (
- conv_osize > 0
- ), "Conv2D output size is <1 with input size %d and subsampling %d" % (
- input_size,
- sub_factor,
- )
-
- output_size = conv_osize * conv_size
-
- configuration["output_size"] = None
- else:
- output_size = body_first_conf.get("hidden_size")
-
- if conv_size is None:
- conv_size = output_size
-
- configuration["output_size"] = output_size
-
- configuration["conv_size"] = conv_size
- configuration["vgg_like"] = vgg_like
- configuration["linear"] = linear
-
- return output_size
-
-
-def validate_architecture(
- input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int
-) -> Tuple[int, int]:
- """Validate specified architecture is valid.
-
- Args:
- input_conf: Encoder input block configuration.
- body_conf: Encoder body blocks configuration.
- input_size: Encoder input size.
-
- Returns:
- input_block_osize: Encoder input block output size.
- : Encoder body block output size.
-
- """
- input_block_osize = validate_input_block(input_conf, body_conf[0], input_size)
-
- cmp_io = []
-
- for i, b in enumerate(body_conf):
- _io = validate_block_arguments(
- b, (i + 1), input_block_osize if i == 0 else cmp_io[i - 1][1]
- )
-
- cmp_io.append(_io)
-
- for i in range(1, len(cmp_io)):
- if cmp_io[(i - 1)][1] != cmp_io[i][0]:
- raise ValueError(
- "Output/Input mismatch between blocks %d and %d"
- " in the encoder body." % ((i - 1), i)
- )
-
- return input_block_osize, cmp_io[-1][1]
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..c837cf5 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -8,6 +8,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
+from typing import Dict
import torch
from torch import nn
@@ -18,6 +19,7 @@
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
+ RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.modules.embedding import (
@@ -25,16 +27,24 @@
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
+ StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import (
+ TooShortUttError,
+ check_short_utt,
+ make_chunk_mask,
+ make_source_mask,
+)
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.modules.repeat import repeat
+from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
@@ -42,6 +52,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
+from funasr.modules.subsampling import StreamingConvInput
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@@ -275,6 +287,188 @@
return (x, pos_emb), mask
return x, mask
+
+class ChunkEncoderLayer(torch.nn.Module):
+ """Chunk Conformer module definition.
+ Args:
+ block_size: Input/output size.
+ self_att: Self-attention module instance.
+ feed_forward: Feed-forward module instance.
+ feed_forward_macaron: Feed-forward module instance for macaron network.
+ conv_mod: Convolution module instance.
+ norm_class: Normalization module class.
+ norm_args: Normalization module arguments.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ block_size: int,
+ self_att: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ feed_forward_macaron: torch.nn.Module,
+ conv_mod: torch.nn.Module,
+ norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_args: Dict = {},
+ dropout_rate: float = 0.0,
+ ) -> None:
+ """Construct a Conformer object."""
+ super().__init__()
+
+ self.self_att = self_att
+
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.feed_forward_scale = 0.5
+
+ self.conv_mod = conv_mod
+
+ self.norm_feed_forward = norm_class(block_size, **norm_args)
+ self.norm_self_att = norm_class(block_size, **norm_args)
+
+ self.norm_macaron = norm_class(block_size, **norm_args)
+ self.norm_conv = norm_class(block_size, **norm_args)
+ self.norm_final = norm_class(block_size, **norm_args)
+
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ self.block_size = block_size
+ self.cache = None
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset self-attention and convolution modules cache for streaming.
+ Args:
+ left_context: Number of left frames during chunk-by-chunk inference.
+ device: Device to use for cache tensor.
+ """
+ self.cache = [
+ torch.zeros(
+ (1, left_context, self.block_size),
+ device=device,
+ ),
+ torch.zeros(
+ (
+ 1,
+ self.block_size,
+ self.conv_mod.kernel_size - 1,
+ ),
+ device=device,
+ ),
+ ]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T)
+ chunk_mask: Chunk mask. (T_2, T_2)
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ mask: Source mask. (B, T)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.dropout(
+ self.feed_forward_macaron(x)
+ )
+
+ residual = x
+ x = self.norm_self_att(x)
+ x_q = x
+ x = residual + self.dropout(
+ self.self_att(
+ x_q,
+ x,
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ )
+
+ residual = x
+
+ x = self.norm_conv(x)
+ x, _ = self.conv_mod(x)
+ x = residual + self.dropout(x)
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+ x = self.norm_final(x)
+ return x, mask, pos_enc
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 0,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode chunk of input sequence.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T_2)
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+ residual = x
+ x = self.norm_self_att(x)
+ if left_context > 0:
+ key = torch.cat([self.cache[0], x], dim=1)
+ else:
+ key = x
+ val = key
+
+ if right_context > 0:
+ att_cache = key[:, -(left_context + right_context) : -right_context, :]
+ else:
+ att_cache = key[:, -left_context:, :]
+ x = residual + self.self_att(
+ x,
+ key,
+ val,
+ pos_enc,
+ mask,
+ left_context=left_context,
+ )
+
+ residual = x
+ x = self.norm_conv(x)
+ x, conv_cache = self.conv_mod(
+ x, cache=self.cache[1], right_context=right_context
+ )
+ x = residual + x
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+ x = self.norm_final(x)
+ self.cache = [att_cache, conv_cache]
+
+ return x, pos_enc
class ConformerEncoder(AbsEncoder):
@@ -604,3 +798,447 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+
+
+class CausalConvolution(torch.nn.Module):
+ """ConformerConvolution module definition.
+ Args:
+ channels: The number of channels.
+ kernel_size: Size of the convolving kernel.
+ activation: Type of activation function.
+ norm_args: Normalization module arguments.
+ causal: Whether to use causal convolution (set to True if streaming).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ norm_args: Dict = {},
+ causal: bool = False,
+ ) -> None:
+ """Construct an ConformerConvolution object."""
+ super().__init__()
+
+ assert (kernel_size - 1) % 2 == 0
+
+ self.kernel_size = kernel_size
+
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ if causal:
+ self.lorder = kernel_size - 1
+ padding = 0
+ else:
+ self.lorder = 0
+ padding = (kernel_size - 1) // 2
+
+ self.depthwise_conv = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ )
+ self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cache: Optional[torch.Tensor] = None,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x: ConformerConvolution input sequences. (B, T, D_hidden)
+ cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
+ right_context: Number of frames in right context.
+ Returns:
+ x: ConformerConvolution output sequences. (B, T, D_hidden)
+ cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
+ """
+ x = self.pointwise_conv1(x.transpose(1, 2))
+ x = torch.nn.functional.glu(x, dim=1)
+
+ if self.lorder > 0:
+ if cache is None:
+ x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ x = torch.cat([cache, x], dim=2)
+
+ if right_context > 0:
+ cache = x[:, :, -(self.lorder + right_context) : -right_context]
+ else:
+ cache = x[:, :, -self.lorder :]
+
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x).transpose(1, 2)
+
+ return x, cache
+
+class ConformerChunkEncoder(torch.nn.Module):
+ """Encoder module definition.
+ Args:
+ input_size: Input size.
+ body_conf: Encoder body configuration.
+ input_conf: Encoder input configuration.
+ main_conf: Encoder main configuration.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ embed_vgg_like: bool = False,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ norm_type: str = "layer_norm",
+ cnn_module_kernel: int = 31,
+ conv_mod_norm_eps: float = 0.00001,
+ conv_mod_norm_momentum: float = 0.1,
+ simplified_att_score: bool = False,
+ dynamic_chunk_training: bool = False,
+ short_chunk_threshold: float = 0.75,
+ short_chunk_size: int = 25,
+ left_chunk_size: int = 0,
+ time_reduction_factor: int = 1,
+ unified_model_training: bool = False,
+ default_chunk_size: int = 16,
+ jitter_range: int = 4,
+ subsampling_factor: int = 1,
+ **activation_parameters,
+ ) -> None:
+ """Construct an Encoder object."""
+ super().__init__()
+
+ assert check_argument_types()
+
+ self.embed = StreamingConvInput(
+ input_size,
+ output_size,
+ subsampling_factor,
+ vgg_like=embed_vgg_like,
+ output_size=output_size,
+ )
+
+ self.pos_enc = StreamingRelPositionalEncoding(
+ output_size,
+ positional_dropout_rate,
+ )
+
+ activation = get_activation(
+ activation_type, **activation_parameters
+ )
+
+ pos_wise_args = (
+ output_size,
+ linear_units,
+ positional_dropout_rate,
+ activation,
+ )
+
+ conv_mod_norm_args = {
+ "eps": conv_mod_norm_eps,
+ "momentum": conv_mod_norm_momentum,
+ }
+
+ conv_mod_args = (
+ output_size,
+ cnn_module_kernel,
+ activation,
+ conv_mod_norm_args,
+ dynamic_chunk_training or unified_model_training,
+ )
+
+ mult_att_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ simplified_att_score,
+ )
+
+ norm_class, norm_args = get_normalization(
+ norm_type,
+ )
+
+ fn_modules = []
+ for _ in range(num_blocks):
+ module = lambda: ChunkEncoderLayer(
+ output_size,
+ RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ CausalConvolution(*conv_mod_args),
+ norm_class=norm_class,
+ norm_args=norm_args,
+ dropout_rate=dropout_rate,
+ )
+ fn_modules.append(module)
+
+ self.encoders = MultiBlocks(
+ [fn() for fn in fn_modules],
+ output_size,
+ norm_class=norm_class,
+ norm_args=norm_args,
+ )
+
+ self.output_size = output_size
+
+ self.dynamic_chunk_training = dynamic_chunk_training
+ self.short_chunk_threshold = short_chunk_threshold
+ self.short_chunk_size = short_chunk_size
+ self.left_chunk_size = left_chunk_size
+
+ self.unified_model_training = unified_model_training
+ self.default_chunk_size = default_chunk_size
+ self.jitter_range = jitter_range
+
+ self.time_reduction_factor = time_reduction_factor
+
+ def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+ """Return the corresponding number of sample for a given chunk size, in frames.
+ Where size is the number of features frames after applying subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ hop_length: Frontend's hop length
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size) * hop_length
+
+ def get_encoder_input_size(self, size: int) -> int:
+ """Return the corresponding number of sample for a given chunk size, in frames.
+ Where size is the number of features frames after applying subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size)
+
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset encoder streaming cache.
+ Args:
+ left_context: Number of frames in left context.
+ device: Device ID.
+ """
+ return self.encoders.reset_streaming_cache(left_context, device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ x_len: Encoder outputs lenghts. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ if self.unified_model_training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+ x_chunk = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x_utt, x_chunk, olens
+
+ elif self.dynamic_chunk_training:
+ max_len = x.size(1)
+ chunk_size = torch.randint(1, max_len, (1,)).item()
+
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ else:
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = None
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x, olens
+
+ def simu_chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+
+ return x
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ processed_frames: torch.tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ """Encode input sequences as chunks.
+ Args:
+ x: Encoder input features. (1, T_in, F)
+ x_len: Encoder input features lengths. (1,)
+ processed_frames: Number of frames already seen.
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ """
+ mask = make_source_mask(x_len)
+ x, mask = self.embed(x, mask, None)
+
+ if left_context > 0:
+ processed_mask = (
+ torch.arange(left_context, device=x.device)
+ .view(1, left_context)
+ .flip(1)
+ )
+ processed_mask = processed_mask >= processed_frames
+ mask = torch.cat([processed_mask, mask], dim=1)
+ pos_enc = self.pos_enc(x, left_context=left_context)
+ x = self.encoders.chunk_forward(
+ x,
+ pos_enc,
+ mask,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+
+ if right_context > 0:
+ x = x[:, 0:-right_context, :]
+
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ return x
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 31d5a87..6202079 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -11,7 +11,7 @@
import numpy
import torch
from torch import nn
-
+from typing import Optional, Tuple
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -741,3 +741,221 @@
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+ """RelPositionMultiHeadedAttention definition.
+ Args:
+ num_heads: Number of attention heads.
+ embed_size: Embedding size.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ num_heads: int,
+ embed_size: int,
+ dropout_rate: float = 0.0,
+ simplified_attention_score: bool = False,
+ ) -> None:
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+
+ self.d_k = embed_size // num_heads
+ self.num_heads = num_heads
+
+ assert self.d_k * num_heads == embed_size, (
+ "embed_size (%d) must be divisible by num_heads (%d)",
+ (embed_size, num_heads),
+ )
+
+ self.linear_q = torch.nn.Linear(embed_size, embed_size)
+ self.linear_k = torch.nn.Linear(embed_size, embed_size)
+ self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+ self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+ if simplified_attention_score:
+ self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+ self.compute_att_score = self.compute_simplified_attention_score
+ else:
+ self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+ self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ self.compute_att_score = self.compute_attention_score
+
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.attn = None
+
+ def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+ """Compute relative positional encoding.
+ Args:
+ x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+ left_context: Number of frames in left context.
+ Returns:
+ x: Output sequence. (B, H, T_1, T_2)
+ """
+ batch_size, n_heads, time1, n = x.shape
+ time2 = time1 + left_context
+
+ batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+ return x.as_strided(
+ (batch_size, n_heads, time1, time2),
+ (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+ storage_offset=(n_stride * (time1 - 1)),
+ )
+
+ def compute_simplified_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Simplified attention score computation.
+ Reference: https://github.com/k2-fsa/icefall/pull/458
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ pos_enc = self.linear_pos(pos_enc)
+
+ matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+ matrix_bd = self.rel_shift(
+ pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+ left_context=left_context,
+ )
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def compute_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Attention score computation.
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+ matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+ v: Value tensor. (B, T_2, size)
+ Returns:
+ q: Transformed query tensor. (B, H, T_1, d_k)
+ k: Transformed key tensor. (B, H, T_2, d_k)
+ v: Transformed value tensor. (B, H, T_2, d_k)
+ """
+ n_batch = query.size(0)
+
+ q = (
+ self.linear_q(query)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ k = (
+ self.linear_k(key)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ v = (
+ self.linear_v(value)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+ Args:
+ value: Transformed value. (B, H, T_2, d_k)
+ scores: Attention score. (B, H, T_1, T_2)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ Returns:
+ attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+ """
+ batch_size = scores.size(0)
+ mask = mask.unsqueeze(1).unsqueeze(2)
+ if chunk_mask is not None:
+ mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+ scores = scores.masked_fill(mask, float("-inf"))
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+ attn_output = self.dropout(self.attn)
+ attn_output = torch.matmul(attn_output, value)
+
+ attn_output = self.linear_out(
+ attn_output.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.num_heads * self.d_k)
+ )
+
+ return attn_output
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Compute scaled dot product attention with rel. positional encoding.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+ value: Value tensor. (B, T_2, size)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ left_context: Number of frames in left context.
+ Returns:
+ : Output tensor. (B, T_1, H * d_k)
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+ return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 79ca0b2..e0070de 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -423,4 +423,79 @@
outputs = F.pad(outputs, (pad_left, pad_right))
outputs = outputs.transpose(1,2)
return outputs
-
+
+class StreamingRelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding.
+ Args:
+ size: Module size.
+ max_len: Maximum input length.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
+ ) -> None:
+ """Construct a RelativePositionalEncoding object."""
+ super().__init__()
+
+ self.size = size
+
+ self.pe = None
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+ self._register_load_state_dict_pre_hook(_pre_hook)
+
+ def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
+ """Reset positional encoding.
+ Args:
+ x: Input sequences. (B, T, ?)
+ left_context: Number of frames in left context.
+ """
+ time1 = x.size(1) + left_context
+
+ if self.pe is not None:
+ if self.pe.size(1) >= time1 * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(device=x.device, dtype=x.dtype)
+ return
+
+ pe_positive = torch.zeros(time1, self.size)
+ pe_negative = torch.zeros(time1, self.size)
+
+ position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.size, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.size)
+ )
+
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+
+ self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
+ dtype=x.dtype, device=x.device
+ )
+
+ def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+ """Compute positional encoding.
+ Args:
+ x: Input sequences. (B, T, ?)
+ left_context: Number of frames in left context.
+ Returns:
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
+ """
+ self.extend_pe(x, left_context=left_context)
+
+ time1 = x.size(1) + left_context
+
+ pos_enc = self.pe[
+ :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
+ ]
+ pos_enc = self.dropout(pos_enc)
+
+ return pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_modules/normalization.py b/funasr/modules/normalization.py
similarity index 100%
rename from funasr/models/encoder/chunk_encoder_modules/normalization.py
rename to funasr/modules/normalization.py
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index a3d2676..7241dd9 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -6,6 +6,8 @@
"""Repeat the same layer definition."""
+from typing import Dict, List, Optional
+
import torch
@@ -31,3 +33,93 @@
"""
return MultiSequential(*[fn(n) for n in range(N)])
+
+
+class MultiBlocks(torch.nn.Module):
+ """MultiBlocks definition.
+ Args:
+ block_list: Individual blocks of the encoder architecture.
+ output_size: Architecture output size.
+ norm_class: Normalization module class.
+ norm_args: Normalization module arguments.
+ """
+
+ def __init__(
+ self,
+ block_list: List[torch.nn.Module],
+ output_size: int,
+ norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_args: Optional[Dict] = None,
+ ) -> None:
+ """Construct a MultiBlocks object."""
+ super().__init__()
+
+ self.blocks = torch.nn.ModuleList(block_list)
+ self.norm_blocks = norm_class(output_size, **norm_args)
+
+ self.num_blocks = len(block_list)
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset encoder streaming cache.
+ Args:
+ left_context: Number of left frames during chunk-by-chunk inference.
+ device: Device to use for cache tensor.
+ """
+ for idx in range(self.num_blocks):
+ self.blocks[idx].reset_streaming_cache(left_context, device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Forward each block of the encoder architecture.
+ Args:
+ x: MultiBlocks input sequences. (B, T, D_block_1)
+ pos_enc: Positional embedding sequences.
+ mask: Source mask. (B, T)
+ chunk_mask: Chunk mask. (T_2, T_2)
+ Returns:
+ x: Output sequences. (B, T, D_block_N)
+ """
+ for block_index, block in enumerate(self.blocks):
+ x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
+
+ x = self.norm_blocks(x)
+
+ return x
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int = 0,
+ left_context: int = 0,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ """Forward each block of the encoder architecture.
+ Args:
+ x: MultiBlocks input sequences. (B, T, D_block_1)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
+ mask: Source mask. (B, T_2)
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: MultiBlocks output sequences. (B, T, D_block_N)
+ """
+ for block_idx, block in enumerate(self.blocks):
+ x, pos_enc = block.chunk_forward(
+ x,
+ pos_enc,
+ mask,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+
+ x = self.norm_blocks(x)
+
+ return x
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index d492ccf..623be65 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -11,6 +11,10 @@
from funasr.modules.embedding import PositionalEncoding
import logging
from funasr.modules.streaming_utils.utils import sequence_mask
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
+from typing import Optional, Tuple, Union
+import math
+
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
@@ -407,3 +411,201 @@
var_dict_tf[name_tf].shape))
return var_dict_torch_update
+class StreamingConvInput(torch.nn.Module):
+ """Streaming ConvInput module definition.
+ Args:
+ input_size: Input size.
+ conv_size: Convolution size.
+ subsampling_factor: Subsampling factor.
+ vgg_like: Whether to use a VGG-like network.
+ output_size: Block output dimension.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ conv_size: Union[int, Tuple],
+ subsampling_factor: int = 4,
+ vgg_like: bool = True,
+ output_size: Optional[int] = None,
+ ) -> None:
+ """Construct a ConvInput object."""
+ super().__init__()
+ if vgg_like:
+ if subsampling_factor == 1:
+ conv_size1, conv_size2 = conv_size
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((1, 2)),
+ torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((1, 2)),
+ )
+
+ output_proj = conv_size2 * ((input_size // 2) // 2)
+
+ self.subsampling_factor = 1
+
+ self.stride_1 = 1
+
+ self.create_new_mask = self.create_new_vgg_mask
+
+ else:
+ conv_size1, conv_size2 = conv_size
+
+ kernel_1 = int(subsampling_factor / 2)
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((kernel_1, 2)),
+ torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((2, 2)),
+ )
+
+ output_proj = conv_size2 * ((input_size // 2) // 2)
+
+ self.subsampling_factor = subsampling_factor
+
+ self.create_new_mask = self.create_new_vgg_mask
+
+ self.stride_1 = kernel_1
+
+ else:
+ if subsampling_factor == 1:
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+ torch.nn.ReLU(),
+ )
+
+ output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
+
+ self.subsampling_factor = subsampling_factor
+ self.kernel_2 = 3
+ self.stride_2 = 1
+
+ self.create_new_mask = self.create_new_conv2d_mask
+
+ else:
+ kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
+ subsampling_factor,
+ input_size,
+ )
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+ torch.nn.ReLU(),
+ )
+
+ output_proj = conv_size * conv_2_output_size
+
+ self.subsampling_factor = subsampling_factor
+ self.kernel_2 = kernel_2
+ self.stride_2 = stride_2
+
+ self.create_new_mask = self.create_new_conv2d_mask
+
+ self.vgg_like = vgg_like
+ self.min_frame_length = 7
+
+ if output_size is not None:
+ self.output = torch.nn.Linear(output_proj, output_size)
+ self.output_size = output_size
+ else:
+ self.output = None
+ self.output_size = output_proj
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: ConvInput input sequences. (B, T, D_feats)
+ mask: Mask of input sequences. (B, 1, T)
+ Returns:
+ x: ConvInput output sequences. (B, sub(T), D_out)
+ mask: Mask of output sequences. (B, 1, sub(T))
+ """
+ if mask is not None:
+ mask = self.create_new_mask(mask)
+ olens = max(mask.eq(0).sum(1))
+
+ b, t, f = x.size()
+ x = x.unsqueeze(1) # (b. 1. t. f)
+
+ if chunk_size is not None:
+ max_input_length = int(
+ chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
+ )
+ x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+ x = list(x)
+ x = torch.stack(x, dim=0)
+ N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
+ x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
+ x = self.conv(x)
+
+ _, c, _, f = x.size()
+ if chunk_size is not None:
+ x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
+ else:
+ x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+
+ if self.output is not None:
+ x = self.output(x)
+
+ return x, mask[:,:olens][:,:x.size(1)]
+
+ def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Create a new mask for VGG output sequences.
+ Args:
+ mask: Mask of input sequences. (B, T)
+ Returns:
+ mask: Mask of output sequences. (B, sub(T))
+ """
+ if self.subsampling_factor > 1:
+ vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
+ mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
+
+ vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
+ mask = mask[:, :vgg2_t_len][:, ::2]
+ else:
+ mask = mask
+
+ return mask
+
+ def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Create new conformer mask for Conv2d output sequences.
+ Args:
+ mask: Mask of input sequences. (B, T)
+ Returns:
+ mask: Mask of output sequences. (B, sub(T))
+ """
+ if self.subsampling_factor > 1:
+ return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+ else:
+ return mask
+
+ def get_size_before_subsampling(self, size: int) -> int:
+ """Return the original size before subsampling for a given size.
+ Args:
+ size: Number of frames after subsampling.
+ Returns:
+ : Number of frames before subsampling.
+ """
+ return size * self.subsampling_factor
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index cae18c1..bb1f996 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -24,7 +24,7 @@
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
from funasr.models.joint_network import JointNetwork
@@ -72,9 +72,9 @@
encoder_choices = ClassChoices(
"encoder",
classes=dict(
- encoder=Encoder,
+ chunk_conformer=ConformerChunkEncoder,
),
- default="encoder",
+ default="chunk_conformer",
)
decoder_choices = ClassChoices(
--
Gitblit v1.9.1