| | |
| | | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.losses.label_smoothing_loss import ( |
| | |
| | | yield |
| | | |
| | | |
| | | class ESPnetASRModel(FunASRModel): |
| | | class SAASRModel(FunASRModel): |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | |
| | | def __init__( |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | preencoder: Optional[AbsPreEncoder], |
| | | asr_encoder: AbsEncoder, |
| | | spk_encoder: torch.nn.Module, |
| | | postencoder: Optional[AbsPostEncoder], |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | spk_weight: float = 0.5, |
| | |
| | | sym_blank: str = "<blank>", |
| | | extract_feats_in_collect_stats: bool = True, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | self.frontend = frontend |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | self.preencoder = preencoder |
| | | self.postencoder = postencoder |
| | | self.asr_encoder = asr_encoder |
| | | self.spk_encoder = spk_encoder |
| | | |
| | |
| | | if self.normalize is not None: |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # Pre-encoder, e.g. used for raw input data |
| | | if self.preencoder is not None: |
| | | feats, feats_lengths = self.preencoder(feats, feats_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | |
| | | encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1) |
| | | else: |
| | | encoder_out_spk=encoder_out_spk_ori |
| | | # Post-encoder, e.g. NLU |
| | | if self.postencoder is not None: |
| | | encoder_out, encoder_out_lens = self.postencoder( |
| | | encoder_out, encoder_out_lens |
| | | ) |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | ) |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk |
| | | |
| | | return encoder_out, encoder_out_lens, encoder_out_spk |
| | | |