| | |
| | | |
| | | import torch |
| | | from packaging.version import parse as V |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | | from funasr.modules.nets_utils import th_accuracy |
| | | from funasr.modules.add_sos_eos import add_sos_eos |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | if V(torch.__version__) >= V("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | yield |
| | | |
| | | |
| | | class TransducerModel(AbsESPnetModel): |
| | | class TransducerModel(FunASRModel): |
| | | """ESPnet2ASRTransducerModel module definition. |
| | | |
| | | Args: |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | encoder: Encoder, |
| | | encoder: AbsEncoder, |
| | | decoder: RNNTDecoder, |
| | | joint_network: JointNetwork, |
| | | att_decoder: Optional[AbsAttDecoder] = None, |
| | |
| | | ) -> None: |
| | | """Construct an ESPnetASRTransducerModel object.""" |
| | | super().__init__() |
| | | |
| | | assert check_argument_types() |
| | | |
| | | # The following label IDs are reserved: 0 (blank) and vocab_size - 1 (sos/eos) |
| | | self.blank_id = 0 |
| | |
| | | self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 |
| | | |
| | | if self.use_auxiliary_ctc: |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) |
| | | self.ctc_dropout_rate = auxiliary_ctc_dropout_rate |
| | | |
| | | if self.use_auxiliary_lm_loss: |
| | |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None: |
| | | encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens, |
| | | chunk_outs=None) |
| | | # 2. Transducer-related I/O preparation |
| | | decoder_in, target, t_len, u_len = get_transducer_task_io( |
| | | text, |
| | |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths) |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | """ |
| | | if self.criterion_transducer is None: |
| | | try: |
| | | # from warprnnt_pytorch import RNNTLoss |
| | | # self.criterion_transducer = RNNTLoss( |
| | | # reduction="mean", |
| | | # fastemit_lambda=self.fastemit_lambda, |
| | | # ) |
| | | from warp_rnnt import rnnt_loss as RNNTLoss |
| | | self.criterion_transducer = RNNTLoss |
| | | |
| | |
| | | ) |
| | | exit(1) |
| | | |
| | | # loss_transducer = self.criterion_transducer( |
| | | # joint_out, |
| | | # target, |
| | | # t_len, |
| | | # u_len, |
| | | # ) |
| | | log_probs = torch.log_softmax(joint_out, dim=-1) |
| | | |
| | | loss_transducer = self.criterion_transducer( |
| | |
| | | |
| | | return loss_lm |
| | | |
| | | class UnifiedTransducerModel(AbsESPnetModel): |
| | | class UnifiedTransducerModel(FunASRModel): |
| | | """ESPnet2ASRTransducerModel module definition. |
| | | Args: |
| | | vocab_size: Size of complete vocabulary (w/ EOS and blank included). |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | encoder: Encoder, |
| | | encoder: AbsEncoder, |
| | | decoder: RNNTDecoder, |
| | | joint_network: JointNetwork, |
| | | att_decoder: Optional[AbsAttDecoder] = None, |
| | |
| | | ) -> None: |
| | | """Construct an ESPnetASRTransducerModel object.""" |
| | | super().__init__() |
| | | |
| | | assert check_argument_types() |
| | | |
| | | # The following label IDs are reserved: 0 (blank) and vocab_size - 1 (sos/eos) |
| | | self.blank_id = 0 |
| | |
| | | self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 |
| | | |
| | | if self.use_auxiliary_ctc: |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) |
| | | self.ctc_dropout_rate = auxiliary_ctc_dropout_rate |
| | | |
| | | if self.use_auxiliary_att: |
| | |
| | | |
| | | batch_size = speech.shape[0] |
| | | text = text[:, : text_lengths.max()] |
| | | #print(speech.shape) |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | |
| | | loss_lm = self._calc_lm_loss(decoder_out, target) |
| | | |
| | | loss_trans = loss_trans_utt + loss_trans_chunk |
| | | loss_ctc = loss_ctc + loss_ctc_chunk |
| | | loss_ctc = loss_att + loss_att_chunk |
| | | loss_ctc = loss_ctc + loss_ctc_chunk |
| | | loss_att = loss_att + loss_att_chunk |
| | | |
| | | loss = ( |
| | | self.transducer_weight * loss_trans |
| | |
| | | """ |
| | | if self.criterion_transducer is None: |
| | | try: |
| | | # from warprnnt_pytorch import RNNTLoss |
| | | # self.criterion_transducer = RNNTLoss( |
| | | # reduction="mean", |
| | | # fastemit_lambda=self.fastemit_lambda, |
| | | # ) |
| | | from warp_rnnt import rnnt_loss as RNNTLoss |
| | | self.criterion_transducer = RNNTLoss |
| | | |
| | |
| | | ) |
| | | exit(1) |
| | | |
| | | # loss_transducer = self.criterion_transducer( |
| | | # joint_out, |
| | | # target, |
| | | # t_len, |
| | | # u_len, |
| | | # ) |
| | | log_probs = torch.log_softmax(joint_out, dim=-1) |
| | | |
| | | loss_transducer = self.criterion_transducer( |