From d79287c37e4e7ae2694a992cbbfb03a5ca4f7670 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 20 Feb 2024 14:05:58 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

---
 funasr/models/transducer/rnnt_decoder.py           |   13 
 funasr/models/seaco_paraformer/model.py            |    2 
 funasr/models/transducer/rnn_encoder.py            |  112 -----
 funasr/models/transducer/beam_search_transducer.py |   10 
 funasr/models/transducer/rnn_decoder.py            |   11 
 funasr/models/conformer/encoder.py                 |  666 +++++++++++++++++++++++++++++++++
 funasr/models/transducer/model.py                  |  155 ++-----
 funasr/models/transducer/joint_network.py          |    7 
 funasr/models/transformer/attention.py             |  213 ++++++++++
 9 files changed, 957 insertions(+), 232 deletions(-)

diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index 1ca437d..1d252c2 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -14,6 +14,7 @@
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
     LegacyRelPositionMultiHeadedAttention,  # noqa: H301
+    RelPositionMultiHeadedAttentionChunk,
 )
 from funasr.models.transformer.embedding import (
     PositionalEncoding,  # noqa: H301
@@ -610,4 +611,669 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
+
 
+class CausalConvolution(torch.nn.Module):
+    """ConformerConvolution module definition.
+    Args:
+        channels: The number of channels.
+        kernel_size: Size of the convolving kernel.
+        activation: Type of activation function.
+        norm_args: Normalization module arguments.
+        causal: Whether to use causal convolution (set to True if streaming).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        norm_args: Dict = {},
+        causal: bool = False,
+    ) -> None:
+        """Construct an ConformerConvolution object."""
+        super().__init__()
+
+        assert (kernel_size - 1) % 2 == 0
+
+        self.kernel_size = kernel_size
+
+        self.pointwise_conv1 = torch.nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        if causal:
+            self.lorder = kernel_size - 1
+            padding = 0
+        else:
+            self.lorder = 0
+            padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+        )
+        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+        self.pointwise_conv2 = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        self.activation = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Compute convolution module.
+        Args:
+            x: CausalConvolution input sequences. (B, T, D_hidden)
+            cache: CausalConvolution input cache. (1, D_hidden, conv_kernel - 1)
+            right_context: Number of frames in right context.
+        Returns:
+            x: CausalConvolution output sequences. (B, T, D_hidden)
+            cache: CausalConvolution output cache. (1, D_hidden, conv_kernel - 1)
+        """
+        x = self.pointwise_conv1(x.transpose(1, 2))
+        x = torch.nn.functional.glu(x, dim=1)
+
+        if self.lorder > 0:
+            if cache is None:
+                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                x = torch.cat([cache, x], dim=2)
+
+                if right_context > 0:
+                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
+                else:
+                    cache = x[:, :, -self.lorder :]
+
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x).transpose(1, 2)
+
+        return x, cache
+
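+# Example (a minimal usage sketch, not part of this patch; shapes follow the
+# docstrings above and the cache layout in ChunkEncoderLayer.reset_streaming_cache):
+#
+#     conv = CausalConvolution(channels=256, kernel_size=15, causal=True)
+#     chunk = torch.randn(1, 16, 256)            # (B, T, D_hidden)
+#     cache = torch.zeros(1, 256, 15 - 1)        # (1, D_hidden, kernel_size - 1)
+#     out, cache = conv(chunk, cache=cache)      # out: (1, 16, 256)
+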
+class ChunkEncoderLayer(torch.nn.Module):
+    """Chunk Conformer module definition.
+    Args:
+        block_size: Input/output size.
+        self_att: Self-attention module instance.
+        feed_forward: Feed-forward module instance.
+        feed_forward_macaron: Feed-forward module instance for macaron network.
+        conv_mod: Convolution module instance.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        self_att: torch.nn.Module,
+        feed_forward: torch.nn.Module,
+        feed_forward_macaron: torch.nn.Module,
+        conv_mod: torch.nn.Module,
+        norm_class: torch.nn.Module = LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a Conformer object."""
+        super().__init__()
+
+        self.self_att = self_att
+
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.feed_forward_scale = 0.5
+
+        self.conv_mod = conv_mod
+
+        self.norm_feed_forward = norm_class(block_size, **norm_args)
+        self.norm_self_att = norm_class(block_size, **norm_args)
+
+        self.norm_macaron = norm_class(block_size, **norm_args)
+        self.norm_conv = norm_class(block_size, **norm_args)
+        self.norm_final = norm_class(block_size, **norm_args)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.block_size = block_size
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset self-attention and convolution modules cache for streaming.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+        """
+        self.cache = [
+            torch.zeros(
+                (1, left_context, self.block_size),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    1,
+                    self.block_size,
+                    self.conv_mod.kernel_size - 1,
+                ),
+                device=device,
+            ),
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            mask: Source mask. (B, T)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.dropout(
+            self.feed_forward_macaron(x)
+        )
+
+        residual = x
+        x = self.norm_self_att(x)
+        x_q = x
+        x = residual + self.dropout(
+            self.self_att(
+                x_q,
+                x,
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+        )
+
+        residual = x
+
+        x = self.norm_conv(x)
+        x, _ = self.conv_mod(x)
+        x = residual + self.dropout(x)
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+        x = self.norm_final(x)
+        return x, mask, pos_enc
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T_2)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_self_att(x)
+        if left_context > 0:
+            key = torch.cat([self.cache[0], x], dim=1)
+        else:
+            key = x
+        val = key
+
+        if right_context > 0:
+            att_cache = key[:, -(left_context + right_context) : -right_context, :]
+        else:
+            att_cache = key[:, -left_context:, :]
+        x = residual + self.self_att(
+            x,
+            key,
+            val,
+            pos_enc,
+            mask,
+            left_context=left_context,
+        )
+
+        residual = x
+        x = self.norm_conv(x)
+        x, conv_cache = self.conv_mod(
+            x, cache=self.cache[1], right_context=right_context
+        )
+        x = residual + x
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+        x = self.norm_final(x)
+        self.cache = [att_cache, conv_cache]
+
+        return x, pos_enc
+
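+# Example (an illustrative streaming sketch, not part of this patch; `layer`,
+# the chunk iterator, and the context size are assumptions):
+#   layer.reset_streaming_cache(left_context=32, device=torch.device("cpu"))
+#   for chunk, pos_enc, mask in chunk_stream:     # chunk: (1, chunk_size, D_block)
+#       out, pos_enc = layer.chunk_forward(chunk, pos_enc, mask, left_context=32)
+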
+@tables.register("encoder_classes", "ChunkConformerEncoder")
+class ConformerChunkEncoder(torch.nn.Module):
+    """Chunk-based Conformer encoder module definition.
+    Args:
+        input_size: Input feature size.
+        output_size: Encoder output (and block) size.
+        See the constructor signature for the remaining configuration
+        arguments (attention, feed-forward, convolution and chunking options).
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        embed_vgg_like: bool = False,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 3,
+        macaron_style: bool = False,
+        rel_pos_type: str = "legacy",
+        pos_enc_layer_type: str = "rel_pos",
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        zero_triu: bool = False,
+        norm_type: str = "layer_norm",
+        cnn_module_kernel: int = 31,
+        conv_mod_norm_eps: float = 0.00001,
+        conv_mod_norm_momentum: float = 0.1,
+        simplified_att_score: bool = False,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        left_chunk_size: int = 0,
+        time_reduction_factor: int = 1,
+        unified_model_training: bool = False,
+        default_chunk_size: int = 16,
+        jitter_range: int = 4,
+        subsampling_factor: int = 1,
+    ) -> None:
+        """Construct an Encoder object."""
+        super().__init__()
+
+        self.embed = StreamingConvInput(
+            input_size=input_size,
+            conv_size=output_size,
+            subsampling_factor=subsampling_factor,
+            vgg_like=embed_vgg_like,
+            output_size=output_size,
+        )
+
+        self.pos_enc = StreamingRelPositionalEncoding(
+            output_size,
+            positional_dropout_rate,
+        )
+
+        activation = get_activation(activation_type)
+
+        pos_wise_args = (
+            output_size,
+            linear_units,
+            dropout_rate,
+            activation,
+        )
+
+        conv_mod_norm_args = {
+            "eps": conv_mod_norm_eps,
+            "momentum": conv_mod_norm_momentum,
+        }
+
+        conv_mod_args = (
+            output_size,
+            cnn_module_kernel,
+            activation,
+            conv_mod_norm_args,
+            dynamic_chunk_training or unified_model_training,
+        )
+
+        mult_att_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            simplified_att_score,
+        )
+
+        self.encoders = MultiBlocks(
+            [
+                ChunkEncoderLayer(
+                    output_size,
+                    RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+                    PositionwiseFeedForward(*pos_wise_args),
+                    PositionwiseFeedForward(*pos_wise_args),
+                    CausalConvolution(*conv_mod_args),
+                    dropout_rate=dropout_rate,
+                )
+                for _ in range(num_blocks)
+            ],
+            output_size,
+        )
+
+        self._output_size = output_size
+
+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size
+        self.left_chunk_size = left_chunk_size
+
+        self.unified_model_training = unified_model_training
+        self.default_chunk_size = default_chunk_size
+        self.jitter_range = jitter_range
+
+        self.time_reduction_factor = time_reduction_factor
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+        """Return the number of raw samples corresponding to a given chunk size,
+        where size is the number of feature frames after subsampling.
+        Args:
+            size: Number of frames after subsampling.
+            hop_length: Frontend's hop length.
+        Returns:
+            : Number of raw samples.
+        """
+        return self.embed.get_size_before_subsampling(size) * hop_length
+
+    def get_encoder_input_size(self, size: int) -> int:
+        """Return the number of feature frames before subsampling corresponding
+        to a given chunk size, where size is the number of frames after subsampling.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of frames before subsampling.
+        """
+        return self.embed.get_size_before_subsampling(size)
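+
+    # Example (illustrative; the exact mapping depends on the embed module's
+    # padding, so the numbers below are assumptions):
+    #   with subsampling_factor=4 and a frontend hop_length of 160, a chunk of
+    #   16 subsampled frames maps to roughly 16 * 4 = 64 input frames, i.e.
+    #   about 64 * 160 = 10240 raw samples.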
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of frames in left context.
+            device: Device to use for cache tensors.
+        """
+        return self.encoders.reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder outputs lengths. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+
+        if self.unified_model_training:
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(
+                    -self.jitter_range, self.jitter_range + 1, (1,)
+                ).item()
+            else:
+                chunk_size = self.default_chunk_size
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
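+            # Unified training runs the encoder twice on the same input: once
+            # without a chunk mask (full-utterance view) and once with it
+            # (streaming view), so both modes share parameters.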
+            x_utt = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=None,
+            )
+            x_chunk = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+
+            olens = mask.eq(0).sum(1)
+            if self.time_reduction_factor > 1:
+                x_utt = x_utt[:, :: self.time_reduction_factor, :]
+                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
+                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+            return x_utt, x_chunk, olens
+
+        elif self.dynamic_chunk_training:
+            max_len = x.size(1)
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
+
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
+            else:
+                chunk_size = self.default_chunk_size
+
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+        else:
+            x, mask = self.embed(x, mask, None)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = None
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+        return x, olens, None
+
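+    # Example (a minimal sketch; the configuration values below are
+    # illustrative assumptions):
+    #   encoder = ConformerChunkEncoder(
+    #       input_size=80, output_size=256, dynamic_chunk_training=True
+    #   )
+    #   feats = torch.randn(4, 200, 80)             # (B, T_in, F)
+    #   feats_len = torch.full((4,), 200, dtype=torch.long)
+    #   x, olens, _ = encoder(feats, feats_len)     # x: (B, T_out, 256)
+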
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode input sequences in full-utterance (non-streaming) mode.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x_utt: Encoder outputs. (B, T_out, D_enc)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:, :: self.time_reduction_factor, :]
+        return x_utt
+
+    def simu_chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences with a chunk mask, simulating chunk-by-chunk
+        streaming inference in a single forward pass.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+            chunk_size: Size of each chunk, in frames.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Encoder outputs. (B, T_out, D_enc)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+
+        x, mask = self.embed(x, mask, chunk_size)
+        pos_enc = self.pos_enc(x)
+        chunk_mask = make_chunk_mask(
+            x.size(1),
+            chunk_size,
+            left_chunk_size=self.left_chunk_size,
+            device=x.device,
+        )
+
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        processed_frames: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences as chunks.
+        Args:
+            x: Encoder input features. (1, T_in, F)
+            x_len: Encoder input features lengths. (1,)
+            processed_frames: Number of frames already seen.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+        """
+        mask = make_source_mask(x_len)
+        x, mask = self.embed(x, mask, None)
+
+        if left_context > 0:
+            processed_mask = (
+                torch.arange(left_context, device=x.device)
+                .view(1, left_context)
+                .flip(1)
+            )
+            processed_mask = processed_mask >= processed_frames
+            mask = torch.cat([processed_mask, mask], dim=1)
+        pos_enc = self.pos_enc(x, left_context=left_context)
+        x = self.encoders.chunk_forward(
+            x,
+            pos_enc,
+            mask,
+            chunk_size=chunk_size,
+            left_context=left_context,
+            right_context=right_context,
+        )
+
+        if right_context > 0:
+            x = x[:, 0:-right_context, :]
+
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+        return x
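+
+    # Example (a streaming sketch; the chunk iterator and sizes below are
+    # illustrative assumptions):
+    #   encoder.reset_streaming_cache(left_context=32, device=device)
+    #   processed_frames = torch.tensor(0)
+    #   for chunk, chunk_len in chunk_iterator:    # chunk: (1, T_chunk, F)
+    #       out = encoder.chunk_forward(
+    #           chunk, chunk_len, processed_frames,
+    #           chunk_size=16, left_context=32,
+    #       )
+    #       processed_frames += out.size(1)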
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 8b8e97e..2f55e6e 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -335,7 +335,7 @@
         
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
-
+        
         # hotword
         self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
         
diff --git a/funasr/models/transducer/beam_search_transducer.py b/funasr/models/transducer/beam_search_transducer.py
index 04b26b3..f599615 100644
--- a/funasr/models/transducer/beam_search_transducer.py
+++ b/funasr/models/transducer/beam_search_transducer.py
@@ -1,10 +1,12 @@
-"""Search algorithms for Transducer models."""
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
+import torch
+import numpy as np
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
 
 from funasr.models.transducer.joint_network import JointNetwork
 
diff --git a/funasr/models/transducer/joint_network.py b/funasr/models/transducer/joint_network.py
index 9fca632..7d424db 100644
--- a/funasr/models/transducer/joint_network.py
+++ b/funasr/models/transducer/joint_network.py
@@ -1,10 +1,15 @@
-"""Transducer joint network implementation."""
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import torch
 
+from funasr.register import tables
 from funasr.models.transformer.utils.nets_utils import get_activation
 
 
+@tables.register("joint_network_classes", "joint_network")
 class JointNetwork(torch.nn.Module):
     """Transducer joint network module.
 
diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py
index 906aa60..fd8ad71 100644
--- a/funasr/models/transducer/model.py
+++ b/funasr/models/transducer/model.py
@@ -1,42 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
 import logging
 from contextlib import contextmanager
+from typing import Dict, Optional, Tuple
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.frontends.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.paraformer.cif_predictor import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.paraformer.cif_predictor import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
 
-from funasr.models.model_class_factory import *
+from funasr.register import tables
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.scorers.length_bonus import LengthBonus
+from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
+
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -45,16 +29,10 @@
     @contextmanager
     def autocast(enabled=True):
         yield
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
 
 
-class Transducer(nn.Module):
-    """ESPnet2ASRTransducerModel module definition."""
-
-    
+@tables.register("model_classes", "Transducer")
+class Transducer(torch.nn.Module):
     def __init__(
         self,
         frontend: Optional[str] = None,
@@ -96,35 +74,30 @@
 
         super().__init__()
 
-        if frontend is not None:
-            frontend_class = frontend_classes.get_class(frontend)
-            frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = specaug_classes.get_class(specaug)
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = normalize_classes.get_class(normalize)
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = encoder_classes.get_class(encoder)
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
 
-        decoder_class = decoder_classes.get_class(decoder)
+        decoder_class = tables.decoder_classes.get(decoder)
         decoder = decoder_class(
             vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
             **decoder_conf,
         )
         decoder_output_size = decoder.output_size
 
-        joint_network_class = joint_network_classes.get_class(decoder)
+        joint_network_class = tables.joint_network_classes.get(joint_network)
         joint_network = joint_network_class(
             vocab_size,
             encoder_output_size,
             decoder_output_size,
             **joint_network_conf,
         )
-        
         
         self.criterion_transducer = None
         self.error_calculator = None
@@ -157,23 +130,17 @@
         self.decoder = decoder
         self.joint_network = joint_network
 
-
-        
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
             padding_idx=ignore_id,
             smoothing=lsm_weight,
             normalize_length=length_normalized_loss,
         )
-        #
-        # if report_cer or report_wer:
-        #     self.error_calculator = ErrorCalculator(
-        #         token_list, sym_space, sym_blank, report_cer, report_wer
-        #     )
-        #
 
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.ctc = None
+        self.ctc_weight = 0.0
     
     def forward(
         self,
@@ -190,8 +157,6 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -283,12 +248,7 @@
         # Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                speech, speech_lengths, ctc=self.ctc
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -449,9 +409,6 @@
     def init_beam_search(self,
                          **kwargs,
                          ):
-        from funasr.models.transformer.search import BeamSearch
-        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
-        from funasr.models.transformer.scorers.length_bonus import LengthBonus
     
         # 1. Build ASR model
         scorers = {}
@@ -466,28 +423,16 @@
             length_bonus=LengthBonus(len(token_list)),
         )
 
-        
         # 3. Build ngram model
         # ngram is not supported now
         ngram = None
         scorers["ngram"] = ngram
         
-        weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
-            ctc=kwargs.get("decoding_ctc_weight", 0.0),
-            lm=kwargs.get("lm_weight", 0.0),
-            ngram=kwargs.get("ngram_weight", 0.0),
-            length_bonus=kwargs.get("penalty", 0.0),
-        )
-        beam_search = BeamSearch(
-            beam_size=kwargs.get("beam_size", 2),
-            weights=weights,
-            scorers=scorers,
-            sos=self.sos,
-            eos=self.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        beam_search = BeamSearchTransducer(
+            self.decoder,
+            self.joint_network,
+            kwargs.get("beam_size", 2),
+            nbest=1,
         )
         # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
         # for scorer in scorers.values():
@@ -495,13 +440,13 @@
         #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
         self.beam_search = beam_search
         
-    def generate(self,
-             data_in: list,
-             data_lengths: list=None,
-             key: list=None,
-             tokenizer=None,
-             **kwargs,
-             ):
+    def inference(self,
+                  data_in: list,
+                  data_lengths: list = None,
+                  key: list = None,
+                  tokenizer=None,
+                  **kwargs,
+                  ):
         
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
@@ -509,10 +454,10 @@
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
-            logging.info("enable beam_search")
-            self.init_beam_search(**kwargs)
-            self.nbest = kwargs.get("nbest", 1)
+        # if self.beam_search is None and (is_use_lm or is_use_ctc):
+        logging.info("enable beam_search")
+        self.init_beam_search(**kwargs)
+        self.nbest = kwargs.get("nbest", 1)
         
         meta_data = {}
         # extract fbank feats
@@ -534,12 +479,8 @@
             encoder_out = encoder_out[0]
         
         # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
-        )
-        
+        nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
         nbest_hyps = nbest_hyps[: self.nbest]
-
 
         results = []
         b, n, d = encoder_out.size()
@@ -553,9 +494,9 @@
                 # remove sos/eos and get results
                 last_pos = -1
                 if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
+                    token_int = hyp.yseq
                 else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
+                    token_int = hyp.yseq.tolist()
                     
                 # remove blank symbol id, which is assumed to be 0
                 token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
diff --git a/funasr/models/transducer/rnn_decoder.py b/funasr/models/transducer/rnn_decoder.py
index 204f0b1..b999d9c 100644
--- a/funasr/models/transducer/rnn_decoder.py
+++ b/funasr/models/transducer/rnn_decoder.py
@@ -1,10 +1,15 @@
-import random
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
-import numpy as np
 import torch
+import random
+import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
 
+from funasr.register import tables
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.transformer.utils.nets_utils import to_device
 from funasr.models.language_model.rnn.attentions import initial_att
@@ -78,7 +83,7 @@
         )
     return att_list
 
-
+@tables.register("decoder_classes", "rnn_decoder")
 class RNNDecoder(nn.Module):
     def __init__(
         self,
diff --git a/funasr/models/transducer/rnn_encoder.py b/funasr/models/transducer/rnn_encoder.py
deleted file mode 100644
index 95fb4a5..0000000
--- a/funasr/models/transducer/rnn_encoder.py
+++ /dev/null
@@ -1,112 +0,0 @@
-
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-
-import numpy as np
-import torch
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.language_model.rnn.encoders import RNN
-from funasr.models.language_model.rnn.encoders import RNNP
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-
-class RNNEncoder(AbsEncoder):
-    """RNNEncoder class.
-    Args:
-        input_size: The number of expected features in the input
-        output_size: The number of output features
-        hidden_size: The number of hidden features
-        bidirectional: If ``True`` becomes a bidirectional LSTM
-        use_projection: Use projection layer or not
-        num_layers: Number of recurrent layers
-        dropout: dropout probability
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        rnn_type: str = "lstm",
-        bidirectional: bool = True,
-        use_projection: bool = True,
-        num_layers: int = 4,
-        hidden_size: int = 320,
-        output_size: int = 320,
-        dropout: float = 0.0,
-        subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
-    ):
-        super().__init__()
-        self._output_size = output_size
-        self.rnn_type = rnn_type
-        self.bidirectional = bidirectional
-        self.use_projection = use_projection
-
-        if rnn_type not in {"lstm", "gru"}:
-            raise ValueError(f"Not supported rnn_type={rnn_type}")
-
-        if subsample is None:
-            subsample = np.ones(num_layers + 1, dtype=np.int32)
-        else:
-            subsample = subsample[:num_layers]
-            # Append 1 at the beginning because the second or later is used
-            subsample = np.pad(
-                np.array(subsample, dtype=np.int32),
-                [1, num_layers - len(subsample)],
-                mode="constant",
-                constant_values=1,
-            )
-
-        rnn_type = ("b" if bidirectional else "") + rnn_type
-        if use_projection:
-            self.enc = torch.nn.ModuleList(
-                [
-                    RNNP(
-                        input_size,
-                        num_layers,
-                        hidden_size,
-                        output_size,
-                        subsample,
-                        dropout,
-                        typ=rnn_type,
-                    )
-                ]
-            )
-
-        else:
-            self.enc = torch.nn.ModuleList(
-                [
-                    RNN(
-                        input_size,
-                        num_layers,
-                        hidden_size,
-                        output_size,
-                        dropout,
-                        typ=rnn_type,
-                    )
-                ]
-            )
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        if prev_states is None:
-            prev_states = [None] * len(self.enc)
-        assert len(prev_states) == len(self.enc)
-
-        current_states = []
-        for module, prev_state in zip(self.enc, prev_states):
-            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
-            current_states.append(states)
-
-        if self.use_projection:
-            xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0)
-        else:
-            xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0)
-        return xs_pad, ilens, current_states
diff --git a/funasr/models/transducer/rnnt_decoder.py b/funasr/models/transducer/rnnt_decoder.py
index 6d35b71..26ca1f2 100644
--- a/funasr/models/transducer/rnnt_decoder.py
+++ b/funasr/models/transducer/rnnt_decoder.py
@@ -1,12 +1,17 @@
-"""RNN decoder definition for Transducer models."""
-
-from typing import List, Optional, Tuple
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import torch
+from typing import List, Optional, Tuple
 
-from funasr.models.transducer.beam_search_transducer import Hypothesis
+from funasr.register import tables
 from funasr.models.specaug.specaug import SpecAug
+from funasr.models.transducer.beam_search_transducer import Hypothesis
 
+
+@tables.register("decoder_classes", "rnnt_decoder")
 class RNNTDecoder(torch.nn.Module):
     """RNN decoder module.
 
diff --git a/funasr/models/transformer/attention.py b/funasr/models/transformer/attention.py
index 32e1e47..f09d642 100644
--- a/funasr/models/transformer/attention.py
+++ b/funasr/models/transformer/attention.py
@@ -312,8 +312,221 @@
         return self.forward_attention(v, scores, mask)
 
 
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+    """RelPositionMultiHeadedAttention definition.
+    Args:
+        num_heads: Number of attention heads.
+        embed_size: Embedding size.
+        dropout_rate: Dropout rate.
+    """
 
+    def __init__(
+        self,
+        num_heads: int,
+        embed_size: int,
+        dropout_rate: float = 0.0,
+        simplified_attention_score: bool = False,
+    ) -> None:
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
 
+        self.d_k = embed_size // num_heads
+        self.num_heads = num_heads
 
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)"
+            % (embed_size, num_heads)
+        )
 
+        self.linear_q = torch.nn.Linear(embed_size, embed_size)
+        self.linear_k = torch.nn.Linear(embed_size, embed_size)
+        self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+        self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+        if simplified_attention_score:
+            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+            self.compute_att_score = self.compute_simplified_attention_score
+        else:
+            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+            self.compute_att_score = self.compute_attention_score
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.attn = None
+
+    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute relative positional encoding.
+        Args:
+            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+            left_context: Number of frames in left context.
+        Returns:
+            x: Output sequence. (B, H, T_1, T_2)
+        """
+        batch_size, n_heads, time1, n = x.shape
+        time2 = time1 + left_context
+
+        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
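+        # Element [b, h, i, j] of the strided view reads
+        # x[b, h, i, (time1 - 1) + (j - i)]: column j of query row i picks the
+        # score at relative offset (j - i) without materializing a copy.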
+        return x.as_strided(
+            (batch_size, n_heads, time1, time2),
+            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+            storage_offset=(n_stride * (time1 - 1)),
+        )
+
+    def compute_simplified_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Simplified attention score computation.
+        Reference: https://github.com/k2-fsa/icefall/pull/458
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        pos_enc = self.linear_pos(pos_enc)
+
+        matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+        matrix_bd = self.rel_shift(
+            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+            left_context=left_context,
+        )
+
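+        # In this simplified variant (cf. icefall PR #458), the positional term
+        # is a per-head scalar bias indexed by relative offset:
+        # matrix_bd[b, h, i, j] == linear_pos(pos_enc)[b, (T_1 - 1) + (j - i), h].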
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def compute_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Attention score computation.
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+        Returns:
+            q: Transformed query tensor. (B, H, T_1, d_k)
+            k: Transformed key tensor. (B, H, T_2, d_k)
+            v: Transformed value tensor. (B, H, T_2, d_k)
+        """
+        n_batch = query.size(0)
+
+        q = (
+            self.linear_q(query)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        k = (
+            self.linear_k(key)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        v = (
+            self.linear_v(value)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+        Args:
+            value: Transformed value. (B, H, T_2, d_k)
+            scores: Attention score. (B, H, T_1, T_2)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+        Returns:
+           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+        """
+        batch_size = scores.size(0)
+        mask = mask.unsqueeze(1).unsqueeze(2)
+        if chunk_mask is not None:
+            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+        scores = scores.masked_fill(mask, float("-inf"))
+        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+        attn_output = self.dropout(self.attn)
+        attn_output = torch.matmul(attn_output, value)
+
+        attn_output = self.linear_out(
+            attn_output.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, -1, self.num_heads * self.d_k)
+        )
+
+        return attn_output
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Compute scaled dot product attention with rel. positional encoding.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+            left_context: Number of frames in left context.
+        Returns:
+            : Output tensor. (B, T_1, H * d_k)
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
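+
+# Example (a minimal sketch, not part of this patch; sizes below are
+# illustrative assumptions, with key length equal to query length):
+#   mha = RelPositionMultiHeadedAttentionChunk(num_heads=4, embed_size=256)
+#   x = torch.randn(2, 50, 256)                  # (B, T, size)
+#   pos = torch.randn(2, 2 * 50 - 1, 256)        # (B, 2 * T - 1, size)
+#   mask = torch.zeros(2, 50, dtype=torch.bool)  # True marks padded frames
+#   out = mha(x, x, x, pos, mask)                # (B, T, size)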
 

--
Gitblit v1.9.1