From e67ed1d45d5a9d7fb7bb22d15fd2bfef17e9076f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 17 Jan 2024 10:57:14 +0800
Subject: [PATCH] auto_model.py: adapt to the updated load_pretrained_model() signature
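
Adapt funasr/auto/auto_model.py to the renamed keyword of
load_pretrained_model() (init_param -> path) and forward the new
optional scope_map and excludes arguments from kwargs. Also rename the
internal self.generate() calls for the vad, spk, and punc sub-models to
self.inference(), and force batch_size=1 when falling back to CPU.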
---
funasr/auto/auto_model.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)
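
Note (patch commentary, not applied by git am): below is a minimal
sketch of how the forwarded kwargs would reach load_pretrained_model()
from the public API after this patch. The model name and the
scope_map/excludes values are illustrative assumptions, not values
taken from this patch, and the public generate() entry point is assumed
unchanged.

    from funasr import AutoModel

    # scope_map and excludes are simply forwarded to
    # load_pretrained_model() by the hunk at -183 below;
    # the concrete values here are hypothetical.
    model = AutoModel(
        model="paraformer-zh",      # assumed model name, for illustration
        ignore_init_mismatch=True,  # tolerate missing/mismatched params on load
        scope_map=["module.", ""],  # assumed: remap checkpoint key prefixes
        excludes=["decoder."],      # assumed: skip params under this prefix
    )
    res = model.generate(input="asr_example.wav")  # assumed public entry point
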
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 25edeb7..0bc5e0e 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -146,7 +146,7 @@
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
device = "cpu"
- # kwargs["batch_size"] = 1
+ kwargs["batch_size"] = 1
kwargs["device"] = device
if kwargs.get("ncpu", None):
@@ -183,9 +183,11 @@
logging.info(f"Loading pretrained params from {init_param}")
load_pretrained_model(
model=model,
- init_param=init_param,
+ path=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
oss_bucket=kwargs.get("oss_bucket", None),
+ scope_map=kwargs.get("scope_map", None),
+ excludes=kwargs.get("excludes", None),
)
return model, kwargs
@@ -264,7 +266,7 @@
# step.1: compute the vad model
self.vad_kwargs.update(cfg)
beg_vad = time.time()
- res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
+ res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
@@ -316,7 +318,7 @@
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
- results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+ results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
all_segments = []
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
@@ -327,7 +329,7 @@
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
- spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
+ spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
@@ -378,7 +380,7 @@
# step.3 compute punc model
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
- punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+ punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
result["text_with_punc"] = punc_res[0]["text"]
# speaker embedding cluster after resorted
--
Gitblit v1.9.1