From 08114ae27d85949106aeab03b3fa5d764d100b33 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 14 六月 2024 15:16:40 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/model.py | 120 +++++++++++++++++++++++++++++++++++++++++-------------------
1 files changed, 82 insertions(+), 38 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 2a55cd6..1151269 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -21,6 +21,8 @@
from funasr.train_utils.device_funcs import to_device
import traceback
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+
@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
@@ -396,7 +398,9 @@
# frontend = model.kwargs.get("frontend")
audio_encoder_output_size = model.model.encoder_output_size
- audio_encoder = model.model.model.encoder
+ audio_encoder = (
+ model.model.model.encoder if hasattr(model.model, "model") else model.model.encoder
+ )
# self.frontend = frontend
@@ -408,49 +412,59 @@
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
- if freeze_layer_num > 0:
- freeze_layer_num = range(freeze_layer_num)
+ # if freeze_layer_num > 0:
+ # freeze_layer_num = range(freeze_layer_num)
if freeze:
for name, param in audio_encoder.named_parameters():
- idx = re.search(r"\.\d+\.", name)
- if idx is not None:
- beg, end = idx.regs[0]
- layer_id = int(name[beg + 1 : end - 1])
- if isinstance(freeze_layer_num, (list, tuple)):
- if layer_id in freeze_layer_num:
+ if freeze_layer_num > 0:
+ idx = re.search(r"\.\d+\.", name)
+ if idx is not None:
+ beg, end = idx.regs[0]
+ layer_id = int(name[beg + 1 : end - 1])
+ if layer_id < freeze_layer_num:
param.requires_grad = False
- else:
+ elif "ln_post." not in name:
param.requires_grad = False
+ else:
+ param.requires_grad = False
+
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
- hub = llm_conf.get("hub", "hf")
self.llm = None
- if hub == "hf":
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- model = AutoModelForCausalLM.from_pretrained(
- init_param_path,
- load_in_8bit=None,
- device_map=None,
- use_cache=None,
- )
- freeze = llm_conf.get("freeze", True)
- if freeze:
- for name, param in model.named_parameters():
- param.requires_grad = False
- model.eval()
- self.llm = model
+ init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ )
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+ self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
+ self.llm = model.to(dtype_map[self.llm_dtype])
+ llm_dim = model.get_input_embeddings().weight.shape[-1]
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+ audio_adaptor_conf["llm_dim"] = llm_dim
audio_adaptor = adaptor_class(**audio_adaptor_conf)
+ init_param_path = audio_adaptor_conf.get("init_param_path", None)
+ if init_param_path is not None:
+ src_state = torch.load(init_param_path, map_location="cpu")
+ flag = audio_adaptor.load_state_dict(src_state, strict=False)
+ logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
self.audio_adaptor = audio_adaptor
@@ -484,11 +498,12 @@
batch_size, frames, _ = speech.shape
- # audio encoder
- encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
+ with torch.cuda.amp.autocast(enabled=False):
+ # audio encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- # audio_adaptor
- encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
+ # audio_adaptor
+ encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
@@ -518,12 +533,17 @@
batch_idx, :min_len, :
]
- labels_ids[labels_ids == -1] = -100
-
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
- loss = model_outputs.loss
+ with torch.cuda.amp.autocast(
+ enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+ ):
+ labels_ids[labels_ids == -1] = -100
+ attention_mask[attention_mask < 0] = 0
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]),
+ attention_mask=attention_mask,
+ labels=labels_ids,
+ )
+ loss = model_outputs.loss
stats = {}
with torch.no_grad():
@@ -545,6 +565,12 @@
batch_size = int((labels_ids > 0 + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
+
+ def encode(self, speech, speech_lengths):
+ # audio encoder
+ encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
+
+ return encoder_out, encoder_out_lens
def data_template(self, data):
system, user, assistant = [], [], []
@@ -701,7 +727,8 @@
speech = speech.to(torch.float16)
elif kwargs.get("bf16", False):
speech = speech.to(torch.bfloat16)
- encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
+ # audio encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
@@ -728,7 +755,6 @@
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
- dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
with torch.cuda.amp.autocast(
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
):
@@ -787,3 +813,21 @@
ibest_writer["text_tn"][key[0]] = response_clean
return results, meta_data
+
+
+@tables.register("model_classes", "LLMASR3")
+class LLMASR3(LLMASR2):
+ """ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+
+ super().__init__(*args, **kwargs)
+
+ def encode(self, speech, speech_lengths):
+ # audio encoder
+ encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
+ return encoder_out, encoder_out_lens
--
Gitblit v1.9.1