From 9afcf0ea7d2877ddbbafec5b1a77f5cf025dab17 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 12 六月 2024 17:17:03 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/model.py | 27 ++++++++++++++++++++++-----
1 files changed, 22 insertions(+), 5 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index f72b2c8..2a55cd6 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -407,9 +407,21 @@
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
+ freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
+ if freeze_layer_num > 0:
+ freeze_layer_num = range(freeze_layer_num)
+
if freeze:
for name, param in audio_encoder.named_parameters():
- param.requires_grad = False
+ idx = re.search(r"\.\d+\.", name)
+ if idx is not None:
+ beg, end = idx.regs[0]
+ layer_id = int(name[beg + 1 : end - 1])
+ if isinstance(freeze_layer_num, (list, tuple)):
+ if layer_id in freeze_layer_num:
+ param.requires_grad = False
+ else:
+ param.requires_grad = False
audio_encoder.eval()
self.audio_encoder = audio_encoder
@@ -687,10 +699,8 @@
# fp16
if kwargs.get("fp16", False):
speech = speech.to(torch.float16)
- encoder_out_lens = encoder_out_lens.to(torch.float16)
elif kwargs.get("bf16", False):
speech = speech.to(torch.bfloat16)
- encoder_out_lens = encoder_out_lens.to(torch.bfloat16)
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
# audio_adaptor
@@ -714,12 +724,18 @@
]
llm_dtype = kwargs.get("llm_dtype", "fp32")
+ if llm_dtype == "fp32":
+ llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
+ llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
+
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
- with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
+ with torch.cuda.amp.autocast(
+ enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+ ):
label = contents["assistant"][0]
self.llm = self.llm.to(dtype_map[llm_dtype])
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
- attention_mask = attention_mask.to(dtype_map[llm_dtype])
+
if not kwargs.get("tearchforing", False):
generated_ids = self.llm.generate(
@@ -739,6 +755,7 @@
labels_ids = batch["labels_ids"]
labels_ids[labels_ids == -1] = -100
attention_mask = batch.get("attention_mask", None)
+ # attention_mask = attention_mask.to(dtype_map[llm_dtype])
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
)
--
Gitblit v1.9.1