From 35b1c051f6db3649a818547902497d219c871b84 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 14 Mar 2024 09:33:30 +0800
Subject: [PATCH] Dev gzf llm (#1493)

---
 funasr/models/llm_asr_nar/model.py                              |  333 +++++++++++++++++++++++++++
 funasr/train_utils/trainer.py                                   |   13 
 runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py        |   14 
 runtime/python/onnxruntime/setup.py                             |    2 
 funasr/auto/auto_model.py                                       |   37 +-
 funasr/datasets/llm_datasets_qwenaudio/datasets.py              |    2 
 runtime/python/onnxruntime/funasr_onnx/punc_bin.py              |   12 
 runtime/python/onnxruntime/funasr_onnx/vad_bin.py               |   10 
 funasr/datasets/llm_datasets/datasets.py                        |  192 +++++++++++++--
 funasr/utils/export_utils.py                                    |    4 
 runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py |    5 
 funasr/datasets/llm_datasets_vicuna/datasets.py                 |    2 
 funasr/bin/train_llm.py                                         |   35 +-
 funasr/datasets/audio_datasets/datasets.py                      |    4 
 14 files changed, 567 insertions(+), 98 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 47456a3..2df1910 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -164,22 +164,23 @@
             tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer_conf = kwargs.get("tokenizer_conf", {})
             tokenizer = tokenizer_class(**tokenizer_conf)
-            kwargs["tokenizer"] = tokenizer
+            
 
             kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
             kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
             vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
         else:
             vocab_size = -1
+        kwargs["tokenizer"] = tokenizer
+        
         # build frontend
         frontend = kwargs.get("frontend", None)
         kwargs["input_size"] = None
         if frontend is not None:
             frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])
-            kwargs["frontend"] = frontend
             kwargs["input_size"] = frontend.output_size() if hasattr(frontend, "output_size") else None
-        
+        kwargs["frontend"] = frontend
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs.get("model_conf", {}), vocab_size=vocab_size)
@@ -469,13 +470,19 @@
         #                      f"time_escape_all: {time_escape_total_all_samples:0.3f}")
         return results_ret_list
 
-    def export(self, input=None,
-               type : str = "onnx",
-               quantize: bool = False,
-               fallback_num: int = 5,
-               calib_num: int = 100,
-               opset_version: int = 14,
-               **cfg):
+    def export(self, input=None, **cfg):
+    
+        """
+        
+        :param input:
+        :param type:
+        :param quantize:
+        :param fallback_num:
+        :param calib_num:
+        :param opset_version:
+        :param cfg:
+        :return:
+        """
     
         device = cfg.get("device", "cpu")
         model = self.model.to(device=device)
@@ -485,7 +492,7 @@
         del kwargs["model"]
         model.eval()
 
-        batch_size = 1
+        type = kwargs.get("type", "onnx")
 
         key_list, data_list = prepare_data_iterator(input, input_len=None, data_type=kwargs.get("data_type", None), key=None)
 
@@ -495,19 +502,11 @@
                 export_dir = export_utils.export_onnx(
                                         model=model,
                                         data_in=data_list,
-                                        quantize=quantize,
-                                        fallback_num=fallback_num,
-                                        calib_num=calib_num,
-                                        opset_version=opset_version,
                                         **kwargs)
             else:
                 export_dir = export_utils.export_torchscripts(
                                         model=model,
                                         data_in=data_list,
-                                        quantize=quantize,
-                                        fallback_num=fallback_num,
-                                        calib_num=calib_num,
-                                        opset_version=opset_version,
                                         **kwargs)
 
         return export_dir
\ No newline at end of file
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 3c93371..a33cd53 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -26,7 +26,7 @@
 # from funasr.tokenizer.build_tokenizer import build_tokenizer
 # from funasr.tokenizer.token_id_converter import TokenIDConverter
 # from funasr.tokenizer.funtoken import build_tokenizer
-
+from funasr import AutoModel
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
@@ -60,6 +60,16 @@
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
+        
+    device = kwargs.get("device", "cpu")
+    kwargs["device"] = "cpu"
+    model = AutoModel(**kwargs)
+    kwargs["device"] = device
+    model = model.model
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    
+    
     
     # save config.yaml
     if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
@@ -68,28 +78,9 @@
         OmegaConf.save(config=kwargs, f=yaml_file)
         logging.info("config.yaml is saved to: %s", yaml_file)
 
-    tokenizer = kwargs.get("tokenizer", None)
-    if tokenizer is not None:
-        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
-        kwargs["tokenizer"] = tokenizer
+
     
-    # build frontend if frontend is none None
-    frontend = kwargs.get("frontend", None)
-    if frontend is not None:
-        frontend_class = tables.frontend_classes.get(frontend)
-        frontend = frontend_class(**kwargs["frontend_conf"])
-        kwargs["frontend"] = frontend
-        kwargs["input_size"] = frontend.output_size()
-
-
-    # build model
-    model_class = tables.model_classes.get(kwargs["model"])
-    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
-    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
-
+    
 
     # init_param
     init_param = kwargs.get("init_param", None)
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 260236c..9825c86 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -90,7 +90,7 @@
 
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
-                if data_list[0].dtype == torch.int64:
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
     
                     pad_value = self.int_pad_value
                 else:
@@ -192,7 +192,7 @@
 
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
-                if data_list[0].dtype == torch.int64:
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                     pad_value = self.int_pad_value
                 else:
                     pad_value = self.float_pad_value
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index d48046b..a49122d 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -5,8 +5,8 @@
 from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 
 
-@tables.register("dataset_classes", "AudioLLMDataset")
-class AudioLLMDataset(torch.utils.data.Dataset):
+@tables.register("dataset_classes", "AudioLLMNARDataset")
+class AudioLLMNARDataset(torch.utils.data.Dataset):
     """
     AudioLLMDataset
     """
@@ -38,8 +38,8 @@
         self.tokenizer = tokenizer
 
         self.float_pad_value = float_pad_value
-        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
-        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+        self.prompt = kwargs.get("prompt", "Please copy the following text.")
+        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)  # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
         self.prompt_af = ""
         self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
         self.int_pad_value = self.IGNORE_INDEX
@@ -72,29 +72,170 @@
         
         
         prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
-        prompt_pre_length = len(prompt_ids_pre)
+        prompt_ids_length = len(prompt_ids_pre)
+        
+        # bos prompt audio bos target
+        # prompt_input = "{}{}".format(self.prompt_pre, target)
+        # prompt_input_ids = self.tokenizer.encode(prompt_input) #[bos, prompt, input]
+        # audio_length = len(prompt_input_ids) - prompt_ids_length
+        target_ids = self.tokenizer.encode(target)
+        if target_ids[0] == self.tokenizer.bos_token_id:
+            target_ids = target_ids[1:]
+        target_ids_length = len(target_ids)
+        audio_length = target_ids_length
+        input_ids = prompt_ids_pre + target_ids + [self.tokenizer.pad_token_id] + target_ids #[bos, prompt, input, pad, target]
+        input_ids = torch.tensor(copy.deepcopy(input_ids), dtype=torch.int64) #[bos, prompt, input, pad, target]
+        input_ids[prompt_ids_length:prompt_ids_length+audio_length] = -1  # [bos, prompt,-1, pad, target] # not strictly needed, only for checking
+        attention_mask = input_ids.ge(-1) # [true, true, true, true, true], length mask
+        
+        # bos prompt audio target eos
+        # prompt_answer = "{}{}".format(self.prompt_pre, target)
+        # prompt_answer_ids = self.tokenizer.encode(prompt_answer) #[bos, prompt, input]
+        # answer_length = len(prompt_answer_ids) - prompt_ids_length
+        target_ids = self.tokenizer.encode(target)
+        if target_ids[0] == self.tokenizer.bos_token_id:
+            target_ids = target_ids[1:]
+        # target_ids_length = len(target_ids)
+        labels_ids = prompt_ids_pre + target_ids + target_ids + [self.tokenizer.eos_token_id] # [bos, prompt, input, target, eos]
+        labels_ids = torch.tensor(copy.deepcopy(labels_ids), dtype=torch.int64)  # [bos, prompt, input, target, eos]
+        labels_ids[:prompt_ids_length] = -1  # [-1, -1, input, target, eos]
+        label_mask = labels_ids.ge(0)  # [false, false, true, true, true], length mask
+        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-1, -1, input, target, eos]
+        
+        audio_mask = [0] * prompt_ids_length + [1] * audio_length + [0] * target_ids_length + [0] # [0, 0, 1, 0, 0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+        
+        ids = target_ids #self.tokenizer.encode(target) # token ids is different from labels_ids
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+
+        prompt_bos_length = torch.tensor([len(prompt_ids_pre)], dtype=torch.int32)
+        
+        return {"speech": speech,
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels_ids,
+                "label_mask": label_mask,
+                "audio_mask": audio_mask,
+                "prompt_bos_length":  prompt_bos_length,
+                }
+    
+    
+    def collator(self, samples: list=None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+    
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+                
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
+
+
+@tables.register("dataset_classes", "AudioLLMDataset")
+class AudioLLMDataset(torch.utils.data.Dataset):
+    """
+    AudioLLMDataset
+    """
+    
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                 **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+        
+        self.float_pad_value = float_pad_value
+        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
+            self.prompt)  # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
+        self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+        self.int_pad_value = self.IGNORE_INDEX
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        # import pdb;
+        # pdb.set_trace()
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
+                                               is_final=True)  # speech: [b, T, d]
+        speech = speech.squeeze(0)
+        
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+        
+        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos,prompt]
+        prompt_ids_length = len(prompt_ids_pre)
         
         prompt_input = "{}{}".format(self.prompt_pre, target)
         prompt_input_ids = self.tokenizer.encode(prompt_input)
-        audio_length = len(prompt_input_ids) - prompt_pre_length
+        audio_length = len(prompt_input_ids) - prompt_ids_length
         input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64) #[bos, prompt, input, pad]
-        input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
-        attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
-
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
+        input_ids[prompt_ids_length:] = -1  # [bos, prompt,-1,-1]
+        attention_mask = input_ids.ge(-1)  # [true, true, true, true], length mask
+        
         prompt_answer = "{}{}".format(self.prompt_pre, target)
         prompt_answer_ids = self.tokenizer.encode(prompt_answer)
-        answer_length = len(prompt_answer_ids) - prompt_pre_length
+        answer_length = len(prompt_answer_ids) - prompt_ids_length
         labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
         labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
-        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        labels_ids[:prompt_ids_length] = -1  # [-1, -1, input, eos]
         label_mask = labels_ids.ge(0)  # [False,False,True,True]
         labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
         
-        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = [0] * prompt_ids_length + [1] * audio_length + [0]
         audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
         
-        ids = self.tokenizer.encode(target) # token ids is different from labels_ids
+        ids = self.tokenizer.encode(target)  # token ids is different from labels_ids
         text = torch.tensor(ids, dtype=torch.int64)
         text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
         
@@ -109,19 +250,18 @@
                 "audio_mask": audio_mask,
                 }
     
-    
-    def collator(self, samples: list=None):
+    def collator(self, samples: list = None):
         outputs = {}
         for sample in samples:
             for key in sample.keys():
                 if key not in outputs:
                     outputs[key] = []
                 outputs[key].append(sample[key])
-
+        
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
-                if data_list[0].dtype == torch.int64:
-    
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+                    
                     pad_value = self.int_pad_value
                 else:
                     pad_value = self.float_pad_value
@@ -199,26 +339,26 @@
             target = self.preprocessor_text(target)
         
         prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos,prompt]
-        prompt_pre_length = len(prompt_ids_pre)
+        prompt_ids_length = len(prompt_ids_pre)
         
         prompt_input = "{}{}".format(self.prompt_pre, target)
         prompt_input_ids = self.tokenizer.encode(prompt_input)
-        audio_length = len(prompt_input_ids) - prompt_pre_length
+        audio_length = len(prompt_input_ids) - prompt_ids_length
         input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
         input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
-        input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
+        input_ids[prompt_ids_length:] = -1  # [bos, prompt,-1,-1]
         attention_mask = input_ids.ge(-1)  # [true, true, true, true], length mask
         
         prompt_answer = "{}{}".format(self.prompt_pre, target)
         prompt_answer_ids = self.tokenizer.encode(prompt_answer)
-        answer_length = len(prompt_answer_ids) - prompt_pre_length
+        answer_length = len(prompt_answer_ids) - prompt_ids_length
         labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
         labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
-        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        labels_ids[:prompt_ids_length] = -1  # [-1, -1, input, eos]
         label_mask = labels_ids.ge(0)  # [False,False,True,True]
         labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
         
-        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = [0] * prompt_ids_length + [1] * audio_length + [0]
         audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
         
         ids = self.tokenizer.encode(target)  # token ids is different from labels_ids
@@ -246,7 +386,7 @@
         
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
-                if data_list[0].dtype == torch.int64:
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                     
                     pad_value = self.int_pad_value
                 else:
diff --git a/funasr/datasets/llm_datasets_qwenaudio/datasets.py b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
index 674217c..7a2ce22 100644
--- a/funasr/datasets/llm_datasets_qwenaudio/datasets.py
+++ b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
@@ -140,7 +140,7 @@
         
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
-                if data_list[0].dtype == torch.int64:
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                     
                     pad_value = self.int_pad_value
                 else:
diff --git a/funasr/datasets/llm_datasets_vicuna/datasets.py b/funasr/datasets/llm_datasets_vicuna/datasets.py
index 6bf6d04..b5fc5bc 100644
--- a/funasr/datasets/llm_datasets_vicuna/datasets.py
+++ b/funasr/datasets/llm_datasets_vicuna/datasets.py
@@ -140,7 +140,7 @@
         
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
-                if data_list[0].dtype == torch.int64:
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                     
                     pad_value = self.int_pad_value
                 else:
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index d83f571..a6096b2 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -16,6 +16,7 @@
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 from funasr.utils import postprocess_utils
+from funasr.models.paraformer.cif_predictor import mae_loss
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables
 
@@ -348,3 +349,335 @@
         
         return results, meta_data
 
+
+@tables.register("model_classes", "LLMASRNARPrompt")
+class LLMASRNARPrompt(nn.Module):
+    """ """
+    
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        llm: str = None,
+        llm_conf: dict = None,
+        adaptor: str = None,
+        adaptor_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        predictor_weight: int = 1.0,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+        
+        super().__init__()
+        
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        
+        # audio encoder
+        hub = encoder_conf.get("hub", None)
+        if hub == "funasr":
+            from funasr import AutoModel
+            init_param_path = encoder_conf.get("init_param_path",
+                                               "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+            # frontend = model.kwargs.get("frontend")
+            model.model.decoder = None
+            
+            self.audio_encoder = model.model
+            # self.frontend = frontend
+            self.predictor_weight = predictor_weight
+        
+        elif hub == "hf":
+            pass
+        else:
+            encoder_class = tables.encoder_classes.get(encoder)
+            encoder = encoder_class(input_size=input_size, **encoder_conf)
+            encoder_output_size = encoder.output_size()
+        
+        # llm
+        hub = llm_conf.get("hub", "hf")
+        self.llm = None
+        if hub == "hf":
+            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+            
+            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+            model = AutoModelForCausalLM.from_pretrained(
+                init_param_path,
+                load_in_8bit=None,
+                device_map=None,
+                use_cache=None,
+            )
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
+                for name, param in model.named_parameters():
+                    param.requires_grad = False
+                model.eval()
+            self.llm = model
+        
+        # adaptor
+        adaptor_class = tables.adaptor_classes.get(adaptor)
+        adaptor = adaptor_class(**adaptor_conf)
+        
+        self.adaptor = adaptor
+        
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+        #
+        # if report_cer or report_wer:
+        #     self.error_calculator = ErrorCalculator(
+        #         token_list, sym_space, sym_blank, report_cer, report_wer
+        #     )
+        #
+        self.error_calculator = None
+        
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels_ids: torch.Tensor,
+        label_mask: torch.Tensor,
+        audio_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        
+        batch_size = speech.shape[0]
+        
+        # audio encoder
+        encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+        
+        if input_ids is not None:
+            input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
+            if hasattr(self.llm.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.embed_tokens(input_ids)
+            elif hasattr(self.llm.model.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+            else:
+                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+            
+            if audio_mask is not None:
+                # inputs_embeds: [bos, prompt, input, pad, target]
+                prompt_bos_length = kwargs.get("prompt_bos_length", None)
+                assert prompt_bos_length is not None
+                prompt_bos_length = prompt_bos_length[0].item()
+                batch_size, token_num, dims = inputs_embeds.shape
+                _, l, _ = encoder_out.shape
+                encoder_outs_pad = F.pad(encoder_out, (0, 0, prompt_bos_length, token_num - prompt_bos_length - l, 0, 0), value=0.0)
+                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0 - audio_mask[:, :, None])
+                inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0) # [prompt, input, pad, target, 0.0]
+        
+        # labels_ids: [bos, prompt, input, target, eos] -> [-1, -1, input, target, eos]
+        # loss:
+        # inputs_embeds[:-1] -> [prompt, input, pad, target]
+        # labels_ids[1:] ->  [prompt, input, target, eos] -> [-1, input, target, eos];
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+        loss_llm = model_outputs.loss
+        loss = loss_llm + loss_pre * self.predictor_weight
+        stats = {}
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
+        
+        
+        stats["loss_pre"] = torch.clone(loss_pre.detach())
+        stats["loss_llm"] = torch.clone(loss_llm.detach())
+        stats["loss"] = torch.clone(loss.detach())
+        
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+    
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ):
+        """Encode speech and run the predictor; returns (pre_acoustic_embeds, pre_token_length, loss_pre).
+
+        Target lengths come from kwarg `audio_mask`, else `len(text_token_int)`; with neither, `loss_pre` is None.
+        """
+        audio_mask = kwargs.get("audio_mask", None)
+        text_token_int = kwargs.get("text_token_int", None)
+        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+        if audio_token_lengths is None and text_token_int is not None:
+            # Single-utterance fallback, built on the feature device to avoid CPU/GPU mixing.
+            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64, device=speech.device)
+        enc, enc_lens = self.audio_encoder.encode(speech=speech, speech_lengths=speech_lengths)
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            # NOTE(review): assumes the predictor accepts target_label_length=None -- confirm.
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
+                enc, mask=enc_mask, target_label_length=audio_token_lengths)
+            loss_pre = None if audio_token_lengths is None else self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
+        return pre_acoustic_embeds, pre_token_length, loss_pre
+    
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        """Transcribe one utterance with a single non-autoregressive LLM forward pass.
+
+        Args:
+            data_in: Fbank tensor when kwarg ``data_type == "fbank"``; otherwise raw
+                input handed to ``load_audio_text_image_video``.
+            data_lengths: Frame lengths for fbank input; derived from the feature
+                shape when omitted.
+            key: Utterance ids; only ``key[0]`` is used.
+            tokenizer: Encodes the prompt / reference text and decodes predictions.
+            frontend: Feature frontend, required for non-fbank input.
+            **kwargs: Reads ``prompt``, ``batch_size``, ``data_type``, ``fs``,
+                ``device`` (required) and ``output_dir``.
+
+        Returns:
+            ``(results, meta_data)``: a one-element list of ``{"key", "text"}`` dicts
+            and a dict of timing information.
+
+        Raises:
+            NotImplementedError: If kwarg ``batch_size`` is greater than 1.
+        """
+        prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+        device = kwargs["device"]
+        data_type = kwargs.get("data_type", "sound")
+
+        meta_data = {}
+        # Stays None unless a reference transcript accompanies the audio.
+        text_token_int = None
+        if isinstance(data_in, torch.Tensor) and data_type == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                # Must be a tensor, not an int: it is moved to `device` below.
+                speech_lengths = torch.full((speech.shape[0],), speech.shape[1], dtype=torch.int64)
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=data_type,
+                                                            tokenizer=None)
+            # Only a multi-entry spec (list/tuple) yields (audio, text); a plain string
+            # such as "sound" has len() > 1 too and must not take this branch.
+            if isinstance(data_type, (list, tuple)) and len(data_type) > 1:
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0].replace(" ", "")
+                text_token_int = tokenizer.encode(text_token_int)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=data_type,
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=device)
+        speech_lengths = speech_lengths.to(device=device)
+
+        # Encoder: encode() returns three values; the predictor loss is unused here.
+        encoder_out, encoder_out_lens, _ = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+
+        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+        prompt_ids = tokenizer.encode(prompt_pre)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)
+
+        # Look for embed_tokens at up to three levels of `.model` nesting.
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(device)
+
+        # Single forward pass; the hypothesis is the per-position argmax of the logits.
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+        # Keep only the text after the last ": " in the decoded string.
+        text = text[0].split(': ')[-1].strip()
+
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer[f"{0 + 1}best_recog"]
+        results = [{"key": key[0], "text": text}]
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = text
+        return results, meta_data
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 723a149..a00b3de 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -88,6 +88,7 @@
         scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
         scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
         self.scaler = scaler
+        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         
     
         try:
@@ -104,7 +105,7 @@
         self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
         
     
-    def _save_checkpoint(self, epoch):
+    def _save_checkpoint(self, epoch, step=None):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
         and the scheduler's state at the end of the given epoch. This method is
@@ -123,7 +124,11 @@
             state["scaler_state"] = self.scaler.state_dict()
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+        if step is None:
+            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+        else:
+            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
+        
         torch.save(state, filename)
         
         print(f'\nCheckpoint saved to {filename}\n')
@@ -337,8 +342,10 @@
                     for key, var in speed_stats.items():
                         self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
 
-
+            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+                self._save_checkpoint(epoch, step=batch_idx+1)
         pbar.close()
+        
 
     def _validate_epoch(self, epoch):
         """
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index f563a9b..29502b1 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -4,8 +4,6 @@
 def export_onnx(model,
                 data_in=None,
 				quantize: bool = False,
-				fallback_num: int = 5,
-				calib_num: int = 100,
 				opset_version: int = 14,
 				**kwargs):
 	model_scripts = model.export(**kwargs)
@@ -19,8 +17,6 @@
 		_onnx(m,
 		      data_in=data_in,
 		      quantize=quantize,
-		      fallback_num=fallback_num,
-		      calib_num=calib_num,
 		      opset_version=opset_version,
 		      export_dir=export_dir,
 		      **kwargs
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index af3c9b9..ee21609 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -35,7 +35,8 @@
                  plot_timestamp_to: str = "",
                  quantize: bool = False,
                  intra_op_num_threads: int = 4,
-                 cache_dir: str = None
+                 cache_dir: str = None,
+                 **kwargs
                  ):
         if not Path(model_dir).exists():
             try:
@@ -64,7 +65,7 @@
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="onnx", quantize=quantize)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
             
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
@@ -243,7 +244,8 @@
                  plot_timestamp_to: str = "",
                  quantize: bool = False,
                  intra_op_num_threads: int = 4,
-                 cache_dir: str = None
+                 cache_dir: str = None,
+                 **kwargs
                  ):
 
         if not Path(model_dir).exists():
@@ -277,7 +279,7 @@
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="onnx", quantize=quantize)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
             
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
@@ -403,4 +405,6 @@
     
     
 class SeacoParaformer(ContextualParaformer):
-    pass # no difference with contextual_paraformer in method of calling onnx models
\ No newline at end of file
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # no difference with contextual_paraformer in method of calling onnx models
\ No newline at end of file
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index 6925960..05563aa 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -24,7 +24,8 @@
                  device_id: Union[str, int] = "-1",
                  quantize: bool = False,
                  intra_op_num_threads: int = 4,
-                 cache_dir: str = None
+                 cache_dir: str = None,
+                 **kwargs
                  ):
 
         if not Path(model_dir).exists():
@@ -56,7 +57,7 @@
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="onnx", quantize=quantize)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
 
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 1b8a1a2..db45baa 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -26,6 +26,7 @@
                  quantize: bool = False,
                  intra_op_num_threads: int = 4,
                  cache_dir: str = None,
+                 **kwargs
                  ):
     
         if not Path(model_dir).exists():
@@ -56,7 +57,7 @@
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(quantize=quantize)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
             
         config_file = os.path.join(model_dir, 'config.yaml')
         config = read_yaml(config_file)
@@ -168,14 +169,9 @@
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
-    def __init__(self, model_dir: Union[str, Path] = None,
-                 batch_size: int = 1,
-                 device_id: Union[str, int] = "-1",
-                 quantize: bool = False,
-                 intra_op_num_threads: int = 4,
-                 cache_dir: str = None
+    def __init__(self, *args, **kwargs
                  ):
-        super().__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
+        super().__init__(*args, **kwargs)
 
     def __call__(self, text: str, param_dict: map, split_size=20):
         cache_key = "cache"
diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index 6b3a1bc..a2f443a 100644
--- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -31,7 +31,8 @@
 	             quantize: bool = False,
 	             intra_op_num_threads: int = 4,
 	             max_end_sil: int = None,
-	             cache_dir: str = None
+	             cache_dir: str = None,
+	             **kwargs
 	             ):
 		
 		if not Path(model_dir).exists():
@@ -62,7 +63,7 @@
 				      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 			
 			model = AutoModel(model=model_dir)
-			model_dir = model.export(type="onnx", quantize=quantize)
+			model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
 		config_file = os.path.join(model_dir, 'config.yaml')
 		cmvn_file = os.path.join(model_dir, 'am.mvn')
 		config = read_yaml(config_file)
@@ -196,7 +197,8 @@
 	             quantize: bool = False,
 	             intra_op_num_threads: int = 4,
 	             max_end_sil: int = None,
-	             cache_dir: str = None
+	             cache_dir: str = None,
+	             **kwargs
 	             ):
 		if not Path(model_dir).exists():
 			try:
@@ -226,7 +228,7 @@
 				      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 			
 			model = AutoModel(model=model_dir)
-			model_dir = model.export(type="onnx", quantize=quantize)
+			model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
 			
 		config_file = os.path.join(model_dir, 'config.yaml')
 		cmvn_file = os.path.join(model_dir, 'am.mvn')
diff --git a/runtime/python/onnxruntime/setup.py b/runtime/python/onnxruntime/setup.py
index df20fd9..39d57e7 100644
--- a/runtime/python/onnxruntime/setup.py
+++ b/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
 
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.3.0'
+VERSION_NUM = '0.3.1'
 
 setuptools.setup(
     name=MODULE_NAME,

--
Gitblit v1.9.1