aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
rnnt reorg
5个文件已修改
24 文件已重命名
6个文件已删除
2267 ■■■■ 已修改文件
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml 32 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_rnnt.py 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_transducer.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_transducer_unified.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_blocks/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_blocks/branchformer.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_blocks/conformer.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_blocks/conv1d.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_blocks/conv_input.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_blocks/linear_input.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_modules/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_modules/attention.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_modules/convolution.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_modules/multi_blocks.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_modules/normalization.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_modules/positional_encoding.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_utils/building.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/chunk_encoder_utils/validation.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/joint_network.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/rnnt_decoder/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/rnnt_decoder/abs_decoder.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/rnnt_decoder/rnn_decoder.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/rnnt_decoder/stateless_decoder.py 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/encoder/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/encoder/sanm_encoder.py 835 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/error_calculator.py 169 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/espnet_transducer_model_uni_asr.py 485 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models_transducer/utils.py 200 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/activation.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/beam_search/beam_search_transducer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/e2e_asr_common.py 151 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/nets_utils.py 195 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr_transducer.py 41 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -1,13 +1,13 @@
encoder_conf:
    main_conf:
      pos_wise_act_type: swish
      pos_enc_dropout_rate: 0.3
      pos_enc_dropout_rate: 0.5
      conv_mod_act_type: swish
      time_reduction_factor: 2
      unified_model_training: true
      default_chunk_size: 16
      jitter_range: 4
      left_chunk_size: 1
      left_chunk_size: 0
    input_conf:
      block_type: conv2d
      conv_size: 512
@@ -18,9 +18,9 @@
      linear_size: 2048
      hidden_size: 512
      heads: 8
      dropout_rate: 0.3
      pos_wise_dropout_rate: 0.3
      att_dropout_rate: 0.3
      dropout_rate: 0.5
      pos_wise_dropout_rate: 0.5
      att_dropout_rate: 0.5
      conv_mod_kernel_size: 15
      num_blocks: 12    
@@ -29,8 +29,8 @@
decoder_conf:
    embed_size: 512
    hidden_size: 512
    embed_dropout_rate: 0.2
    dropout_rate: 0.1
    embed_dropout_rate: 0.5
    dropout_rate: 0.5
joint_network_conf:
    joint_space_size: 512
@@ -41,14 +41,14 @@
# minibatch related
use_amp: true
batch_type: numel
batch_bins: 1600000
batch_type: unsorted
batch_size: 16
num_workers: 16
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 80
max_epoch: 200
val_scheduler_criterion:
    - valid
    - loss
@@ -56,11 +56,11 @@
-   - valid
    - cer_transducer_chunk
    - min
keep_nbest_models: 5
keep_nbest_models: 10
optim: adam
optim_conf:
   lr: 0.0003
   lr: 0.001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 25000
@@ -75,10 +75,12 @@
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    - 40
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
    - 50
    num_time_mask: 5
log_interval: 50
funasr/bin/asr_inference_rnnt.py
@@ -16,11 +16,11 @@
from packaging.version import parse as V
from typeguard import check_argument_types, check_return_type
from funasr.models_transducer.beam_search_transducer import (
from funasr.modules.beam_search.beam_search_transducer import (
    BeamSearchTransducer,
    Hypothesis,
)
from funasr.models_transducer.utils import TooShortUttError
from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.asr_transducer import ASRTransducerTask
from funasr.tasks.lm import LMTask
@@ -500,7 +500,6 @@
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
<<<<<<< HEAD
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            assert len(batch.keys()) == 1
@@ -541,59 +540,6 @@
                if text is not None:
                    ibest_writer["text"][key] = text
=======
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))
            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            results = speech2text(cache=cache, **batch)
            if len(results) < 1:
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
            time_end = time.time()
            forward_time = time_end - time_beg
            lfr_factor = results[0][-1]
            length = results[0][-2]
            forward_time_total += forward_time
            length_total += length
            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
            logging.info(rtf_cur)
            for batch_id in range(_bs):
                result = [results[batch_id][:-2]]
                key = keys[batch_id]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]
                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                        item = {'key': key, 'value': text_postprocessed}
                        asr_result_list.append(item)
                        finish_count += 1
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
        logging.info(rtf_avg)
        if writer is not None:
            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
        return asr_result_list
    return _forward
>>>>>>> main
def get_parser():
funasr/models/e2e_transducer.py
File was renamed from funasr/models_transducer/espnet_transducer_model.py
@@ -10,11 +10,11 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -28,7 +28,7 @@
        yield
class ESPnetASRTransducerModel(AbsESPnetModel):
class TransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.
    Args:
funasr/models/e2e_transducer_unified.py
File was renamed from funasr/models_transducer/espnet_transducer_model_unified.py
@@ -10,10 +10,10 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -23,7 +23,7 @@
from funasr.losses.label_smoothing_loss import (  # noqa: H301
    LabelSmoothingLoss,
)
from funasr.models_transducer.error_calculator import ErrorCalculator
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
@@ -33,7 +33,7 @@
        yield
class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
class UnifiedTransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.
    Args:
@@ -289,7 +289,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def collect_feats(
funasr/models/encoder/chunk_encoder.py
File was renamed from funasr/models_transducer/encoder/encoder.py
@@ -1,26 +1,23 @@
"""Encoder for Transducer model."""
from typing import Any, Dict, List, Tuple
import torch
from typeguard import check_argument_types
from funasr.models_transducer.encoder.building import (
from funasr.models.encoder.chunk_encoder_utils.building import (
    build_body_blocks,
    build_input_block,
    build_main_parameters,
    build_positional_encoding,
)
from funasr.models_transducer.encoder.validation import validate_architecture
from funasr.models_transducer.utils import (
from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
from funasr.modules.nets_utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)
class Encoder(torch.nn.Module):
class ChunkEncoder(torch.nn.Module):
    """Encoder module definition.
    Args:
@@ -61,10 +58,9 @@
        self.unified_model_training = main_params["unified_model_training"]
        self.default_chunk_size = main_params["default_chunk_size"]
        self.jitter_range = main_params["jitter_range"]
        self.jitter_range = main_params["jitter_range"]
        self.time_reduction_factor = main_params["time_reduction_factor"]
        self.time_reduction_factor = main_params["time_reduction_factor"]
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
@@ -79,7 +75,7 @@
        """
        return self.embed.get_size_before_subsampling(size) * hop_length
    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
@@ -157,7 +153,7 @@
                mask,
                chunk_mask=chunk_mask,
            )
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:,::self.time_reduction_factor,:]
@@ -194,14 +190,14 @@
            mask,
            chunk_mask=chunk_mask,
        )
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
        return x, olens
    def simu_chunk_forward(
        self,
        x: torch.Tensor,
@@ -290,7 +286,7 @@
        if right_context > 0:
            x = x[:, 0:-right_context, :]
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
        return x
funasr/models/encoder/chunk_encoder_blocks/__init__.py
funasr/models/encoder/chunk_encoder_blocks/branchformer.py
funasr/models/encoder/chunk_encoder_blocks/conformer.py
funasr/models/encoder/chunk_encoder_blocks/conv1d.py
funasr/models/encoder/chunk_encoder_blocks/conv_input.py
File was renamed from funasr/models_transducer/encoder/blocks/conv_input.py
@@ -5,7 +5,7 @@
import torch
import math
from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
class ConvInput(torch.nn.Module):
funasr/models/encoder/chunk_encoder_blocks/linear_input.py
funasr/models/encoder/chunk_encoder_modules/__init__.py
funasr/models/encoder/chunk_encoder_modules/attention.py
funasr/models/encoder/chunk_encoder_modules/convolution.py
funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
funasr/models/encoder/chunk_encoder_modules/normalization.py
funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
funasr/models/encoder/chunk_encoder_utils/building.py
File was renamed from funasr/models_transducer/encoder/building.py
@@ -2,22 +2,22 @@
from typing import Any, Dict, List, Optional, Union
from funasr.models_transducer.activation import get_activation
from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
from funasr.models_transducer.encoder.blocks.conformer import Conformer
from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
from funasr.models_transducer.encoder.modules.attention import (  # noqa: H301
from funasr.modules.activation import get_activation
from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
from funasr.models.encoder.chunk_encoder_modules.attention import (  # noqa: H301
    RelPositionMultiHeadedAttention,
)
from funasr.models_transducer.encoder.modules.convolution import (  # noqa: H301
from funasr.models.encoder.chunk_encoder_modules.convolution import (  # noqa: H301
    ConformerConvolution,
    ConvolutionalSpatialGatingUnit,
)
from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
from funasr.models_transducer.encoder.modules.normalization import get_normalization
from funasr.models_transducer.encoder.modules.positional_encoding import (  # noqa: H301
from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
from funasr.models.encoder.chunk_encoder_modules.positional_encoding import (  # noqa: H301
    RelPositionalEncoding,
)
from funasr.modules.positionwise_feed_forward import (
funasr/models/encoder/chunk_encoder_utils/validation.py
File was renamed from funasr/models_transducer/encoder/validation.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple
from funasr.models_transducer.utils import sub_factor_to_params
from funasr.modules.nets_utils import sub_factor_to_params
def validate_block_arguments(
funasr/models/joint_network.py
File was renamed from funasr/models_transducer/joint_network.py
@@ -2,7 +2,7 @@
import torch
from funasr.models_transducer.activation import get_activation
from funasr.modules.activation import get_activation
class JointNetwork(torch.nn.Module):
funasr/models/rnnt_decoder/__init__.py
funasr/models/rnnt_decoder/abs_decoder.py
funasr/models/rnnt_decoder/rnn_decoder.py
File was renamed from funasr/models_transducer/decoder/rnn_decoder.py
@@ -5,8 +5,8 @@
import torch
from typeguard import check_argument_types
from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class RNNDecoder(AbsDecoder):
funasr/models/rnnt_decoder/stateless_decoder.py
File was renamed from funasr/models_transducer/decoder/stateless_decoder.py
@@ -5,8 +5,8 @@
import torch
from typeguard import check_argument_types
from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@
        embed_size: int = 256,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a StatelessDecoder object."""
        super().__init__()
@@ -42,14 +41,6 @@
        self.device = next(self.parameters()).device
        self.score_cache = {}
        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=1,
                apply_freq_mask=False,
                apply_time_warp=False
            )
    def forward(
@@ -69,9 +60,6 @@
        """
        dec_embed = self.embed_dropout_rate(self.embed(labels))
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
        return dec_embed
    def score(
funasr/models_transducer/__init__.py
funasr/models_transducer/encoder/__init__.py
funasr/models_transducer/encoder/sanm_encoder.py
File was deleted
funasr/models_transducer/error_calculator.py
File was deleted
funasr/models_transducer/espnet_transducer_model_uni_asr.py
File was deleted
funasr/models_transducer/utils.py
File was deleted
funasr/modules/activation.py
funasr/modules/beam_search/beam_search_transducer.py
File was renamed from funasr/models_transducer/beam_search_transducer.py
@@ -6,8 +6,8 @@
import numpy as np
import torch
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork
@dataclass
funasr/modules/e2e_asr_common.py
@@ -6,6 +6,8 @@
"""Common functions for ASR."""
from typing import List, Optional, Tuple
import json
import logging
import sys
@@ -13,7 +15,11 @@
from itertools import groupby
import numpy as np
import six
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.
@@ -247,3 +253,148 @@
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)
class ErrorCalculatorTransducer:
    """Compute CER and/or WER for a Transducer model from greedy-width beam search.

    The calculator owns a ``BeamSearchTransducer`` configured with
    ``beam_size=1`` and ``search_type="default"`` and uses it to decode every
    utterance of a batch before scoring the hypotheses against the targets.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: Token strings, indexed by label ID.
        sym_space: Space symbol (replaced by ``" "`` in the decoded text).
        sym_blank: Blank symbol (removed from the decoded text).
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.

    """

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        token_list: List[str],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )

        # Kept separately so that __call__ can look up the decoder's device.
        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Optional[float], Optional[float]]:
        """Decode a batch and score it against the reference labels.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            : CER over the batch, or None when ``report_cer`` is False.
            : WER over the batch, or None when ``report_wer`` is False.

        Raises:
            ZeroDivisionError: If a requested metric is computed on a batch
                whose references are all empty.

        """
        cer, wer = None, None

        # Decoding runs on whatever device the decoder parameters live on.
        decoder_device = next(self.decoder.parameters()).device
        encoder_out = encoder_out.to(decoder_device)

        batch_size = int(encoder_out.size(0))
        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batch_size)]

        # Keep the best hypothesis only and drop the first element of its
        # label sequence.
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def convert_to_char(
        self, pred: List[List[int]], target: torch.Tensor
    ) -> Tuple[List[str], List[str]]:
        """Convert label ID sequences to character strings.

        Each ID is mapped through ``token_list``; the tokens are concatenated,
        the space symbol is replaced by ``" "`` and the blank symbol removed.

        Args:
            pred: Prediction label ID sequences. [B x (U)]
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction strings, one per utterance.
            char_target: Target strings, one per utterance.

        """
        char_pred, char_target = [], []

        for pred_ids, target_ids in zip(pred, target):
            pred_text = "".join(self.token_list[int(h)] for h in pred_ids)
            target_text = "".join(self.token_list[int(r)] for r in target_ids)

            char_pred.append(
                pred_text.replace(self.space, " ").replace(self.blank, "")
            )
            char_target.append(
                target_text.replace(self.space, " ").replace(self.blank, "")
            )

        return char_pred, char_target

    def calculate_cer(self, char_pred: List[str], char_target: List[str]) -> float:
        """Calculate the character error rate over a batch.

        Spaces are stripped from both sides before comparison. The result is
        the summed edit distance divided by the summed reference length, i.e.
        it is aggregated over the batch rather than averaged per sentence.

        Args:
            char_pred: Prediction strings, one per utterance.
            char_target: Target strings, one per utterance.

        Returns:
            : Total character edit distance / total reference characters.

        Raises:
            ZeroDivisionError: If every reference is empty.

        """
        import editdistance

        distances, lens = [], []

        for pred_text, target_text in zip(char_pred, char_target):
            pred = pred_text.replace(" ", "")
            target = target_text.replace(" ", "")

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)

    def calculate_wer(self, char_pred: List[str], char_target: List[str]) -> float:
        """Calculate the word error rate over a batch.

        Words are obtained by replacing ``"▁"`` with a space and splitting on
        whitespace. The result is the summed edit distance divided by the
        summed reference length, aggregated over the batch.

        Args:
            char_pred: Prediction strings, one per utterance.
            char_target: Target strings, one per utterance.

        Returns:
            : Total word edit distance / total reference words.

        Raises:
            ZeroDivisionError: If every reference is empty.

        """
        import editdistance

        distances, lens = [], []

        for pred_text, target_text in zip(char_pred, char_target):
            pred = pred_text.replace("▁", " ").split()
            target = target_text.replace("▁", " ").split()

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)
funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
from typing import Dict
from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -506,3 +506,196 @@
    }
    return activation_funcs[act]()
class TooShortUttError(Exception):
    """Error signalling that an utterance is too short to be subsampled.

    Args:
        message: Text forwarded to the base ``Exception``.
        actual_size: Size of the input that failed the subsampling check.
        limit: Size limit associated with the subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Keep both sizes on the exception next to the message."""
        super().__init__(message)

        self.actual_size, self.limit = actual_size, limit
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Tell whether an input is too short for the given subsampling factor.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : True when the input is rejected as too short.
        : Size limit reported for the factor, or -1 when the input is accepted
          (or the factor is not one of 2, 4, 6).

    """
    # (factor, smallest accepted size, limit reported to the caller)
    # NOTE(review): for factor 2 the rejection threshold (3) and the reported
    # limit (7) disagree - confirm which one is intended.
    rules = ((2, 3, 7), (4, 7, 7), (6, 11, 11))

    for factor, min_size, limit in rules:
        if sub_factor == factor and size < min_size:
            return True, limit

    return False, -1
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Return the second-convolution settings for a subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X). Must be 2, 4 or 6.
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    Raises:
        ValueError: If ``sub_factor`` is not 2, 4 or 6.

    """
    if sub_factor not in (2, 4, 6):
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )

    # Every supported factor starts from the same halved size.
    halved = (input_size - 1) // 2

    if sub_factor == 2:
        return 3, 1, halved - 2

    if sub_factor == 4:
        return 3, 2, (halved - 1) // 2

    return 5, 3, (halved - 2) // 3
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Build a square chunk mask in which True marks a blocked position.

    Row ``i`` is allowed to see its own chunk and, depending on
    ``left_chunk_size``, either everything before it or only a limited number
    of preceding chunks. The returned mask is the inverse of that visibility.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    visible = torch.zeros(size, size, device=device, dtype=torch.bool)
    full_left_context = left_chunk_size <= 0

    for row in range(size):
        chunk_idx = row // chunk_size

        if full_left_context:
            first = 0
        else:
            first = max((chunk_idx - left_chunk_size) * chunk_size, 0)

        # The row's own chunk is always visible, clipped to the mask size.
        last = min((chunk_idx + 1) * chunk_size, size)

        visible[row, first:last] = True

    return ~visible
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Build a mask that is True at every position at or past a sequence's length.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len)

    """
    longest = lengths.max()
    num_seqs = lengths.size(0)

    # Position indices 0..max_len-1, matched to the dtype/device of `lengths`.
    positions = torch.arange(longest).to(lengths)
    position_grid = positions.expand(num_seqs, longest)

    return position_grid >= lengths.unsqueeze(1)
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare the tensors consumed by the Transducer loss.

    Padding entries (``ignore_id``) are stripped from each label sequence; the
    decoder input is the stripped sequence prefixed with ``blank_id``. Both the
    decoder input and the target are re-padded with ``blank_id``.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences, as int32. (B, U)
        t_len: Time lengths, as int32. (B,)
        u_len: Label lengths, as int32. (B,)

    """

    def _pad_batch(seqs: List[torch.Tensor], padding_value: int = 0):
        """Stack variable-length sequences into one right-padded tensor.

        Args:
            seqs: Label sequences. [B x (?)]
            padding_value: Value written to the padded positions.

        Returns:
            : Padded batch with the dtype/device of the first sequence.

        """
        head = seqs[0]
        longest = max(seq.size(0) for seq in seqs)

        batch = head.new_full((len(seqs), longest, *head.size()[1:]), padding_value)

        # Each `row` is a view into `batch`, so the slice assignment fills it.
        for row, seq in zip(batch, seqs):
            row[: seq.size(0)] = seq

        return batch

    device = labels.device

    stripped = [seq[seq != ignore_id] for seq in labels]
    blank = labels[0].new_tensor([blank_id])

    prefixed = [torch.cat([blank, seq], dim=0) for seq in stripped]
    decoder_in = _pad_batch(prefixed, blank_id).to(device)

    target = _pad_batch(stripped, blank_id).to(torch.int32).to(device)

    t_len = torch.tensor(
        [int(n) for n in encoder_out_lens], dtype=torch.int32
    ).to(device)
    u_len = torch.tensor(
        [seq.size(0) for seq in stripped], dtype=torch.int32
    ).to(device)

    return decoder_in, target, t_len, u_len
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Right-pad `t` with zeros along `dim` until its size there is `pad_len`.

    The input tensor itself is returned when it already has the requested size.
    """
    missing = pad_len - t.size(dim)
    if missing == 0:
        return t

    # Same shape as `t` except along `dim`, where only the missing part is built.
    filler_shape = list(t.shape)
    filler_shape[dim] = missing
    filler = t.new_zeros(filler_shape)

    return torch.cat([t, filler], dim=dim)
funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
    LightweightConvolutionTransformerDecoder,
    TransformerDecoder,
)
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
        "encoder",
        classes=dict(
                encoder=Encoder,
                sanm_chunk_opt=SANMEncoderChunkOpt,
        ),
        default="encoder",
)
@@ -158,7 +155,7 @@
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetASRTransducerModel),
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        # group.add_argument(
@@ -354,7 +351,7 @@
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Required data depending on task mode.
        Args:
            cls: ASRTransducerTask object.
@@ -440,22 +437,8 @@
        # 7. Build model
        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
            model = UniASRTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        elif encoder.unified_model_training:
            model = ESPnetASRUnifiedTransducerModel(
        if encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
@@ -469,7 +452,7 @@
            )
        else:
            model = ESPnetASRTransducerModel(
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,