From 7a4816651fd59ba02f780884613c1fbf52031f76 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 28 Feb 2024 14:38:05 +0800
Subject: [PATCH] init param
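
Make checkpoint-to-model key mapping configurable when loading
pretrained weights. load_pretrained_model() now takes a "scope_map"
string of comma-separated <src_prefix>,<dst_prefix> pairs (the literal
"none", in any case, stands for an empty prefix), replacing the
hard-coded handling of the DDP "module." prefix. The callers in
auto_model.py and train.py default to "module.,none", so a checkpoint
key "module.<k>" can initialize model key "<k>".

Also rename funasr/models/llm_asr to funasr/models/llm_asr_nar and
switch its inference path from autoregressive llm.generate() to a
single non-autoregressive forward pass decoded by argmax over the
logits.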
---
funasr/models/llm_asr_nar/model.py | 39 +++++++++++--------
funasr/models/llm_asr_nar/__init__.py | 0
funasr/bin/train.py | 2 +-
funasr/models/llm_asr_nar/adaptor.py | 0
funasr/auto/auto_model.py | 4 +-
funasr/train_utils/load_pretrained_model.py | 41 ++++++++++++++++----
6 files changed, 58 insertions(+), 28 deletions(-)
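
Note: a minimal sketch of the intended scope_map semantics (the key
name below is hypothetical, chosen for illustration):

    # "none" (case-insensitive) denotes an empty prefix.
    scope_map = "module.,none".split(",")     # ["module.", "none"]
    src_prefix = "module."                    # prefix used in the checkpoint
    dst_prefix = ""                           # prefix used in the model
    k = "encoder.linear.weight"               # hypothetical model key
    # strip dst_prefix, prepend src_prefix to find the checkpoint key:
    k_src = src_prefix + k[len(dst_prefix):]  # "module.encoder.linear.weight"
    # then dst_state[k] = src_state[k_src] if k_src is in the checkpoint

The non-autoregressive decode in llm_asr_nar/model.py reduces to
(paraphrasing the hunk below):

    # one forward pass; greedy per-position decode instead of generate()
    model_outputs = self.llm(inputs_embeds=inputs_embeds,
                             attention_mask=attention_mask, labels=None)
    preds = torch.argmax(model_outputs.logits, -1)    # (batch, seq_len)
    text = tokenizer.batch_decode(preds, skip_special_tokens=True)[0]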
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 48a983c..046e9bf 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -90,7 +90,7 @@
class AutoModel:
def __init__(self, **kwargs):
- if not kwargs.get("disable_log", False):
+ if not kwargs.get("disable_log", True):
tables.print()
model, kwargs = self.build_model(**kwargs)
@@ -188,7 +188,7 @@
path=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
oss_bucket=kwargs.get("oss_bucket", None),
- scope_map=kwargs.get("scope_map", None),
+ scope_map=kwargs.get("scope_map", "module.,none"),
excludes=kwargs.get("excludes", None),
)
else:
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 44d84e7..6650f0a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -105,7 +105,7 @@
path=p,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
oss_bucket=kwargs.get("oss_bucket", None),
- scope_map=kwargs.get("scope_map", None),
+ scope_map=kwargs.get("scope_map", "module.,none"),
excludes=kwargs.get("excludes", None),
)
else:
diff --git a/funasr/models/llm_asr/__init__.py b/funasr/models/llm_asr_nar/__init__.py
similarity index 100%
rename from funasr/models/llm_asr/__init__.py
rename to funasr/models/llm_asr_nar/__init__.py
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr_nar/adaptor.py
similarity index 100%
rename from funasr/models/llm_asr/adaptor.py
rename to funasr/models/llm_asr_nar/adaptor.py
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr_nar/model.py
similarity index 90%
rename from funasr/models/llm_asr/model.py
rename to funasr/models/llm_asr_nar/model.py
index 2b6db96..a61190c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -294,24 +294,29 @@
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
- model_outputs = self.llm.generate(
- inputs_embeds=inputs_embeds,
- max_length=kwargs.get("max_length", 200),
- max_new_tokens=kwargs.get("max_new_tokens", 200),
- num_beams=kwargs.get("num_beams", 4),
- do_sample=kwargs.get("do_sample", False),
- min_length=kwargs.get("min_length", 1),
- top_p=kwargs.get("top_p", 1.0),
- repetition_penalty=kwargs.get("repetition_penalty", 1.0),
- length_penalty=kwargs.get("length_penalty", 1.0),
- temperature=kwargs.get("temperature", 1.0),
- attention_mask=attention_mask,
- bos_token_id=tokenizer.bos_token_id,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- )
+ # model_outputs = self.llm.generate(
+ # inputs_embeds=inputs_embeds,
+ # max_length=kwargs.get("max_length", 200),
+ # max_new_tokens=kwargs.get("max_new_tokens", 200),
+ # num_beams=kwargs.get("num_beams", 4),
+ # do_sample=kwargs.get("do_sample", False),
+ # min_length=kwargs.get("min_length", 1),
+ # top_p=kwargs.get("top_p", 1.0),
+ # repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+ # length_penalty=kwargs.get("length_penalty", 1.0),
+ # temperature=kwargs.get("temperature", 1.0),
+ # attention_mask=attention_mask,
+ # bos_token_id=tokenizer.bos_token_id,
+ # eos_token_id=tokenizer.eos_token_id,
+ # pad_token_id=tokenizer.pad_token_id
+ # )
- text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+ preds = torch.argmax(model_outputs.logits, -1)
+ text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+ text = text[0].split(': "\n')[-1]
+ # preds = torch.argmax(model_outputs.logits, -1)
ibest_writer = None
if kwargs.get("output_dir") is not None:
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 03a6ff5..23a6ef5 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -82,7 +82,7 @@
ignore_init_mismatch: bool,
map_location: str = "cpu",
oss_bucket=None,
- scope_map=None,
+ scope_map="module.:none",
excludes=None,
):
"""Load a model state and set it to the model.
@@ -108,15 +108,40 @@
src_state = src_state["model"] if "model" in src_state else src_state
+ if isinstance(scope_map, str):
+ scope_map = scope_map.split(",")
+
for k in dst_state.keys():
- if not k.startswith("module.") and "module." + k in src_state.keys():
- k_ddp = "module." + k
+ # if not k.startswith("module.") and "module." + k in src_state.keys():
+ # k_ddp = "module." + k
+ # else:
+ # k_ddp = k
+ k_src = k
+
+ if scope_map is not None:
+ src_prefix = ""
+ dst_prefix = ""
+ for i in range(0, len(scope_map), 2):
+ src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
+ dst_prefix = scope_map[i+1] if scope_map[i+1].lower() != "none" else ""
+
+ if k.startswith(dst_prefix) and (src_prefix + k[len(dst_prefix):]) in src_state.keys():
+ k_src = src_prefix + k[len(dst_prefix):]
+ print(f"init param, map: {k} from {k_src} in ckpt")
+
+ if k_src in src_state.keys():
+ dst_state[k] = src_state[k_src]
+
+ # if k_ddp.startswith("audio_encoder"):
+ # if k_ddp.replace("audio_encoder", "encoder.model") in src_state.keys():
+ # k_ddp = k_ddp.replace("audio_encoder", "encoder.model")
+ # if k_ddp.startswith("adaptor"):
+ # if k_ddp.replace("adaptor", "encoder_projector") in src_state.keys():
+ # k_ddp = k_ddp.replace("adaptor", "encoder_projector")
+ # if k_ddp in src_state:
+ # dst_state[k] = src_state[k_ddp]
else:
- k_ddp = k
- if k_ddp in src_state:
- dst_state[k] = src_state[k_ddp]
- else:
- print(f"Warning, miss key in ckpt: {k}, mapped: {k_ddp}")
+ print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
flag = obj.load_state_dict(dst_state, strict=False)
# print(flag)
--
Gitblit v1.9.1