From f4db9317f2cdfce81b09b80c183d4197d29533ca Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 18 四月 2023 10:55:27 +0800
Subject: [PATCH] Update modelscope_models.md

---
 funasr/tasks/asr.py |   21 ++++++---------------
 1 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 05eace7..52a0ce7 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -39,7 +39,7 @@
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
 from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
 from funasr.models.e2e_uni_asr import UniASR
@@ -121,11 +121,12 @@
         asr=ESPnetASRModel,
         uniasr=UniASR,
         paraformer=Paraformer,
+        paraformer_online=ParaformerOnline,
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
         mfcca=MFCCA,
-        timestamp_predictor=TimestampPredictor,
+        timestamp_prediction=TimestampPredictor,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -826,7 +827,7 @@
             if "model.ckpt-" in model_name or ".bin" in model_name:
                 model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                             '.pb')) if ".bin" in model_name else os.path.join(
-                    model_dir, "{}.pth".format(model_name))
+                    model_dir, "{}.pb".format(model_name))
                 if os.path.exists(model_name_pth):
                     logging.info("model_file is load from pth: {}".format(model_name_pth))
                     model_dict = torch.load(model_name_pth, map_location=device)
@@ -1073,7 +1074,7 @@
             if "model.ckpt-" in model_name or ".bin" in model_name:
                 model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                             '.pb')) if ".bin" in model_name else os.path.join(
-                    model_dir, "{}.pth".format(model_name))
+                    model_dir, "{}.pb".format(model_name))
                 if os.path.exists(model_name_pth):
                     logging.info("model_file is load from pth: {}".format(model_name_pth))
                     model_dict = torch.load(model_name_pth, map_location=device)
@@ -1278,8 +1279,6 @@
             token_list = list(args.token_list)
         else:
             raise RuntimeError("token_list must be str or list")
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
 
         # 1. frontend
         if args.input_size is None:
@@ -1316,6 +1315,7 @@
             frontend=frontend,
             encoder=encoder,
             predictor=predictor,
+            token_list=token_list,
             **args.model_conf,
         )
 
@@ -1332,12 +1332,3 @@
     ) -> Tuple[str, ...]:
         retval = ("speech", "text")
         return retval
-
-
-class ASRTaskAligner(ASRTaskParaformer):
-    @classmethod
-    def required_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        retval = ("speech", "text")
-        return retval
\ No newline at end of file

--
Gitblit v1.9.1