shixian.shi
2024-02-20 151c339ffeced822917e85255431fcfb74f24db9
support transducer model inference
8个文件已修改
1个文件已删除
1167 ■■■■ 已修改文件
funasr/models/conformer/encoder.py 666 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/model.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transducer/beam_search_transducer.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transducer/joint_network.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transducer/model.py 135 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transducer/rnn_decoder.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transducer/rnn_encoder.py 112 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transducer/rnnt_decoder.py 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transformer/attention.py 213 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer/encoder.py
@@ -14,6 +14,7 @@
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttentionChunk,
)
from funasr.models.transformer.embedding import (
    PositionalEncoding,  # noqa: H301
@@ -611,3 +612,668 @@
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
class CausalConvolution(torch.nn.Module):
    """Conformer convolution module with an optional causal (streaming) mode.

    Pipeline: pointwise conv (D -> 2D) -> GLU -> depthwise conv -> BatchNorm ->
    activation -> pointwise conv (D -> D).

    In causal mode the depthwise convolution gets no built-in padding; instead
    ``kernel_size - 1`` frames are prepended on the left (zeros, or a cache of
    the previous chunk), so an output frame never depends on later input frames.

    Args:
        channels: The number of channels.
        kernel_size: Size of the depthwise convolving kernel (must be odd).
        activation: Activation module. When omitted, a new ``torch.nn.ReLU()``
            is created for this instance.
        norm_args: Keyword arguments for ``torch.nn.BatchNorm1d``.
        causal: Whether to use causal convolution (set to True if streaming).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: Optional[torch.nn.Module] = None,
        norm_args: Optional[Dict] = None,
        causal: bool = False,
    ) -> None:
        """Construct a CausalConvolution object."""
        super().__init__()
        # Symmetric "same" padding in the non-causal case needs an odd kernel.
        assert (kernel_size - 1) % 2 == 0
        if activation is None:
            # Create the default per instance: a default argument expression is
            # evaluated only once, which would share one module across all objects.
            activation = torch.nn.ReLU()
        if norm_args is None:
            norm_args = {}

        self.kernel_size = kernel_size
        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        if causal:
            # Left context handled manually in forward(); no symmetric padding.
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2
        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute convolution module.

        Args:
            x: Input sequences. (B, T, D_hidden)
            cache: Causal left-context cache holding the last ``kernel_size - 1``
                GLU output frames of the previous chunk.
                (B, D_hidden, kernel_size - 1). Only used in causal mode.
            right_context: Number of trailing frames of ``x`` that are
                look-ahead; they are excluded from the returned cache.

        Returns:
            x: Output sequences. (B, T, D_hidden)
            cache: Updated cache (B, D_hidden, kernel_size - 1) when a cache was
                given in causal mode; otherwise the ``cache`` argument is
                returned unchanged (possibly None).
        """
        # (B, T, D) -> (B, D, T) for Conv1d.
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if self.lorder > 0:
            if cache is None:
                # No history: zero-pad on the left only.
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)
                # Next cache: last ``lorder`` frames before the right-context frames.
                if right_context > 0:
                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
                else:
                    cache = x[:, :, -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x).transpose(1, 2)
        return x, cache
class ChunkEncoderLayer(torch.nn.Module):
    """Conformer encoder layer for full-sequence and chunk-wise (streaming) encoding.

    Layout (pre-norm, macaron style): half-step feed-forward -> self-attention
    -> convolution -> half-step feed-forward -> final normalization.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance. Must expose ``kernel_size`` and
            return an ``(output, cache)`` pair.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Optional[Dict] = None,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a ChunkEncoderLayer object."""
        super().__init__()
        if norm_args is None:
            norm_args = {}
        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Both feed-forward modules contribute a half-step residual (macaron).
        self.feed_forward_scale = 0.5
        self.conv_mod = conv_mod
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.block_size = block_size
        # Streaming state [attention_cache, conv_cache]; set by reset_streaming_cache().
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            # Attention cache: (1, left_context, D_block).
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            # Convolution cache: (1, D_block, kernel_size - 1).
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        residual = x
        x = self.norm_self_att(x)
        x = residual + self.dropout(
            self.self_att(
                x,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)
        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one chunk of an input sequence, updating the streaming cache.

        ``reset_streaming_cache`` must have been called before the first chunk.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Unused; kept for interface compatibility.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)
        # Prepend cached (normalized) frames of previous chunks as extra keys/values.
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # New attention cache: the last ``left_context`` frames, excluding the
        # trailing right-context (look-ahead) frames.
        if left_context == 0:
            # Keep an empty cache; ``key[:, -0:]`` would slice the whole sequence.
            att_cache = key[:, :0, :]
        elif right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]

        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]
        return x, pos_enc
@tables.register("encoder_classes", "ChunkConformerEncoder")
class ConformerChunkEncoder(torch.nn.Module):
    """Conformer encoder built from ``ChunkEncoderLayer`` blocks, with chunk-wise inference.

    Encoding mode is selected at construction time:
    * default: full-context encoding;
    * ``dynamic_chunk_training``: a chunk attention mask with a chunk size
      sampled per call during training (``default_chunk_size`` at eval);
    * ``unified_model_training``: each call runs both a full-context pass and a
      chunked pass and returns both outputs.
    In both streaming modes the convolution modules are causal.

    Args:
        input_size: Input feature size.
        output_size: Encoder hidden/output size.
        attention_heads: Number of self-attention heads.
        linear_units: Hidden size of the feed-forward modules.
        num_blocks: Number of encoder layers.
        dropout_rate: Dropout rate inside each encoder layer.
        positional_dropout_rate: Dropout rate of the positional encoding; also
            passed to the feed-forward modules.
        attention_dropout_rate: Dropout rate passed to the attention modules.
        embed_vgg_like: Passed to the input module as ``vgg_like``.
        activation_type: Activation name resolved with ``get_activation``.
        cnn_module_kernel: Kernel size of the convolution modules.
        conv_mod_norm_eps: BatchNorm ``eps`` of the convolution modules.
        conv_mod_norm_momentum: BatchNorm ``momentum`` of the convolution modules.
        simplified_att_score: Passed to the chunk self-attention modules.
        dynamic_chunk_training: Enable dynamic chunk training.
        short_chunk_threshold: In dynamic chunk training, a sampled chunk size
            above ``max_len * short_chunk_threshold`` becomes the full length.
        short_chunk_size: Otherwise the chunk size is
            ``sampled % short_chunk_size + 1``.
        left_chunk_size: Passed to ``make_chunk_mask`` as ``left_chunk_size``.
        time_reduction_factor: Keep every n-th output frame when > 1.
        unified_model_training: Enable unified (full + chunk) training.
        default_chunk_size: Chunk size used outside of training, and the centre
            of the jittered chunk size in unified training.
        jitter_range: Maximum deviation from ``default_chunk_size`` sampled
            during unified training.
        subsampling_factor: Subsampling factor of the input module.

    ``normalize_before``, ``concat_after``, ``positionwise_layer_type``,
    ``positionwise_conv_kernel_size``, ``macaron_style``, ``rel_pos_type``,
    ``pos_enc_layer_type``, ``selfattention_layer_type``, ``use_cnn_module``,
    ``zero_triu`` and ``norm_type`` are accepted for configuration
    compatibility but are not used by this implementation.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct a ConformerChunkEncoder object."""
        super().__init__()
        self.embed = StreamingConvInput(
            input_size=input_size,
            conv_size=output_size,
            subsampling_factor=subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )
        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )
        activation = get_activation(activation_type)

        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )
        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }
        # The convolution is causal whenever the model is trained for streaming.
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )
        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )
        # Each layer gets freshly constructed sub-modules; the single
        # ``activation`` module instance is shared by all of them.
        self.encoders = MultiBlocks(
            [
                ChunkEncoderLayer(
                    output_size,
                    RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                    PositionwiseFeedForward(*pos_wise_args),
                    PositionwiseFeedForward(*pos_wise_args),
                    CausalConvolution(*conv_mod_args),
                    dropout_rate=dropout_rate,
                )
                for _ in range(num_blocks)
            ],
            output_size,
        )
        self._output_size = output_size
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size
        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output size."""
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the number of raw audio samples needed for ``size`` output frames.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length.

        Returns:
            : Number of raw samples.
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the number of input feature frames needed for ``size`` output frames.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of feature frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device ID.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def _check_input_length(self, x: torch.Tensor) -> None:
        """Raise ``TooShortUttError`` if ``x`` has too few frames for subsampling."""
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

    def _reduce_time(self, x: torch.Tensor) -> torch.Tensor:
        """Keep every ``time_reduction_factor``-th frame of ``x`` (B, T, D)."""
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
        return x

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            With ``unified_model_training``:
                x_utt: Full-context encoder outputs. (B, T_out, D_enc)
                x_chunk: Chunk-masked encoder outputs. (B, T_out, D_enc)
                olens: Encoder output lengths. (B,)
            Otherwise:
                x: Encoder outputs. (B, T_out, D_enc)
                olens: Encoder output lengths. (B,)
                None
        """
        self._check_input_length(x)
        mask = make_source_mask(x_len).to(x.device)

        if self.unified_model_training:
            if self.training:
                chunk_size = (
                    self.default_chunk_size
                    + torch.randint(
                        -self.jitter_range, self.jitter_range + 1, (1,)
                    ).item()
                )
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            x_utt = self.encoders(x, pos_enc, mask, chunk_mask=None)
            x_chunk = self.encoders(x, pos_enc, mask, chunk_mask=chunk_mask)
            # Non-zero mask entries mark padding, so zeros count valid frames.
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = self._reduce_time(x_utt)
                x_chunk = self._reduce_time(x_chunk)
                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
            return x_utt, x_chunk, olens

        if self.dynamic_chunk_training:
            max_len = x.size(1)
            if self.training:
                chunk_size = torch.randint(1, max_len, (1,)).item()
                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None

        x = self.encoders(x, pos_enc, mask, chunk_mask=chunk_mask)
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = self._reduce_time(x)
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
        return x, olens, None

    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> torch.Tensor:
        """Encode input sequences with full context (no chunk mask).

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x_utt: Encoder outputs. (B, T_out, D_enc)
        """
        self._check_input_length(x)
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(x, pos_enc, mask, chunk_mask=None)
        return self._reduce_time(x_utt)

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode a whole sequence in one pass, simulating streaming with a chunk mask.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Chunk size used for the input module and chunk mask.
            left_context: Unused; kept for interface compatibility.
            right_context: Unused; kept for interface compatibility.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        self._check_input_length(x)
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
        x = self.encoders(x, pos_enc, mask, chunk_mask=chunk_mask)
        return self._reduce_time(x)

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode one chunk of a stream, using and updating the streaming cache.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Chunk size passed to the encoder layers.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs without the right-context frames. (1, T_out, D_enc)
        """
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        if left_context > 0:
            # Distances [left_context - 1, ..., 0] of the cached frames; slots
            # not yet filled by real history (>= processed_frames) are masked.
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)
        pos_enc = self.pos_enc(x, left_context=left_context)
        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )
        if right_context > 0:
            x = x[:, 0:-right_context, :]
        return self._reduce_time(x)
funasr/models/seaco_paraformer/model.py
funasr/models/transducer/beam_search_transducer.py
@@ -1,10 +1,12 @@
"""Search algorithms for Transducer models."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from funasr.models.transducer.joint_network import JointNetwork
funasr/models/transducer/joint_network.py
@@ -1,10 +1,15 @@
"""Transducer joint network implementation."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import get_activation
@tables.register("joint_network_classes", "joint_network")
class JointNetwork(torch.nn.Module):
    """Transducer joint network module.
funasr/models/transducer/model.py
@@ -1,42 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import torch
import logging
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis
from funasr.models.model_class_factory import *
from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.scorers.length_bonus import LengthBonus
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -45,16 +29,10 @@
    @contextmanager
    def autocast(enabled=True):
        yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
class Transducer(nn.Module):
    """ESPnet2ASRTransducerModel module definition."""
@tables.register("model_classes", "Transducer")
class Transducer(torch.nn.Module):
    def __init__(
        self,
        frontend: Optional[str] = None,
@@ -96,35 +74,30 @@
        super().__init__()
        if frontend is not None:
            frontend_class = frontend_classes.get_class(frontend)
            frontend = frontend_class(**frontend_conf)
        if specaug is not None:
            specaug_class = specaug_classes.get_class(specaug)
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = normalize_classes.get_class(normalize)
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = encoder_classes.get_class(encoder)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = decoder_classes.get_class(decoder)
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
        decoder_output_size = decoder.output_size
        joint_network_class = joint_network_classes.get_class(decoder)
        joint_network_class = tables.joint_network_classes.get(joint_network)
        joint_network = joint_network_class(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **joint_network_conf,
        )
        
        self.criterion_transducer = None
        self.error_calculator = None
@@ -157,23 +130,17 @@
        self.decoder = decoder
        self.joint_network = joint_network
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        self.ctc = None
        self.ctc_weight = 0.0
    
    def forward(
        self,
@@ -190,8 +157,6 @@
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
@@ -283,11 +248,6 @@
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
@@ -449,9 +409,6 @@
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
    
        # 1. Build ASR model
        scorers = {}
@@ -466,28 +423,16 @@
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        beam_search = BeamSearchTransducer(
            self.decoder,
            self.joint_network,
            kwargs.get("beam_size", 2),
            nbest=1,
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
@@ -495,7 +440,7 @@
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
        
    def generate(self,
    def inference(self,
             data_in: list,
             data_lengths: list=None,
             key: list=None,
@@ -509,7 +454,7 @@
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
        # if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
@@ -534,12 +479,8 @@
            encoder_out = encoder_out[0]
        
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
        )
        nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        b, n, d = encoder_out.size()
@@ -553,9 +494,9 @@
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                    token_int = hyp.yseq#[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                    token_int = hyp.yseq#[1:last_pos].tolist()
                    
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
funasr/models/transducer/rnn_decoder.py
@@ -1,10 +1,15 @@
import random
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import numpy as np
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import initial_att
@@ -78,7 +83,7 @@
        )
    return att_list
@tables.register("decoder_classes", "rnn_decoder")
class RNNDecoder(nn.Module):
    def __init__(
        self,
funasr/models/transducer/rnn_encoder.py
File was deleted
funasr/models/transducer/rnnt_decoder.py
@@ -1,12 +1,17 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
from typing import List, Optional, Tuple
from funasr.models.transducer.beam_search_transducer import Hypothesis
from funasr.register import tables
from funasr.models.specaug.specaug import SpecAug
from funasr.models.transducer.beam_search_transducer import Hypothesis
@tables.register("decoder_classes", "rnnt_decoder")
class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.
funasr/models/transformer/attention.py
@@ -312,8 +312,221 @@
        return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """Multi-head attention with relative positional encoding, optional chunk mask and left context.

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size; must be divisible by ``num_heads``.
        dropout_rate: Dropout rate applied to the attention weights.
        simplified_attention_score: If True, the positional term is a per-head
            scalar projection of ``pos_enc`` (no ``pos_bias_u`` / ``pos_bias_v``);
            otherwise the full relative-position score with learned biases is used.
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()
        self.d_k = embed_size // num_heads
        self.num_heads = num_heads
        # The message must be a formatted string; a (format, args) tuple would be
        # shown verbatim in the AssertionError.
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)" % (embed_size, num_heads)
        )
        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)
        if simplified_attention_score:
            # One positional scalar per head and relative offset.
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
            self.pos_bias_u = torch.nn.Parameter(torch.empty(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.empty(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
            self.compute_att_score = self.compute_attention_score
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Last attention weights computed by forward_attention (B, H, T_1, T_2).
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Shift relative-position scores into (query, key) alignment.

        Output element ``[b, h, i, j]`` is input element ``[b, h, i, T_1 - 1 - i + j]``.

        Args:
            x: Input scores. (B, H, T_1, n) with n >= T_1 + T_2 - 1
               (i.e. n == 2 * T_1 - 1 when left_context == 0).
            left_context: Number of frames in left context.
        Returns:
            x: Output sequence. (B, H, T_1, T_2) where T_2 = T_1 + left_context.
        """
        batch_size, n_heads, time1, _ = x.shape
        time2 = time1 + left_context
        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
        # as_strided's storage_offset is absolute within the underlying storage,
        # so x's own offset must be included; otherwise a view (e.g. a slice)
        # would be read from the wrong position.
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=x.storage_offset() + n_stride * (time1 - 1),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B or 1, T_1 + T_2 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        # (B|1, T_1 + T_2 - 1, H) -> (B|1, H, T_1, T_1 + T_2 - 1), one row per query.
        pos_enc = self.linear_pos(pos_enc)
        matrix_ac = torch.matmul(query, key.transpose(2, 3))
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation with learned content/position biases.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B or 1, T_1 + T_2 - 1, size);
                a batch dim of 1 is broadcast by matmul.
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
        # Put heads last so the (H, d_k) biases broadcast, then restore (B, H, T_1, d_k).
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query, key and value and split them into heads.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)
        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask, True marks positions to exclude. (B, T_2)
            chunk_mask: Chunk mask, True marks positions to exclude. (T_1, T_2)
        Returns:
           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        # Zeroing after softmax also replaces the NaNs of fully-masked rows.
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )
        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size), T_2 = T_1 + left_context
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B or 1, T_1 + T_2 - 1, size)
            mask: Source mask, True marks positions to exclude. (B, T_2)
            chunk_mask: Chunk mask, True marks positions to exclude. (T_1, T_2)
            left_context: Number of frames in left context.
        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)