From e33bb15d269bb3e2e41f7a3540d9b92703bb5c50 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 15 三月 2023 10:51:52 +0800
Subject: [PATCH] Merge branch 'main' into dev_aky

---
 funasr/tasks/sv.py |  110 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 107 insertions(+), 3 deletions(-)

diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index 16384a7..1b08c4d 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -1,14 +1,18 @@
 import argparse
 import logging
+import os
+from pathlib import Path
 from typing import Callable
 from typing import Collection
 from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Union
 
 import numpy as np
 import torch
+import yaml
 from typeguard import check_argument_types
 from typeguard import check_return_type
 
@@ -21,7 +25,7 @@
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
 from funasr.models.pooling.statistic_pooling import StatisticPooling
 from funasr.models.decoder.sv_decoder import DenseDecoder
 from funasr.models.e2e_sv import ESPnetSVModel
@@ -103,6 +107,7 @@
     "encoder",
     classes=dict(
         resnet34=ResNet34,
+        resnet34_sp_l2reg=ResNet34_SP_L2Reg,
         rnn=RNNEncoder,
     ),
     type_check=AbsEncoder,
@@ -394,9 +399,16 @@
 
         # 7. Pooling layer
         pooling_class = pooling_choices.get_class(args.pooling_type)
+        pooling_dim = (2, 3)
+        eps = 1e-12
+        if hasattr(args, "pooling_type_conf"):
+            if "pooling_dim" in args.pooling_type_conf:
+                pooling_dim = args.pooling_type_conf["pooling_dim"]
+            if "eps" in args.pooling_type_conf:
+                eps = args.pooling_type_conf["eps"]
         pooling_layer = pooling_class(
-            pooling_dim=(2, 3),
-            eps=1e-12,
+            pooling_dim=pooling_dim,
+            eps=eps,
         )
         if args.pooling_type == "statistic":
             encoder_output_size *= 2
@@ -435,3 +447,95 @@
 
         assert check_return_type(model)
         return model
+
+    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+    @classmethod
+    def build_model_from_file(
+            cls,
+            config_file: Union[Path, str] = None,
+            model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            device: str = "cpu",
+    ):
+        """Build model from the files.
+
+        This method is used for inference or fine-tuning.
+
+        Args:
+            config_file: The yaml file saved when training.
+            model_file: The model file saved when training.
+            cmvn_file: The cmvn file for front-end
+            device: Device type, "cpu", "cuda", or "cuda:N".
+
+        """
+        assert check_argument_types()
+        if config_file is None:
+            assert model_file is not None, (
+                "The argument 'model_file' must be provided "
+                "if the argument 'config_file' is not specified."
+            )
+            config_file = Path(model_file).parent / "config.yaml"
+        else:
+            config_file = Path(config_file)
+
+        with config_file.open("r", encoding="utf-8") as f:
+            args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
+        args = argparse.Namespace(**args)
+        model = cls.build_model(args)
+        if not isinstance(model, AbsESPnetModel):
+            raise RuntimeError(
+                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+            )
+        model.to(device)
+        model_dict = dict()
+        model_name_pth = None
+        if model_file is not None:
+            logging.info("model_file is {}".format(model_file))
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
+            model_dir = os.path.dirname(model_file)
+            model_name = os.path.basename(model_file)
+            if "model.ckpt-" in model_name or ".bin" in model_name:
+                if ".bin" in model_name:
+                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
+                else:
+                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+                if os.path.exists(model_name_pth):
+                    logging.info("model_file is load from pth: {}".format(model_name_pth))
+                    model_dict = torch.load(model_name_pth, map_location=device)
+                else:
+                    model_dict = cls.convert_tf2torch(model, model_file)
+                model.load_state_dict(model_dict)
+            else:
+                model_dict = torch.load(model_file, map_location=device)
+        model.load_state_dict(model_dict)
+        if model_name_pth is not None and not os.path.exists(model_name_pth):
+            torch.save(model_dict, model_name_pth)
+            logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+        return model, args
+
+    @classmethod
+    def convert_tf2torch(
+            cls,
+            model,
+            ckpt,
+    ):
+        logging.info("start convert tf model to torch model")
+        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+        var_dict_tf = load_tf_dict(ckpt)
+        var_dict_torch = model.state_dict()
+        var_dict_torch_update = dict()
+        # speech encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # pooling layer
+        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+        return var_dict_torch_update

--
Gitblit v1.9.1