From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example
---
funasr/bin/inference.py | 183 +++++----------------------------------------
1 files changed, 22 insertions(+), 161 deletions(-)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index d63ebc9..d2f0c14 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -1,170 +1,31 @@
-import os.path
-
-import torch
-import numpy as np
import hydra
-import json
-from omegaconf import DictConfig, OmegaConf
-from funasr.utils.dynamic_import import dynamic_import
import logging
-from funasr.download.download_from_hub import download_model
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
-from funasr.train_utils.device_funcs import to_device
-from tqdm import tqdm
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-import time
-import random
-import string
+from omegaconf import DictConfig, OmegaConf, ListConfig
+
+from funasr.auto.auto_model import AutoModel
+
@hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
- assert "model" in kwargs
+def main_hydra(cfg: DictConfig):
+ def to_plain_list(cfg_item):
+ if isinstance(cfg_item, ListConfig):
+ return OmegaConf.to_container(cfg_item, resolve=True)
+ elif isinstance(cfg_item, DictConfig):
+ return {k: to_plain_list(v) for k, v in cfg_item.items()}
+ else:
+ return cfg_item
+
+ kwargs = to_plain_list(cfg)
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
- pipeline = infer(**kwargs)
- res = pipeline(input=kwargs["input"])
- print(res)
-
-def infer(**kwargs):
-
- if ":" not in kwargs["model"]:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
- kwargs = download_model(**kwargs)
-
- set_all_random_seed(kwargs.get("seed", 0))
+ logging.basicConfig(level=log_level)
-
- device = kwargs.get("device", "cuda")
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1):
- device = "cpu"
- batch_size = 1
- kwargs["device"] = device
-
- # build_tokenizer
- tokenizer = build_tokenizer(
- token_type=kwargs.get("token_type", "char"),
- bpemodel=kwargs.get("bpemodel", None),
- delimiter=kwargs.get("delimiter", None),
- space_symbol=kwargs.get("space_symbol", "<space>"),
- non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
- g2p_type=kwargs.get("g2p_type", None),
- token_list=kwargs.get("token_list", None),
- unk_symbol=kwargs.get("unk_symbol", "<unk>"),
- )
-
- import pdb;
- pdb.set_trace()
- # build model
- model_class = dynamic_import(kwargs.get("model"))
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- model.eval()
- model.to(device)
- frontend = model.frontend
- kwargs["token_list"] = tokenizer.token_list
-
-
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- logging.info(f"Loading pretrained params from {init_param}")
- load_pretrained_model(
- model=model,
- init_param=init_param,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
- oss_bucket=kwargs.get("oss_bucket", None),
- )
-
- def _forward(input, input_len=None, **cfg):
- cfg = OmegaConf.merge(kwargs, cfg)
- date_type = cfg.get("date_type", "sound")
-
- key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)
-
- speed_stats = {}
- asr_result_list = []
- num_samples = len(data_list)
- pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
- for beg_idx in range(0, num_samples, batch_size):
-
- end_idx = min(num_samples, beg_idx + batch_size)
- data_batch = data_list[beg_idx:end_idx]
- key_batch = key_list[beg_idx:end_idx]
- batch = {"data_in": data_batch, "key": key_batch}
-
- time1 = time.perf_counter()
- results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
- time2 = time.perf_counter()
-
- asr_result_list.append(results)
- pbar.update(1)
-
- # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
- batch_data_time = meta_data.get("batch_data_time", -1)
- speed_stats["load_data"] = meta_data["load_data"]
- speed_stats["extract_feat"] = meta_data["extract_feat"]
- speed_stats["forward"] = f"{time2 - time1:0.3f}"
- speed_stats["rtf"] = f"{(time2 - time1)/batch_data_time:0.3f}"
- description = (
- f"{speed_stats}, "
- )
- pbar.set_description(description)
-
- torch.cuda.empty_cache()
- return asr_result_list
-
- return _forward
-
-
-def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
- """
-
- :param input:
- :param input_len:
- :param date_type:
- :param frontend:
- :return:
- """
- data_list = []
- key_list = []
- filelist = [".scp", ".txt", ".json", ".jsonl"]
-
- chars = string.ascii_letters + string.digits
-
- if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
- _, file_extension = os.path.splitext(data_in)
- file_extension = file_extension.lower()
- if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
- with open(data_in, encoding='utf-8') as fin:
- for line in fin:
- key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
- if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
- lines = json.loads(line.strip())
- data = lines["source"]
- key = data["key"] if "key" in data else key
- else: # filelist, wav.scp, text.txt: id \t data or data
- lines = line.strip().split()
- data = lines[1] if len(lines)>1 else lines[0]
- key = lines[0] if len(lines)>1 else key
-
- data_list.append(data)
- key_list.append(key)
- else:
- key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
- data_list = [data_in]
- key_list = [key]
- elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
- data_list = data_in
- key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
- else: # raw text; audio sample point, fbank
- if isinstance(data_in, bytes): # audio bytes
- data_in = load_bytes(data_in)
- key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
- data_list = [data_in]
- key_list = [key]
-
- return key_list, data_list
+ if kwargs.get("debug", False):
+ import pdb; pdb.set_trace()
+ model = AutoModel(**kwargs)
+ res = model.generate(input=kwargs["input"])
+ print(res)
if __name__ == '__main__':
- main_hydra()
\ No newline at end of file
+ main_hydra()
\ No newline at end of file
--
Gitblit v1.9.1