From 341182c3bfc62831aa02781d0e6bbe2a479f3fb2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 07 Mar 2024 10:03:02 +0800
Subject: [PATCH] Dev gzf (#1440)

---
 examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml |   93 ++++++++
 funasr/models/llm_asr_nar/model.py                                           |    2 
 examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh       |    4 
 funasr/models/llm_asr/model.py                                               |  151 ++++++-------
 funasr/train_utils/trainer.py                                                |    2 
 funasr/register.py                                                           |    6 
 funasr/datasets/llm_datasets_vicuna/__init__.py                              |    0 
 funasr/models/whisper/model.py                                               |    2 
 funasr/datasets/llm_datasets_vicuna/datasets.py                              |  150 +++++++++++++
 funasr/datasets/llm_datasets_vicuna/samplers.py                              |  179 ++++++++++++++++
 funasr/frontends/whisper_frontend.py                                         |    8 
 funasr/datasets/llm_datasets_vicuna/preprocessor.py                          |   37 +++
 12 files changed, 545 insertions(+), 89 deletions(-)

diff --git a/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml b/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml
new file mode 100644
index 0000000..eacbd35
--- /dev/null
+++ b/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml
@@ -0,0 +1,93 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: LLMASR
+model_conf:
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: true
+
+# encoder
+audio_encoder: "/nfs/zhifu.gzf/init_model/Whisper-large-v3" #iic/Whisper-large-v3
+audio_encoder_conf:
+    hub: ms
+    freeze: true
+
+llm: Vicuna
+llm_conf:
+  hub: hf
+  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+  freeze: true
+
+audio_adaptor: Linear
+audio_adaptor_conf:
+  downsample_rate: 5
+  llm_dim: 4096
+  encoder_dim: 512
+
+# frontend related
+frontend: WhisperFrontend
+frontend_conf:
+    fs: 16000
+    whisper_model: large-v3
+    do_pad_trim: true
+    permute: true # true: [bs, frames, dims]; false: [bs, dims, frames]
+
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  keep_nbest_models: 10
+  log_interval: 10
+
+optim: adamw
+optim_conf:
+   lr: 0.0001
+   weight_decay: 0.000001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 1500
+
+dataset: AudioLLMVicunaDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_type: example # example or length
+    batch_size: 8 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 2048 # filter out samples if source_token_len+target_token_len > max_token_length
+    buffer_size: 500
+    shuffle: True
+    num_workers: 4
+#    preprocessor_text: TextPreprocessRemovePunctuation
+    audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
+    audio_encoder_downsample_rate: 2
+
+
+
+tokenizer: HuggingfaceTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
index a518d57..6632175 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
@@ -19,11 +19,11 @@
 val_data="/nfs/zhifu.gzf/data/datalist/aishell1_aishell2_wav_speech_llm_train_data_tail500.json"
 
 # exp output dir
-output_dir="/Users/zhifu/exp"
+output_dir="/nfs/zhifu.gzf/ckpt/exp/llm_asr_whisper_vicuna_exp1"
 log_file="${output_dir}/log.txt"
 
 workspace=`pwd`
-config="template.yaml"
+config="whisper_vicuna_linear.yaml"
 
 init_param="${output_dir}/model.pt"
 
diff --git a/funasr/datasets/llm_datasets_vicuna/__init__.py b/funasr/datasets/llm_datasets_vicuna/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/llm_datasets_vicuna/__init__.py
diff --git a/funasr/datasets/llm_datasets_vicuna/datasets.py b/funasr/datasets/llm_datasets_vicuna/datasets.py
new file mode 100644
index 0000000..6bf6d04
--- /dev/null
+++ b/funasr/datasets/llm_datasets_vicuna/datasets.py
@@ -0,0 +1,150 @@
+import torch
+import copy
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "AudioLLMVicunaDataset")
+class AudioLLMVicunaDataset(torch.utils.data.Dataset):
+    """
+    AudioLLMVicunaDataset
+    """
+    
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                 **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+        
+        self.float_pad_value = float_pad_value
+        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        # self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+        self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+        self.int_pad_value = self.IGNORE_INDEX
+        self.audio_adaptor_downsample_rate = kwargs.get("audio_adaptor_downsample_rate", 5)
+        self.audio_encoder_downsample_rate = kwargs.get("audio_encoder_downsample_rate", 2)
+        self.prompt_template = "USER: {}\n ASSISTANT:"
+        self.answer_template = "{}"
+        
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        # import pdb;
+        # pdb.set_trace()
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
+                                               is_final=True)  # speech: [b, T, d]
+        speech = speech.squeeze(0)
+
+        audio_pseudo_length = (speech.shape[0] + 1) // self.audio_adaptor_downsample_rate // self.audio_encoder_downsample_rate
+        audio_pseudo = torch.full((audio_pseudo_length,), -1) # placeholder
+        
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+
+        self.prompt_pre = self.prompt_template.format(self.prompt)
+        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos,prompt]
+        prompt_pre_length = len(prompt_ids_pre)
+        
+        # input
+        input = self.answer_template.format(target.lower())
+        prompt_input = "{}{}".format(self.prompt_pre, input)
+        prompt_input_ids = self.tokenizer.encode(prompt_input) # [bos, prompt, input]
+        # audio_length = len(prompt_input_ids) - prompt_pre_length
+        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id] # [bos, prompt, input, pad]
+        input_ids_length = len(input_ids)
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
+        input_ids = torch.cat((audio_pseudo, input_ids)) # [audio, bos, prompt, input, pad]
+        # input_ids[:audio_pseudo_length] = -1 # [-1, bos, prompt, input, pad]
+        attention_mask = input_ids.ge(-1)  # [true, true, true, true, true], length mask
+        # input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
+        # attention_mask = input_ids.ge(-1)  # [true, true, true, true], length mask
+        
+        # label
+        answer = self.answer_template.format(target.lower())
+        prompt_answer = "{}{}".format(self.prompt_pre, answer)
+        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+        # answer_length = len(prompt_answer_ids) - prompt_pre_length
+        labels_ids = copy.deepcopy(prompt_answer_ids) + [self.tokenizer.eos_token_id]
+        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, answer, eos]
+        labels_ids = torch.cat((audio_pseudo, labels_ids))  # [audio, bos, prompt, answer, eos]
+        labels_ids[:audio_pseudo_length+prompt_pre_length] = -1 # [-1, -1, -1, answer, eos]
+        # labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        label_mask = labels_ids.ge(0)  # [false, false, false, true, true]
+        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, -100, answer, eos]
+        
+        # audio_mask for input_ids
+        audio_mask = [1]*audio_pseudo_length + [0]*input_ids_length
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+        
+        ids = self.tokenizer.encode(target)  # token ids is different from labels_ids
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+        
+        return {"speech": speech,
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels_ids,
+                "label_mask": label_mask,
+                "audio_mask": audio_mask,
+                }
+    
+    def collator(self, samples: list = None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+        
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64:
+                    
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+                
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
diff --git a/funasr/datasets/llm_datasets_vicuna/preprocessor.py b/funasr/datasets/llm_datasets_vicuna/preprocessor.py
new file mode 100644
index 0000000..9f20672
--- /dev/null
+++ b/funasr/datasets/llm_datasets_vicuna/preprocessor.py
@@ -0,0 +1,37 @@
+import os
+import json
+import torch
+import logging
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from typing import Collection
+import torch
+import torchaudio
+from torch import nn
+import random
+import re
+import string
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.register import tables
+
+
+
+@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
+class TextPreprocessSegDict(nn.Module):
+	def __init__(self,
+	             **kwargs):
+		super().__init__()
+		
+	
+	def forward(self, text, **kwargs):
+		# define the set of English punctuation characters
+		en_punct = string.punctuation
+		# define a commonly used subset of Chinese punctuation characters
+		cn_punct = '。？！，、；：“”‘’（）《》【】…—～·'
+		# merge the English and Chinese punctuation into one set
+		all_punct = en_punct + cn_punct
+		# build a regex pattern that matches any character in all_punct
+		punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
+		# strip those punctuation characters with re.sub
+		return punct_pattern.sub('', text)
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
new file mode 100644
index 0000000..29ec11d
--- /dev/null
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -0,0 +1,179 @@
+import torch
+import numpy as np
+import logging
+import torch.distributed as dist
+
+from funasr.register import tables
+
+
+@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
+class BatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = False,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 5000)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+    
+    def __len__(self):
+        return (self.total_samples-1) // self.batch_size + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+        
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                sample_len_cur = source_len + target_len
+                
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+                
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                if self.batch_type != 'example':
+                    max_token_padding *= max_token_cur
+                if max_token_padding <= self.batch_size:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    yield batch
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
+
+
+@tables.register("batch_sampler_classes", "BatchSampler")
+@tables.register("batch_sampler_classes", "RankFullGlobalShuffleBatchSampler")
+class RankFullGlobalShuffleBatchSampler(torch.utils.data.BatchSampler):
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = True,
+                 shuffle: bool = True,
+                 is_training: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = int(batch_size)
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 1500)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+        self.rank = rank
+        self.world_size = world_size
+        
+    def __len__(self):
+        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+    
+        batch_size_total = self.batch_size * self.world_size
+        
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            # if iter == iter_num -1 and self.drop_last:
+            #     continue
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                
+                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+                sample_len_cur = source_len + target_len
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                # if self.batch_type != 'example':
+                #     max_token_padding *= max_token_cur
+                if max_token_padding <= batch_size_total:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
+                    yield batch_rank
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
+
diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index 9290a25..0598c61 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -20,6 +20,8 @@
             whisper_model: str = None,
             do_pad_trim: bool = True,
             n_mels: int = 80,
+            permute: bool = False,
+            **kwargs,
     ):
         super().__init__()
         assert fs == 16000
@@ -39,6 +41,7 @@
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
+        self.permute = permute
 
         # assert whisper_model in whisper.available_models()
 
@@ -77,7 +80,7 @@
         return log_spec, olens
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+            self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = input.size(0)
         feats = []
@@ -98,5 +101,6 @@
             feats_pad = pad_sequence(feats,
                                      batch_first=True,
                                      padding_value=0.0)
-
+        if self.permute:
+            feats_pad = feats_pad.permute(0, 2, 1)
         return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 4139d8c..3223190 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -12,7 +12,7 @@
 from funasr.models.ctc.ctc import CTC
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
-# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.metrics.common import ErrorCalculator
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 from funasr.utils import postprocess_utils
@@ -30,8 +30,10 @@
         specaug_conf: dict = None,
         normalize: str = None,
         normalize_conf: dict = None,
-        encoder: str = None,
-        encoder_conf: dict = None,
+        audio_encoder: str = None,
+        audio_encoder_conf: dict = None,
+        audio_adaptor: str = None,
+        audio_adaptor_conf: dict = None,
         decoder: str = None,
         decoder_conf: dict = None,
         ctc: str = None,
@@ -39,8 +41,6 @@
         ctc_weight: float = 0.5,
         llm: str = None,
         llm_conf: dict = None,
-        adaptor: str = None,
-        adaptor_conf: dict = None,
         input_size: int = 80,
         vocab_size: int = -1,
         ignore_id: int = -1,
@@ -70,23 +70,30 @@
             normalize = normalize_class(**normalize_conf)
         
         # audio encoder
-        hub = encoder_conf.get("hub", None)
-        if hub == "funasr":
+        hub = audio_encoder_conf.get("hub", None)
+        if hub == "ms":
             from funasr import AutoModel
-            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+            model = AutoModel(model=audio_encoder, model_revision="v2.0.4")
             # frontend = model.kwargs.get("frontend")
-            model.model.decoder = None
+            audio_encoder_output_size = model.model.encoder_output_size
+
+            audio_encoder = model.model.model.encoder
             
-            self.audio_encoder = model.model
             # self.frontend = frontend
             
         elif hub == "hf":
             pass
         else:
-            encoder_class = tables.encoder_classes.get(encoder)
-            encoder = encoder_class(input_size=input_size, **encoder_conf)
-            encoder_output_size = encoder.output_size()
+            encoder_class = tables.encoder_classes.get(audio_encoder)
+            audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
+            audio_encoder_output_size = audio_encoder.output_size()
+        freeze = audio_encoder_conf.get("freeze", True)
+        if freeze:
+            for name, param in audio_encoder.named_parameters():
+                param.requires_grad = False
+            audio_encoder.eval()
+            
+        self.audio_encoder = audio_encoder
 
         # llm
         hub = llm_conf.get("hub", "hf")
@@ -95,6 +102,7 @@
             from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
             init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+            
             model = AutoModelForCausalLM.from_pretrained(
                 init_param_path,
                 load_in_8bit=None,
@@ -109,10 +117,11 @@
             self.llm = model
         
         # adaptor
-        adaptor_class = tables.adaptor_classes.get(adaptor)
-        adaptor = adaptor_class(**adaptor_conf)
+        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
+        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+        audio_adaptor = adaptor_class(**audio_adaptor_conf)
         
-        self.adaptor = adaptor
+        self.audio_adaptor = audio_adaptor
         
         
         self.blank_id = blank_id
@@ -122,8 +131,6 @@
         self.ignore_id = ignore_id
         self.specaug = specaug
         self.normalize = normalize
-        self.encoder = encoder
-
 
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
@@ -131,12 +138,7 @@
             smoothing=lsm_weight,
             normalize_length=length_normalized_loss,
         )
-        #
-        # if report_cer or report_wer:
-        #     self.error_calculator = ErrorCalculator(
-        #         token_list, sym_space, sym_blank, report_cer, report_wer
-        #     )
-        #
+
         self.error_calculator = None
 
         self.length_normalized_loss = length_normalized_loss
@@ -172,27 +174,26 @@
         batch_size = speech.shape[0]
         
         # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         
-        # adaptor
-        encoder_out = self.adaptor(encoder_out)
+        # audio_adaptor
+        encoder_out = self.audio_adaptor(encoder_out)
+        
+        input_ids[input_ids == -1] = 0
+        input_ids[input_ids == -100] = 0
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(input_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
 
-        if input_ids is not None:
-            input_ids[input_ids == -1] = 0
-            input_ids[input_ids == -100] = 0
-            if hasattr(self.llm.model, "embed_tokens"):
-                inputs_embeds = self.llm.model.embed_tokens(input_ids)
-            elif hasattr(self.llm.model.model, "embed_tokens"):
-                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
-            else:
-                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
-
-            if audio_mask is not None:
-                batch_size, token_num, dims = inputs_embeds.shape
-                _, l, _ = encoder_out.shape
-                encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
-                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
-                inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+        if audio_mask is not None:
+            batch_size, token_num, dims = inputs_embeds.shape
+            _, l, _ = encoder_out.shape
+            # [audio, bos, prompt, input, pad]
+            encoder_outs_pad = F.pad(encoder_out, (0, 0, 0, token_num - l, 0, 0), value=0.0)
+            inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
         loss = model_outputs.loss
@@ -211,26 +212,18 @@
             batch_size = int((text_lengths + 1).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
-    
+
     def encode(
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ):
+        speech = speech.permute(0, 2, 1)
+        res = self.audio_encoder(speech)
+        if len(res) > 1:
+            encoder_out, encoder_out_lens = res[0], res[1]
+        else:
+            encoder_out, encoder_out_lens = res, speech_lengths
+        return encoder_out, encoder_out_lens
     
-        audio_mask = kwargs.get("audio_mask", None)
-        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
-
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        enc, enc_lens = self.audio_encoder.encode(**batch)
-        with autocast(False):
-            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
-                                                                               mask=enc_mask,
-                                                                               target_label_length=audio_token_lengths,
-                                                                               )
-
-        return pre_acoustic_embeds, pre_token_length
-
-
     def inference(self,
                   data_in,
                   data_lengths=None,
@@ -275,7 +268,7 @@
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
         # adaptor
-        encoder_out = self.adaptor(encoder_out)
+        encoder_out = self.audio_adaptor(encoder_out)
         
     
         prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
@@ -294,26 +287,24 @@
         inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
-        # model_outputs = self.llm.generate(
-        #     inputs_embeds=inputs_embeds,
-        #     max_length=kwargs.get("max_length", 200),
-        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
-        #     num_beams=kwargs.get("num_beams", 4),
-        #     do_sample=kwargs.get("do_sample", False),
-        #     min_length=kwargs.get("min_length", 1),
-        #     top_p=kwargs.get("top_p", 1.0),
-        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-        #     length_penalty=kwargs.get("length_penalty", 1.0),
-        #     temperature=kwargs.get("temperature", 1.0),
-        #     attention_mask=attention_mask,
-        #     bos_token_id=tokenizer.bos_token_id,
-        #     eos_token_id=tokenizer.eos_token_id,
-        #     pad_token_id=tokenizer.pad_token_id
-        # )
+        preds = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_length=kwargs.get("max_length", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            num_beams=kwargs.get("num_beams", 4),
+            do_sample=kwargs.get("do_sample", False),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 1.0),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=attention_mask,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id
+        )
 
 
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
-        preds = torch.argmax(model_outputs.logits, -1)
         text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
 
         text = text[0].split(': ')[-1]
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index db81c47..0f4973d 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -214,7 +214,7 @@
     
     def encode(
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ):
     
         audio_mask = kwargs.get("audio_mask", None)
         audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 73d70d7..35de1c9 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -41,6 +41,8 @@
         
         self.model = model
         
+        self.encoder_output_size = self.model.dims.n_audio_state
+        
     def forward(self, ):
         pass
     
diff --git a/funasr/register.py b/funasr/register.py
index ebfdaac..cfa1b20 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -29,7 +29,7 @@
                 flag = key in classes_key
             if classes_key.endswith("_meta") and flag:
                 print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
-                headers = ["class name", "class location"]
+                headers = ["register name", "class name", "class location"]
                 metas = []
                 for register_key, meta in classes_dict.items():
                     metas.append(meta)
@@ -67,8 +67,8 @@
             class_line = inspect.getsourcelines(target_class)[1]
             pattern = r'^.+/funasr/'
             class_file = re.sub(pattern, 'funasr/', class_file)
-            meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
-            # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+            # meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
+            meata_data = [f"{registry_key}", f"{target_class.__name__}", f"{class_file}:{class_line}"]
             registry_meta[registry_key] = meata_data
             # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
             return target_class
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 3b20596..2a57a9b 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -163,7 +163,7 @@
                 self.scaler.load_state_dict(checkpoint['scaler_state'])
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{ckpt}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', does not resume status!")
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()

--
Gitblit v1.9.1