aky15
2023-04-17 8672352ecde80a86609fe01195b398ebe77f0ed1
merge many functions
8 files changed
1 file renamed
5 files deleted
2294 lines changed
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml 30
funasr/bin/asr_inference_rnnt.py 2
funasr/bin/asr_train_transducer.py 2
funasr/models/decoder/rnnt_decoder.py 3
funasr/models/e2e_asr_transducer.py 535
funasr/models/e2e_asr_transducer_unified.py 586
funasr/models/encoder/conformer_encoder.py 7
funasr/models/rnnt_predictor/__init__.py
funasr/models/rnnt_predictor/abs_decoder.py 110
funasr/models/rnnt_predictor/stateless_decoder.py 145
funasr/modules/beam_search/beam_search_transducer.py 3
funasr/modules/e2e_asr_common.py 3
funasr/tasks/asr.py 391
funasr/tasks/asr_transducer.py 477
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -1,32 +1,26 @@
encoder: chunk_conformer
encoder_conf:
    main_conf:
-      pos_wise_act_type: swish
-      pos_enc_dropout_rate: 0.5
-      conv_mod_act_type: swish
+      activation_type: swish
+      positional_dropout_rate: 0.5
      time_reduction_factor: 2
      unified_model_training: true
      default_chunk_size: 16
      jitter_range: 4
      left_chunk_size: 0
    input_conf:
      block_type: conv2d
      conv_size: 512
      embed_vgg_like: false
      subsampling_factor: 4
      num_frame: 1
    body_conf:
    - block_type: conformer
-      linear_size: 2048
-      hidden_size: 512
-      heads: 8
+      linear_units: 2048
+      output_size: 512
+      attention_heads: 8
      dropout_rate: 0.5
-      pos_wise_dropout_rate: 0.5
-      att_dropout_rate: 0.5
-      conv_mod_kernel_size: 15
+      positional_dropout_rate: 0.5
+      attention_dropout_rate: 0.5
+      cnn_module_kernel: 15
      num_blocks: 12
# decoder related
-decoder: rnn
-decoder_conf:
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
    embed_size: 512
    hidden_size: 512
    embed_dropout_rate: 0.5
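A quick way to confirm a config has been migrated is to load it and assert on the renamed fields. A minimal sketch, assuming PyYAML is available and that the renamed keys shown in this hunk are the surviving ones:

```python
# Minimal sanity check for the renamed config keys (assumes PyYAML is installed).
import yaml

with open("egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml") as f:
    conf = yaml.safe_load(f)

assert conf["encoder"] == "chunk_conformer"
assert "rnnt_decoder_conf" in conf                           # replaces decoder_conf
body = conf["encoder_conf"]["body_conf"][0]
assert "linear_units" in body and "linear_size" not in body  # renamed key
```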
funasr/bin/asr_inference_rnnt.py
@@ -22,7 +22,7 @@
)
from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
funasr/bin/asr_train_transducer.py
@@ -2,7 +2,7 @@
import os
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
# for ASR Training
funasr/models/decoder/rnnt_decoder.py
File was renamed from funasr/models/rnnt_predictor/rnn_decoder.py
@@ -6,10 +6,9 @@
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
-class RNNDecoder(AbsDecoder):
+class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.
    Args:
funasr/models/e2e_asr_transducer.py
@@ -10,7 +10,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -63,9 +63,9 @@
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
-        decoder: AbsDecoder,
-        att_decoder: Optional[AbsAttDecoder],
+        decoder: RNNTDecoder,
        joint_network: JointNetwork,
+        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
@@ -482,3 +482,532 @@
        )
        return loss_lm
class UnifiedTransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.
    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of token
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        joint_network: Joint Network module.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
    """
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_att_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ) -> None:
        """Construct an ESPnetASRTransducerModel object."""
        super().__init__()
        assert check_argument_types()
        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        if sym_sos in token_list:
            self.sos = token_list.index(sym_sos)
        else:
            self.sos = vocab_size - 1
        if sym_eos in token_list:
            self.eos = token_list.index(sym_eos)
        else:
            self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()
        self.sym_space = sym_space
        self.sym_blank = sym_blank
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network
        self.criterion_transducer = None
        self.error_calculator = None
        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_att = auxiliary_att_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_att:
            self.att_decoder = att_decoder
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_att_weight = auxiliary_att_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".
        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_att, loss_att_chunk = 0.0, 0.0
        if self.use_auxiliary_att:
            loss_att, _ = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_att_chunk, _ = self._calc_att_loss(
                encoder_out_chunk, encoder_out_lens, text, text_lengths
            )
        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )
        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)
        # 4. Joint Network
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        joint_out_chunk = self.joint_network(
            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        # 5. Losses
        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )
        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
            encoder_out_chunk,
            joint_out_chunk,
            target,
            t_len,
            u_len,
        )
        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
            loss_ctc_chunk = self._calc_ctc_loss(
                encoder_out_chunk,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)
        loss_trans = loss_trans_utt + loss_trans_chunk
        loss_ctc = loss_ctc + loss_ctc_chunk
        loss_att = loss_att + loss_att_chunk
        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_att_weight * loss_att
            + self.auxiliary_lm_loss_weight * loss_lm
        )
        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans_utt.detach(),
            loss_transducer_chunk=loss_trans_chunk.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
            aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
            aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
            cer_transducer_chunk=cer_trans_chunk,
            wer_transducer_chunk=wer_trans_chunk,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".
        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode speech sequences.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            encoder_out: Full-context encoder outputs. (B, T, D_enc)
            encoder_out_chunk: Chunk-wise encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # 4. Forward encoder
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        return encoder_out, encoder_out_chunk, encoder_out_lens
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.
        """
        if self.criterion_transducer is None:
            try:
                # The previous warprnnt_pytorch binding was constructed once:
                #   from warprnnt_pytorch import RNNTLoss
                #   self.criterion_transducer = RNNTLoss(
                #       reduction="mean", fastemit_lambda=self.fastemit_lambda
                #   )
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                exit(1)
        # loss_transducer = self.criterion_transducer(
        #     joint_out,
        #     target,
        #     t_len,
        #     u_len,
        # )
        log_probs = torch.log_softmax(joint_out, dim=-1)
        loss_transducer = self.criterion_transducer(
                log_probs,
                target,
                t_len,
                u_len,
                reduction="mean",
                blank=self.blank_id,
                fastemit_lambda=self.fastemit_lambda,
                gather=True,
        )
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            return loss_transducer, cer_transducer, wer_transducer
        return loss_transducer, None, None
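This path now calls the warp_rnnt package's functional rnnt_loss (log-softmax inputs, with reduction, blank, fastemit_lambda, and gather passed per call) instead of the warprnnt_pytorch module. A stand-alone sketch of the same call on random tensors; warp-rnnt must be installed, its usual builds require a CUDA device, and all shapes are illustrative:

```python
# Stand-alone sketch of the warp_rnnt call above (requires warp-rnnt + CUDA).
import torch
from warp_rnnt import rnnt_loss

B, T, U, V = 2, 8, 4, 10                                # batch, frames, labels, vocab
joint_out = torch.randn(B, T, U + 1, V, device="cuda")  # joint network output
log_probs = torch.log_softmax(joint_out, dim=-1)
target = torch.randint(1, V, (B, U), dtype=torch.int32, device="cuda")
t_len = torch.full((B,), T, dtype=torch.int32, device="cuda")
u_len = torch.full((B,), U, dtype=torch.int32, device="cuda")

loss = rnnt_loss(log_probs, target, t_len, u_len,
                 reduction="mean", blank=0, gather=True, fastemit_lambda=0.0)
```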
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_ctc: CTC loss value.
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)
        return loss_ctc
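The auxiliary CTC loss feeds time-major log-probabilities and a flattened (concatenated) target vector to torch.nn.functional.ctc_loss. A minimal self-contained illustration of that calling convention, with toy shapes:

```python
# Toy illustration of the ctc_loss calling convention used above.
import torch
import torch.nn.functional as F

B, T, V = 2, 12, 10
log_probs = torch.randn(T, B, V).log_softmax(-1)  # time-major: (T, B, V)
targets = torch.tensor([1, 2, 3, 4, 5])           # concatenated label IDs
input_lengths = torch.tensor([12, 12])
target_lengths = torch.tensor([3, 2])             # 3 + 2 == len(targets)

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                  blank=0, reduction="sum", zero_infinity=True) / B
```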
    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss.
        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)
        Return:
            loss_lm: LM loss value.
        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)
        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)
            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )
        return loss_lm
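The LM loss builds the label-smoothed target distribution by hand and scores it with KL divergence, masking blank (ID 0) positions. The same recipe in isolation, with toy shapes and an illustrative smoothing value:

```python
# Isolated toy of the label-smoothed KL objective in _calc_lm_loss.
import torch

N, V, smoothing = 3, 5, 0.1
logits = torch.randn(N, V)
target = torch.tensor([1, 2, 0])                  # 0 = blank, to be ignored

true_dist = torch.full((N, V), smoothing / (V - 1))
true_dist.scatter_(1, target.unsqueeze(1), 1 - smoothing)

kl = torch.nn.functional.kl_div(logits.log_softmax(1), true_dist, reduction="none")
loss = kl.masked_fill((target == 0).unsqueeze(1), 0).sum() / N
```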
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
            ys_pad = torch.cat(
                [
                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
                    ys_pad,
                ],
                dim=1,
            )
            ys_pad_lens += 1
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, _ = self.att_decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att
funasr/models/e2e_asr_transducer_unified.py
File was deleted
funasr/models/encoder/conformer_encoder.py
@@ -894,7 +894,7 @@
        return x, cache
-class ConformerChunkEncoder(torch.nn.Module):
+class ConformerChunkEncoder(AbsEncoder):
    """Encoder module definition.
    Args:
        input_size: Input size.
@@ -1007,7 +1007,7 @@
            output_size,
        )
-        self.output_size = output_size
+        self._output_size = output_size
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
@@ -1020,6 +1020,9 @@
        self.time_reduction_factor = time_reduction_factor
+    def output_size(self) -> int:
+        return self._output_size
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
        Where size is the number of features frames after applying subsampling.
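Storing the width under self._output_size keeps the output_size() accessor required by the AbsEncoder interface callable; an instance attribute of the same name would shadow the method. A toy illustration (not FunASR code):

```python
# Toy demonstration of the attribute-vs-method clash resolved above.
class Enc:
    def __init__(self, size: int):
        self._output_size = size   # an attribute named output_size would shadow the method

    def output_size(self) -> int:
        return self._output_size

assert Enc(512).output_size() == 512
```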
funasr/models/rnnt_predictor/__init__.py
File was deleted
funasr/models/rnnt_predictor/abs_decoder.py
File was deleted
funasr/models/rnnt_predictor/stateless_decoder.py
File was deleted
funasr/modules/beam_search/beam_search_transducer.py
@@ -6,7 +6,6 @@
import numpy as np
import torch
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -68,7 +67,7 @@
    def __init__(
        self,
-        decoder: AbsDecoder,
+        decoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
funasr/modules/e2e_asr_common.py
@@ -18,7 +18,6 @@
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
@@ -268,7 +267,7 @@
    def __init__(
        self,
-        decoder: AbsDecoder,
+        decoder,
        joint_network: JointNetwork,
        token_list: List[int],
        sym_space: str,
funasr/tasks/asr.py
@@ -38,13 +38,16 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
@@ -150,6 +153,7 @@
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
+        chunk_conformer=ConformerChunkEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
@@ -207,6 +211,16 @@
    type_check=AbsDecoder,
    default="rnn",
)
+rnnt_decoder_choices = ClassChoices(
+    "rnnt_decoder",
+    classes=dict(
+        rnnt=RNNTDecoder,
+    ),
+    type_check=RNNTDecoder,
+    default="rnnt",
+)
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
@@ -1331,3 +1345,378 @@
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval
class ASRTransducerTask(AbsTask):
    """ASR Transducer Task definition."""
    num_optimizers: int = 1
    class_choices_list = [
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        rnnt_decoder_choices,
    ]
    trainer = Trainer
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add Transducer task arguments.
        Args:
            cls: ASRTransducerTask object.
            parser: Transducer arguments parser.
        """
        group = parser.add_argument_group(description="Task related.")
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="Integer-string mapper for tokens.",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of dimensions for input features.",
        )
        group.add_argument(
            "--init",
            type=str_or_none,
            default=None,
            help="Type of model initialization to use.",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        # group.add_argument(
        #     "--encoder_conf",
        #     action=NestedDictAction,
        #     default={},
        #     help="The keyword arguments for the encoder class.",
        # )
        group.add_argument(
            "--joint_network_conf",
            action=NestedDictAction,
            default={},
            help="The keyword arguments for the joint network class.",
        )
        group = parser.add_argument_group(description="Preprocess related.")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to input data.",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The type of tokens to use during tokenization.",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The path of the sentencepiece model.",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="The 'non_linguistic_symbols' file path.",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Text cleaner to use.",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="g2p method to use if --token_type=phn.",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Normalization value for maximum amplitude scaling.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The RIR SCP file path.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of the applied RIR convolution.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The path of noise SCP file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of the applied noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )
        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --decoder and --decoder_conf
            class_choices.add_arguments(group)
    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build collate function.
        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.
        Return:
            : Callable collate function.
        """
        assert check_argument_types()
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
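CommonCollateFn pads speech with 0.0 and token IDs with -1, which matches the model's default ignore_id. A toy view of that padding convention, using pad_sequence rather than the real collate function:

```python
# Toy view of the collate padding convention (not the real CommonCollateFn).
import torch
from torch.nn.utils.rnn import pad_sequence

texts = [torch.tensor([5, 9, 2]), torch.tensor([7])]
padded = pad_sequence(texts, batch_first=True, padding_value=-1)
print(padded.tolist())   # [[5, 9, 2], [7, -1, -1]]
```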
    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Build pre-processing function.
        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.
        Return:
            : Callable pre-processing function.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval
    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Required data depending on task mode.
        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.
        Return:
            retval: Required task data.
        """
        if not inference:
            retval = ("speech", "text")
        else:
            retval = ("speech",)
        return retval
    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Optional data depending on task mode.
        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.
        Return:
            retval: Optional task data.
        """
        retval = ()
        assert check_return_type(retval)
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Required data depending on task mode.
        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
        Return:
            model: ASR Transducer model.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size }")
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            frontend = None
            input_size = args.input_size
        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None
        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None
        # 4. Encoder
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
        else:
            encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()
        # 5. Decoder
        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
        decoder = rnnt_decoder_class(
            vocab_size,
            **args.rnnt_decoder_conf,
        )
        decoder_output_size = decoder.output_size
        if getattr(args, "decoder", None) is not None:
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )
        else:
            att_decoder = None
        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )
        # 7. Build model
        if encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        else:
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported.",
                "Initialization part will be reworked in a short future.",
            )
        #assert check_return_type(model)
        return model
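With funasr/tasks/asr_transducer.py deleted, entry points reach the task through funasr.tasks.asr, as in funasr/bin/asr_train_transducer.py above. A hedged sketch of launching it programmatically, assuming the ESPnet-style AbsTask.main(cmd=...) entry point is unchanged; paths are placeholders:

```python
# Hedged sketch: launching the merged task programmatically. Assumes the
# ESPnet-style AbsTask.main(cmd=...) entry point; paths are placeholders.
from funasr.tasks.asr import ASRTransducerTask

if __name__ == "__main__":
    ASRTransducerTask.main(
        cmd=[
            "--config", "egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml",
            "--token_list", "data/zh_token_list/char/tokens.txt",
            "--output_dir", "exp/rnnt_unified",
        ]
    )
```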
funasr/tasks/asr_transducer.py
File was deleted