"""LinearInput block for Transducer encoder."""
|
|
from typing import Optional, Tuple, Union
|
|
import torch
|
|
class LinearInput(torch.nn.Module):
    """LinearInput module definition.

    Embeds the input features with a linear projection followed by layer
    normalization and dropout. The block applies no subsampling itself.

    Args:
        input_size: Input size.
        output_size: Block output dimension. Falls back to ``input_size``
            when left as ``None``.
        subsampling_factor: Subsampling factor. Only stored on the module;
            this block does not subsample its input.
        dropout_rate: Dropout probability applied after layer normalization.

    """

    def __init__(
        self,
        input_size: int,
        output_size: Optional[int] = None,
        subsampling_factor: int = 1,
        dropout_rate: float = 0.1,
    ) -> None:
        """Construct a LinearInput object."""
        super().__init__()

        # torch.nn.Linear and torch.nn.LayerNorm both need a concrete integer
        # size, so the documented ``None`` default must be resolved first.
        if output_size is None:
            output_size = input_size

        self.embed = torch.nn.Sequential(
            torch.nn.Linear(input_size, output_size),
            torch.nn.LayerNorm(output_size),
            torch.nn.Dropout(dropout_rate),
        )

        self.subsampling_factor = subsampling_factor
        self.min_frame_length = 1

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed the input sequences.

        Args:
            x: Input features whose last dimension is ``input_size``.
            mask: Mask associated with the input, or None. It is not used
                by this block.

        Returns:
            x: Embedded features whose last dimension is ``output_size``.
            mask: The mask given as input, returned unchanged.

        """
        x = self.embed(x)

        return x, mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        No subsampling is performed by this block, so the size is returned
        as is.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.

        """
        return size
|