From 8672352ecde80a86609fe01195b398ebe77f0ed1 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Mon, 17 Apr 2023 16:09:23 +0800
Subject: [PATCH] Merge transducer task and unified transducer model into main ASR modules
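
Merge the standalone transducer task and the unified transducer model into
the main ASR code paths:

* funasr/tasks/asr_transducer.py is folded into funasr/tasks/asr.py
  (ASRTransducerTask, rnnt_decoder_choices, the chunk_conformer encoder choice).
* UnifiedTransducerModel moves from funasr/models/e2e_asr_transducer_unified.py
  into funasr/models/e2e_asr_transducer.py.
* The funasr/models/rnnt_predictor package is removed: RNNDecoder is renamed
  to RNNTDecoder under funasr/models/decoder/, and AbsDecoder type checks in
  the beam search and error calculator are replaced by duck typing.

Downstream code only needs the import path update:

    from funasr.tasks.asr import ASRTransducerTask  # was: funasr.tasks.asr_transducer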
---
 egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 30 +-
 funasr/bin/asr_inference_rnnt.py | 2 +-
 funasr/bin/asr_train_transducer.py | 2 +-
 funasr/models/{rnnt_predictor/rnn_decoder.py => decoder/rnnt_decoder.py} | 3 +-
 funasr/models/e2e_asr_transducer.py | 535 ++++++++-
 funasr/models/e2e_asr_transducer_unified.py | 586 ----------
 funasr/models/encoder/conformer_encoder.py | 7 +-
 funasr/models/rnnt_predictor/__init__.py | 0
 funasr/models/rnnt_predictor/abs_decoder.py | 110 ----
 funasr/models/rnnt_predictor/stateless_decoder.py | 145 ----
 funasr/modules/beam_search/beam_search_transducer.py | 3 +-
 funasr/modules/e2e_asr_common.py | 3 +-
 funasr/tasks/asr.py | 391 +++++++-
 funasr/tasks/asr_transducer.py | 477 ----------
 14 files changed, 944 insertions(+), 1350 deletions(-)
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index 60f796c..8a1c40c 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -1,32 +1,26 @@
+encoder: chunk_conformer
encoder_conf:
- main_conf:
- pos_wise_act_type: swish
- pos_enc_dropout_rate: 0.5
- conv_mod_act_type: swish
+ activation_type: swish
+ positional_dropout_rate: 0.5
time_reduction_factor: 2
unified_model_training: true
default_chunk_size: 16
jitter_range: 4
left_chunk_size: 0
- input_conf:
- block_type: conv2d
- conv_size: 512
+ embed_vgg_like: false
subsampling_factor: 4
- num_frame: 1
- body_conf:
- - block_type: conformer
- linear_size: 2048
- hidden_size: 512
- heads: 8
+ linear_units: 2048
+ output_size: 512
+ attention_heads: 8
dropout_rate: 0.5
- pos_wise_dropout_rate: 0.5
- att_dropout_rate: 0.5
- conv_mod_kernel_size: 15
+ positional_dropout_rate: 0.5
+ attention_dropout_rate: 0.5
+ cnn_module_kernel: 15
num_blocks: 12
# decoder related
-decoder: rnn
-decoder_conf:
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
embed_size: 512
hidden_size: 512
embed_dropout_rate: 0.5
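
The renamed keys above are flat keyword arguments consumed by
ASRTransducerTask.build_model() in funasr/tasks/asr.py. A minimal sketch of
the wiring, assuming the constructor keyword names match the YAML (only the
positional vocab_size/input_size arguments are confirmed by build_model):

    # Sketch: how build_model() materializes this config (subset of keys;
    # the remaining encoder_conf entries pass through **args.encoder_conf).
    encoder = ConformerChunkEncoder(
        input_size,  # from frontend.output_size()
        output_size=512,
        attention_heads=8,
        linear_units=2048,
        num_blocks=12,
        unified_model_training=True,
        default_chunk_size=16,
    )
    decoder = RNNTDecoder(
        vocab_size,
        embed_size=512,
        hidden_size=512,
        embed_dropout_rate=0.5,
    )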
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 465f882..bff8702 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -22,7 +22,7 @@
)
from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
index 9b6d287..fe418db 100755
--- a/funasr/bin/asr_train_transducer.py
+++ b/funasr/bin/asr_train_transducer.py
@@ -2,7 +2,7 @@
import os
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
# for ASR Training
diff --git a/funasr/models/rnnt_predictor/rnn_decoder.py b/funasr/models/decoder/rnnt_decoder.py
similarity index 98%
rename from funasr/models/rnnt_predictor/rnn_decoder.py
rename to funasr/models/decoder/rnnt_decoder.py
index 0df6fc7..5401ab2 100644
--- a/funasr/models/rnnt_predictor/rnn_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -6,10 +6,9 @@
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
-class RNNDecoder(AbsDecoder):
+class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
Args:
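
With the rename, RNNTDecoder subclasses torch.nn.Module directly and the
AbsDecoder base class is deleted; BeamSearchTransducer and
ErrorCalculatorTransducer (see below) now accept any decoder object. The
implicit contract they still rely on, taken from the removed
funasr/models/rnnt_predictor/abs_decoder.py, is roughly:

    # Informal, duck-typed decoder interface (sketch, not an actual class).
    class TransducerDecoderLike:
        def forward(self, labels, label_lens): ...  # (B, U) -> (B, U, D_dec)
        def score(self, label, label_sequence, dec_state): ...
        def batch_score(self, hyps): ...
        def set_device(self, device): ...
        def init_state(self, batch_size): ...
        def select_state(self, states, idx=0): ...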
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index 6eb0023..0cae306 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -10,7 +10,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -63,9 +63,9 @@
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: Encoder,
- decoder: AbsDecoder,
- att_decoder: Optional[AbsAttDecoder],
+ decoder: RNNTDecoder,
joint_network: JointNetwork,
+ att_decoder: Optional[AbsAttDecoder] = None,
transducer_weight: float = 1.0,
fastemit_lambda: float = 0.0,
auxiliary_ctc_weight: float = 0.0,
@@ -482,3 +482,532 @@
)
return loss_lm
+
+class UnifiedTransducerModel(AbsESPnetModel):
+ """ESPnet2ASRTransducerModel module definition.
+ Args:
+ vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+ token_list: List of token
+ frontend: Frontend module.
+ specaug: SpecAugment module.
+ normalize: Normalization module.
+ encoder: Encoder module.
+ decoder: Decoder module.
+ joint_network: Joint Network module.
+ transducer_weight: Weight of the Transducer loss.
+ fastemit_lambda: FastEmit lambda value.
+ auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+ auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+ auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+ auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+ ignore_id: Initial padding ID.
+ sym_space: Space symbol.
+ sym_blank: Blank Symbol
+ report_cer: Whether to report Character Error Rate during validation.
+ report_wer: Whether to report Word Error Rate during validation.
+ extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: Encoder,
+ decoder: RNNTDecoder,
+ joint_network: JointNetwork,
+ att_decoder: Optional[AbsAttDecoder] = None,
+ transducer_weight: float = 1.0,
+ fastemit_lambda: float = 0.0,
+ auxiliary_ctc_weight: float = 0.0,
+ auxiliary_att_weight: float = 0.0,
+ auxiliary_ctc_dropout_rate: float = 0.0,
+ auxiliary_lm_loss_weight: float = 0.0,
+ auxiliary_lm_loss_smoothing: float = 0.0,
+ ignore_id: int = -1,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_sos: str = "<sos/eos>",
+ sym_eos: str = "<sos/eos>",
+ extract_feats_in_collect_stats: bool = True,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ ) -> None:
+ """Construct an ESPnetASRTransducerModel object."""
+ super().__init__()
+
+ assert check_argument_types()
+
+ # The following label IDs are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+ self.blank_id = 0
+
+ if sym_sos in token_list:
+ self.sos = token_list.index(sym_sos)
+ else:
+ self.sos = vocab_size - 1
+ if sym_eos in token_list:
+ self.eos = token_list.index(sym_eos)
+ else:
+ self.eos = vocab_size - 1
+
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.token_list = token_list.copy()
+
+ self.sym_space = sym_space
+ self.sym_blank = sym_blank
+
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+
+ self.encoder = encoder
+ self.decoder = decoder
+ self.joint_network = joint_network
+
+ self.criterion_transducer = None
+ self.error_calculator = None
+
+ self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+ self.use_auxiliary_att = auxiliary_att_weight > 0
+ self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+ if self.use_auxiliary_ctc:
+ self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
+ self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+ if self.use_auxiliary_att:
+ self.att_decoder = att_decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+
+ if self.use_auxiliary_lm_loss:
+ self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+ self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+ self.transducer_weight = transducer_weight
+ self.fastemit_lambda = fastemit_lambda
+
+ self.auxiliary_ctc_weight = auxiliary_ctc_weight
+ self.auxiliary_att_weight = auxiliary_att_weight
+ self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+ self.report_cer = report_cer
+ self.report_wer = report_wer
+
+ self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Forward architecture and compute loss(es).
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ text: Label ID sequences. (B, L)
+ text_lengths: Label ID sequences lengths. (B,)
+ kwargs: Contains "utts_id".
+ Return:
+ loss: Main loss value.
+ stats: Task statistics.
+ weight: Task weights.
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+ batch_size = speech.shape[0]
+ text = text[:, : text_lengths.max()]
+ # 1. Encoder
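+ # Unified training: the chunk encoder returns both a full-context pass
+ # (encoder_out) and a chunk-masked streaming pass (encoder_out_chunk);
+ # every loss below is computed on both, so one model serves offline and
+ # streaming decoding.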
+ encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ loss_att, loss_att_chunk = 0.0, 0.0
+
+ if self.use_auxiliary_att:
+ loss_att, _ = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+ loss_att_chunk, _ = self._calc_att_loss(
+ encoder_out_chunk, encoder_out_lens, text, text_lengths
+ )
+
+ # 2. Transducer-related I/O preparation
+ decoder_in, target, t_len, u_len = get_transducer_task_io(
+ text,
+ encoder_out_lens,
+ ignore_id=self.ignore_id,
+ )
+
+ # 3. Decoder
+ self.decoder.set_device(encoder_out.device)
+ decoder_out = self.decoder(decoder_in, u_len)
+
+ # 4. Joint Network
+ joint_out = self.joint_network(
+ encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+ )
+
+ joint_out_chunk = self.joint_network(
+ encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
+ )
+
+ # 5. Losses
+ loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
+ encoder_out,
+ joint_out,
+ target,
+ t_len,
+ u_len,
+ )
+
+ loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
+ encoder_out_chunk,
+ joint_out_chunk,
+ target,
+ t_len,
+ u_len,
+ )
+
+ loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
+
+ if self.use_auxiliary_ctc:
+ loss_ctc = self._calc_ctc_loss(
+ encoder_out,
+ target,
+ t_len,
+ u_len,
+ )
+ loss_ctc_chunk = self._calc_ctc_loss(
+ encoder_out_chunk,
+ target,
+ t_len,
+ u_len,
+ )
+
+ if self.use_auxiliary_lm_loss:
+ loss_lm = self._calc_lm_loss(decoder_out, target)
+
+ loss_trans = loss_trans_utt + loss_trans_chunk
+ loss_ctc = loss_ctc + loss_ctc_chunk
+ loss_att = loss_att + loss_att_chunk
+
+ loss = (
+ self.transducer_weight * loss_trans
+ + self.auxiliary_ctc_weight * loss_ctc
+ + self.auxiliary_att_weight * loss_att
+ + self.auxiliary_lm_loss_weight * loss_lm
+ )
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_transducer=loss_trans_utt.detach(),
+ loss_transducer_chunk=loss_trans_chunk.detach(),
+ aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+ aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
+ aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
+ aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
+ aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+ cer_transducer=cer_trans,
+ wer_transducer=wer_trans,
+ cer_transducer_chunk=cer_trans_chunk,
+ wer_transducer_chunk=wer_trans_chunk,
+ )
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ """Collect features sequences and features lengths sequences.
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ text: Label ID sequences. (B, L)
+ text_lengths: Label ID sequences lengths. (B,)
+ kwargs: Contains "utts_id".
+ Return:
+ {}: "feats": Features sequences. (B, T, D_feats),
+ "feats_lengths": Features sequences lengths. (B,)
+ """
+ if self.extract_feats_in_collect_stats:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ else:
+ # Generate dummy stats if extract_feats_in_collect_stats is False
+ logging.warning(
+ "Generating dummy stats for feats and feats_lengths, "
+ "because encoder_conf.extract_feats_in_collect_stats is "
+ f"{self.extract_feats_in_collect_stats}"
+ )
+
+ feats, feats_lengths = speech, speech_lengths
+
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Encode speech sequences.
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ Return:
+ encoder_out: Full-context encoder outputs. (B, T, D_enc)
+ encoder_out_chunk: Chunk-masked encoder outputs. (B, T, D_enc)
+ encoder_out_lens: Encoder outputs lengths. (B,)
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # 4. Forward encoder
+ encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
+
+ assert encoder_out.size(0) == speech.size(0), (
+ encoder_out.size(),
+ speech.size(0),
+ )
+ assert encoder_out.size(1) <= encoder_out_lens.max(), (
+ encoder_out.size(),
+ encoder_out_lens.max(),
+ )
+
+ return encoder_out, encoder_out_chunk, encoder_out_lens
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Extract features sequences and features sequences lengths.
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ Return:
+ feats: Features sequences. (B, T, D_feats)
+ feats_lengths: Features sequences lengths. (B,)
+ """
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ feats, feats_lengths = speech, speech_lengths
+
+ return feats, feats_lengths
+
+ def _calc_transducer_loss(
+ self,
+ encoder_out: torch.Tensor,
+ joint_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+ """Compute Transducer loss.
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ joint_out: Joint Network output sequences (B, T, U, D_joint)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
+ Return:
+ loss_transducer: Transducer loss value.
+ cer_transducer: Character error rate for Transducer.
+ wer_transducer: Word Error Rate for Transducer.
+ """
+ if self.criterion_transducer is None:
+ try:
+ # from warprnnt_pytorch import RNNTLoss
+ # self.criterion_transducer = RNNTLoss(
+ # reduction="mean",
+ # fastemit_lambda=self.fastemit_lambda,
+ # )
+ from warp_rnnt import rnnt_loss as RNNTLoss
+ self.criterion_transducer = RNNTLoss
+
+ except ImportError:
+ logging.error(
+ "warp-rnnt was not installed."
+ "Please consult the installation documentation."
+ )
+ exit(1)
+
+ # loss_transducer = self.criterion_transducer(
+ # joint_out,
+ # target,
+ # t_len,
+ # u_len,
+ # )
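+ # warp_rnnt.rnnt_loss takes log-probabilities of shape (B, T, U, V)
+ # rather than raw logits, hence the explicit log_softmax below;
+ # gather=True first gathers the target/blank lanes to reduce memory.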
+ log_probs = torch.log_softmax(joint_out, dim=-1)
+
+ loss_transducer = self.criterion_transducer(
+ log_probs,
+ target,
+ t_len,
+ u_len,
+ reduction="mean",
+ blank=self.blank_id,
+ fastemit_lambda=self.fastemit_lambda,
+ gather=True,
+ )
+
+ if not self.training and (self.report_cer or self.report_wer):
+ if self.error_calculator is None:
+ self.error_calculator = ErrorCalculator(
+ self.decoder,
+ self.joint_network,
+ self.token_list,
+ self.sym_space,
+ self.sym_blank,
+ report_cer=self.report_cer,
+ report_wer=self.report_wer,
+ )
+
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+ return loss_transducer, cer_transducer, wer_transducer
+
+ return loss_transducer, None, None
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute CTC loss.
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
+ Return:
+ loss_ctc: CTC loss value.
+ """
+ ctc_in = self.ctc_lin(
+ torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+ )
+ ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
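+ # torch.nn.functional.ctc_loss expects (T, B, V) log-probs, and 1-D
+ # targets must be the unpadded labels concatenated across the batch,
+ # so entries equal to the blank/padding ID (0) are filtered out.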
+ target_mask = target != 0
+ ctc_target = target[target_mask].cpu()
+
+ with torch.backends.cudnn.flags(deterministic=True):
+ loss_ctc = torch.nn.functional.ctc_loss(
+ ctc_in,
+ ctc_target,
+ t_len,
+ u_len,
+ zero_infinity=True,
+ reduction="sum",
+ )
+ loss_ctc /= target.size(0)
+
+ return loss_ctc
+
+ def _calc_lm_loss(
+ self,
+ decoder_out: torch.Tensor,
+ target: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute LM loss.
+ Args:
+ decoder_out: Decoder output sequences. (B, U, D_dec)
+ target: Target label ID sequences. (B, L)
+ Return:
+ loss_lm: LM loss value.
+ """
+ lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+ lm_target = target.view(-1).type(torch.int64)
+
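+ # Label smoothing via KL divergence: the target distribution assigns
+ # (1 - smoothing) to the gold label and spreads the remainder uniformly;
+ # positions whose target is the blank ID (0) are masked from the sum.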
+ with torch.no_grad():
+ true_dist = lm_loss_in.clone()
+ true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+ # Ignore blank ID (0)
+ ignore = lm_target == 0
+ lm_target = lm_target.masked_fill(ignore, 0)
+
+ true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+ loss_lm = torch.nn.functional.kl_div(
+ torch.log_softmax(lm_loss_in, dim=1),
+ true_dist,
+ reduction="none",
+ )
+ loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+ 0
+ )
+
+ return loss_lm
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
+ ys_pad = torch.cat(
+ [
+ self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
+ ys_pad,
+ ],
+ dim=1,
+ )
+ ys_pad_lens += 1
+
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.att_decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ )
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+
+ return loss_att, acc_att
diff --git a/funasr/models/e2e_asr_transducer_unified.py b/funasr/models/e2e_asr_transducer_unified.py
deleted file mode 100644
index ad61d12..0000000
--- a/funasr/models/e2e_asr_transducer_unified.py
+++ /dev/null
@@ -1,586 +0,0 @@
-"""ESPnet2 ASR Transducer model."""
-
-import logging
-from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from packaging.version import parse as V
-from typeguard import check_argument_types
-
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.modules.nets_utils import get_transducer_task_io
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.modules.nets_utils import th_accuracy
-from funasr.losses.label_smoothing_loss import ( # noqa: H301
- LabelSmoothingLoss,
-)
-from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
-if V(torch.__version__) >= V("1.6.0"):
- from torch.cuda.amp import autocast
-else:
-
- @contextmanager
- def autocast(enabled=True):
- yield
-
-
-class UnifiedTransducerModel(AbsESPnetModel):
- """ESPnet2ASRTransducerModel module definition.
-
- Args:
- vocab_size: Size of complete vocabulary (w/ EOS and blank included).
- token_list: List of token
- frontend: Frontend module.
- specaug: SpecAugment module.
- normalize: Normalization module.
- encoder: Encoder module.
- decoder: Decoder module.
- joint_network: Joint Network module.
- transducer_weight: Weight of the Transducer loss.
- fastemit_lambda: FastEmit lambda value.
- auxiliary_ctc_weight: Weight of auxiliary CTC loss.
- auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
- auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
- auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
- ignore_id: Initial padding ID.
- sym_space: Space symbol.
- sym_blank: Blank Symbol
- report_cer: Whether to report Character Error Rate during validation.
- report_wer: Whether to report Word Error Rate during validation.
- extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
-
- """
-
- def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- encoder: Encoder,
- decoder: AbsDecoder,
- att_decoder: Optional[AbsAttDecoder],
- joint_network: JointNetwork,
- transducer_weight: float = 1.0,
- fastemit_lambda: float = 0.0,
- auxiliary_ctc_weight: float = 0.0,
- auxiliary_att_weight: float = 0.0,
- auxiliary_ctc_dropout_rate: float = 0.0,
- auxiliary_lm_loss_weight: float = 0.0,
- auxiliary_lm_loss_smoothing: float = 0.0,
- ignore_id: int = -1,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- report_cer: bool = True,
- report_wer: bool = True,
- sym_sos: str = "<sos/eos>",
- sym_eos: str = "<sos/eos>",
- extract_feats_in_collect_stats: bool = True,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- ) -> None:
- """Construct an ESPnetASRTransducerModel object."""
- super().__init__()
-
- assert check_argument_types()
-
- # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
- self.blank_id = 0
-
- if sym_sos in token_list:
- self.sos = token_list.index(sym_sos)
- else:
- self.sos = vocab_size - 1
- if sym_eos in token_list:
- self.eos = token_list.index(sym_eos)
- else:
- self.eos = vocab_size - 1
-
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.token_list = token_list.copy()
-
- self.sym_space = sym_space
- self.sym_blank = sym_blank
-
- self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
-
- self.encoder = encoder
- self.decoder = decoder
- self.joint_network = joint_network
-
- self.criterion_transducer = None
- self.error_calculator = None
-
- self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
- self.use_auxiliary_att = auxiliary_att_weight > 0
- self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
-
- if self.use_auxiliary_ctc:
- self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
- self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
-
- if self.use_auxiliary_att:
- self.att_decoder = att_decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
-
- if self.use_auxiliary_lm_loss:
- self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
- self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
-
- self.transducer_weight = transducer_weight
- self.fastemit_lambda = fastemit_lambda
-
- self.auxiliary_ctc_weight = auxiliary_ctc_weight
- self.auxiliary_att_weight = auxiliary_att_weight
- self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
-
- self.report_cer = report_cer
- self.report_wer = report_wer
-
- self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Forward architecture and compute loss(es).
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
- text: Label ID sequences. (B, L)
- text_lengths: Label ID sequences lengths. (B,)
- kwargs: Contains "utts_id".
-
- Return:
- loss: Main loss value.
- stats: Task statistics.
- weight: Task weights.
-
- """
- assert text_lengths.dim() == 1, text_lengths.shape
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-
- batch_size = speech.shape[0]
- text = text[:, : text_lengths.max()]
- #print(speech.shape)
- # 1. Encoder
- encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_att, loss_att_chunk = 0.0, 0.0
-
- if self.use_auxiliary_att:
- loss_att, _ = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
- loss_att_chunk, _ = self._calc_att_loss(
- encoder_out_chunk, encoder_out_lens, text, text_lengths
- )
-
- # 2. Transducer-related I/O preparation
- decoder_in, target, t_len, u_len = get_transducer_task_io(
- text,
- encoder_out_lens,
- ignore_id=self.ignore_id,
- )
-
- # 3. Decoder
- self.decoder.set_device(encoder_out.device)
- decoder_out = self.decoder(decoder_in, u_len)
-
- # 4. Joint Network
- joint_out = self.joint_network(
- encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
- )
-
- joint_out_chunk = self.joint_network(
- encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
- )
-
- # 5. Losses
- loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
- encoder_out,
- joint_out,
- target,
- t_len,
- u_len,
- )
-
- loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
- encoder_out_chunk,
- joint_out_chunk,
- target,
- t_len,
- u_len,
- )
-
- loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
-
- if self.use_auxiliary_ctc:
- loss_ctc = self._calc_ctc_loss(
- encoder_out,
- target,
- t_len,
- u_len,
- )
- loss_ctc_chunk = self._calc_ctc_loss(
- encoder_out_chunk,
- target,
- t_len,
- u_len,
- )
-
- if self.use_auxiliary_lm_loss:
- loss_lm = self._calc_lm_loss(decoder_out, target)
-
- loss_trans = loss_trans_utt + loss_trans_chunk
- loss_ctc = loss_ctc + loss_ctc_chunk
- loss_ctc = loss_att + loss_att_chunk
-
- loss = (
- self.transducer_weight * loss_trans
- + self.auxiliary_ctc_weight * loss_ctc
- + self.auxiliary_att_weight * loss_att
- + self.auxiliary_lm_loss_weight * loss_lm
- )
-
- stats = dict(
- loss=loss.detach(),
- loss_transducer=loss_trans_utt.detach(),
- loss_transducer_chunk=loss_trans_chunk.detach(),
- aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
- aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
- aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
- aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
- aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
- cer_transducer=cer_trans,
- wer_transducer=wer_trans,
- cer_transducer_chunk=cer_trans_chunk,
- wer_transducer_chunk=wer_trans_chunk,
- )
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Dict[str, torch.Tensor]:
- """Collect features sequences and features lengths sequences.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
- text: Label ID sequences. (B, L)
- text_lengths: Label ID sequences lengths. (B,)
- kwargs: Contains "utts_id".
-
- Return:
- {}: "feats": Features sequences. (B, T, D_feats),
- "feats_lengths": Features sequences lengths. (B,)
-
- """
- if self.extract_feats_in_collect_stats:
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
- else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
-
- feats, feats_lengths = speech, speech_lengths
-
- return {"feats": feats, "feats_lengths": feats_lengths}
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encoder speech sequences.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
-
- Return:
- encoder_out: Encoder outputs. (B, T, D_enc)
- encoder_out_lens: Encoder outputs lengths. (B,)
-
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(feats, feats_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # 4. Forward encoder
- encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
-
- assert encoder_out.size(0) == speech.size(0), (
- encoder_out.size(),
- speech.size(0),
- )
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
- encoder_out.size(),
- encoder_out_lens.max(),
- )
-
- return encoder_out, encoder_out_chunk, encoder_out_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Extract features sequences and features sequences lengths.
-
- Args:
- speech: Speech sequences. (B, S)
- speech_lengths: Speech sequences lengths. (B,)
-
- Return:
- feats: Features sequences. (B, T, D_feats)
- feats_lengths: Features sequences lengths. (B,)
-
- """
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
-
- if self.frontend is not None:
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- feats, feats_lengths = speech, speech_lengths
-
- return feats, feats_lengths
-
- def _calc_transducer_loss(
- self,
- encoder_out: torch.Tensor,
- joint_out: torch.Tensor,
- target: torch.Tensor,
- t_len: torch.Tensor,
- u_len: torch.Tensor,
- ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
- """Compute Transducer loss.
-
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- joint_out: Joint Network output sequences (B, T, U, D_joint)
- target: Target label ID sequences. (B, L)
- t_len: Encoder output sequences lengths. (B,)
- u_len: Target label ID sequences lengths. (B,)
-
- Return:
- loss_transducer: Transducer loss value.
- cer_transducer: Character error rate for Transducer.
- wer_transducer: Word Error Rate for Transducer.
-
- """
- if self.criterion_transducer is None:
- try:
- # from warprnnt_pytorch import RNNTLoss
- # self.criterion_transducer = RNNTLoss(
- # reduction="mean",
- # fastemit_lambda=self.fastemit_lambda,
- # )
- from warp_rnnt import rnnt_loss as RNNTLoss
- self.criterion_transducer = RNNTLoss
-
- except ImportError:
- logging.error(
- "warp-rnnt was not installed."
- "Please consult the installation documentation."
- )
- exit(1)
-
- # loss_transducer = self.criterion_transducer(
- # joint_out,
- # target,
- # t_len,
- # u_len,
- # )
- log_probs = torch.log_softmax(joint_out, dim=-1)
-
- loss_transducer = self.criterion_transducer(
- log_probs,
- target,
- t_len,
- u_len,
- reduction="mean",
- blank=self.blank_id,
- fastemit_lambda=self.fastemit_lambda,
- gather=True,
- )
-
- if not self.training and (self.report_cer or self.report_wer):
- if self.error_calculator is None:
- self.error_calculator = ErrorCalculator(
- self.decoder,
- self.joint_network,
- self.token_list,
- self.sym_space,
- self.sym_blank,
- report_cer=self.report_cer,
- report_wer=self.report_wer,
- )
-
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
- return loss_transducer, cer_transducer, wer_transducer
-
- return loss_transducer, None, None
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- target: torch.Tensor,
- t_len: torch.Tensor,
- u_len: torch.Tensor,
- ) -> torch.Tensor:
- """Compute CTC loss.
-
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- target: Target label ID sequences. (B, L)
- t_len: Encoder output sequences lengths. (B,)
- u_len: Target label ID sequences lengths. (B,)
-
- Return:
- loss_ctc: CTC loss value.
-
- """
- ctc_in = self.ctc_lin(
- torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
- )
- ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
-
- target_mask = target != 0
- ctc_target = target[target_mask].cpu()
-
- with torch.backends.cudnn.flags(deterministic=True):
- loss_ctc = torch.nn.functional.ctc_loss(
- ctc_in,
- ctc_target,
- t_len,
- u_len,
- zero_infinity=True,
- reduction="sum",
- )
- loss_ctc /= target.size(0)
-
- return loss_ctc
-
- def _calc_lm_loss(
- self,
- decoder_out: torch.Tensor,
- target: torch.Tensor,
- ) -> torch.Tensor:
- """Compute LM loss.
-
- Args:
- decoder_out: Decoder output sequences. (B, U, D_dec)
- target: Target label ID sequences. (B, L)
-
- Return:
- loss_lm: LM loss value.
-
- """
- lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
- lm_target = target.view(-1).type(torch.int64)
-
- with torch.no_grad():
- true_dist = lm_loss_in.clone()
- true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
-
- # Ignore blank ID (0)
- ignore = lm_target == 0
- lm_target = lm_target.masked_fill(ignore, 0)
-
- true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
-
- loss_lm = torch.nn.functional.kl_div(
- torch.log_softmax(lm_loss_in, dim=1),
- true_dist,
- reduction="none",
- )
- loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
- 0
- )
-
- return loss_lm
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
- ys_pad = torch.cat(
- [
- self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
- ys_pad,
- ],
- dim=1,
- )
- ys_pad_lens += 1
-
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.att_decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- )
-
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_out_pad)
- acc_att = th_accuracy(
- decoder_out.view(-1, self.vocab_size),
- ys_out_pad,
- ignore_label=self.ignore_id,
- )
-
- return loss_att, acc_att
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index b7b552c..9777cee 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -894,7 +894,7 @@
return x, cache
-class ConformerChunkEncoder(torch.nn.Module):
+class ConformerChunkEncoder(AbsEncoder):
"""Encoder module definition.
Args:
input_size: Input size.
@@ -1007,7 +1007,7 @@
output_size,
)
- self.output_size = output_size
+ self._output_size = output_size
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
@@ -1020,6 +1020,9 @@
self.time_reduction_factor = time_reduction_factor
+ def output_size(self) -> int:
+ return self._output_size
+
def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
"""Return the corresponding number of sample for a given chunk size, in frames.
Where size is the number of features frames after applying subsampling.
diff --git a/funasr/models/rnnt_predictor/__init__.py b/funasr/models/rnnt_predictor/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/rnnt_predictor/__init__.py
+++ /dev/null
diff --git a/funasr/models/rnnt_predictor/abs_decoder.py b/funasr/models/rnnt_predictor/abs_decoder.py
deleted file mode 100644
index 5b4a335..0000000
--- a/funasr/models/rnnt_predictor/abs_decoder.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""Abstract decoder definition for Transducer models."""
-
-from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple
-
-import torch
-
-
-class AbsDecoder(torch.nn.Module, ABC):
- """Abstract decoder module."""
-
- @abstractmethod
- def forward(self, labels: torch.Tensor) -> torch.Tensor:
- """Encode source label sequences.
-
- Args:
- labels: Label ID sequences. (B, L)
-
- Returns:
- dec_out: Decoder output sequences. (B, T, D_dec)
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def score(
- self,
- label: torch.Tensor,
- label_sequence: List[int],
- dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
- """One-step forward hypothesis.
-
- Args:
- label: Previous label. (1, 1)
- label_sequence: Current label sequence.
- dec_state: Previous decoder hidden states.
- ((N, 1, D_dec), (N, 1, D_dec) or None) or None
-
- Returns:
- dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
- dec_state: Decoder hidden states.
- ((N, 1, D_dec), (N, 1, D_dec) or None) or None
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def batch_score(
- self,
- hyps: List[Any],
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
- """One-step forward hypotheses.
-
- Args:
- hyps: Hypotheses.
-
- Returns:
- dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
- states: Decoder hidden states.
- ((N, B, D_dec), (N, B, D_dec) or None) or None
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def set_device(self, device: torch.Tensor) -> None:
- """Set GPU device to use.
-
- Args:
- device: Device ID.
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def init_state(
- self, batch_size: int
- ) -> Optional[Tuple[torch.Tensor, Optional[torch.tensor]]]:
- """Initialize decoder states.
-
- Args:
- batch_size: Batch size.
-
- Returns:
- : Initial decoder hidden states.
- ((N, B, D_dec), (N, B, D_dec) or None) or None
-
- """
- raise NotImplementedError
-
- @abstractmethod
- def select_state(
- self,
- states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
- idx: int = 0,
- ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
- """Get specified ID state from batch of states, if provided.
-
- Args:
- states: Decoder hidden states.
- ((N, B, D_dec), (N, B, D_dec) or None) or None
- idx: State ID to extract.
-
- Returns:
- : Decoder hidden state for given ID.
- ((N, 1, D_dec), (N, 1, D_dec) or None) or None
-
- """
- raise NotImplementedError
diff --git a/funasr/models/rnnt_predictor/stateless_decoder.py b/funasr/models/rnnt_predictor/stateless_decoder.py
deleted file mode 100644
index 70cd877..0000000
--- a/funasr/models/rnnt_predictor/stateless_decoder.py
+++ /dev/null
@@ -1,145 +0,0 @@
-"""Stateless decoder definition for Transducer models."""
-
-from typing import List, Optional, Tuple
-
-import torch
-from typeguard import check_argument_types
-
-from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
-from funasr.models.specaug.specaug import SpecAug
-
-class StatelessDecoder(AbsDecoder):
- """Stateless Transducer decoder module.
-
- Args:
- vocab_size: Output size.
- embed_size: Embedding size.
- embed_dropout_rate: Dropout rate for embedding layer.
- embed_pad: Embed/Blank symbol ID.
-
- """
-
- def __init__(
- self,
- vocab_size: int,
- embed_size: int = 256,
- embed_dropout_rate: float = 0.0,
- embed_pad: int = 0,
- ) -> None:
- """Construct a StatelessDecoder object."""
- super().__init__()
-
- assert check_argument_types()
-
- self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
- self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate)
-
- self.output_size = embed_size
- self.vocab_size = vocab_size
-
- self.device = next(self.parameters()).device
- self.score_cache = {}
-
-
-
- def forward(
- self,
- labels: torch.Tensor,
- label_lens: torch.Tensor,
- states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
- ) -> torch.Tensor:
- """Encode source label sequences.
-
- Args:
- labels: Label ID sequences. (B, L)
- states: Decoder hidden states. None
-
- Returns:
- dec_embed: Decoder output sequences. (B, U, D_emb)
-
- """
- dec_embed = self.embed_dropout_rate(self.embed(labels))
- return dec_embed
-
- def score(
- self,
- label: torch.Tensor,
- label_sequence: List[int],
- state: None,
- ) -> Tuple[torch.Tensor, None]:
- """One-step forward hypothesis.
-
- Args:
- label: Previous label. (1, 1)
- label_sequence: Current label sequence.
- state: Previous decoder hidden states. None
-
- Returns:
- dec_out: Decoder output sequence. (1, D_emb)
- state: Decoder hidden states. None
-
- """
- str_labels = "_".join(map(str, label_sequence))
-
- if str_labels in self.score_cache:
- dec_embed = self.score_cache[str_labels]
- else:
- dec_embed = self.embed(label)
-
- self.score_cache[str_labels] = dec_embed
-
- return dec_embed[0], None
-
- def batch_score(
- self,
- hyps: List[Hypothesis],
- ) -> Tuple[torch.Tensor, None]:
- """One-step forward hypotheses.
-
- Args:
- hyps: Hypotheses.
-
- Returns:
- dec_out: Decoder output sequences. (B, D_dec)
- states: Decoder hidden states. None
-
- """
- labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
- dec_embed = self.embed(labels)
-
- return dec_embed.squeeze(1), None
-
- def set_device(self, device: torch.device) -> None:
- """Set GPU device to use.
-
- Args:
- device: Device ID.
-
- """
- self.device = device
-
- def init_state(self, batch_size: int) -> None:
- """Initialize decoder states.
-
- Args:
- batch_size: Batch size.
-
- Returns:
- : Initial decoder hidden states. None
-
- """
- return None
-
- def select_state(self, states: Optional[torch.Tensor], idx: int) -> None:
- """Get specified ID state from decoder hidden states.
-
- Args:
- states: Decoder hidden states. None
- idx: State ID to extract.
-
- Returns:
- : Decoder hidden state for given ID. None
-
- """
- return None
diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
index 8b7e613..3eb8e08 100644
--- a/funasr/modules/beam_search/beam_search_transducer.py
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -6,7 +6,6 @@
import numpy as np
import torch
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -68,7 +67,7 @@
def __init__(
self,
- decoder: AbsDecoder,
+ decoder,
joint_network: JointNetwork,
beam_size: int,
lm: Optional[torch.nn.Module] = None,
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index a01cd5e..f430fcb 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -18,7 +18,6 @@
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
@@ -268,7 +267,7 @@
def __init__(
self,
- decoder: AbsDecoder,
+ decoder,
joint_network: JointNetwork,
token_list: List[int],
sym_space: str,
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index e151473..87db05c 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,13 +38,16 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
@@ -150,6 +153,7 @@
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
+ chunk_conformer=ConformerChunkEncoder,
),
type_check=AbsEncoder,
default="rnn",
@@ -207,6 +211,16 @@
type_check=AbsDecoder,
default="rnn",
)
+
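+# ClassChoices also generates the --rnnt_decoder and --rnnt_decoder_conf CLI
+# options (appended in add_task_arguments below), which is how the
+# rnnt_decoder/rnnt_decoder_conf keys in the YAML config are consumed.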
+rnnt_decoder_choices = ClassChoices(
+ "rnnt_decoder",
+ classes=dict(
+ rnnt=RNNTDecoder,
+ ),
+ type_check=RNNTDecoder,
+ default="rnnt",
+)
+
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
@@ -1331,3 +1345,378 @@
) -> Tuple[str, ...]:
retval = ("speech", "text")
return retval
+
+
+class ASRTransducerTask(AbsTask):
+ """ASR Transducer Task definition."""
+
+ num_optimizers: int = 1
+
+ class_choices_list = [
+ frontend_choices,
+ specaug_choices,
+ normalize_choices,
+ encoder_choices,
+ rnnt_decoder_choices,
+ ]
+
+ trainer = Trainer
+
+ @classmethod
+ def add_task_arguments(cls, parser: argparse.ArgumentParser):
+ """Add Transducer task arguments.
+ Args:
+ cls: ASRTransducerTask object.
+ parser: Transducer arguments parser.
+ """
+ group = parser.add_argument_group(description="Task related.")
+
+ # required = parser.get_default("required")
+ # required += ["token_list"]
+
+ group.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="Integer-string mapper for tokens.",
+ )
+ group.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ group.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of dimensions for input features.",
+ )
+ group.add_argument(
+ "--init",
+ type=str_or_none,
+ default=None,
+ help="Type of model initialization to use.",
+ )
+ group.add_argument(
+ "--model_conf",
+ action=NestedDictAction,
+ default=get_default_kwargs(TransducerModel),
+ help="The keyword arguments for the model class.",
+ )
+ # group.add_argument(
+ # "--encoder_conf",
+ # action=NestedDictAction,
+ # default={},
+ # help="The keyword arguments for the encoder class.",
+ # )
+ group.add_argument(
+ "--joint_network_conf",
+ action=NestedDictAction,
+ default={},
+ help="The keyword arguments for the joint network class.",
+ )
+ group = parser.add_argument_group(description="Preprocess related.")
+ group.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Whether to apply preprocessing to input data.",
+ )
+ group.add_argument(
+ "--token_type",
+ type=str,
+ default="bpe",
+ choices=["bpe", "char", "word", "phn"],
+ help="The type of tokens to use during tokenization.",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The path of the sentencepiece model.",
+ )
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ help="The 'non_linguistic_symbols' file path.",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Text cleaner to use.",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="g2p method to use if --token_type=phn.",
+ )
+ parser.add_argument(
+ "--speech_volume_normalize",
+ type=float_or_none,
+ default=None,
+ help="Normalization value for maximum amplitude scaling.",
+ )
+ parser.add_argument(
+ "--rir_scp",
+ type=str_or_none,
+ default=None,
+ help="The RIR SCP file path.",
+ )
+ parser.add_argument(
+ "--rir_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability of the applied RIR convolution.",
+ )
+ parser.add_argument(
+ "--noise_scp",
+ type=str_or_none,
+ default=None,
+ help="The path of noise SCP file.",
+ )
+ parser.add_argument(
+ "--noise_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability of the applied noise addition.",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of the noise decibel level.",
+ )
+ for class_choices in cls.class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --decoder and --decoder_conf
+ class_choices.add_arguments(group)
+
+ @classmethod
+ def build_collate_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Callable[
+ [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+ Tuple[List[str], Dict[str, torch.Tensor]],
+ ]:
+ """Build collate function.
+ Args:
+ cls: ASRTransducerTask object.
+ args: Task arguments.
+ train: Training mode.
+ Return:
+ : Callable collate function.
+ """
+ assert check_argument_types()
+
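+ # int_pad_value=-1 matches the models' ignore_id, so padded label
+ # positions can be dropped again by get_transducer_task_io().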
+ return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+ @classmethod
+ def build_preprocess_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+ """Build pre-processing function.
+ Args:
+ cls: ASRTransducerTask object.
+ args: Task arguments.
+ train: Training mode.
+ Return:
+ : Callable pre-processing function.
+ """
+ assert check_argument_types()
+
+ if args.use_preprocessor:
+ retval = CommonPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob
+ if hasattr(args, "rir_apply_prob")
+ else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob
+ if hasattr(args, "noise_apply_prob")
+ else 1.0,
+ noise_db_range=args.noise_db_range
+ if hasattr(args, "noise_db_range")
+ else "13_15",
+ speech_volume_normalize=args.speech_volume_normalize
+ if hasattr(args, "rir_scp")
+ else None,
+ )
+ else:
+ retval = None
+
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def required_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ """Required data depending on task mode.
+ Args:
+ cls: ASRTransducerTask object.
+ train: Training mode.
+ inference: Inference mode.
+ Return:
+ retval: Required task data.
+ """
+ if not inference:
+ retval = ("speech", "text")
+ else:
+ retval = ("speech",)
+
+ return retval
+
+ @classmethod
+ def optional_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ """Optional data depending on task mode.
+ Args:
+ cls: ASRTransducerTask object.
+ train: Training mode.
+ inference: Inference mode.
+ Return:
+ retval: Optional task data.
+ """
+ retval = ()
+ assert check_return_type(retval)
+
+ return retval
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace) -> TransducerModel:
+ """Required data depending on task mode.
+ Args:
+ cls: ASRTransducerTask object.
+ args: Task arguments.
+ Return:
+ model: ASR Transducer model.
+ """
+ assert check_argument_types()
+
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size }")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Encoder
+
+ if getattr(args, "encoder", None) is not None:
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size, **args.encoder_conf)
+ else:
+ encoder = ConformerChunkEncoder(input_size, **args.encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ # 5. Decoder
+ rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+ decoder = rnnt_decoder_class(
+ vocab_size,
+ **args.rnnt_decoder_conf,
+ )
+ decoder_output_size = decoder.output_size
+
+ if getattr(args, "decoder", None) is not None:
+ att_decoder_class = decoder_choices.get_class(args.decoder)
+
+ att_decoder = att_decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **args.decoder_conf,
+ )
+ else:
+ att_decoder = None
+ # 6. Joint Network
+ joint_network = JointNetwork(
+ vocab_size,
+ encoder_output_size,
+ decoder_output_size,
+ **args.joint_network_conf,
+ )
+
+ # 7. Build model
+
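+ # unified_model_training comes from encoder_conf (see the YAML config);
+ # when true, the chunk encoder emits both offline and streaming outputs
+ # and UnifiedTransducerModel supervises both.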
+ if encoder.unified_model_training:
+ model = UnifiedTransducerModel(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
+
+ else:
+ model = TransducerModel(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
+
+ # 8. Initialize model
+ if args.init is not None:
+ raise NotImplementedError(
+ "Currently not supported.",
+ "Initialization part will be reworked in a short future.",
+ )
+
+ #assert check_return_type(model)
+
+ return model
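
For reference, a minimal entry point against the merged task, mirroring
funasr/bin/asr_train_transducer.py (the cmd-list form of AbsTask.main() and
the token-list path are assumptions):

    # Sketch: launching training with the unified config (paths are placeholders).
    from funasr.tasks.asr import ASRTransducerTask

    ASRTransducerTask.main(
        cmd=[
            "--config", "egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml",
            "--token_list", "data/token_list/char/tokens.txt",
            "--use_preprocessor", "true",
        ]
    )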
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
deleted file mode 100644
index d4136d0..0000000
--- a/funasr/tasks/asr_transducer.py
+++ /dev/null
@@ -1,477 +0,0 @@
-"""ASR Transducer Task."""
-
-import argparse
-import logging
-from typing import Callable, Collection, Dict, List, Optional, Tuple
-
-import numpy as np
-import torch
-from typeguard import check_argument_types, check_return_type
-
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.decoder.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder,
- DynamicConvolutionTransformerDecoder,
- LightweightConvolution2DTransformerDecoder,
- LightweightConvolutionTransformerDecoder,
- TransformerDecoder,
-)
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
-from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder
-from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
-from funasr.models.e2e_asr_transducer import TransducerModel
-from funasr.models.e2e_asr_transducer_unified import UnifiedTransducerModel
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.class_choices import ClassChoices
-from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none, int_or_none, str2bool, str_or_none
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- ),
- type_check=AbsFrontend,
- default="default",
-)
-specaug_choices = ClassChoices(
- "specaug",
- classes=dict(
- specaug=SpecAug,
- ),
- type_check=AbsSpecAug,
- default=None,
- optional=True,
-)
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- type_check=AbsNormalize,
- default="utterance_mvn",
- optional=True,
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- chunk_conformer=ConformerChunkEncoder,
- ),
- default="chunk_conformer",
-)
-
-decoder_choices = ClassChoices(
- "decoder",
- classes=dict(
- rnn=RNNDecoder,
- stateless=StatelessDecoder,
- ),
- type_check=AbsDecoder,
- default="rnn",
-)
-
-att_decoder_choices = ClassChoices(
- "att_decoder",
- classes=dict(
- transformer=TransformerDecoder,
- lightweight_conv=LightweightConvolutionTransformerDecoder,
- lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
- dynamic_conv=DynamicConvolutionTransformerDecoder,
- dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
- ),
- type_check=AbsAttDecoder,
- default=None,
- optional=True,
-)
-class ASRTransducerTask(AbsTask):
- """ASR Transducer Task definition."""
-
- num_optimizers: int = 1
-
- class_choices_list = [
- frontend_choices,
- specaug_choices,
- normalize_choices,
- encoder_choices,
- decoder_choices,
- att_decoder_choices,
- ]
-
- trainer = Trainer
-
- @classmethod
- def add_task_arguments(cls, parser: argparse.ArgumentParser):
- """Add Transducer task arguments.
- Args:
- cls: ASRTransducerTask object.
- parser: Transducer arguments parser.
- """
- group = parser.add_argument_group(description="Task related.")
-
- # required = parser.get_default("required")
- # required += ["token_list"]
-
- group.add_argument(
- "--token_list",
- type=str_or_none,
- default=None,
- help="Integer-string mapper for tokens.",
- )
- group.add_argument(
- "--split_with_space",
- type=str2bool,
- default=True,
- help="whether to split text using <space>",
- )
- group.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of dimensions for input features.",
- )
- group.add_argument(
- "--init",
- type=str_or_none,
- default=None,
- help="Type of model initialization to use.",
- )
- group.add_argument(
- "--model_conf",
- action=NestedDictAction,
- default=get_default_kwargs(TransducerModel),
- help="The keyword arguments for the model class.",
- )
- # group.add_argument(
- # "--encoder_conf",
- # action=NestedDictAction,
- # default={},
- # help="The keyword arguments for the encoder class.",
- # )
- group.add_argument(
- "--joint_network_conf",
- action=NestedDictAction,
- default={},
- help="The keyword arguments for the joint network class.",
- )
- group = parser.add_argument_group(description="Preprocess related.")
- group.add_argument(
- "--use_preprocessor",
- type=str2bool,
- default=True,
- help="Whether to apply preprocessing to input data.",
- )
- group.add_argument(
- "--token_type",
- type=str,
- default="bpe",
- choices=["bpe", "char", "word", "phn"],
- help="The type of tokens to use during tokenization.",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The path of the sentencepiece model.",
- )
- parser.add_argument(
- "--non_linguistic_symbols",
- type=str_or_none,
- help="The 'non_linguistic_symbols' file path.",
- )
- parser.add_argument(
- "--cleaner",
- type=str_or_none,
- choices=[None, "tacotron", "jaconv", "vietnamese"],
- default=None,
- help="Text cleaner to use.",
- )
- parser.add_argument(
- "--g2p",
- type=str_or_none,
- choices=g2p_choices,
- default=None,
- help="g2p method to use if --token_type=phn.",
- )
- parser.add_argument(
- "--speech_volume_normalize",
- type=float_or_none,
- default=None,
- help="Normalization value for maximum amplitude scaling.",
- )
- parser.add_argument(
- "--rir_scp",
- type=str_or_none,
- default=None,
- help="The RIR SCP file path.",
- )
- parser.add_argument(
- "--rir_apply_prob",
- type=float,
- default=1.0,
- help="The probability of the applied RIR convolution.",
- )
- parser.add_argument(
- "--noise_scp",
- type=str_or_none,
- default=None,
- help="The path of noise SCP file.",
- )
- parser.add_argument(
- "--noise_apply_prob",
- type=float,
- default=1.0,
- help="The probability of the applied noise addition.",
- )
- parser.add_argument(
- "--noise_db_range",
- type=str,
- default="13_15",
- help="The range of the noise decibel level.",
- )
- for class_choices in cls.class_choices_list:
- # Append --<name> and --<name>_conf.
- # e.g. --decoder and --decoder_conf
- class_choices.add_arguments(group)
-
- @classmethod
- def build_collate_fn(
- cls, args: argparse.Namespace, train: bool
- ) -> Callable[
- [Collection[Tuple[str, Dict[str, np.ndarray]]]],
- Tuple[List[str], Dict[str, torch.Tensor]],
- ]:
- """Build collate function.
- Args:
- cls: ASRTransducerTask object.
- args: Task arguments.
- train: Training mode.
- Return:
- : Callable collate function.
- """
- assert check_argument_types()
-
- return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
-
- @classmethod
- def build_preprocess_fn(
- cls, args: argparse.Namespace, train: bool
- ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- """Build pre-processing function.
- Args:
- cls: ASRTransducerTask object.
- args: Task arguments.
- train: Training mode.
- Return:
- : Callable pre-processing function.
- """
- assert check_argument_types()
-
- if args.use_preprocessor:
- retval = CommonPreprocessor(
- train=train,
- token_type=args.token_type,
- token_list=args.token_list,
- bpemodel=args.bpemodel,
- non_linguistic_symbols=args.non_linguistic_symbols,
- text_cleaner=args.cleaner,
- g2p_type=args.g2p,
- split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
- rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
- rir_apply_prob=args.rir_apply_prob
- if hasattr(args, "rir_apply_prob")
- else 1.0,
- noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
- noise_apply_prob=args.noise_apply_prob
- if hasattr(args, "noise_apply_prob")
- else 1.0,
- noise_db_range=args.noise_db_range
- if hasattr(args, "noise_db_range")
- else "13_15",
- speech_volume_normalize=args.speech_volume_normalize
- if hasattr(args, "rir_scp")
- else None,
- )
- else:
- retval = None
-
- assert check_return_type(retval)
- return retval
-
- @classmethod
- def required_data_names(
- cls, train: bool = True, inference: bool = False
- ) -> Tuple[str, ...]:
- """Required data depending on task mode.
- Args:
- cls: ASRTransducerTask object.
- train: Training mode.
- inference: Inference mode.
- Return:
- retval: Required task data.
- """
- if not inference:
- retval = ("speech", "text")
- else:
- retval = ("speech",)
-
- return retval
-
- @classmethod
- def optional_data_names(
- cls, train: bool = True, inference: bool = False
- ) -> Tuple[str, ...]:
- """Optional data depending on task mode.
- Args:
- cls: ASRTransducerTask object.
- train: Training mode.
- inference: Inference mode.
- Return:
- retval: Optional task data.
- """
- retval = ()
- assert check_return_type(retval)
-
- return retval
-
- @classmethod
- def build_model(cls, args: argparse.Namespace) -> TransducerModel:
- """Required data depending on task mode.
- Args:
- cls: ASRTransducerTask object.
- args: Task arguments.
- Return:
- model: ASR Transducer model.
- """
- assert check_argument_types()
-
- if isinstance(args.token_list, str):
- with open(args.token_list, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
-
- # Overwriting token_list to keep it as "portable".
- args.token_list = list(token_list)
- elif isinstance(args.token_list, (tuple, list)):
- token_list = list(args.token_list)
- else:
- raise RuntimeError("token_list must be str or list")
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size }")
-
- # 1. frontend
- if args.input_size is None:
- # Extract features in the model
- frontend_class = frontend_choices.get_class(args.frontend)
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- # Give features from data-loader
- frontend = None
- input_size = args.input_size
-
- # 2. Data augmentation for spectrogram
- if args.specaug is not None:
- specaug_class = specaug_choices.get_class(args.specaug)
- specaug = specaug_class(**args.specaug_conf)
- else:
- specaug = None
-
- # 3. Normalization layer
- if args.normalize is not None:
- normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
- else:
- normalize = None
-
- # 4. Encoder
-
- if getattr(args, "encoder", None) is not None:
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(input_size, **args.encoder_conf)
- else:
- encoder = Encoder(input_size, **args.encoder_conf)
- encoder_output_size = encoder.output_size
-
- # 5. Decoder
- decoder_class = decoder_choices.get_class(args.decoder)
- decoder = decoder_class(
- vocab_size,
- **args.decoder_conf,
- )
- decoder_output_size = decoder.output_size
-
- if getattr(args, "att_decoder", None) is not None:
- att_decoder_class = att_decoder_choices.get_class(args.att_decoder)
-
- att_decoder = att_decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **args.att_decoder_conf,
- )
- else:
- att_decoder = None
-
- # 6. Joint Network
- joint_network = JointNetwork(
- vocab_size,
- encoder_output_size,
- decoder_output_size,
- **args.joint_network_conf,
- )
-
- # 7. Build model
-
- if encoder.unified_model_training:
- model = UnifiedTransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- else:
- model = TransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- # 8. Initialize model
- if args.init is not None:
- raise NotImplementedError(
- "Currently not supported.",
- "Initialization part will be reworked in a short future.",
- )
-
- #assert check_return_type(model)
-
- return model
--
Gitblit v1.9.1