From e9acc5db07daa51a22cd51ea9233ee09a38d726d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 18:36:22 +0800
Subject: [PATCH] auto frontend
---
funasr/models/llm_asr/model.py | 76 ++++++++++----------------------------
1 files changed, 20 insertions(+), 56 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 11db009..411b59d 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -385,13 +385,6 @@
super().__init__()
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug)
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = tables.normalize_classes.get(normalize)
- normalize = normalize_class(**normalize_conf)
-
# audio encoder
hub = audio_encoder_conf.get("hub", None)
if hub == "ms":
@@ -422,23 +415,23 @@
# llm
hub = llm_conf.get("hub", "hf")
self.llm = None
- # if hub == "hf":
- # from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- #
- # init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
- #
- # model = AutoModelForCausalLM.from_pretrained(
- # init_param_path,
- # load_in_8bit=None,
- # device_map=None,
- # use_cache=None,
- # )
- # freeze = llm_conf.get("freeze", True)
- # if freeze:
- # for name, param in model.named_parameters():
- # param.requires_grad = False
- # model.eval()
- # self.llm = model
+ if hub == "hf":
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+ init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ )
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+ self.llm = model
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
@@ -446,21 +439,6 @@
audio_adaptor = adaptor_class(**audio_adaptor_conf)
self.audio_adaptor = audio_adaptor
-
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.specaug = specaug
- self.normalize = normalize
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
self.error_calculator = None
@@ -493,10 +471,10 @@
batch_size = speech.shape[0]
# audio encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
# audio_adaptor
- encoder_out = self.audio_adaptor(encoder_out)
+ encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
input_ids[input_ids == -1] = 0
input_ids[input_ids == -100] = 0
@@ -530,23 +508,9 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
+ batch_size = int((labels_ids > 0 + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- **kwargs,
- ):
- speech = speech.permute(0, 2, 1)
- res = self.audio_encoder(speech)
- if isinstance(res, (list, tuple)):
- encoder_out, encoder_out_lens = res[0], res[1]
- else:
- encoder_out, encoder_out_lens = res, speech_lengths
- return encoder_out, encoder_out_lens
def inference(
self,
--
Gitblit v1.9.1