zhifu gao
2024-03-07 341182c3bfc62831aa02781d0e6bbe2a479f3fb2
Dev gzf (#1440)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm
7个文件已修改
5个文件已添加
606 ■■■■ 已修改文件
examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml 93 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_vicuna/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_vicuna/datasets.py 150 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_vicuna/preprocessor.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_vicuna/samplers.py 179 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/frontends/whisper_frontend.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py 123 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr_nar/model.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/whisper/model.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/register.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml
New file
@@ -0,0 +1,93 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: LLMASR
model_conf:
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: true
# encoder
audio_encoder: "/nfs/zhifu.gzf/init_model/Whisper-large-v3" #iic/Whisper-large-v3
audio_encoder_conf:
    hub: ms
    freeze: true
llm: Vicuna
llm_conf:
  hub: hf
  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
  freeze: true
audio_adaptor: Linear
audio_adaptor_conf:
  downsample_rate: 5
  llm_dim: 4096
  encoder_dim: 512
# frontend related
frontend: WhisperFrontend
frontend_conf:
    fs: 16000
    whisper_model: large-v3
    do_pad_trim: true
    permute: true # true: [bs, frames, dims]; false: [bs, dims, frames]
specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 10
  log_interval: 10
optim: adamw
optim_conf:
   lr: 0.0001
   weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 1500
dataset: AudioLLMVicunaDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: RankFullLocalShuffleBatchSampler
    batch_type: example # example or length
    batch_size: 8 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len + target_token_len;
    max_token_length: 2048 # filter out samples if source_token_len + target_token_len > max_token_length
    buffer_size: 500
    shuffle: True
    num_workers: 4
#    preprocessor_text: TextPreprocessRemovePunctuation
    audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
    audio_encoder_downsample_rate: 2
tokenizer: HuggingfaceTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
@@ -19,11 +19,11 @@
val_data="/nfs/zhifu.gzf/data/datalist/aishell1_aishell2_wav_speech_llm_train_data_tail500.json"
# exp output dir
output_dir="/Users/zhifu/exp"
output_dir="/nfs/zhifu.gzf/ckpt/exp/llm_asr_whisper_vicuna_exp1"
log_file="${output_dir}/log.txt"
workspace=`pwd`
config="template.yaml"
config="whisper_vicuna_linear.yaml"
init_param="${output_dir}/model.pt"
funasr/datasets/llm_datasets_vicuna/__init__.py
funasr/datasets/llm_datasets_vicuna/datasets.py
New file
@@ -0,0 +1,150 @@
import torch
import copy
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@tables.register("dataset_classes", "AudioLLMVicunaDataset")
class AudioLLMVicunaDataset(torch.utils.data.Dataset):
    """
    Speech-to-text dataset for LLM (Vicuna) based ASR training.

    Each item pairs fbank features extracted by ``frontend`` with tokenized
    prompt/answer sequences laid out as
    ``[audio placeholder, bos, prompt, answer, pad/eos]`` so the model can
    later splice audio embeddings over the placeholder region.
    """
    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                 **kwargs):
        super().__init__()
        # Resolve and build the index dataset (e.g. IndexDSJsonl) that maps an
        # integer index to an item with "source" (audio) and "target" (text).
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        # Optional registered preprocessor applied to the raw audio.
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
        self.preprocessor_speech = preprocessor_speech
        # Optional registered preprocessor applied to the transcript text.
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
        self.preprocessor_text = preprocessor_text
        self.frontend = frontend
        # Default to 16 kHz when no frontend is given.
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.float_pad_value = float_pad_value
        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
        self.prompt_af = ""
        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
        # NOTE: the ``int_pad_value`` argument is deliberately overridden so
        # that integer (label) padding is ignored by the loss.
        self.int_pad_value = self.IGNORE_INDEX
        # Downsampling factors used to predict how many LLM token slots the
        # audio embeddings will occupy after the encoder and the adaptor.
        self.audio_adaptor_downsample_rate = kwargs.get("audio_adaptor_downsample_rate", 5)
        self.audio_encoder_downsample_rate = kwargs.get("audio_encoder_downsample_rate", 2)
        self.prompt_template = "USER: {}\n ASSISTANT:"
        self.answer_template = "{}"

    def get_source_len(self, index):
        # Audio-side length of the sample; used by length-based batch samplers.
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        # Text-side length of the sample; used by length-based batch samplers.
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]
        source = item["source"]
        # Load the waveform and extract fbank features via the frontend.
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
                                               is_final=True)  # speech: [b, T, d]
        speech = speech.squeeze(0)
        # Number of LLM token positions the audio will occupy after the
        # encoder (1/audio_encoder_downsample_rate) and the adaptor
        # (1/audio_adaptor_downsample_rate) downsampling.
        audio_pseudo_length = (speech.shape[0] + 1) // self.audio_adaptor_downsample_rate // self.audio_encoder_downsample_rate
        audio_pseudo = torch.full((audio_pseudo_length,), -1) # placeholder
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        self.prompt_pre = self.prompt_template.format(self.prompt)
        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos, prompt]
        prompt_pre_length = len(prompt_ids_pre)
        # Model input: prompt + lower-cased transcript + one pad token.
        input = self.answer_template.format(target.lower())
        prompt_input = "{}{}".format(self.prompt_pre, input)
        prompt_input_ids = self.tokenizer.encode(prompt_input) # [bos, prompt, input]
        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id] # [bos, prompt, input, pad]
        input_ids_length = len(input_ids)
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
        input_ids = torch.cat((audio_pseudo, input_ids)) # [audio, bos, prompt, input, pad]
        # Every id (including the -1 placeholder) is >= -1, so this yields an
        # all-true length mask over the concatenated sequence.
        attention_mask = input_ids.ge(-1)  # [true, true, true, true, true], length mask
        # Labels: same prompt, answer kept, eos appended.
        answer = self.answer_template.format(target.lower())
        prompt_answer = "{}{}".format(self.prompt_pre, answer)
        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
        labels_ids = copy.deepcopy(prompt_answer_ids) + [self.tokenizer.eos_token_id]
        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, answer, eos]
        labels_ids = torch.cat((audio_pseudo, labels_ids))  # [audio, bos, prompt, answer, eos]
        # Mask the audio placeholder and the prompt out of the loss.
        labels_ids[:audio_pseudo_length+prompt_pre_length] = -1 # [-1, -1, -1, answer, eos]
        label_mask = labels_ids.ge(0)  # [false, false, false, true, true]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, -100, answer, eos]
        # audio_mask marks which positions of input_ids get replaced by audio
        # embeddings (1 over the placeholder, 0 over text tokens).
        audio_mask = [1]*audio_pseudo_length + [0]*input_ids_length
        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
        # Plain transcript token ids, kept separately from labels_ids
        # (no prompt, no eos, original casing).
        ids = self.tokenizer.encode(target)  # token ids is different from labels_ids
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels_ids": labels_ids,
                "label_mask": label_mask,
                "audio_mask": audio_mask,
                }

    def collator(self, samples: list = None):
        """Collate a list of sample dicts into padded batch tensors.

        int64 tensors are padded with IGNORE_INDEX (-100) so padding is
        ignored by the loss; all other tensors use ``float_pad_value``.
        Non-tensor fields are returned as plain lists.
        """
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])
        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs
funasr/datasets/llm_datasets_vicuna/preprocessor.py
New file
@@ -0,0 +1,37 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torch
import torchaudio
from torch import nn
import random
import re
import string
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessSegDict(nn.Module):
    """Text preprocessor that strips English and common Chinese punctuation.

    Registered in the preprocessor table as ``TextPreprocessRemovePunctuation``.
    The punctuation set and the compiled regex are class-level constants:
    the original implementation rebuilt the strings and recompiled the
    pattern on every ``forward`` call.
    """

    # English punctuation plus common full-width Chinese punctuation.
    _EN_PUNCT = string.punctuation
    _CN_PUNCT = '。?!,、;:“”‘’()《》【】…—~·'
    # Character class matching any single punctuation character above.
    _PUNCT_PATTERN = re.compile('[{}]'.format(re.escape(_EN_PUNCT + _CN_PUNCT)))

    def __init__(self,
                 **kwargs):
        super().__init__()

    def forward(self, text, **kwargs):
        """Return *text* with all matched punctuation characters removed."""
        return self._PUNCT_PATTERN.sub('', text)
funasr/datasets/llm_datasets_vicuna/samplers.py
New file
@@ -0,0 +1,179 @@
import torch
import numpy as np
import logging
import torch.distributed as dist
from funasr.register import tables
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
    """Locally shuffled dynamic batch sampler.

    Samples are read in fixed-size buffers, sorted by length inside each
    buffer, and grouped into batches either by example count
    (``batch_type="example"``) or by a padded token budget
    (``batch_type="length"``). Samples longer than ``max_token_length``
    are filtered out.

    Fix: the original implementation never yielded the trailing partial
    batch, silently dropping up to ``batch_size - 1`` samples per epoch
    even when ``drop_last=False``.
    """

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1  # buffers up to this index are skipped (resume support)
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        # Samples whose length exceeds this are silently filtered out.
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training  # never shuffle at eval time
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

    def __len__(self):
        return (self.total_samples - 1) // self.batch_size + 1

    def set_epoch(self, epoch):
        # Seed numpy so the per-epoch shuffle is reproducible.
        np.random.seed(epoch)

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch = []
        max_token = 0
        num_sample = 0
        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for iter in range(self.pre_idx + 1, iter_num):
            # Collect (index, length) pairs for this buffer.
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                # NOTE(review): lengths are taken from the shuffled index
                # ``idx_map`` while the raw ``idx`` is what ends up in the
                # batch below — confirm this is intended.
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                datalen_with_index.append([idx, source_len + target_len])
            # Sort the buffer by sample length so batches pad efficiently.
            for idx, sample_len_cur_raw in sorted(datalen_with_index, key=lambda x: x[1]):
                if sample_len_cur_raw > self.max_token_length:
                    continue  # drop over-long samples
                max_token_cur = max(max_token, sample_len_cur_raw)
                # Cost of adding this sample: example count, or padded tokens
                # (samples * longest sample) for length-based batching.
                max_token_padding = 1 + num_sample
                if self.batch_type != 'example':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
        # Emit the trailing partial batch unless drop_last is requested.
        if batch and not self.drop_last:
            yield batch
@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullGlobalShuffleBatchSampler")
class RankFullGlobalShuffleBatchSampler(torch.utils.data.BatchSampler):
    """Distributed batch sampler: builds one global batch of
    ``batch_size * world_size`` samples, then each rank yields its own
    contiguous slice of it.

    Samples are read in fixed-size buffers, sorted by length inside each
    buffer, and samples longer than ``max_token_length`` are filtered out.
    All ranks must shuffle identically (see ``set_epoch``) for the per-rank
    slices to partition each global batch.
    """
    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = True,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1  # buffers up to this index are skipped (resume support)
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)  # per-rank batch size
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training  # never shuffle at eval time
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        # Fall back to single-process values when torch.distributed is not
        # initialized (e.g. local CPU debugging).
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size
    def __len__(self):
        # Number of global batches across all ranks.
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
    def set_epoch(self, epoch):
        # Seed numpy identically on every rank for this epoch.
        np.random.seed(epoch)
    def __iter__(self):
        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch = []
        max_token = 0
        num_sample = 0
        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            # if iter == iter_num -1 and self.drop_last:
            #     continue
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                # NOTE(review): lengths come from the shuffled index idx_map,
                # while the raw idx is what is batched below — confirm intended.
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])
            # Sort the buffer by sample length so batches pad efficiently.
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue  # drop over-long samples
                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                # if self.batch_type != 'example':
                #     max_token_padding *= max_token_cur
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # Emit only this rank's slice of the full global batch.
                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
funasr/frontends/whisper_frontend.py
@@ -20,6 +20,8 @@
            whisper_model: str = None,
            do_pad_trim: bool = True,
            n_mels: int = 80,
            permute: bool = False,
            **kwargs,
    ):
        super().__init__()
        assert fs == 16000
@@ -39,6 +41,7 @@
        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim
        self.permute = permute
        # assert whisper_model in whisper.available_models()
@@ -77,7 +80,7 @@
        return log_spec, olens
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
            self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
@@ -98,5 +101,6 @@
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        if self.permute:
            feats_pad = feats_pad.permute(0, 2, 1)
        return feats_pad, feats_lens
funasr/models/llm_asr/model.py
@@ -12,7 +12,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.metrics.common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
@@ -30,8 +30,10 @@
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        audio_encoder: str = None,
        audio_encoder_conf: dict = None,
        audio_adaptor: str = None,
        audio_adaptor_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
@@ -39,8 +41,6 @@
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        adaptor: str = None,
        adaptor_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
@@ -70,23 +70,30 @@
            normalize = normalize_class(**normalize_conf)
        
        # audio encoder
        hub = encoder_conf.get("hub", None)
        if hub == "funasr":
        hub = audio_encoder_conf.get("hub", None)
        if hub == "ms":
            from funasr import AutoModel
            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
            model = AutoModel(model=audio_encoder, model_revision="v2.0.4")
            # frontend = model.kwargs.get("frontend")
            model.model.decoder = None
            audio_encoder_output_size = model.model.encoder_output_size
            
            self.audio_encoder = model.model
            audio_encoder = model.model.model.encoder
            # self.frontend = frontend
            
        elif hub == "hf":
            pass
        else:
            encoder_class = tables.encoder_classes.get(encoder)
            encoder = encoder_class(input_size=input_size, **encoder_conf)
            encoder_output_size = encoder.output_size()
            encoder_class = tables.encoder_classes.get(audio_encoder)
            audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
            audio_encoder_output_size = audio_encoder.output_size()
        freeze = audio_encoder_conf.get("freeze", True)
        if freeze:
            for name, param in audio_encoder.named_parameters():
                param.requires_grad = False
            audio_encoder.eval()
        self.audio_encoder = audio_encoder
        # llm
        hub = llm_conf.get("hub", "hf")
@@ -95,6 +102,7 @@
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
@@ -109,10 +117,11 @@
            self.llm = model
        
        # adaptor
        adaptor_class = tables.adaptor_classes.get(adaptor)
        adaptor = adaptor_class(**adaptor_conf)
        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
        audio_adaptor = adaptor_class(**audio_adaptor_conf)
        
        self.adaptor = adaptor
        self.audio_adaptor = audio_adaptor
        
        
        self.blank_id = blank_id
@@ -122,8 +131,6 @@
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
@@ -131,12 +138,7 @@
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None
        self.length_normalized_loss = length_normalized_loss
@@ -172,12 +174,11 @@
        batch_size = speech.shape[0]
        
        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        
        # adaptor
        encoder_out = self.adaptor(encoder_out)
        # audio_adaptor
        encoder_out = self.audio_adaptor(encoder_out)
        if input_ids is not None:
            input_ids[input_ids == -1] = 0
            input_ids[input_ids == -100] = 0
            if hasattr(self.llm.model, "embed_tokens"):
@@ -190,9 +191,9 @@
            if audio_mask is not None:
                batch_size, token_num, dims = inputs_embeds.shape
                _, l, _ = encoder_out.shape
                encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
            # [audio, bos, prompt, input, pad]
            encoder_outs_pad = F.pad(encoder_out, (0, 0, 0, token_num - l, 0, 0), value=0.0)
                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
                inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
        loss = model_outputs.loss
@@ -214,22 +215,14 @@
    
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
        with autocast(False):
            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
                                                                               mask=enc_mask,
                                                                               target_label_length=audio_token_lengths,
                                                                               )
        return pre_acoustic_embeds, pre_token_length
    ):
        speech = speech.permute(0, 2, 1)
        res = self.audio_encoder(speech)
        if len(res) > 1:
            encoder_out, encoder_out_lens = res[0], res[1]
        else:
            encoder_out, encoder_out_lens = res, speech_lengths
        return encoder_out, encoder_out_lens
    def inference(self,
                  data_in,
@@ -275,7 +268,7 @@
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # adaptor
        encoder_out = self.adaptor(encoder_out)
        encoder_out = self.audio_adaptor(encoder_out)
        
    
        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
@@ -294,26 +287,24 @@
        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
        
        # model_outputs = self.llm.generate(
        #     inputs_embeds=inputs_embeds,
        #     max_length=kwargs.get("max_length", 200),
        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
        #     num_beams=kwargs.get("num_beams", 4),
        #     do_sample=kwargs.get("do_sample", False),
        #     min_length=kwargs.get("min_length", 1),
        #     top_p=kwargs.get("top_p", 1.0),
        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
        #     length_penalty=kwargs.get("length_penalty", 1.0),
        #     temperature=kwargs.get("temperature", 1.0),
        #     attention_mask=attention_mask,
        #     bos_token_id=tokenizer.bos_token_id,
        #     eos_token_id=tokenizer.eos_token_id,
        #     pad_token_id=tokenizer.pad_token_id
        # )
        preds = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
        preds = torch.argmax(model_outputs.logits, -1)
        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
        text = text[0].split(': ')[-1]
funasr/models/llm_asr_nar/model.py
@@ -214,7 +214,7 @@
    
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    ):
    
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
funasr/models/whisper/model.py
@@ -41,6 +41,8 @@
        
        self.model = model
        
        self.encoder_output_size = self.model.dims.n_audio_state
    def forward(self, ):
        pass
    
funasr/register.py
@@ -29,7 +29,7 @@
                flag = key in classes_key
            if classes_key.endswith("_meta") and flag:
                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                headers = ["class name", "class location"]
                headers = ["register name", "class name", "class location"]
                metas = []
                for register_key, meta in classes_dict.items():
                    metas.append(meta)
@@ -67,8 +67,8 @@
            class_line = inspect.getsourcelines(target_class)[1]
            pattern = r'^.+/funasr/'
            class_file = re.sub(pattern, 'funasr/', class_file)
            meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
            # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
            # meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
            meata_data = [f"{registry_key}", f"{target_class.__name__}", f"{class_file}:{class_line}"]
            registry_meta[registry_key] = meata_data
            # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
            return target_class
funasr/train_utils/trainer.py
@@ -163,7 +163,7 @@
                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
            print(f"No checkpoint found at '{ckpt}', does not resume status!")
        if self.use_ddp or self.use_fsdp:
            dist.barrier()