From c3192dffdd79c7b8a75ce1dc880b0a17b72d33a1 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 12 Mar 2024 17:27:02 +0800
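Subject: [PATCH] Dev gzf (#1480)

- funasr/auto/auto_model.py: read "tokenizer_conf" with
  kwargs.get("tokenizer_conf", {}) instead of indexing kwargs directly.
- funasr/datasets/llm_datasets/datasets.py: join the prompt_pre
  assignment onto a single line (formatting only).
- funasr/train_utils/trainer.py: add a self.model.train() call after the
  validation logging code at the end of the file.
- runtime/python/onnxruntime/funasr_onnx/punc_bin.py: fix the config
  filename typo ('confi.yaml' -> 'config.yaml').
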
---
 funasr/datasets/llm_datasets/datasets.py           |    3 +--
 funasr/train_utils/trainer.py                      |    4 +++-
 funasr/auto/auto_model.py                          |    3 ++-
 runtime/python/onnxruntime/funasr_onnx/punc_bin.py |    2 +-
 4 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index a18224f..47456a3 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -162,7 +162,8 @@
         tokenizer = kwargs.get("tokenizer", None)
         if tokenizer is not None:
             tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+            tokenizer_conf = kwargs.get("tokenizer_conf", {})
+            tokenizer = tokenizer_class(**tokenizer_conf)
             kwargs["tokenizer"] = tokenizer
 
             kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 22151a1..d48046b 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -39,8 +39,7 @@
 
         self.float_pad_value = float_pad_value
         self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
-        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
-            self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
         self.prompt_af = ""
         self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
         self.int_pad_value = self.IGNORE_INDEX
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 2a57a9b..723a149 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -401,4 +401,6 @@
                                                    epoch * len(self.dataloader_val) + batch_idx)
                         for key, var in speed_stats.items():
                             self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
-                                                   epoch * len(self.dataloader_val) + batch_idx)
\ No newline at end of file
+                                                   epoch * len(self.dataloader_val) + batch_idx)
+
+        self.model.train()
\ No newline at end of file
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index d9bac14..1b8a1a2 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -58,7 +58,7 @@
             model = AutoModel(model=model_dir)
             model_dir = model.export(quantize=quantize)
             
-        config_file = os.path.join(model_dir, 'confi.yaml')
+        config_file = os.path.join(model_dir, 'config.yaml')
         config = read_yaml(config_file)
         token_list = os.path.join(model_dir, 'tokens.json')
         with open(token_list, 'r', encoding='utf-8') as f:

--
Gitblit v1.9.1