From e67ed1d45d5a9d7fb7bb22d15fd2bfef17e9076f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 17 一月 2024 10:57:14 +0800
Subject: [PATCH] Update load_pretrained_model.py

---
 funasr/auto/auto_model.py |   14 ++++++++------
 1 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 25edeb7..0bc5e0e 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -146,7 +146,7 @@
         device = kwargs.get("device", "cuda")
         if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
             device = "cpu"
-            # kwargs["batch_size"] = 1
+            kwargs["batch_size"] = 1
         kwargs["device"] = device
         
         if kwargs.get("ncpu", None):
@@ -183,9 +183,11 @@
             logging.info(f"Loading pretrained params from {init_param}")
             load_pretrained_model(
                 model=model,
-                init_param=init_param,
+                path=init_param,
                 ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                 oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
             )
         
         return model, kwargs
@@ -264,7 +266,7 @@
         # step.1: compute the vad model
         self.vad_kwargs.update(cfg)
         beg_vad = time.time()
-        res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
+        res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
         end_vad = time.time()
         print(f"time cost vad: {end_vad - beg_vad:0.3f}")
 
@@ -316,7 +318,7 @@
                 batch_size_ms_cum = 0
                 end_idx = j + 1
                 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])       
-                results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
                 if self.spk_model is not None:
                     all_segments = []
                     # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
@@ -327,7 +329,7 @@
                         segments = sv_chunk(vad_segments)
                         all_segments.extend(segments)
                         speech_b = [i[2] for i in segments]
-                        spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
+                        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
                         results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
                 beg_idx = end_idx
                 if len(results) < 1:
@@ -378,7 +380,7 @@
             # step.3 compute punc model
             if self.punc_model is not None:
                 self.punc_kwargs.update(cfg)
-                punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                 result["text_with_punc"] = punc_res[0]["text"]
                      
             # speaker embedding cluster after resorted

--
Gitblit v1.9.1