From 5de8bfdcd8a617ac13c13478505401bbf4e57472 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 15:38:17 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/model.py | 32 ++++++++++++++++++++------------
1 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index fb0bee3..15969e3 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -21,6 +21,8 @@
from funasr.train_utils.device_funcs import to_device
import traceback
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+
@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
@@ -413,15 +415,18 @@
if freeze:
for name, param in audio_encoder.named_parameters():
- idx = re.search(r"\.\d+\.", name)
- if idx is not None:
- beg, end = idx.regs[0]
- layer_id = int(name[beg + 1 : end - 1])
- if isinstance(freeze_layer_num, (list, tuple)):
+ if isinstance(freeze_layer_num, (list, tuple)):
+ idx = re.search(r"\.\d+\.", name)
+ if idx is not None:
+ beg, end = idx.regs[0]
+ layer_id = int(name[beg + 1 : end - 1])
if layer_id in freeze_layer_num:
param.requires_grad = False
else:
param.requires_grad = False
+ else:
+ param.requires_grad = False
+
audio_encoder.eval()
self.audio_encoder = audio_encoder
@@ -446,6 +451,7 @@
model.eval()
self.llm = model
llm_dim = model.get_input_embeddings().weight.shape[-1]
+ self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
@@ -524,12 +530,15 @@
batch_idx, :min_len, :
]
- labels_ids[labels_ids == -1] = -100
-
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
- loss = model_outputs.loss
+ with torch.cuda.amp.autocast(
+ enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+ ):
+ labels_ids[labels_ids == -1] = -100
+ attention_mask[attention_mask < 0] = 0
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+ loss = model_outputs.loss
stats = {}
with torch.no_grad():
@@ -734,7 +743,6 @@
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
- dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
with torch.cuda.amp.autocast(
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
):
--
Gitblit v1.9.1