| examples/industrial_data_pretraining/llm_asr/demo_speech2text.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/__init__.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/auto/auto_model.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/datasets/openai_datasets/datasets.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/datasets/openai_datasets/index_ds.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/download/download_from_hub.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/llm_asr/model.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/utils/dynamic_import.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/utils/export_utils.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
examples/industrial_data_pretraining/llm_asr/demo_speech2text.py
@@ -9,19 +9,20 @@ from funasr import AutoModel ckpt_dir = "/nfs/beinian.lzr/workspace/GPT-4o/Exp/exp6/5m-8gpu/exp6_speech2text_linear_ddp_0609" ckpt_id = "model.pt.ep0.90000" jsonl = ( "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/aishell1_test_speech2text.jsonl" ) output_dir = f"{os.path.join(ckpt_dir, ckpt_id)}" device = "cuda:0" if len(sys.argv) > 1: ckpt_dir = sys.argv[1] ckpt_id = sys.argv[2] jsonl = sys.argv[3] output_dir = sys.argv[4] device = sys.argv[5] else: ckpt_dir = "/nfs/beinian.lzr/workspace/GPT-4o/Exp/exp6/5m-8gpu/exp6_speech2text_linear_ddp_0609" ckpt_id = "model.pt.ep0.90000" jsonl = "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/aishell1_test_speech2text.jsonl" dataset = jsonl.split("/")[-1] output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset) device = "cuda:0" ckpt_dir = sys.argv[1] ckpt_id = sys.argv[2] jsonl = sys.argv[3] output_dir = sys.argv[4] device = sys.argv[5] model = AutoModel( model=ckpt_dir, examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# [diff view] New file @@ -0,0 +1,76 @@
"""Multi-turn speech-to-text inference demo.

Usage:
    demo_speech2text_multi.py CKPT_DIR CKPT_ID JSONL OUTPUT_DIR DEVICE

Without arguments the built-in default checkpoint, test set and device are
used.  Every line of the JSONL file holds one dialog under the "messages"
key; the dialog history is extended turn by turn and the model is run once
per turn on the history accumulated so far.
"""

import json
import os
import sys

from funasr import AutoModel

if len(sys.argv) > 1:
    # All five positional arguments are required together; fail with a usage
    # message instead of an IndexError when only some of them are given.
    if len(sys.argv) < 6:
        sys.exit(f"usage: {sys.argv[0]} CKPT_DIR CKPT_ID JSONL OUTPUT_DIR DEVICE")
    ckpt_dir, ckpt_id, jsonl, output_dir, device = sys.argv[1:6]
else:
    ckpt_dir = "/nfs/beinian.lzr/workspace/GPT-4o/Exp/exp7/5m-8gpu/exp5-1-0619"
    ckpt_id = "model.pt.ep6"
    jsonl = (
        "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl"
    )
    dataset = jsonl.split("/")[-1]
    output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
    # Previously left undefined on this path, which made the AutoModel call
    # below fail with NameError; same default as demo_speech2text.py.
    device = "cuda:0"

model = AutoModel(
    model=ckpt_dir,
    init_param=os.path.join(ckpt_dir, ckpt_id),
    output_dir=output_dir,
    device=device,
    fp16=False,
    bf16=False,
    llm_dtype="bf16",
)

# Explicit encoding so reading does not depend on the platform locale.
with open(jsonl, "r", encoding="utf-8") as f:
    lines = f.readlines()

# NOTE: the keyword is spelled "tearchforing" because that is the name the
# model's inference code looks up in its kwargs; do not "fix" the spelling here.
tearchforing = False
for i, line in enumerate(lines):
    key_i = f"dialog_{i}"

    data_dict = json.loads(line.strip())
    data = data_dict["messages"]

    # data_template returns parallel per-turn lists under these three keys.
    contents = model.model.data_template(data)
    system = contents["system"]
    user = contents["user"]
    assistant = contents["assistant"]

    # Dialog history, grown by one user/assistant pair per turn.
    contents_i = []
    for j, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
        key = f"{key_i}_turn_{j}"

        # The system prompt is only sent once, at the start of the dialog.
        if j == 0:
            contents_i.append({"role": "system", "content": system_prompt})

        contents_i.append({"role": "user", "content": user_prompt})
        contents_i.append({"role": "assistant", "content": target_out})

        res = model.generate(
            input=[contents_i],
            tearchforing=tearchforing,
            cache={},
            key=key,
        )

        print(res)

# [diff view] next file: funasr/__init__.py
@@ -1,8 +1,6 @@ """Initialize funasr package.""" import os import pkgutil import importlib dirname = os.path.dirname(__file__) version_file = os.path.join(dirname, "version.txt") funasr/auto/auto_model.py
@@ -92,7 +92,8 @@ if isinstance(data_i, str) and os.path.exists(data_i): key = misc.extract_filename_without_extension(data_i) else: key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) if key is None: key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) key_list.append(key) else: # raw text; audio sample point, fbank; bytes funasr/datasets/openai_datasets/datasets.py
@@ -283,10 +283,11 @@ self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") # self.kwargs = kwargs self.max_token_length = kwargs.get("max_token_length", 1024) self.max_token_length = kwargs.get("max_token_length", 1500) self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5) self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500) self.multiturn_num_max = kwargs.get("multiturn_num_max", 5) self.max_source_length = kwargs.get("max_source_length", 3000) def get_source_len(self, index): item = self.index_ds[index] @@ -334,6 +335,12 @@ ): if i >= self.multiturn_num_max: break if len(input_ids) > self.max_token_length: logging.info( f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}" ) break if i == 0: source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" else: @@ -372,6 +379,11 @@ frontend=self.frontend, is_final=True, ) # speech: [b, T, d] if speech_lengths > self.max_source_length: logging.info( f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" ) badcase_flag = True if self.permute: speech = speech.permute(0, 2, 1) # if speech_lengths > self.batch_size: @@ -399,13 +411,9 @@ fbank_mask += fbank_mask_i fbank_lens.append(speech_lengths) if len(input_ids) > self.max_token_length: logging.info( f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}" ) badcase_flag = True if badcase_flag: continue input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] funasr/datasets/openai_datasets/index_ds.py
@@ -16,6 +16,12 @@ def __init__(self, path: str, **kwargs): super().__init__() self.max_source_length = kwargs.get("max_source_length", 3000) self.min_source_length = kwargs.get("min_source_length", 0) self.max_target_length = kwargs.get("max_target_length", 2048) self.min_target_length = kwargs.get("min_target_length", 0) self.max_token_length = kwargs.get("max_token_length", 2200) is_training = kwargs.get("is_training", True) if not (path.endswith(".jsonl") or path.endswith(".json")): # jsonl list file @@ -47,6 +53,15 @@ data = data_dict["messages"] speech_length = data_dict.get("speech_length", -1) // 8 text_length = data_dict.get("text_length", 0) if speech_length > self.max_source_length: logging.info( "speech_length: {speech_length} > {self.max_source_length}, drop it" ) continue if text_length > self.max_target_length: continue self.max_target_length = kwargs.get("max_target_length", 2048) system, user, assistant = [], [], [] for i, item in enumerate(data): funasr/download/download_from_hub.py
@@ -84,6 +84,12 @@ from funasr.utils.install_model_requirements import install_requirements install_requirements(requirements) if kwargs.get("trust_remote_code", False): import model # from funasr.register import tables # tables.print("model") return kwargs funasr/models/llm_asr/model.py
@@ -988,9 +988,9 @@ text: (Batch, Length) text_lengths: (Batch,) """ import pdb pdb.set_trace() # import pdb # # pdb.set_trace() if len(speech_lengths.size()) > 1: speech_lengths = speech_lengths[:, 0] @@ -1011,6 +1011,7 @@ fake_token_len = kwargs.get("fake_token_len") fake_token_len[fake_token_len < 0] = 0 fbank_beg[fbank_beg < 0] = 0 speech_idx = 0 for batch_idx in range(batch_size): @@ -1025,12 +1026,15 @@ batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token except Exception as e: # logging.error(f"{str(e)}, {traceback.format_exc()}") logging.info( f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[speech_idx].item()}" f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" ) # import pdb; # pdb.set_trace() speech_token_len = encoder_out_lens[speech_idx].item() speech_token = encoder_out[speech_idx, turn_id, :speech_token_len, :] speech_token = encoder_out[speech_idx, :speech_token_len, :] inputs_embeds[ batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token @@ -1064,6 +1068,12 @@ stats["batch_size_x_tokens"] = token_num * batch_size stats["batch_size_real_tokens"] = attention_mask.sum().item() stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] dialog_turns = (fbank_beg > 0).sum(-1) dialog_turns_max = torch.max(dialog_turns).int().item() dialog_turns_avg = dialog_turns.sum().item() / batch_size stats["dialog_turns_max"] = dialog_turns_max stats["dialog_turns_avg"] = dialog_turns_avg # force_gatherable: to-device and to-tensor if scalar for DataParallel if self.length_normalized_loss: @@ -1105,8 +1115,8 @@ user 
= contents["user"] assistant = contents["assistant"] pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)") input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = ( [], input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = ( [], [], [], @@ -1115,21 +1125,30 @@ [], [], ) input_source_ids = [] for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)): if i >= kwargs.get("multiturn_num_max", 5): break if len(input_ids) > kwargs.get("max_token_length", 1500): source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" break if i == 0: source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" else: source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" splits = pattern.split(source_input) source_ids_i = [] source_ids = [] fbank_i = [] fbank_mask_i = [] fbank_beg_i = [] fake_token_len_i = 0 fbank_beg_i = -1 fbank_lens_i = [] # target_ids_i = [] for k, sub_str in enumerate(splits): if not sub_str.startswith("<|startofspeech|>"): sub_token = tokenizer.encode(sub_str) source_ids_i += sub_token source_ids += sub_token fbank_mask_i += [0] * len(sub_token) else: sub_str = sub_str.replace("<|startofspeech|>", "").replace( @@ -1162,42 +1181,57 @@ if kwargs.get("permute", True): speech = speech.permute(0, 2, 1) if speech_lengths > kwargs.get("max_source_length", 5500): # logging.info( # f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" # ) badcase_flag = True olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 olens = 1 + (olens - 3 + 2 * 1) // 2 sub_token_len = (olens - 1) // 2 + 1 sub_token = [0] * sub_token_len fbank_beg_i = [len(source_ids_i)] source_ids_i += sub_token fbank_mask_i += [1] * len(sub_token) fake_token_len_i = (olens - 1) // 2 + 1 fake_token = [0] * 
fake_token_len_i fbank_beg_i = len(source_ids) source_ids += fake_token fbank_mask_i += [1] * len(fake_token) source_mask = [-100] * len(source_ids_i) fbank_beg += [fbank_beg_i + len(input_ids)] fake_token_len += [fake_token_len_i] source_mask = [-100] * len(source_ids) target_out = f"{target_out}<|im_end|>" target_ids = tokenizer.encode(target_out) input_ids += source_ids_i + target_ids input_source_ids = input_ids + source_ids input_ids += source_ids + target_ids labels += source_mask + target_ids fbank.append(speech[0, :, :]) fbank_mask += fbank_mask_i fbank_beg.append(fbank_beg_i) fbank_lens.append(speech_lengths) input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length] attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32) labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length] source_ids = torch.tensor(source_ids_i, dtype=torch.int64) target_ids = torch.tensor(target_ids, dtype=torch.int64) fbank = speech[0, :, :] fbank_lens = speech_lengths # fbank = speech[0, :, :] # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32) fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32) fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32) fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32) source_ids = torch.tensor(input_source_ids, dtype=torch.int64) target_ids = torch.tensor(target_ids, dtype=torch.int64) speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0) speech_lengths = torch.nn.utils.rnn.pad_sequence( fbank_lens, batch_first=True, padding_value=-1 ) output = { "speech": fbank[None, :, :], "speech_lengths": fbank_lens[:, None], "speech": speech, "speech_lengths": speech_lengths, "fbank_mask": fbank_mask[None, :], "fbank_beg": fbank_beg[None,], "input_ids": input_ids[None, :], "attention_mask": attention_mask[None, :], "labels_ids": labels[None, :], "fake_token_len": fake_token_len[None, :], "input_ids": input_ids[None,], "attention_mask": 
attention_mask[None,], "labels_ids": labels, "source_ids": source_ids[None, :], "target_ids": target_ids[None, :], } @@ -1240,20 +1274,48 @@ input_ids = batch["input_ids"] source_ids = batch["source_ids"] fbank_beg = batch["fbank_beg"] fake_token_len = batch["fake_token_len"] if not kwargs.get("tearchforing", False): input_ids = source_ids input_ids[input_ids < 0] = 0 inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) batch_size, token_num, dims = inputs_embeds.shape fbank_beg = batch["fbank_beg"] fake_token_len[fake_token_len < 0] = 0 fbank_beg[fbank_beg < 0] = 0 speech_idx = 0 for batch_idx in range(batch_size): min_len = encoder_out_lens[batch_idx].item() fbank_beg_idx = fbank_beg[batch_idx] inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[ batch_idx, :min_len, : ] for turn_id in range(fbank_beg.shape[1]): fbank_beg_idx = fbank_beg[batch_idx, turn_id].item() if fbank_beg_idx > 0: speech_token_len = fake_token_len[batch_idx, turn_id] speech_token = encoder_out[speech_idx, :speech_token_len, :] try: inputs_embeds[ batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token except Exception as e: # logging.error(f"{str(e)}, {traceback.format_exc()}") logging.info( f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}" ) # import pdb; # pdb.set_trace() speech_token_len = encoder_out_lens[speech_idx].item() speech_token = encoder_out[speech_idx, :speech_token_len, :] inputs_embeds[ batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, : ] = speech_token speech_idx += 1 llm_dtype = kwargs.get("llm_dtype", "fp32") if llm_dtype == "fp32": @@ -1263,7 +1325,7 @@ with torch.cuda.amp.autocast( enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] ): label = 
contents["assistant"][0] label = contents["assistant"][-1] self.llm = self.llm.to(dtype_map[llm_dtype]) inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) @@ -1313,8 +1375,8 @@ results.append(result_i) if ibest_writer is not None: ibest_writer["text"][key[0]] = response ibest_writer["label"][key[0]] = label ibest_writer["text"][key[0]] = response.replace("\n", " ") ibest_writer["label"][key[0]] = label.replace("\n", " ") ibest_writer["text_tn"][key[0]] = response_clean return results, meta_data funasr/utils/dynamic_import.py
# [diff view] New file @@ -0,0 +1,39 @@
import importlib.util
import inspect
import os


def load_module_from_path(file_path, module_name=None):
    """Dynamically load a Python module from a source file path.

    The module is executed and returned; it is not inserted into
    ``sys.modules``.

    :param file_path: Path of the module file (``str`` or ``os.PathLike``).
    :param module_name: Name to give the loaded module. Defaults to the file
        name without its extension.
    :return: The loaded module object.
    :raises ImportError: If no import spec/loader can be created for
        ``file_path`` (for example, the file does not have a Python suffix).
    """
    file_path = os.fspath(file_path)
    if module_name is None:
        # Strip only the final extension; a plain ``replace(".py", "")`` would
        # also remove ".py" occurring in the middle of the file name.
        module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None when it cannot pick a loader.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {module_name!r} from {file_path!r}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# [diff view] next file: funasr/utils/export_utils.py
@@ -5,7 +5,7 @@ try: import torch_blade except Exception as e: print(f"failed to load torch_blade: {e}") print(f"Warning, if you are exporting bladedisc, please install it and try it again: pip install -U torch_blade\n") def export(model, data_in=None, quantize: bool = False, opset_version: int = 14, type='onnx', **kwargs): @@ -196,4 +196,4 @@ model.encoder = _bladedisc_opt(model.encoder, input_data[:2]) model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs)) model_script = torch.jit.trace(model, input_data) model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscripts")) model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscripts"))