From c3c78fc5e790d48b3a2f9da79199320c06108d38 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Fri, 12 Jan 2024 18:23:56 +0800
Subject: [PATCH] bug fix: reassign .to(device) results so inputs actually move to the target device
---
funasr/models/paraformer/model.py | 991 ++++++------
funasr/models/fsmn_vad_streaming/model.py | 3
funasr/models/fsmn_vad/model.py | 3
funasr/models/contextual_paraformer/model.py | 935 +++++-----
funasr/models/paraformer_streaming/model.py | 1035 ++++++------
funasr/models/transducer/model.py | 995 ++++++------
funasr/models/monotonic_aligner/model.py | 3
funasr/models/transformer/model.py | 823 +++++-----
 8 files changed, 2398 insertions(+), 2390 deletions(-)
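
The functional change, repeated across the touched models, replaces bare
statements of the form
`speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])`
with explicit reassignments. torch.Tensor.to() is out-of-place: it returns a
new tensor on the target device and leaves the original tensor where it was,
so the old expression statement was a silent no-op and the features never left
their original device. The bulk of the remaining churn appears to be
whitespace-only reindentation. A minimal sketch of the failure mode (tensor
names are illustrative, not taken from the patched files):

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    speech = torch.zeros(2, 80)   # starts on CPU

    speech.to(device)             # no-op for `speech`: the moved copy is discarded
    print(speech.device)          # still cpu

    speech = speech.to(device)    # fix: rebind the name to the returned tensor
    print(speech.device)          # now on `device`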
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 6fdf2dc..67d4fb0 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -19,7 +19,7 @@
import time
# from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
+ LabelSmoothingLoss, # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
@@ -40,12 +40,12 @@
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
+ from torch.cuda.amp import autocast
else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
@@ -57,477 +57,478 @@
@tables.register("model_classes", "ContextualParaformer")
class ContextualParaformer(Paraformer):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- FunASR: A Fundamental End-to-End Speech Recognition Toolkit
- https://arxiv.org/abs/2305.11013
- """
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
-
- self.target_buffer_length = kwargs.get("target_buffer_length", -1)
- inner_dim = kwargs.get("inner_dim", 256)
- bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
- use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
- crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
- crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
- bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ FunASR: A Fundamental End-to-End Speech Recognition Toolkit
+ https://arxiv.org/abs/2305.11013
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.target_buffer_length = kwargs.get("target_buffer_length", -1)
+ inner_dim = kwargs.get("inner_dim", 256)
+ bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
+ use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
+ crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
+ crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
+ bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
- if bias_encoder_type == 'lstm':
- logging.warning("enable bias encoder sampling and contextual training")
- self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
- self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
- elif bias_encoder_type == 'mean':
- logging.warning("enable bias encoder sampling and contextual training")
- self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
- else:
- logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
-
- if self.target_buffer_length > 0:
- self.hotword_buffer = None
- self.length_record = []
- self.current_buffer_length = 0
- self.use_decoder_embedding = use_decoder_embedding
- self.crit_attn_weight = crit_attn_weight
- if self.crit_attn_weight > 0:
- self.attn_loss = torch.nn.L1Loss()
- self.crit_attn_smooth = crit_attn_smooth
+ if bias_encoder_type == 'lstm':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
+ self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+ elif bias_encoder_type == 'mean':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+ else:
+            logging.error("Unsupported bias encoder type: {}".format(bias_encoder_type))
+
+ if self.target_buffer_length > 0:
+ self.hotword_buffer = None
+ self.length_record = []
+ self.current_buffer_length = 0
+ self.use_decoder_embedding = use_decoder_embedding
+ self.crit_attn_weight = crit_attn_weight
+ if self.crit_attn_weight > 0:
+ self.attn_loss = torch.nn.L1Loss()
+ self.crit_attn_smooth = crit_attn_smooth
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
-
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
- hotword_pad = kwargs.get("hotword_pad")
- hotword_lengths = kwargs.get("hotword_lengths")
- dha_pad = kwargs.get("dha_pad")
-
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ hotword_pad = kwargs.get("hotword_pad")
+ hotword_lengths = kwargs.get("hotword_lengths")
+ dha_pad = kwargs.get("dha_pad")
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_ctc, cer_ctc = None, None
-
- stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
+
+ loss_ctc, cer_ctc = None, None
+
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
- # 2b. Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
- encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
- if loss_ideal is not None:
- loss = loss + loss_ideal * self.crit_attn_weight
- stats["loss_ideal"] = loss_ideal.detach().cpu()
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + self.predictor_bias).sum())
-
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
-
- def _calc_att_clas_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- hotword_pad: torch.Tensor,
- hotword_lengths: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
-
- # -1. bias encoder
- if self.use_decoder_embedding:
- hw_embed = self.decoder.embed(hotword_pad)
- else:
- hw_embed = self.bias_embed(hotword_pad)
- hw_embed, (_, _) = self.bias_encoder(hw_embed)
- _ind = np.arange(0, hotword_pad.shape[0]).tolist()
- selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
- contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-
- # 0. sampler
- decoder_out_1st = None
- if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds, contextual_info)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- '''
- if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
- ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
- attn_non_blank = attn[:,:,:,:-1]
- ideal_attn_non_blank = ideal_attn[:,:,:-1]
- loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
- else:
- loss_ideal = None
- '''
- loss_ideal = None
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
-
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0,
- index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
- value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
- clas_scale=1.0):
- if hw_list is None:
- hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
- hw_list_pad = pad_list(hw_list, 0)
- if self.use_decoder_embedding:
- hw_embed = self.decoder.embed(hw_list_pad)
- else:
- hw_embed = self.bias_embed(hw_list_pad)
- hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
- hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
- else:
- hw_lengths = [len(i) for i in hw_list]
- hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
- if self.use_decoder_embedding:
- hw_embed = self.decoder.embed(hw_list_pad)
- else:
- hw_embed = self.bias_embed(hw_list_pad)
- hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
- enforce_sorted=False)
- _, (h_n, _) = self.bias_encoder(hw_embed)
- hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def generate(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
-
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
-
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data[
- "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+ # 2b. Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
+ encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+ if loss_ideal is not None:
+ loss = loss + loss_ideal * self.crit_attn_weight
+ stats["loss_ideal"] = loss_ideal.detach().cpu()
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + self.predictor_bias).sum())
+
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+
+ def _calc_att_clas_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ hotword_pad: torch.Tensor,
+ hotword_lengths: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # -1. bias encoder
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hotword_pad)
+ else:
+ hw_embed = self.bias_embed(hotword_pad)
+ hw_embed, (_, _) = self.bias_encoder(hw_embed)
+ _ind = np.arange(0, hotword_pad.shape[0]).tolist()
+ selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
+ contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds, contextual_info)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ '''
+ if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
+ ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
+ attn_non_blank = attn[:,:,:,:-1]
+ ideal_attn_non_blank = ideal_attn[:,:,:-1]
+ loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
+ else:
+ loss_ideal = None
+ '''
+ loss_ideal = None
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
+
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0,
+ index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
+ value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
+ clas_scale=1.0):
+ if hw_list is None:
+ hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
+ hw_list_pad = pad_list(hw_list, 0)
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hw_list_pad)
+ else:
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
+ hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+ else:
+ hw_lengths = [len(i) for i in hw_list]
+ hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hw_list_pad)
+ else:
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+ enforce_sorted=False)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+ is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data[
+ "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
- # hotword
- self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
-
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # predictor
- predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
+ # hotword
+ self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
- decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
- pre_acoustic_embeds,
- pre_token_length,
- hw_list=self.hotword_list,
- clas_scale=kwargs.get("clas_scale", 1.0))
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
-
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(
- filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- if tokenizer is not None:
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "text": text_postprocessed}
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
- else:
- result_i = {"key": key[i], "token_int": token_int}
- results.append(result_i)
-
- return results, meta_data
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
+ pre_acoustic_embeds,
+ pre_token_length,
+ hw_list=self.hotword_list,
+ clas_scale=kwargs.get("clas_scale", 1.0))
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(
+ filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ if tokenizer is not None:
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "text": text_postprocessed}
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+ else:
+ result_i = {"key": key[i], "token_int": token_int}
+ results.append(result_i)
+
+ return results, meta_data
- def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
- def load_seg_dict(seg_dict_file):
- seg_dict = {}
- assert isinstance(seg_dict_file, str)
- with open(seg_dict_file, "r", encoding="utf8") as f:
- lines = f.readlines()
- for line in lines:
- s = line.strip().split()
- key = s[0]
- value = s[1:]
- seg_dict[key] = " ".join(value)
- return seg_dict
-
- def seg_tokenize(txt, seg_dict):
- pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
- out_txt = ""
- for word in txt:
- word = word.lower()
- if word in seg_dict:
- out_txt += seg_dict[word] + " "
- else:
- if pattern.match(word):
- for char in word:
- if char in seg_dict:
- out_txt += seg_dict[char] + " "
- else:
- out_txt += "<unk>" + " "
- else:
- out_txt += "<unk>" + " "
- return out_txt.strip().split()
-
- seg_dict = None
- if frontend.cmvn_file is not None:
- model_dir = os.path.dirname(frontend.cmvn_file)
- seg_dict_file = os.path.join(model_dir, 'seg_dict')
- if os.path.exists(seg_dict_file):
- seg_dict = load_seg_dict(seg_dict_file)
- else:
- seg_dict = None
- # for None
- if hotword_list_or_file is None:
- hotword_list = None
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords from local txt...")
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hw_list = hw.split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_str_list.append(hw)
- hotword_list.append(tokenizer.tokens2ids(hw_list))
- hotword_list.append([self.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- elif hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hw_list = hw.split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_str_list.append(hw)
- hotword_list.append(tokenizer.tokens2ids(hw_list))
- hotword_list.append([self.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for text str input
- elif not hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords as str...")
- hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- hw_list = hw.strip().split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_list.append(tokenizer.tokens2ids(hw_list))
- hotword_list.append([self.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- else:
- hotword_list = None
- return hotword_list
+ def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
+ def load_seg_dict(seg_dict_file):
+ seg_dict = {}
+ assert isinstance(seg_dict_file, str)
+ with open(seg_dict_file, "r", encoding="utf8") as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ seg_dict[key] = " ".join(value)
+ return seg_dict
+
+ def seg_tokenize(txt, seg_dict):
+ pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+ out_txt = ""
+ for word in txt:
+ word = word.lower()
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ if pattern.match(word):
+ for char in word:
+ if char in seg_dict:
+ out_txt += seg_dict[char] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+ seg_dict = None
+ if frontend.cmvn_file is not None:
+ model_dir = os.path.dirname(frontend.cmvn_file)
+ seg_dict_file = os.path.join(model_dir, 'seg_dict')
+ if os.path.exists(seg_dict_file):
+ seg_dict = load_seg_dict(seg_dict_file)
+ else:
+ seg_dict = None
+ # for None
+ if hotword_list_or_file is None:
+ hotword_list = None
+ # for local txt inputs
+ elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords from local txt...")
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_str_list.append(hw)
+ hotword_list.append(tokenizer.tokens2ids(hw_list))
+ hotword_list.append([self.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
+ elif hotword_list_or_file.startswith('http'):
+ logging.info("Attempting to parse hotwords from url...")
+ work_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(work_dir):
+ os.makedirs(work_dir)
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+ local_file = requests.get(hotword_list_or_file)
+ open(text_file_path, "wb").write(local_file.content)
+ hotword_list_or_file = text_file_path
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_str_list.append(hw)
+ hotword_list.append(tokenizer.tokens2ids(hw_list))
+ hotword_list.append([self.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for text str input
+ elif not hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords as str...")
+ hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ hw_list = hw.strip().split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_list.append(tokenizer.tokens2ids(hw_list))
+ hotword_list.append([self.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ else:
+ hotword_list = None
+ return hotword_list
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 15d2af5..b31e061 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -555,7 +555,8 @@
meta_data[
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
# b. Forward Encoder streaming
t_offset = 0
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index e0d104a..9ceacf6 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -578,7 +578,8 @@
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
batch = {
"feats": speech,
diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
index 584b692..1b43c2f 100644
--- a/funasr/models/monotonic_aligner/model.py
+++ b/funasr/models/monotonic_aligner/model.py
@@ -166,7 +166,8 @@
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 78a72ec..f60bead 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -8,7 +8,7 @@
import time
from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
+ LabelSmoothingLoss, # noqa: H301
)
from funasr.models.paraformer.cif_predictor import mae_loss
@@ -30,416 +30,416 @@
@tables.register("model_classes", "Paraformer")
class Paraformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- # token_list: Union[Tuple[str, ...], List[str]],
- specaug: Optional[str] = None,
- specaug_conf: Optional[Dict] = None,
- normalize: str = None,
- normalize_conf: Optional[Dict] = None,
- encoder: str = None,
- encoder_conf: Optional[Dict] = None,
- decoder: str = None,
- decoder_conf: Optional[Dict] = None,
- ctc: str = None,
- ctc_conf: Optional[Dict] = None,
- predictor: str = None,
- predictor_conf: Optional[Dict] = None,
- ctc_weight: float = 0.5,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- # report_cer: bool = True,
- # report_wer: bool = True,
- # sym_space: str = "<space>",
- # sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- # predictor=None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- use_1st_decoder_loss: bool = False,
- **kwargs,
- ):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ # token_list: Union[Tuple[str, ...], List[str]],
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ decoder: str = None,
+ decoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ predictor: str = None,
+ predictor_conf: Optional[Dict] = None,
+ ctc_weight: float = 0.5,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ # report_cer: bool = True,
+ # report_wer: bool = True,
+ # sym_space: str = "<space>",
+ # sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ # predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ use_1st_decoder_loss: bool = False,
+ **kwargs,
+ ):
- super().__init__()
+ super().__init__()
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug.lower())
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = tables.normalize_classes.get(normalize.lower())
- normalize = normalize_class(**normalize_conf)
- encoder_class = tables.encoder_classes.get(encoder.lower())
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug.lower())
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize.lower())
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = tables.encoder_classes.get(encoder.lower())
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
- if decoder is not None:
- decoder_class = tables.decoder_classes.get(decoder.lower())
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
- if ctc_weight > 0.0:
-
- if ctc_conf is None:
- ctc_conf = {}
-
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
- )
- if predictor is not None:
- predictor_class = tables.predictor_classes.get(predictor.lower())
- predictor = predictor_class(**predictor_conf)
-
- # note that eos is the same as sos (equivalent ID)
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.ctc_weight = ctc_weight
- # self.token_list = token_list.copy()
- #
- # self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
- # self.preencoder = preencoder
- # self.postencoder = postencoder
- self.encoder = encoder
- #
- # if not hasattr(self.encoder, "interctc_use_conditioning"):
- # self.encoder.interctc_use_conditioning = False
- # if self.encoder.interctc_use_conditioning:
- # self.encoder.conditioning_layer = torch.nn.Linear(
- # vocab_size, self.encoder.output_size()
- # )
- #
- # self.error_calculator = None
- #
- if ctc_weight == 1.0:
- self.decoder = None
- else:
- self.decoder = decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
- if ctc_weight == 0.0:
- self.ctc = None
- else:
- self.ctc = ctc
- #
- # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
- self.predictor = predictor
- self.predictor_weight = predictor_weight
- self.predictor_bias = predictor_bias
- self.sampling_ratio = sampling_ratio
- self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
- # self.step_cur = 0
- #
- self.share_embedding = share_embedding
- if self.share_embedding:
- self.decoder.embed = None
-
- self.use_1st_decoder_loss = use_1st_decoder_loss
- self.length_normalized_loss = length_normalized_loss
- self.beam_search = None
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- # import pdb;
- # pdb.set_trace()
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
-
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if decoder is not None:
+ decoder_class = tables.decoder_classes.get(decoder.lower())
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **decoder_conf,
+ )
+ if ctc_weight > 0.0:
+
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+ if predictor is not None:
+ predictor_class = tables.predictor_classes.get(predictor.lower())
+ predictor = predictor_class(**predictor_conf)
+
+ # note that eos is the same as sos (equivalent ID)
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ # self.token_list = token_list.copy()
+ #
+ # self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ # self.preencoder = preencoder
+ # self.postencoder = postencoder
+ self.encoder = encoder
+ #
+ # if not hasattr(self.encoder, "interctc_use_conditioning"):
+ # self.encoder.interctc_use_conditioning = False
+ # if self.encoder.interctc_use_conditioning:
+ # self.encoder.conditioning_layer = torch.nn.Linear(
+ # vocab_size, self.encoder.output_size()
+ # )
+ #
+ # self.error_calculator = None
+ #
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+ #
+ # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+ self.predictor = predictor
+ self.predictor_weight = predictor_weight
+ self.predictor_bias = predictor_bias
+ self.sampling_ratio = sampling_ratio
+ self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+ # self.step_cur = 0
+ #
+ self.share_embedding = share_embedding
+ if self.share_embedding:
+ self.decoder.embed = None
+
+ self.use_1st_decoder_loss = use_1st_decoder_loss
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_ctc, cer_ctc = None, None
- loss_pre = None
- stats = dict()
-
- # decoder: CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
+
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # decoder: CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
- # decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum()
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = (text_lengths + self.predictor_bias).sum()
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- speech, speech_lengths = self.normalize(speech, speech_lengths)
-
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
- # Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
- return encoder_out, encoder_out_lens
-
- def calc_predictor(self, encoder_out, encoder_out_lens):
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
- encoder_out_mask,
- ignore_id=self.ignore_id)
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
-
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
+ return encoder_out, encoder_out_lens
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
-
- # 0. sampler
- decoder_out_1st = None
- pre_loss_att = None
- if self.sampling_ratio > 0.0:
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ pre_loss_att = None
+ if self.sampling_ratio > 0.0:
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
-
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0,
- index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
- value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- # Calc CTC loss
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
- # Calc CER using CTC
- cer_ctc = None
- if not self.training and self.error_calculator is not None:
- ys_hat = self.ctc.argmax(encoder_out).data
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
- return loss_ctc, cer_ctc
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0,
+ index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
+ value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
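
The `_calc_ctc_loss` helper above delegates to the model's CTC module, which wraps the standard CTC objective. A minimal, self-contained sketch of the same computation using torch.nn.CTCLoss directly; shapes and values here are illustrative, not FunASR's internals:

import torch

# (T, B, V) log-probabilities from an encoder; blank id assumed to be 0,
# matching the "blank symbol id ... assumed to be 0" convention in this file.
log_probs = torch.randn(50, 2, 10).log_softmax(-1)
targets = torch.randint(1, 10, (2, 12))      # (B, L) label ids, no blanks
input_lengths = torch.tensor([50, 42])       # valid encoder frames per utterance
target_lengths = torch.tensor([12, 9])       # valid labels per utterance

ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(float(loss))
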
-
- def init_beam_search(self,
- **kwargs,
- ):
- from funasr.models.paraformer.search import BeamSearchPara
- from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
- if self.ctc != None:
- ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = kwargs.get("token_list")
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.paraformer.search import BeamSearchPara
+ from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+        if self.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- weights = dict(
- decoder=1.0 - kwargs.get("decoding_ctc_weight"),
- ctc=kwargs.get("decoding_ctc_weight", 0.0),
- lm=kwargs.get("lm_weight", 0.0),
- ngram=kwargs.get("ngram_weight", 0.0),
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearchPara(
- beam_size=kwargs.get("beam_size", 2),
- weights=weights,
- scorers=scorers,
- sos=self.sos,
- eos=self.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
- )
- # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
- # for scorer in scorers.values():
- # if isinstance(scorer, torch.nn.Module):
- # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
- self.beam_search = beam_search
-
- def generate(self,
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+ ctc=kwargs.get("decoding_ctc_weight", 0.0),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearchPara(
+ beam_size=kwargs.get("beam_size", 2),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+ )
+ # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ # for scorer in scorers.values():
+ # if isinstance(scorer, torch.nn.Module):
+ # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ self.beam_search = beam_search
+
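
The `weights` dict built in init_beam_search implies a linear mixture of per-hypothesis scores; BeamSearchPara's real implementation adds prefix scoring and pruning on top, so the sketch below shows only the weighting idea, with made-up numbers:

# Hypothetical illustration of the score mixture implied by `weights`;
# the real BeamSearchPara combines partial/prefix scores incrementally.
def combined_score(scores: dict, weights: dict) -> float:
    return sum(weights.get(name, 0.0) * value for name, value in scores.items())

hyp_scores = {"decoder": -1.2, "ctc": -1.5, "length_bonus": 3.0}
mix = {"decoder": 0.7, "ctc": 0.3, "length_bonus": 0.1}
print(combined_score(hyp_scores, mix))
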
+ def generate(self,
data_in,
data_lengths=None,
key: list=None,
@@ -447,105 +447,106 @@
frontend=None,
**kwargs,
):
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
- if isinstance(data_in, torch.Tensor): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # predictor
- predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
- pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor): # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+            return [], meta_data
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- results = []
- b, n, d = decoder_out.size()
- if isinstance(key[0], (list, tuple)):
- key = key[0]
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
+ results = []
+ b, n, d = decoder_out.size()
+ if isinstance(key[0], (list, tuple)):
+ key = key[0]
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx+1}best_recog"]
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- if tokenizer is not None:
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
- result_i = {"key": key[i], "text": text_postprocessed}
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ if tokenizer is not None:
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+
+ result_i = {"key": key[i], "text": text_postprocessed}
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- # ibest_writer["text"][key[i]] = text
- ibest_writer["text"][key[i]] = text_postprocessed
- else:
- result_i = {"key": key[i], "token_int": token_int}
- results.append(result_i)
-
- return results, meta_data
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ # ibest_writer["text"][key[i]] = text
+ ibest_writer["text"][key[i]] = text_postprocessed
+ else:
+ result_i = {"key": key[i], "token_int": token_int}
+ results.append(result_i)
+
+ return results, meta_data
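
When beam search is disabled, the fallback path above is a pure per-position argmax over the decoder posteriors, wrapped in sos/eos and filtered for special ids. A runnable toy with illustrative token ids and random scores:

import torch

sos, eos, blank = 1, 2, 0                 # illustrative special-token ids
am_scores = torch.randn(6, 8)             # (pre_token_length, vocab) scores
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0].sum()
yseq = torch.tensor([sos] + yseq.tolist() + [eos])
token_int = [t for t in yseq[1:-1].tolist() if t not in (sos, eos, blank)]
print(token_int, float(score))
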
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index b736aa9..e6f3038 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -19,7 +19,7 @@
import time
# from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
+ LabelSmoothingLoss, # noqa: H301
)
from funasr.models.paraformer.cif_predictor import mae_loss
@@ -32,12 +32,12 @@
from funasr.models.paraformer.search import Hypothesis
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
+ from torch.cuda.amp import autocast
else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
@@ -50,531 +50,532 @@
@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
-
- super().__init__(*args, **kwargs)
-
- # import pdb;
- # pdb.set_trace()
- self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+
+ super().__init__(*args, **kwargs)
+
+ # import pdb;
+ # pdb.set_trace()
+ self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
- self.scama_mask = None
- if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
- from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
- self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
- self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
+ self.scama_mask = None
+ if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+ from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+ self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+ self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- # import pdb;
- # pdb.set_trace()
- decoding_ind = kwargs.get("decoding_ind")
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # Encoder
- if hasattr(self.encoder, "overlap_chunk_cls"):
- ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_ctc, cer_ctc = None, None
- loss_pre = None
- stats = dict()
-
- # decoder: CTC branch
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ decoding_ind = kwargs.get("decoding_ind")
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # Encoder
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ else:
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # decoder: CTC branch
- if self.ctc_weight > 0.0:
- if hasattr(self.encoder, "overlap_chunk_cls"):
- encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- else:
- encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
-
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
- )
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- else:
- loss = self.ctc_weight * loss_ctc + (
- 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum()
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode_chunk(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
-
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- speech, speech_lengths = self.normalize(speech, speech_lengths)
-
- # Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- return encoder_out, torch.tensor([encoder_out.size(1)])
-
- def _calc_att_predictor_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- mask_chunk_predictor = None
- if self.encoder.overlap_chunk_cls is not None:
- mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(
- 0))
- mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
- batch_size=encoder_out.size(0))
- encoder_out = encoder_out * mask_shfit_chunk
- pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
- ys_pad,
- encoder_out_mask,
- ignore_id=self.ignore_id,
- mask_chunk_predictor=mask_chunk_predictor,
- target_label_length=ys_pad_lens,
- )
- predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
- encoder_out_lens)
-
- scama_mask = None
- if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
- encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
- attention_chunk_center_bias = 0
- attention_chunk_size = encoder_chunk_size
- decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
- mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
- get_mask_shift_att_chunk_decoder(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(0)
- )
- scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
- predictor_alignments=predictor_alignments,
- encoder_sequence_length=encoder_out_lens,
- chunk_size=1,
- encoder_chunk_size=encoder_chunk_size,
- attention_chunk_center_bias=attention_chunk_center_bias,
- attention_chunk_size=attention_chunk_size,
- attention_chunk_type=self.decoder_attention_chunk_type,
- step=None,
- predictor_mask_chunk_hopping=mask_chunk_predictor,
- decoder_att_look_back_factor=decoder_att_look_back_factor,
- mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
- target_length=ys_pad_lens,
- is_training=self.training,
- )
- elif self.encoder.overlap_chunk_cls is not None:
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- # 0. sampler
- decoder_out_1st = None
- pre_loss_att = None
- if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- if self.use_1st_decoder_loss:
- sematic_embeds, decoder_out_1st, pre_loss_att = \
- self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
- ys_pad_lens, pre_acoustic_embeds, scama_mask)
- else:
- sematic_embeds, decoder_out_1st = \
- self.sampler(encoder_out, encoder_out_lens, ys_pad,
- ys_pad_lens, pre_acoustic_embeds, scama_mask)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
-
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
+ if self.ctc_weight > 0.0:
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ else:
+ encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+ )
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ else:
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = (text_lengths + self.predictor_bias).sum()
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
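
The multi-task objective assembled in forward() is a weighted sum of the CTC, attention, and predictor (MAE token-count) losses: loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att + predictor_weight * loss_pre. A tiny numeric check with made-up loss values:

# Made-up values, only to show the weighting used in forward() above.
ctc_weight, predictor_weight = 0.3, 1.0
loss_ctc, loss_att, loss_pre = 2.0, 1.0, 0.1
loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att + predictor_weight * loss_pre
print(loss)  # 1.4
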
+ def encode_chunk(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, torch.tensor([encoder_out.size(1)])
+
+ def _calc_att_predictor_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ mask_chunk_predictor = None
+ if self.encoder.overlap_chunk_cls is not None:
+ mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(
+ 0))
+ mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+ batch_size=encoder_out.size(0))
+ encoder_out = encoder_out * mask_shfit_chunk
+ pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+ ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id,
+ mask_chunk_predictor=mask_chunk_predictor,
+ target_label_length=ys_pad_lens,
+ )
+ predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+ encoder_out_lens)
+
+ scama_mask = None
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+ attention_chunk_center_bias = 0
+ attention_chunk_size = encoder_chunk_size
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+ mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+ get_mask_shift_att_chunk_decoder(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(0)
+ )
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+ predictor_alignments=predictor_alignments,
+ encoder_sequence_length=encoder_out_lens,
+ chunk_size=1,
+ encoder_chunk_size=encoder_chunk_size,
+ attention_chunk_center_bias=attention_chunk_center_bias,
+ attention_chunk_size=attention_chunk_size,
+ attention_chunk_type=self.decoder_attention_chunk_type,
+ step=None,
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+ target_length=ys_pad_lens,
+ is_training=self.training,
+ )
+ elif self.encoder.overlap_chunk_cls is not None:
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ # 0. sampler
+ decoder_out_1st = None
+ pre_loss_att = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ if self.use_1st_decoder_loss:
+ sematic_embeds, decoder_out_1st, pre_loss_att = \
+ self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+ ys_pad_lens, pre_acoustic_embeds, scama_mask)
+ else:
+ sematic_embeds, decoder_out_1st = \
+ self.sampler(encoder_out, encoder_out_lens, ys_pad,
+ ys_pad_lens, pre_acoustic_embeds, scama_mask)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
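
The sampler above is a glancing-style training trick: run the decoder once without gradients, count the positions it already predicts correctly, then reveal the ground-truth embedding at a `sampling_ratio`-scaled random subset of the erroneous positions for the second pass. A runnable toy of just the mask bookkeeping (sizes and ids are made up):

import torch

torch.manual_seed(0)
sampling_ratio, ignore_id = 0.75, -1
ys = torch.tensor([[3, 5, 7, 9, ignore_id]])     # (B=1, L=5); last slot is padding
pred = torch.tensor([[3, 4, 7, 8, 0]])           # first-pass argmax predictions
nonpad = ys.ne(ignore_id)
seq_len = nonpad.sum(1)                          # 4 real tokens
same = ((pred == ys) & nonpad).sum(1)            # 2 already correct
target_num = ((seq_len - same).float() * sampling_ratio).long()
keep = torch.ones_like(ys)
keep[0].scatter_(0, torch.randperm(int(seq_len[0]))[: int(target_num[0])], 0)
keep = keep.eq(1) & nonpad                       # False -> take ground-truth embedding
print(keep)
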
- def calc_predictor(self, encoder_out, encoder_out_lens):
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- mask_chunk_predictor = None
- if self.encoder.overlap_chunk_cls is not None:
- mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(
- 0))
- mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
- batch_size=encoder_out.size(0))
- encoder_out = encoder_out * mask_shfit_chunk
- pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
- None,
- encoder_out_mask,
- ignore_id=self.ignore_id,
- mask_chunk_predictor=mask_chunk_predictor,
- target_label_length=None,
- )
- predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
- encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
-
- scama_mask = None
- if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
- encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
- attention_chunk_center_bias = 0
- attention_chunk_size = encoder_chunk_size
- decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
- mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
- get_mask_shift_att_chunk_decoder(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(0)
- )
- scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
- predictor_alignments=predictor_alignments,
- encoder_sequence_length=encoder_out_lens,
- chunk_size=1,
- encoder_chunk_size=encoder_chunk_size,
- attention_chunk_center_bias=attention_chunk_center_bias,
- attention_chunk_size=attention_chunk_size,
- attention_chunk_type=self.decoder_attention_chunk_type,
- step=None,
- predictor_mask_chunk_hopping=mask_chunk_predictor,
- decoder_att_look_back_factor=decoder_att_look_back_factor,
- mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
- target_length=None,
- is_training=self.training,
- )
- self.scama_mask = scama_mask
-
- return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
-
- def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
- is_final = kwargs.get("is_final", False)
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ mask_chunk_predictor = None
+ if self.encoder.overlap_chunk_cls is not None:
+ mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(
+ 0))
+ mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+ batch_size=encoder_out.size(0))
+ encoder_out = encoder_out * mask_shfit_chunk
+ pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id,
+ mask_chunk_predictor=mask_chunk_predictor,
+ target_label_length=None,
+ )
+ predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+ encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+
+ scama_mask = None
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+ attention_chunk_center_bias = 0
+ attention_chunk_size = encoder_chunk_size
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+ mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+ get_mask_shift_att_chunk_decoder(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(0)
+ )
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+ predictor_alignments=predictor_alignments,
+ encoder_sequence_length=encoder_out_lens,
+ chunk_size=1,
+ encoder_chunk_size=encoder_chunk_size,
+ attention_chunk_center_bias=attention_chunk_center_bias,
+ attention_chunk_size=attention_chunk_size,
+ attention_chunk_type=self.decoder_attention_chunk_type,
+ step=None,
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+ target_length=None,
+ is_training=self.training,
+ )
+ self.scama_mask = scama_mask
+
+ return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
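
calc_predictor returns the CIF quantities: per-frame firing weights (alphas), their integral rounded to a token count, and the peak (emission) positions. A toy of the continuous-integrate-and-fire thresholding, independent of the real predictor class, which also redistributes the residual weight:

import torch

# Toy CIF: integrate per-frame alphas; each time the accumulator crosses 1.0,
# a token "fires" at that frame.
alphas = torch.tensor([0.2, 0.4, 0.5, 0.1, 0.9, 0.3, 0.8])
acc, peaks = 0.0, []
for t, a in enumerate(alphas.tolist()):
    acc += a
    if acc >= 1.0:
        peaks.append(t)
        acc -= 1.0
print(len(peaks), peaks)   # predicted token count and emission frames
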
+ def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
+ is_final = kwargs.get("is_final", False)
- return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
- decoder_outs = self.decoder.forward_chunk(
- encoder_out, sematic_embeds, cache["decoder"]
- )
- decoder_out = decoder_outs
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def init_cache(self, cache: dict = {}, **kwargs):
- chunk_size = kwargs.get("chunk_size", [0, 10, 5])
- encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
- decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
- batch_size = 1
+ return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
+ decoder_outs = self.decoder.forward_chunk(
+ encoder_out, sematic_embeds, cache["decoder"]
+ )
+ decoder_out = decoder_outs
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+    def init_cache(self, cache: dict = None, **kwargs):
+        if cache is None:
+            cache = {}  # avoid sharing a mutable default dict across calls
+ chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+ encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+ decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+ batch_size = 1
- enc_output_size = kwargs["encoder_conf"]["output_size"]
- feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
- cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
- "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
- "tail_chunk": False}
- cache["encoder"] = cache_encoder
-
- cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
- "chunk_size": chunk_size}
- cache["decoder"] = cache_decoder
- cache["frontend"] = {}
- cache["prev_samples"] = torch.empty(0)
-
- return cache
-
- def generate_chunk(self,
- speech,
- speech_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- cache = kwargs.get("cache", {})
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
- # Encoder
- encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # predictor
- predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
- encoder_out_lens,
- pre_acoustic_embeds,
- pre_token_length,
- cache=cache
- )
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ enc_output_size = kwargs["encoder_conf"]["output_size"]
+ feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+ cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+ "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+ "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+ "tail_chunk": False}
+ cache["encoder"] = cache_encoder
+
+ cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
+ "chunk_size": chunk_size}
+ cache["decoder"] = cache_decoder
+ cache["frontend"] = {}
+ cache["prev_samples"] = torch.empty(0)
+
+ return cache
+
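
init_cache's `chunk_size=[0, 10, 5]` follows the [look-back, current, look-ahead] convention in encoder frames. Under the assumed front-end geometry (16 kHz audio, 10 ms hop, LFR n=6, so one encoder frame spans 60 ms = 960 samples), `chunk_size[1] * 960` reproduces the 600 ms stride used by generate():

# Assumed geometry; these numbers are consistent with the "600ms" comment
# in generate() but are not read from any config here.
fs = 16000                                        # sample rate (Hz)
hop_ms, lfr_n = 10, 6                             # fbank hop and LFR stacking factor
samples_per_frame = fs * hop_ms * lfr_n // 1000   # 960 samples = 60 ms
chunk_size = [0, 10, 5]
print(chunk_size[1] * samples_per_frame)          # 9600 samples = 600 ms
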
+ def generate_chunk(self,
+ speech,
+ speech_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ cache = kwargs.get("cache", {})
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+ encoder_out_lens,
+ pre_acoustic_embeds,
+ pre_token_length,
+ cache=cache
+ )
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- results = []
- b, n, d = decoder_out.size()
- if isinstance(key[0], (list, tuple)):
- key = key[0]
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
-
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for nbest_idx, hyp in enumerate(nbest_hyps):
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
+ results = []
+ b, n, d = decoder_out.size()
+ if isinstance(key[0], (list, tuple)):
+ key = key[0]
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- # text = tokenizer.tokens2text(token)
-
- result_i = token
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ # text = tokenizer.tokens2text(token)
+
+ result_i = token
- results.extend(result_i)
-
- return results
-
- def generate(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- cache: dict={},
- **kwargs,
- ):
+ results.extend(result_i)
+
+ return results
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+                 cache: dict = None,
+ **kwargs,
+ ):
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
- if len(cache) == 0:
- self.init_cache(cache, **kwargs)
-
-
- meta_data = {}
- chunk_size = kwargs.get("chunk_size", [0, 10, 5])
- chunk_stride_samples = int(chunk_size[1] * 960) # 600ms
-
- time1 = time.perf_counter()
- cfg = {"is_final": kwargs.get("is_final", False)}
- audio_sample_list = load_audio_text_image_video(data_in,
- fs=frontend.fs,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer,
- cache=cfg,
- )
- _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
-
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- assert len(audio_sample_list) == 1, "batch_size must be set 1"
-
- audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
-
- n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
- m = int(len(audio_sample) % chunk_stride_samples * (1-int(_is_final)))
- tokens = []
- for i in range(n):
- kwargs["is_final"] = _is_final and i == n -1
- audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
+        if cache is None:
+            cache = {}
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+
+
+ meta_data = {}
+ chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+        chunk_stride_samples = int(chunk_size[1] * 960)  # 960 samples = 60 ms per encoder frame at 16 kHz; 600 ms for the default chunk_size[1] == 10
+
+ time1 = time.perf_counter()
+ cfg = {"is_final": kwargs.get("is_final", False)}
+ audio_sample_list = load_audio_text_image_video(data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ cache=cfg,
+ )
+        _is_final = cfg["is_final"]  # load_audio_text_image_video sets is_final=True when data_in is a complete file or url
+
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be 1"
+
+ audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
+ n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+        m = int((len(audio_sample) % chunk_stride_samples) * (1 - int(_is_final)))
+ tokens = []
+ for i in range(n):
+            kwargs["is_final"] = _is_final and i == n - 1
+ audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
- # extract fbank feats
- speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
- frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
- tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
- tokens.extend(tokens_i)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
-
- result_i = {"key": key[0], "text": text_postprocessed}
- result = [result_i]
-
-
- cache["prev_samples"] = audio_sample[:-m]
- if _is_final:
- self.init_cache(cache, **kwargs)
-
- if kwargs.get("output_dir"):
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{1}best_recog"]
- ibest_writer["token"][key[0]] = " ".join(tokens)
- ibest_writer["text"][key[0]] = text_postprocessed
-
- return result, meta_data
+ # extract fbank feats
+ speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
+ tokens.extend(tokens_i)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
+
+ result_i = {"key": key[0], "text": text_postprocessed}
+ result = [result_i]
+
+
+        cache["prev_samples"] = audio_sample[len(audio_sample) - m:]  # unprocessed tail; empty when m == 0
+ if _is_final:
+ self.init_cache(cache, **kwargs)
+
+ if kwargs.get("output_dir"):
+ writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer["1best_recog"]
+ ibest_writer["token"][key[0]] = " ".join(tokens)
+ ibest_writer["text"][key[0]] = text_postprocessed
+
+ return result, meta_data
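
The streaming generate() above consumes the concatenated audio in n full strides and carries the m-sample remainder forward in cache["prev_samples"] (the unprocessed tail). A runnable toy of that bookkeeping across three calls, with dummy audio in place of real samples:

import torch

chunk_stride = 9600                          # 600 ms at 16 kHz, as above
prev = torch.empty(0)
for new_len, is_final in [(15000, False), (8000, False), (1000, True)]:
    audio = torch.cat((prev, torch.zeros(new_len)))
    n = len(audio) // chunk_stride + int(is_final)
    m = (len(audio) % chunk_stride) * (1 - int(is_final))
    for i in range(n):
        chunk = audio[i * chunk_stride:(i + 1) * chunk_stride]
        # ... encode_chunk / calc_predictor_chunk / decoder would run here ...
    prev = audio[len(audio) - m:]            # unprocessed tail; empty when m == 0
    print(n, m, len(prev))
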
diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py
index 1b33b6c..906aa60 100644
--- a/funasr/models/transducer/model.py
+++ b/funasr/models/transducer/model.py
@@ -17,7 +17,7 @@
import numpy as np
import time
from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
+ LabelSmoothingLoss, # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
@@ -39,12 +39,12 @@
from funasr.models.model_class_factory import *
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
+ from torch.cuda.amp import autocast
else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
@@ -52,525 +52,526 @@
class Transducer(nn.Module):
- """ESPnet2ASRTransducerModel module definition."""
+ """ESPnet2ASRTransducerModel module definition."""
-
- def __init__(
- self,
- frontend: Optional[str] = None,
- frontend_conf: Optional[Dict] = None,
- specaug: Optional[str] = None,
- specaug_conf: Optional[Dict] = None,
- normalize: str = None,
- normalize_conf: Optional[Dict] = None,
- encoder: str = None,
- encoder_conf: Optional[Dict] = None,
- decoder: str = None,
- decoder_conf: Optional[Dict] = None,
- joint_network: str = None,
- joint_network_conf: Optional[Dict] = None,
- transducer_weight: float = 1.0,
- fastemit_lambda: float = 0.0,
- auxiliary_ctc_weight: float = 0.0,
- auxiliary_ctc_dropout_rate: float = 0.0,
- auxiliary_lm_loss_weight: float = 0.0,
- auxiliary_lm_loss_smoothing: float = 0.0,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- # report_cer: bool = True,
- # report_wer: bool = True,
- # sym_space: str = "<space>",
- # sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- **kwargs,
- ):
+
+ def __init__(
+ self,
+ frontend: Optional[str] = None,
+ frontend_conf: Optional[Dict] = None,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ decoder: str = None,
+ decoder_conf: Optional[Dict] = None,
+ joint_network: str = None,
+ joint_network_conf: Optional[Dict] = None,
+ transducer_weight: float = 1.0,
+ fastemit_lambda: float = 0.0,
+ auxiliary_ctc_weight: float = 0.0,
+ auxiliary_ctc_dropout_rate: float = 0.0,
+ auxiliary_lm_loss_weight: float = 0.0,
+ auxiliary_lm_loss_smoothing: float = 0.0,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ # report_cer: bool = True,
+ # report_wer: bool = True,
+ # sym_space: str = "<space>",
+ # sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ **kwargs,
+ ):
- super().__init__()
+ super().__init__()
- if frontend is not None:
- frontend_class = frontend_classes.get_class(frontend)
- frontend = frontend_class(**frontend_conf)
- if specaug is not None:
- specaug_class = specaug_classes.get_class(specaug)
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = normalize_classes.get_class(normalize)
- normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_classes.get_class(encoder)
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
+ if frontend is not None:
+ frontend_class = frontend_classes.get_class(frontend)
+ frontend = frontend_class(**frontend_conf)
+ if specaug is not None:
+ specaug_class = specaug_classes.get_class(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = normalize_classes.get_class(normalize)
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = encoder_classes.get_class(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
- decoder_class = decoder_classes.get_class(decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
- decoder_output_size = decoder.output_size
+ decoder_class = decoder_classes.get_class(decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **decoder_conf,
+ )
+ decoder_output_size = decoder.output_size
- joint_network_class = joint_network_classes.get_class(decoder)
- joint_network = joint_network_class(
- vocab_size,
- encoder_output_size,
- decoder_output_size,
- **joint_network_conf,
- )
-
-
- self.criterion_transducer = None
- self.error_calculator = None
-
- self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
- self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
-
- if self.use_auxiliary_ctc:
- self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
- self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
-
- if self.use_auxiliary_lm_loss:
- self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
- self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
-
- self.transducer_weight = transducer_weight
- self.fastemit_lambda = fastemit_lambda
-
- self.auxiliary_ctc_weight = auxiliary_ctc_weight
- self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
- self.encoder = encoder
- self.decoder = decoder
- self.joint_network = joint_network
+        joint_network_class = joint_network_classes.get_class(joint_network)
+ joint_network = joint_network_class(
+ vocab_size,
+ encoder_output_size,
+ decoder_output_size,
+ **joint_network_conf,
+ )
+
+
+ self.criterion_transducer = None
+ self.error_calculator = None
+
+ self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+ self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+ if self.use_auxiliary_ctc:
+ self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
+ self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+ if self.use_auxiliary_lm_loss:
+ self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+ self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+ self.transducer_weight = transducer_weight
+ self.fastemit_lambda = fastemit_lambda
+
+ self.auxiliary_ctc_weight = auxiliary_ctc_weight
+ self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+ self.decoder = decoder
+ self.joint_network = joint_network
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
- self.length_normalized_loss = length_normalized_loss
- self.beam_search = None
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- # import pdb;
- # pdb.set_trace()
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
- chunk_outs=None)
- # 2. Transducer-related I/O preparation
- decoder_in, target, t_len, u_len = get_transducer_task_io(
- text,
- encoder_out_lens,
- ignore_id=self.ignore_id,
- )
-
- # 3. Decoder
- self.decoder.set_device(encoder_out.device)
- decoder_out = self.decoder(decoder_in, u_len)
-
- # 4. Joint Network
- joint_out = self.joint_network(
- encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
- )
-
- # 5. Losses
- loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
- encoder_out,
- joint_out,
- target,
- t_len,
- u_len,
- )
-
- loss_ctc, loss_lm = 0.0, 0.0
-
- if self.use_auxiliary_ctc:
- loss_ctc = self._calc_ctc_loss(
- encoder_out,
- target,
- t_len,
- u_len,
- )
-
- if self.use_auxiliary_lm_loss:
- loss_lm = self._calc_lm_loss(decoder_out, target)
-
- loss = (
- self.transducer_weight * loss_trans
- + self.auxiliary_ctc_weight * loss_ctc
- + self.auxiliary_lm_loss_weight * loss_lm
- )
-
- stats = dict(
- loss=loss.detach(),
- loss_transducer=loss_trans.detach(),
- aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
- aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
- cer_transducer=cer_trans,
- wer_transducer=wer_trans,
- )
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
- return loss, stats, weight
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
+ chunk_outs=None)
+ # 2. Transducer-related I/O preparation
+ decoder_in, target, t_len, u_len = get_transducer_task_io(
+ text,
+ encoder_out_lens,
+ ignore_id=self.ignore_id,
+ )
+
+ # 3. Decoder
+ self.decoder.set_device(encoder_out.device)
+ decoder_out = self.decoder(decoder_in, u_len)
+
+ # 4. Joint Network
+ joint_out = self.joint_network(
+ encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+ )
+
+ # 5. Losses
+ loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
+ encoder_out,
+ joint_out,
+ target,
+ t_len,
+ u_len,
+ )
+
+ loss_ctc, loss_lm = 0.0, 0.0
+
+ if self.use_auxiliary_ctc:
+ loss_ctc = self._calc_ctc_loss(
+ encoder_out,
+ target,
+ t_len,
+ u_len,
+ )
+
+ if self.use_auxiliary_lm_loss:
+ loss_lm = self._calc_lm_loss(decoder_out, target)
+
+ loss = (
+ self.transducer_weight * loss_trans
+ + self.auxiliary_ctc_weight * loss_ctc
+ + self.auxiliary_lm_loss_weight * loss_lm
+ )
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_transducer=loss_trans.detach(),
+ aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+ aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+ cer_transducer=cer_trans,
+ wer_transducer=wer_trans,
+ )
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+ return loss, stats, weight
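Note: step 4 above relies on broadcasting — the encoder and prediction-network outputs are expanded against each other before the joint projection. A shape sketch with a plain additive joint (illustrative dimensions; the actual joint_network module may combine the two differently):

    import torch

    B, T, U, D = 2, 50, 12, 256
    enc = torch.randn(B, T, D).unsqueeze(2)   # (B, T, 1, D)
    dec = torch.randn(B, U, D).unsqueeze(1)   # (B, 1, U, D)
    joint_in = torch.tanh(enc + dec)          # broadcasts to (B, T, U, D)
    assert joint_in.shape == (B, T, U, D)     # a linear layer then maps D -> vocab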
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- speech, speech_lengths = self.normalize(speech, speech_lengths)
-
- # Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder(
- speech, speech_lengths, ctc=self.ctc
- )
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
-
- return encoder_out, encoder_out_lens
-
- def _calc_transducer_loss(
- self,
- encoder_out: torch.Tensor,
- joint_out: torch.Tensor,
- target: torch.Tensor,
- t_len: torch.Tensor,
- u_len: torch.Tensor,
- ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
- """Compute Transducer loss.
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ speech, speech_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def _calc_transducer_loss(
+ self,
+ encoder_out: torch.Tensor,
+ joint_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+ """Compute Transducer loss.
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- joint_out: Joint Network output sequences (B, T, U, D_joint)
- target: Target label ID sequences. (B, L)
- t_len: Encoder output sequences lengths. (B,)
- u_len: Target label ID sequences lengths. (B,)
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ joint_out: Joint Network output sequences (B, T, U, D_joint)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
- Return:
- loss_transducer: Transducer loss value.
- cer_transducer: Character error rate for Transducer.
- wer_transducer: Word Error Rate for Transducer.
+ Return:
+ loss_transducer: Transducer loss value.
+ cer_transducer: Character error rate for Transducer.
+ wer_transducer: Word Error Rate for Transducer.
- """
- if self.criterion_transducer is None:
- try:
- from warp_rnnt import rnnt_loss as RNNTLoss
- self.criterion_transducer = RNNTLoss
-
- except ImportError:
- logging.error(
- "warp-rnnt was not installed."
- "Please consult the installation documentation."
- )
- exit(1)
-
- log_probs = torch.log_softmax(joint_out, dim=-1)
-
- loss_transducer = self.criterion_transducer(
- log_probs,
- target,
- t_len,
- u_len,
- reduction="mean",
- blank=self.blank_id,
- fastemit_lambda=self.fastemit_lambda,
- gather=True,
- )
-
- if not self.training and (self.report_cer or self.report_wer):
- if self.error_calculator is None:
- from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
-
- self.error_calculator = ErrorCalculator(
- self.decoder,
- self.joint_network,
- self.token_list,
- self.sym_space,
- self.sym_blank,
- report_cer=self.report_cer,
- report_wer=self.report_wer,
- )
-
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
-
- return loss_transducer, cer_transducer, wer_transducer
-
- return loss_transducer, None, None
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- target: torch.Tensor,
- t_len: torch.Tensor,
- u_len: torch.Tensor,
- ) -> torch.Tensor:
- """Compute CTC loss.
+ """
+ if self.criterion_transducer is None:
+ try:
+ from warp_rnnt import rnnt_loss as RNNTLoss
+ self.criterion_transducer = RNNTLoss
+
+ except ImportError:
+ logging.error(
+                    "warp-rnnt was not installed. "
+ "Please consult the installation documentation."
+ )
+ exit(1)
+
+ log_probs = torch.log_softmax(joint_out, dim=-1)
+
+ loss_transducer = self.criterion_transducer(
+ log_probs,
+ target,
+ t_len,
+ u_len,
+ reduction="mean",
+ blank=self.blank_id,
+ fastemit_lambda=self.fastemit_lambda,
+ gather=True,
+ )
+
+ if not self.training and (self.report_cer or self.report_wer):
+ if self.error_calculator is None:
+ from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
+
+ self.error_calculator = ErrorCalculator(
+ self.decoder,
+ self.joint_network,
+ self.token_list,
+ self.sym_space,
+ self.sym_blank,
+ report_cer=self.report_cer,
+ report_wer=self.report_wer,
+ )
+
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
+
+ return loss_transducer, cer_transducer, wer_transducer
+
+ return loss_transducer, None, None
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute CTC loss.
- Args:
- encoder_out: Encoder output sequences. (B, T, D_enc)
- target: Target label ID sequences. (B, L)
- t_len: Encoder output sequences lengths. (B,)
- u_len: Target label ID sequences lengths. (B,)
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
- Return:
- loss_ctc: CTC loss value.
+ Return:
+ loss_ctc: CTC loss value.
- """
- ctc_in = self.ctc_lin(
- torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
- )
- ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
-
- target_mask = target != 0
- ctc_target = target[target_mask].cpu()
-
- with torch.backends.cudnn.flags(deterministic=True):
- loss_ctc = torch.nn.functional.ctc_loss(
- ctc_in,
- ctc_target,
- t_len,
- u_len,
- zero_infinity=True,
- reduction="sum",
- )
- loss_ctc /= target.size(0)
-
- return loss_ctc
-
- def _calc_lm_loss(
- self,
- decoder_out: torch.Tensor,
- target: torch.Tensor,
- ) -> torch.Tensor:
- """Compute LM loss.
+ """
+ ctc_in = self.ctc_lin(
+ torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+ )
+ ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+ target_mask = target != 0
+ ctc_target = target[target_mask].cpu()
+
+ with torch.backends.cudnn.flags(deterministic=True):
+ loss_ctc = torch.nn.functional.ctc_loss(
+ ctc_in,
+ ctc_target,
+ t_len,
+ u_len,
+ zero_infinity=True,
+ reduction="sum",
+ )
+ loss_ctc /= target.size(0)
+
+ return loss_ctc
+
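Note: _calc_ctc_loss above follows the torch.nn.functional.ctc_loss conventions: log-probabilities shaped (T, B, V) (hence the transpose) and, for a 1-D target tensor, all non-blank labels concatenated across the batch. A self-contained sketch with toy values:

    import torch

    T, B, V = 30, 2, 10
    log_probs = torch.randn(T, B, V).log_softmax(dim=-1)
    targets = torch.tensor([1, 2, 3, 4, 5])   # both utterances' labels, flattened
    input_lengths = torch.tensor([30, 30])
    target_lengths = torch.tensor([3, 2])     # 3 + 2 labels = 5 in total
    loss = torch.nn.functional.ctc_loss(
        log_probs, targets, input_lengths, target_lengths,
        blank=0, zero_infinity=True, reduction="sum",
    ) / B                                     # per-utterance average, as above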
+ def _calc_lm_loss(
+ self,
+ decoder_out: torch.Tensor,
+ target: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute LM loss.
- Args:
- decoder_out: Decoder output sequences. (B, U, D_dec)
- target: Target label ID sequences. (B, L)
+ Args:
+ decoder_out: Decoder output sequences. (B, U, D_dec)
+ target: Target label ID sequences. (B, L)
- Return:
- loss_lm: LM loss value.
+ Return:
+ loss_lm: LM loss value.
- """
- lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
- lm_target = target.view(-1).type(torch.int64)
-
- with torch.no_grad():
- true_dist = lm_loss_in.clone()
- true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
-
- # Ignore blank ID (0)
- ignore = lm_target == 0
- lm_target = lm_target.masked_fill(ignore, 0)
-
- true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
-
- loss_lm = torch.nn.functional.kl_div(
- torch.log_softmax(lm_loss_in, dim=1),
- true_dist,
- reduction="none",
- )
- loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
- 0
- )
-
- return loss_lm
-
- def init_beam_search(self,
- **kwargs,
- ):
- from funasr.models.transformer.search import BeamSearch
- from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
- if self.ctc != None:
- ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = kwargs.get("token_list")
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
+ """
+ lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+ lm_target = target.view(-1).type(torch.int64)
+
+ with torch.no_grad():
+ true_dist = lm_loss_in.clone()
+ true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+ # Ignore blank ID (0)
+ ignore = lm_target == 0
+ lm_target = lm_target.masked_fill(ignore, 0)
+
+ true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+ loss_lm = torch.nn.functional.kl_div(
+ torch.log_softmax(lm_loss_in, dim=1),
+ true_dist,
+ reduction="none",
+ )
+ loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+ 0
+ )
+
+ return loss_lm
+
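Note: the smoothed target built in _calc_lm_loss is the standard label-smoothing distribution — mass 1 - eps on the gold token and eps / (V - 1) spread over the rest. Worked with toy numbers:

    import torch

    V, eps = 5, 0.1
    gold = torch.tensor([3])
    true_dist = torch.full((1, V), eps / (V - 1))        # 0.025 everywhere
    true_dist.scatter_(1, gold.unsqueeze(1), 1.0 - eps)  # 0.9 on the gold token
    # true_dist == [[0.025, 0.025, 0.025, 0.9, 0.025]]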
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.transformer.search import BeamSearch
+ from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+        if getattr(self, "ctc", None) is not None:  # Transducer does not define self.ctc
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- weights = dict(
- decoder=1.0 - kwargs.get("decoding_ctc_weight"),
- ctc=kwargs.get("decoding_ctc_weight", 0.0),
- lm=kwargs.get("lm_weight", 0.0),
- ngram=kwargs.get("ngram_weight", 0.0),
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 2),
- weights=weights,
- scorers=scorers,
- sos=self.sos,
- eos=self.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
- )
- # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
- # for scorer in scorers.values():
- # if isinstance(scorer, torch.nn.Module):
- # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
- self.beam_search = beam_search
-
- def generate(self,
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+ ctc=kwargs.get("decoding_ctc_weight", 0.0),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 2),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+            pre_beam_score_key=None if getattr(self, "ctc_weight", 0.0) == 1.0 else "full",
+ )
+ # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ # for scorer in scorers.values():
+ # if isinstance(scorer, torch.nn.Module):
+ # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ self.beam_search = beam_search
+
+ def generate(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):
-
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and getattr(self, "ctc", None) is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+        # c. Pass the encoder output to the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
- results = []
- b, n, d = encoder_out.size()
- for i in range(b):
+ results = []
+ b, n, d = encoder_out.size()
+ for i in range(b):
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx+1}best_recog"]
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-
- return results, meta_data
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+ return results, meta_data
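Note: the n-best loop above strips sos/eos and blank ids before detokenizing. The same filtering in isolation (toy ids; tokenizer stands for any object exposing the two methods used in generate):

    sos, eos, blank = 1, 2, 0
    yseq = [1, 7, 0, 9, 0, 13, 2]           # sos ... blank-interleaved labels ... eos
    token_int = [t for t in yseq[1:-1] if t not in (sos, eos, blank)]
    assert token_int == [7, 9, 13]
    # token = tokenizer.ids2tokens(token_int)
    # text = tokenizer.tokens2text(token)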
diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index 4e91751..2f6e15a 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -19,433 +19,434 @@
@tables.register("model_classes", "Transformer")
class Transformer(nn.Module):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """CTC-attention hybrid Encoder-Decoder model"""
-
- def __init__(
- self,
- frontend: Optional[str] = None,
- frontend_conf: Optional[Dict] = None,
- specaug: Optional[str] = None,
- specaug_conf: Optional[Dict] = None,
- normalize: str = None,
- normalize_conf: Optional[Dict] = None,
- encoder: str = None,
- encoder_conf: Optional[Dict] = None,
- decoder: str = None,
- decoder_conf: Optional[Dict] = None,
- ctc: str = None,
- ctc_conf: Optional[Dict] = None,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- **kwargs,
- ):
+
+ def __init__(
+ self,
+ frontend: Optional[str] = None,
+ frontend_conf: Optional[Dict] = None,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ decoder: str = None,
+ decoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ **kwargs,
+ ):
- super().__init__()
+ super().__init__()
- if frontend is not None:
- frontend_class = tables.frontend_classes.get_class(frontend.lower())
- frontend = frontend_class(**frontend_conf)
- if specaug is not None:
- specaug_class = tables.specaug_classes.get_class(specaug.lower())
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = tables.normalize_classes.get_class(normalize.lower())
- normalize = normalize_class(**normalize_conf)
- encoder_class = tables.encoder_classes.get_class(encoder.lower())
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
- if decoder is not None:
- decoder_class = tables.decoder_classes.get_class(decoder.lower())
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
- if ctc_weight > 0.0:
-
- if ctc_conf is None:
- ctc_conf = {}
-
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
- )
-
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.ctc_weight = ctc_weight
- self.frontend = frontend
- self.specaug = specaug
- self.normalize = normalize
- self.encoder = encoder
+ if frontend is not None:
+ frontend_class = tables.frontend_classes.get_class(frontend.lower())
+ frontend = frontend_class(**frontend_conf)
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get_class(specaug.lower())
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get_class(normalize.lower())
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = tables.encoder_classes.get_class(encoder.lower())
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+ if decoder is not None:
+ decoder_class = tables.decoder_classes.get_class(decoder.lower())
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **decoder_conf,
+ )
+ if ctc_weight > 0.0:
+
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
- if not hasattr(self.encoder, "interctc_use_conditioning"):
- self.encoder.interctc_use_conditioning = False
- if self.encoder.interctc_use_conditioning:
- self.encoder.conditioning_layer = torch.nn.Linear(
- vocab_size, self.encoder.output_size()
- )
- self.interctc_weight = interctc_weight
+ if not hasattr(self.encoder, "interctc_use_conditioning"):
+ self.encoder.interctc_use_conditioning = False
+ if self.encoder.interctc_use_conditioning:
+ self.encoder.conditioning_layer = torch.nn.Linear(
+ vocab_size, self.encoder.output_size()
+ )
+ self.interctc_weight = interctc_weight
- # self.error_calculator = None
- if ctc_weight == 1.0:
- self.decoder = None
- else:
- self.decoder = decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
- if ctc_weight == 0.0:
- self.ctc = None
- else:
- self.ctc = ctc
-
- self.share_embedding = share_embedding
- if self.share_embedding:
- self.decoder.embed = None
-
- self.length_normalized_loss = length_normalized_loss
- self.beam_search = None
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- # import pdb;
- # pdb.set_trace()
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- loss_att, acc_att, cer_att, wer_att = None, None, None, None
- loss_ctc, cer_ctc = None, None
- stats = dict()
-
- # decoder: CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
- # Collect Intermedaite CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
-
- # Collect total loss stats
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
+ # self.error_calculator = None
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+
+ self.share_embedding = share_embedding
+ if self.share_embedding:
+ self.decoder.embed = None
+
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ stats = dict()
+
+ # decoder: CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+                # Collect intermediate CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+
+ # Collect total loss stats
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + 1).sum())
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
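Note: the hybrid objective above is a plain interpolation, and with intermediate CTC enabled the CTC term is itself a mix. Worked with example weights:

    ctc_weight, interctc_weight = 0.3, 0.5
    loss_ctc_last, loss_interctc, loss_att = 12.0, 14.0, 8.0

    loss_ctc = (1 - interctc_weight) * loss_ctc_last + interctc_weight * loss_interctc
    loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
    print(loss_ctc, round(loss, 6))   # 13.0 9.5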
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- speech, speech_lengths = self.normalize(speech, speech_lengths)
-
- # Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder(
- speech, speech_lengths, ctc=self.ctc
- )
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
-
- return encoder_out, encoder_out_lens
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- )
-
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_out_pad)
- acc_att = th_accuracy(
- decoder_out.view(-1, self.vocab_size),
- ys_out_pad,
- ignore_label=self.ignore_id,
- )
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- # Calc CTC loss
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
- # Calc CER using CTC
- cer_ctc = None
- if not self.training and self.error_calculator is not None:
- ys_hat = self.ctc.argmax(encoder_out).data
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
- return loss_ctc, cer_ctc
-
- def init_beam_search(self,
- **kwargs,
- ):
- from funasr.models.transformer.search import BeamSearch
- from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
- if self.ctc != None:
- ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = kwargs.get("token_list")
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ speech, speech_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ )
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.transformer.search import BeamSearch
+ from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+        if self.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- weights = dict(
- decoder=1.0 - kwargs.get("decoding_ctc_weight"),
- ctc=kwargs.get("decoding_ctc_weight", 0.0),
- lm=kwargs.get("lm_weight", 0.0),
- ngram=kwargs.get("ngram_weight", 0.0),
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 2),
- weights=weights,
- scorers=scorers,
- sos=self.sos,
- eos=self.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
- )
- # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
- # for scorer in scorers.values():
- # if isinstance(scorer, torch.nn.Module):
- # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
- self.beam_search = beam_search
-
- def generate(self,
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+ ctc=kwargs.get("decoding_ctc_weight", 0.0),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 2),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+ )
+ # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ # for scorer in scorers.values():
+ # if isinstance(scorer, torch.nn.Module):
+ # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ self.beam_search = beam_search
+
+ def generate(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):
-
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+        # c. Pass the encoder output to the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
- results = []
- b, n, d = encoder_out.size()
- for i in range(b):
+ results = []
+ b, n, d = encoder_out.size()
+ for i in range(b):
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx+1}best_recog"]
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-
- return results, meta_data
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+ return results, meta_data
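Note: for the offline generate() path above, a one-shot decoding sketch through the AutoModel wrapper (the model id is an assumption for illustration; passing output_dir exercises the DatadirWriter branch that writes the per-key token/text files):

    from funasr import AutoModel

    model = AutoModel(model="paraformer-zh")            # assumed model id
    res = model.generate(input="example.wav", batch_size=1,
                         output_dir="./decode_out")     # optional 1best_recog/* dump
    print(res[0]["key"], res[0]["text"])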
--
Gitblit v1.9.1