From bb97d3ed19ee3a219e67b9568d662df489aa2823 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 Jan 2024 15:47:01 +0800
Subject: [PATCH] fix Windows checkpoint-path bug and related loading issues

---
 examples/industrial_data_pretraining/emotion2vec/demo.py |    2 
 funasr/auto/auto_model.py                                |    4 
 funasr/bin/train.py                                      |    4 
 funasr/datasets/audio_datasets/datasets.py               |    4 
 funasr/train_utils/load_pretrained_model.py              |  210 ++++++++++++++++++++++--------
 5 files changed, 113 insertions(+), 111 deletions(-)

diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index 91d00aa..a41641e 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -7,6 +7,6 @@
 
 model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
 
-wav_file = f"{model.model_path}/example/example/test.wav"
+wav_file = f"{model.model_path}/example/test.wav"
 res = model.generate(wav_file, output_dir="./outputs", granularity="utterance")
 print(res)
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 580cca8..ffb56a5 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -183,9 +183,11 @@
             logging.info(f"Loading pretrained params from {init_param}")
             load_pretrained_model(
                 model=model,
-                init_param=init_param,
+                path=init_param,
                 ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                 oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
             )
         
         return model, kwargs
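With this change, scope_map and excludes ride along in the AutoModel kwargs. A minimal
usage sketch, assuming the constructor kwargs reach this call site unchanged (the model
name and the prefix values are placeholders, not real checkpoint keys):

    from funasr import AutoModel

    # "iic/some_model", "decoder,decoder2" and "ctc" are placeholder values
    model = AutoModel(
        model="iic/some_model",
        scope_map="decoder,decoder2",  # rename checkpoint keys starting with "decoder" to "decoder2"
        excludes="ctc",                # drop checkpoint keys starting with "ctc"
    )
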
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0881cb2..ef0d205 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -96,9 +96,11 @@
             logging.info(f"Loading pretrained params from {p}")
             load_pretrained_model(
                 model=model,
-                init_param=p,
+                path=p,
                 ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                 oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
             )
     else:
         initialize(model, kwargs.get("init", "kaiming_normal"))
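The rename from init_param to path pairs with dropping the
<file_path>:<src_key>:<dst_key>:<exclude_Keys> parsing further down: splitting on ":"
mangled Windows drive-letter paths such as C:\ckpt\model.pt. A sketch of a direct call
under the new signature (the path and the exclude prefix are placeholders):

    from funasr.train_utils.load_pretrained_model import load_pretrained_model

    load_pretrained_model(
        path=r"C:\ckpt\model.pt",     # placeholder; the drive letter's ":" is no longer split off
        model=model,                  # assumes an already-constructed torch.nn.Module
        ignore_init_mismatch=True,
        excludes="decoder.embed",     # placeholder prefix to drop from the checkpoint
    )
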
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index edf127f..5af33fc 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -1,7 +1,7 @@
 import torch
 
 from funasr.register import tables
-from funasr.utils.load_utils import extract_fbank
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 
 
 @tables.register("dataset_classes", "AudioDataset")
@@ -55,7 +55,7 @@
         # import pdb;
         # pdb.set_trace()
         source = item["source"]
-        data_src = load_audio(source, fs=self.fs)
+        data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
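load_audio was never imported from funasr.utils.load_utils (only extract_fbank was), so
this dataset path raised a NameError at runtime; load_audio_text_image_video is the actual
loader. A standalone sketch of the corrected flow, assuming a 16 kHz wav file (the
filename, sampling rate, data_type, and frontend instance are placeholders):

    from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video

    data_src = load_audio_text_image_video("example.wav", fs=16000)  # placeholder file and rate
    # frontend is assumed to be the dataset's already-built frontend instance
    speech, speech_lengths = extract_fbank(data_src, data_type="sound", frontend=frontend)
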
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index ef9d93a..16feabd 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -10,119 +10,117 @@
 
 
 def filter_state_dict(
-    dst_state: Dict[str, Union[float, torch.Tensor]],
-    src_state: Dict[str, Union[float, torch.Tensor]],
+	dst_state: Dict[str, Union[float, torch.Tensor]],
+	src_state: Dict[str, Union[float, torch.Tensor]],
 ):
-    """Filter name, size mismatch instances between dicts.
+	"""Filter out entries whose name or size does not match between the two state dicts.
 
-    Args:
-        dst_state: reference state dict for filtering
-        src_state: target state dict for filtering
+	Args:
+		dst_state: reference state dict for filtering
+		src_state: target state dict for filtering
 
-    """
-    match_state = {}
-    for key, value in src_state.items():
-        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
-            match_state[key] = value
-        else:
-            if key not in dst_state:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of name not found in target dict"
-                )
-            else:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of size mismatch"
-                    + f"({dst_state[key].size()}-{src_state[key].size()})"
-                )
-    return match_state
+	"""
+	match_state = {}
+	for key, value in src_state.items():
+		if key in dst_state and (dst_state[key].size() == src_state[key].size()):
+			match_state[key] = value
+		else:
+			if key not in dst_state:
+				logging.warning(
+					f"Filter out {key} from pretrained dict"
+					+ " because of name not found in target dict"
+				)
+			else:
+				logging.warning(
+					f"Filter out {key} from pretrained dict"
+					+ " because of size mismatch "
+					+ f"({dst_state[key].size()}-{src_state[key].size()})"
+				)
+	return match_state
 
+def assignment_scope_map(dst_state: dict, src_state: dict, scope_map: str = None):
+	"""Map checkpoint variables onto matching model variables, renaming key prefixes via scope_map."""
+	import collections
+
+	# current model variables
+	name_to_variable = collections.OrderedDict()
+	for name, var in dst_state.items():
+		name_to_variable[name] = var
+	
+	scope_map_num = 0
+	if scope_map is not None:
+		scope_map = scope_map.split(",")
+		scope_map_num = len(scope_map) // 2
+		for scope_map_idx in range(scope_map_num):
+			scope_map_id = scope_map_idx * 2
+			logging.info('assignment_map from scope {} to {}'.format(scope_map[scope_map_id], scope_map[scope_map_id+1]))
+	
+	assignment_map = {}
+	for name, var in src_state.items():
+
+		if scope_map:
+			for scope_map_idx in range(scope_map_num):
+				scope_map_id = scope_map_idx * 2
+				try:
+					idx = name.index(scope_map[scope_map_id])
+					new_name = scope_map[scope_map_id+1] + name[idx + len(scope_map[scope_map_id]):]
+					if new_name in name_to_variable:
+						assignment_map[new_name] = var  # store under the renamed key so it matches the model
+				except ValueError:  # this scope prefix does not occur in the key
+					continue
+		else:
+			if name in name_to_variable:
+				assignment_map[name] = var
+	
+	return assignment_map
 
 def load_pretrained_model(
-    init_param: str,
-    model: torch.nn.Module,
-    ignore_init_mismatch: bool,
-    map_location: str = "cpu",
-    oss_bucket=None,
+	path: str,
+	model: torch.nn.Module,
+	ignore_init_mismatch: bool,
+	map_location: str = "cpu",
+	oss_bucket=None,
+	scope_map=None,
+	excludes=None,
 ):
-    """Load a model state and set it to the model.
+	"""Load a model state and set it to the model.
 
-    Args:
-        init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
+	Args:
+		path: checkpoint file path
+		ignore_init_mismatch: skip checkpoint keys whose name or shape does not match the model
+		scope_map: comma-separated "src_prefix,dst_prefix" pairs used to rename checkpoint keys
+		excludes: comma-separated key prefixes to drop from the checkpoint
 
-    Examples:
-        >>> load_pretrained_model("somewhere/model.pb", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder:", model)
-        >>> load_pretrained_model(
-        ...     "somewhere/model.pb:decoder:decoder:decoder.embed", model
-        ... )
-        >>> load_pretrained_model("somewhere/decoder.pb::decoder", model)
-    """
-    sps = init_param.split(":", 4)
-    if len(sps) == 4:
-        path, src_key, dst_key, excludes = sps
-    elif len(sps) == 3:
-        path, src_key, dst_key = sps
-        excludes = None
-    elif len(sps) == 2:
-        path, src_key = sps
-        dst_key, excludes = None, None
-    else:
-        (path,) = sps
-        src_key, dst_key, excludes = None, None, None
-    if src_key == "":
-        src_key = None
-    if dst_key == "":
-        dst_key = None
+	Examples (illustrative; paths and key prefixes below are placeholders):
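+		>>> load_pretrained_model("somewhere/model.pb", model, ignore_init_mismatch=True)
+		>>> # scope_map/excludes take placeholder prefixes here:
+		>>> load_pretrained_model(
+		...     "somewhere/model.pb", model, ignore_init_mismatch=True,
+		...     scope_map="decoder,decoder2", excludes="ctc",
+		... )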
 
-    if dst_key is None:
-        obj = model
-    else:
-
-        def get_attr(obj: Any, key: str):
-            """Get an nested attribute.
-
-            >>> class A(torch.nn.Module):
-            ...     def __init__(self):
-            ...         super().__init__()
-            ...         self.linear = torch.nn.Linear(10, 10)
-            >>> a = A()
-            >>> assert A.linear.weight is get_attr(A, 'linear.weight')
-
-            """
-            if key.strip() == "":
-                return obj
-            for k in key.split("."):
-                obj = getattr(obj, k)
-            return obj
-
-        obj = get_attr(model, dst_key)
-
-    if oss_bucket is None:
-        src_state = torch.load(path, map_location=map_location)
-    else:
-        buffer = BytesIO(oss_bucket.get_object(path).read())
-        src_state = torch.load(buffer, map_location=map_location)
-    src_state = src_state["model"] if "model" in src_state else src_state
-    if excludes is not None:
-        for e in excludes.split(","):
-            src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
-
-    if src_key is not None:
-        src_state = {
-            k[len(src_key) + 1 :]: v
-            for k, v in src_state.items()
-            if k.startswith(src_key)
-        }
-
-    dst_state = obj.state_dict()
-    if ignore_init_mismatch:
-        src_state = filter_state_dict(dst_state, src_state)
-
-    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
-    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-    dst_state.update(src_state)
-    obj.load_state_dict(dst_state)
-    
\ No newline at end of file
+	"""
+	
+	obj = model
+	
+	if oss_bucket is None:
+		src_state = torch.load(path, map_location=map_location)
+	else:
+		buffer = BytesIO(oss_bucket.get_object(path).read())
+		src_state = torch.load(buffer, map_location=map_location)
+	src_state = src_state["model"] if "model" in src_state else src_state
+	
+	if excludes is not None:
+		for e in excludes.split(","):
+			src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
+	
+	dst_state = obj.state_dict()
+	src_state = assignment_scope_map(dst_state, src_state, scope_map)
+	
+	if ignore_init_mismatch:
+		src_state = filter_state_dict(dst_state, src_state)
+	
+	logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
+	logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
+	dst_state.update(src_state)  # merge matched pretrained params; otherwise load_state_dict reloads the model's own weights
+	obj.load_state_dict(dst_state)
\ No newline at end of file
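
A standalone sketch of the scope_map remapping above, with toy tensors (the key names are
made up for illustration; scope_map is read as consecutive "src_prefix,dst_prefix" pairs):

    import torch
    from funasr.train_utils.load_pretrained_model import assignment_scope_map

    dst_state = {"decoder2.linear.weight": torch.zeros(2, 2)}  # model keys
    src_state = {"decoder.linear.weight": torch.ones(2, 2)}    # checkpoint keys

    mapped = assignment_scope_map(dst_state, src_state, scope_map="decoder,decoder2")
    print(list(mapped.keys()))  # ['decoder2.linear.weight']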

--
Gitblit v1.9.1