From be26169447b2e5f8f38c97af8f5f6a201bc6ce40 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 12 Jun 2024 14:00:34 +0800
Subject: [PATCH] decoding

---
 funasr/datasets/openai_datasets/datasets.py |    1 +
 funasr/train_utils/load_pretrained_model.py |   53 +++++++++++++++++------------------------------------
 2 files changed, 18 insertions(+), 36 deletions(-)

diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index 6307930..39b8453 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -376,6 +376,7 @@
                 target_ids = self.tokenizer.encode(target_out)
                 input_ids += source_ids + target_ids
                 labels += source_mask + target_ids
+                fbank.append(speech)
                 fbank_mask += fbank_mask_i
                 fbank_beg.append(fbank_beg_i)
 
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 02abfd5..c31a4d2 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -10,36 +10,6 @@
 import pdb
 
 
-def filter_state_dict(
-    dst_state: Dict[str, Union[float, torch.Tensor]],
-    src_state: Dict[str, Union[float, torch.Tensor]],
-):
-    """Filter name, size mismatch instances between dicts.
-
-    Args:
-            dst_state: reference state dict for filtering
-            src_state: target state dict for filtering
-
-    """
-    match_state = {}
-    for key, value in src_state.items():
-        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
-            match_state[key] = value
-        else:
-            if key not in dst_state:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of name not found in target dict"
-                )
-            else:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of size mismatch"
-                    + f"({dst_state[key].size()}-{src_state[key].size()})"
-                )
-    return match_state
-
-
 def load_pretrained_model(
     path: str,
     model: torch.nn.Module,
@@ -62,7 +32,7 @@
     obj = model
     dst_state = obj.state_dict()
 
-    print(f"ckpt: {path}")
+    logging.info(f"ckpt: {path}")
 
     if oss_bucket is None:
         src_state = torch.load(path, map_location=map_location)
@@ -77,8 +47,19 @@
     if isinstance(scope_map, str):
         scope_map = scope_map.split(",")
     scope_map += ["module.", "None"]
+    logging.info(f"scope_map: {scope_map}")
+
+    if isinstance(excludes, str):
+        excludes = excludes.split(",")
+    excludes = excludes or []
+    logging.info(f"excludes: {excludes}")
 
     for k in dst_state.keys():
+
+        # Skip destination keys whose name starts with an excluded prefix.
+        if any(k.startswith(k_ex) for k_ex in excludes or []):
+            logging.info(f"key: {k} matched excludes {excludes}, excluded")
+            continue
 
         k_src = k
 
@@ -92,25 +73,25 @@
                 if dst_prefix == "" and (src_prefix + k) in src_state.keys():
                     k_src = src_prefix + k
                     if not k_src.startswith("module."):
-                        print(f"init param, map: {k} from {k_src} in ckpt")
+                        logging.info(f"init param, map: {k} from {k_src} in ckpt")
                 elif (
                     k.startswith(dst_prefix)
                     and k.replace(dst_prefix, src_prefix, 1) in src_state.keys()
                 ):
                     k_src = k.replace(dst_prefix, src_prefix, 1)
                     if not k_src.startswith("module."):
-                        print(f"init param, map: {k} from {k_src} in ckpt")
+                        logging.info(f"init param, map: {k} from {k_src} in ckpt")
 
         if k_src in src_state.keys():
             if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
-                print(
+                logging.info(
                     f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
                 )
             else:
                 dst_state[k] = src_state[k_src]
 
         else:
-            print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
+            logging.info(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
 
     flag = obj.load_state_dict(dst_state, strict=True)
-    # print(flag)
+    logging.info(f"Loading ckpt: {path}, status: {flag}")

--
Gitblit v1.9.1