From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example

---
 funasr/bin/inference.py |  187 +++++-----------------------------------------
 1 files changed, 20 insertions(+), 167 deletions(-)

diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 50ea4d4..d2f0c14 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -1,178 +1,31 @@
-import os.path
-
-import torch
-import numpy as np
 import hydra
-import json
-from omegaconf import DictConfig, OmegaConf
 import logging
-from funasr.download.download_from_hub import download_model
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_bytes
-from funasr.train_utils.device_funcs import to_device
-from tqdm import tqdm
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-import time
-import random
-import string
-from funasr.utils.register import registry_tables
+from omegaconf import DictConfig, OmegaConf, ListConfig
 
+from funasr.auto.auto_model import AutoModel
 
-def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
-	"""
-	
-	:param input:
-	:param input_len:
-	:param data_type:
-	:param frontend:
-	:return:
-	"""
-	data_list = []
-	key_list = []
-	filelist = [".scp", ".txt", ".json", ".jsonl"]
-	
-	chars = string.ascii_letters + string.digits
-	
-	if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
-		_, file_extension = os.path.splitext(data_in)
-		file_extension = file_extension.lower()
-		if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
-			with open(data_in, encoding='utf-8') as fin:
-				for line in fin:
-					key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
-					if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
-						lines = json.loads(line.strip())
-						data = lines["source"]
-						key = data["key"] if "key" in data else key
-					else: # filelist, wav.scp, text.txt: id \t data or data
-						lines = line.strip().split()
-						data = lines[1] if len(lines)>1 else lines[0]
-						key = lines[0] if len(lines)>1 else key
-					
-					data_list.append(data)
-					key_list.append(key)
-		else:
-			key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
-			data_list = [data_in]
-			key_list = [key]
-	elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank]
-		data_list = data_in
-		key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
-	else: # raw text; audio sample point, fbank; bytes
-		if isinstance(data_in, bytes): # audio bytes
-			data_in = load_bytes(data_in)
-		key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
-		data_list = [data_in]
-		key_list = [key]
-	
-	return key_list, data_list
 
 @hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
-	log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+def main_hydra(cfg: DictConfig):
+    def to_plain_list(cfg_item):
+        if isinstance(cfg_item, ListConfig):
+            return OmegaConf.to_container(cfg_item, resolve=True)
+        elif isinstance(cfg_item, DictConfig):
+            return {k: to_plain_list(v) for k, v in cfg_item.items()}
+        else:
+            return cfg_item
+    
+    kwargs = to_plain_list(cfg)
+    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
 
-	logging.basicConfig(level=log_level)
+    logging.basicConfig(level=log_level)
 
-	import pdb;
-	pdb.set_trace()
-	model = AutoModel(**kwargs)
-	res = model.generate(input=kwargs["input"])
-	print(res)
+    if kwargs.get("debug", False):
+        import pdb; pdb.set_trace()
+    model = AutoModel(**kwargs)
+    res = model.generate(input=kwargs["input"])
+    print(res)
 
-class AutoModel:
-	def __init__(self, **kwargs):
-		registry_tables.print()
-		assert "model" in kwargs
-		if "model_conf" not in kwargs:
-			logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
-			kwargs = download_model(**kwargs)
-		
-		set_all_random_seed(kwargs.get("seed", 0))
-		
-		device = kwargs.get("device", "cuda")
-		if not torch.cuda.is_available() or kwargs.get("ngpu", 1):
-			device = "cpu"
-			kwargs["batch_size"] = 1
-		kwargs["device"] = device
-
-		# build tokenizer
-		tokenizer = kwargs.get("tokenizer", None)
-		if tokenizer is not None:
-			tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
-			tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
-			kwargs["tokenizer"] = tokenizer
-			kwargs["token_list"] = tokenizer.token_list
-		
-		# build frontend
-		frontend = kwargs.get("frontend", None)
-		if frontend is not None:
-			frontend_class = registry_tables.frontend_classes.get(frontend.lower())
-			frontend = frontend_class(**kwargs["frontend_conf"])
-			kwargs["frontend"] = frontend
-			kwargs["input_size"] = frontend.output_size()
-		
-		# build model
-		model_class = registry_tables.model_classes.get(kwargs["model"].lower())
-		model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
-		model.eval()
-		model.to(device)
-		
-		# init_param
-		init_param = kwargs.get("init_param", None)
-		if init_param is not None:
-			logging.info(f"Loading pretrained params from {init_param}")
-			load_pretrained_model(
-				model=model,
-				init_param=init_param,
-				ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
-				oss_bucket=kwargs.get("oss_bucket", None),
-			)
-		self.kwargs = kwargs
-		self.model = model
-		self.tokenizer = tokenizer
-	
-	def generate(self, input, input_len=None, **cfg):
-		self.kwargs.update(cfg)
-		data_type = self.kwargs.get("data_type", "sound")
-		batch_size = self.kwargs.get("batch_size", 1)
-		if self.kwargs.get("device", "cpu") == "cpu":
-			batch_size = 1
-		
-		key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
-		
-		speed_stats = {}
-		asr_result_list = []
-		num_samples = len(data_list)
-		pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
-		for beg_idx in range(0, num_samples, batch_size):
-			end_idx = min(num_samples, beg_idx + batch_size)
-			data_batch = data_list[beg_idx:end_idx]
-			key_batch = key_list[beg_idx:end_idx]
-			batch = {"data_in": data_batch, "key": key_batch}
-			if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
-				batch["data_batch"] = data_batch[0]
-				batch["data_lengths"] = input_len
-		
-			time1 = time.perf_counter()
-			results, meta_data = self.model.generate(**batch, **self.kwargs)
-			time2 = time.perf_counter()
-			
-			asr_result_list.append(results)
-			pbar.update(1)
-			
-			# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
-			batch_data_time = meta_data.get("batch_data_time", -1)
-			speed_stats["load_data"] = meta_data.get("load_data", 0.0)
-			speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
-			speed_stats["forward"] = f"{time2 - time1:0.3f}"
-			speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
-			description = (
-				f"{speed_stats}, "
-			)
-			pbar.set_description(description)
-		
-		torch.cuda.empty_cache()
-		return asr_result_list
 
 if __name__ == '__main__':
-	main_hydra()
\ No newline at end of file
+    main_hydra()
\ No newline at end of file

--
Gitblit v1.9.1