From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/tasks/asr.py |   18 ++++--------------
 1 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 05eace7..e151473 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -125,7 +125,7 @@
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
         mfcca=MFCCA,
-        timestamp_predictor=TimestampPredictor,
+        timestamp_prediction=TimestampPredictor,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -826,7 +826,7 @@
             if "model.ckpt-" in model_name or ".bin" in model_name:
                 model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                             '.pb')) if ".bin" in model_name else os.path.join(
-                    model_dir, "{}.pth".format(model_name))
+                    model_dir, "{}.pb".format(model_name))
                 if os.path.exists(model_name_pth):
                     logging.info("model_file is load from pth: {}".format(model_name_pth))
                     model_dict = torch.load(model_name_pth, map_location=device)
@@ -1073,7 +1073,7 @@
             if "model.ckpt-" in model_name or ".bin" in model_name:
                 model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                             '.pb')) if ".bin" in model_name else os.path.join(
-                    model_dir, "{}.pth".format(model_name))
+                    model_dir, "{}.pb".format(model_name))
                 if os.path.exists(model_name_pth):
                     logging.info("model_file is load from pth: {}".format(model_name_pth))
                     model_dict = torch.load(model_name_pth, map_location=device)
@@ -1278,8 +1278,6 @@
             token_list = list(args.token_list)
         else:
             raise RuntimeError("token_list must be str or list")
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
 
         # 1. frontend
         if args.input_size is None:
@@ -1316,6 +1314,7 @@
             frontend=frontend,
             encoder=encoder,
             predictor=predictor,
+            token_list=token_list,
             **args.model_conf,
         )
 
@@ -1332,12 +1331,3 @@
     ) -> Tuple[str, ...]:
         retval = ("speech", "text")
         return retval
-
-
-class ASRTaskAligner(ASRTaskParaformer):
-    @classmethod
-    def required_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        retval = ("speech", "text")
-        return retval
\ No newline at end of file

--
Gitblit v1.9.1