zhifu gao
2024-03-11 4a7a984a5f3e3f894f86ce82e76ddd13d8a42a20
funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
@@ -99,6 +99,7 @@
        )
        self.train_decoder = kwargs.get("train_decoder", False)
        self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
        self.predictor_name = kwargs.get("predictor")
        
    def forward(
        self,
@@ -127,7 +128,7 @@
    
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        seaco_label_pad = kwargs.get("seaco_label_pad")
        
        batch_size = speech.shape[0]
        # for data-parallel
@@ -147,7 +148,7 @@
                                        ys_lengths, 
                                        hotword_pad, 
                                        hotword_lengths, 
                                        dha_pad,
                                        seaco_label_pad,
                                        )
        if self.train_decoder:
            loss_att, acc_att = self._calc_att_loss(
@@ -170,6 +171,12 @@
    def _merge(self, cif_attended, dec_attended):
        return cif_attended + dec_attended
    
    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        predictor_outs = self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
        return predictor_outs[:4]
    def _calc_seaco_loss(
            self,
            encoder_out: torch.Tensor,
@@ -178,7 +185,7 @@
            ys_lengths: torch.Tensor,
            hotword_pad: torch.Tensor,
            hotword_lengths: torch.Tensor,
            dha_pad: torch.Tensor,
            seaco_label_pad: torch.Tensor,
    ):  
        # predictor forward
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
@@ -197,7 +204,7 @@
        dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
        merged = self._merge(cif_attended, dec_attended)
        dha_output = self.hotword_output_layer(merged[:, :-1])  # remove the last token in loss calculation
        loss_att = self.criterion_seaco(dha_output, dha_pad)
        loss_att = self.criterion_seaco(dha_output, seaco_label_pad)
        return loss_att
    def _seaco_decode_with_ASF(self, 
@@ -248,7 +255,7 @@
            def _merge_res(dec_output, dha_output):
                lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
                dha_ids = dha_output.max(-1)[-1]# [0]
                dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
                dha_mask = (dha_ids == self.NO_BIAS).int().unsqueeze(-1)
                a = (1 - lmbd) / lmbd
                b = 1 / lmbd
                a, b = a.to(dec_output.device), b.to(dec_output.device)
@@ -332,23 +339,28 @@
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                   pre_acoustic_embeds,
                                                   pre_token_length,
                                                   hw_list=self.hotword_list)
        decoder_out = self._seaco_decode_with_ASF(encoder_out,
                                                  encoder_out_lens,
                                                  pre_acoustic_embeds,
                                                  pre_token_length,
                                                  hw_list=self.hotword_list
                                                  )
        # decoder_out, _ = decoder_outs[0], decoder_outs[1]
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)
        if self.predictor_name == "CifPredictorV3":
            _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out,
                                                                      encoder_out_lens,
                                                                      pre_token_length)
        else:
            us_alphas = None
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
@@ -393,23 +405,25 @@
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
                                                               us_peaks[i][:encoder_out_lens[i] * 3],
                                                               copy.copy(token),
                                                               vad_offset=kwargs.get("begin_time", 0))
                    text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                        token, timestamp)
                    result_i = {"key": key[i], "text": text_postprocessed,
                                "timestamp": time_stamp_postprocessed
                                }
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                        ibest_writer["text"][key[i]] = text_postprocessed
                    if us_alphas is not None:
                        _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
                                                                us_peaks[i][:encoder_out_lens[i] * 3],
                                                                copy.copy(token),
                                                                vad_offset=kwargs.get("begin_time", 0))
                        text_postprocessed, time_stamp_postprocessed, _ = \
                            postprocess_utils.sentence_postprocess(token, timestamp)
                        result_i = {"key": key[i], "text": text_postprocessed,
                                    "timestamp": time_stamp_postprocessed}
                        if ibest_writer is not None:
                            ibest_writer["token"][key[i]] = " ".join(token)
                            ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                            ibest_writer["text"][key[i]] = text_postprocessed
                    else:
                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                        result_i = {"key": key[i], "text": text_postprocessed}
                        if ibest_writer is not None:
                            ibest_writer["token"][key[i]] = " ".join(token)
                            ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)