Shi Xian
2024-03-21 980007a486a4a2cfbcf4f98c7a88e28e17fd0f48
update seaco finetune (#1526)

3个文件已修改
15 ■■■■ 已修改文件
examples/industrial_data_pretraining/seaco_paraformer/finetune.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/contextual_paraformer/model.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/model.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
@@ -10,7 +10,7 @@
## option 1, download model automatically
model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.4"
+model_revision="v2.0.7"
## option 2, download model by git
#local_path_root=${workspace}/modelscope_models
funasr/models/contextual_paraformer/model.py
@@ -94,10 +94,8 @@
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
        batch_size = speech.shape[0]
funasr/models/seaco_paraformer/model.py
@@ -117,6 +117,8 @@
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
@@ -164,7 +166,7 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -190,8 +192,7 @@
        # predictor forward
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
        # decoder forward
        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
        selected = self._hotword_representation(hotword_pad,