游雁
2024-06-06 e9acc5db07daa51a22cd51ea9233ee09a38d726d
auto frontend
3个文件已修改
89 ■■■■（变更行数合计）· 已修改文件：
examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/openai_datasets/datasets.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py 76 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml
@@ -35,7 +35,7 @@
frontend_conf:
    fs: 16000
    whisper_model: large-v3
    do_pad_trim: true
    do_pad_trim: false
    permute: false # true: [bs, frames, dims]; false: [bs, dims, frames]
    filters_path: "/nfs/zhifu.gzf/init_model/SenseVoiceModelscope/assets/mel_filters.npz"
funasr/datasets/openai_datasets/datasets.py
@@ -123,21 +123,20 @@
                            )  # speech: [b, T, d]
                            if self.permute:
                                speech = speech.permute(0, 2, 1)
                            if speech_lengths > self.batch_size:
                                continue
                            # if speech_lengths > self.batch_size:
                            #     continue
                            fbank_lens = speech_lengths[0].item()
                            olens = 1 + (fbanks_len - 3 + 2 * 1) // 2
                            olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                            olens = 1 + (olens - 3 + 2 * 1) // 2
                            sub_token_len = (olens - 1) // 2 + 1
                            sub_token = [0] * sub_token_len[0]
                            sub_token = [0] * sub_token_len
                            fbank_beg_i = [len(source_ids)]
                            source_ids += sub_token
                            fbank_mask_i += [1] * len(sub_token)
                source_mask = [-100] * len(source_ids)
                target_out = f"{target_out}<|im_end|>"
                target_ids = tokenizer.encode(target_out)
                target_ids = self.tokenizer.encode(target_out)
                input_ids += source_ids + target_ids
                labels += source_mask + target_ids
                fbank_mask += fbank_mask_i
funasr/models/llm_asr/model.py
@@ -385,13 +385,6 @@
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        # audio encoder
        hub = audio_encoder_conf.get("hub", None)
        if hub == "ms":
@@ -422,23 +415,23 @@
        # llm
        hub = llm_conf.get("hub", "hf")
        self.llm = None
        # if hub == "hf":
        #     from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
        #
        #     init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
        #
        #     model = AutoModelForCausalLM.from_pretrained(
        #         init_param_path,
        #         load_in_8bit=None,
        #         device_map=None,
        #         use_cache=None,
        #     )
        #     freeze = llm_conf.get("freeze", True)
        #     if freeze:
        #         for name, param in model.named_parameters():
        #             param.requires_grad = False
        #         model.eval()
        #     self.llm = model
        if hub == "hf":
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
                device_map=None,
                use_cache=None,
            )
            freeze = llm_conf.get("freeze", True)
            if freeze:
                for name, param in model.named_parameters():
                    param.requires_grad = False
                model.eval()
            self.llm = model
        # adaptor
        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
@@ -446,21 +439,6 @@
        audio_adaptor = adaptor_class(**audio_adaptor_conf)
        self.audio_adaptor = audio_adaptor
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.error_calculator = None
@@ -493,10 +471,10 @@
        batch_size = speech.shape[0]
        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
        # audio_adaptor
        encoder_out = self.audio_adaptor(encoder_out)
        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
        input_ids[input_ids == -1] = 0
        input_ids[input_ids == -100] = 0
@@ -530,23 +508,9 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
            batch_size = int((labels_ids > 0 + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        speech = speech.permute(0, 2, 1)
        res = self.audio_encoder(speech)
        if isinstance(res, (list, tuple)):
            encoder_out, encoder_out_lens = res[0], res[1]
        else:
            encoder_out, encoder_out_lens = res, speech_lengths
        return encoder_out, encoder_out_lens
    def inference(
        self,