From c3192dffdd79c7b8a75ce1dc880b0a17b72d33a1 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 12 Mar 2024 17:27:02 +0800
Subject: [PATCH] Dev gzf (#1480)
---
 funasr/auto/auto_model.py | 3 ++-
 funasr/datasets/llm_datasets/datasets.py | 3 +--
 funasr/train_utils/trainer.py | 4 +++-
 runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 2 +-
 4 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index a18224f..47456a3 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -162,7 +162,8 @@
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
- tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+ tokenizer_conf = kwargs.get("tokenizer_conf", {})
+ tokenizer = tokenizer_class(**tokenizer_conf)
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 22151a1..d48046b 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -39,8 +39,7 @@
self.float_pad_value = float_pad_value
self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
- self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
- self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+ self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
self.prompt_af = ""
self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
self.int_pad_value = self.IGNORE_INDEX
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 2a57a9b..723a149 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -401,4 +401,6 @@
epoch * len(self.dataloader_val) + batch_idx)
for key, var in speed_stats.items():
self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
- epoch * len(self.dataloader_val) + batch_idx)
\ No newline at end of file
+ epoch * len(self.dataloader_val) + batch_idx)
+
+ self.model.train()
\ No newline at end of file
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index d9bac14..1b8a1a2 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -58,7 +58,7 @@
model = AutoModel(model=model_dir)
model_dir = model.export(quantize=quantize)
- config_file = os.path.join(model_dir, 'confi.yaml')
+ config_file = os.path.join(model_dir, 'config.yaml')
config = read_yaml(config_file)
token_list = os.path.join(model_dir, 'tokens.json')
with open(token_list, 'r', encoding='utf-8') as f:
--
Gitblit v1.9.1