From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2: refactor inference entry point into AutoModel
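
Refactor the inference entry point: the function-based infer()/_forward
pipeline becomes an AutoModel class, and tokenizer, frontend, and model
classes are now resolved through registry_tables instead of
dynamic_import. build_iter_for_infer drops its unused frontend argument.

Intended usage, as a minimal sketch of the new API (the model name and
keyword values in the example are illustrative placeholders, not values
taken from this diff):

    from funasr.bin.inference import AutoModel

    # "model" must name a class registered in registry_tables.model_classes;
    # "paraformer" here is an illustrative placeholder, not from this diff.
    # When "model_conf" is absent, weights are fetched via download_model().
    model = AutoModel(model="paraformer", device="cuda", batch_size=1)

    # input accepts a wav path, a wav.scp/.jsonl/.txt filelist, raw bytes,
    # an fbank tensor, or a list of such items (see build_iter_for_infer).
    res = model.generate(input="asr_example.wav")
    print(res)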
---
funasr/bin/inference.py | 218 ++++++++++++++++++++++++++++--------------------------
1 file changed, 113 insertions(+), 105 deletions(-)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index d63ebc9..09e28f3 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -5,123 +5,25 @@
import hydra
import json
from omegaconf import DictConfig, OmegaConf
-from funasr.utils.dynamic_import import dynamic_import
import logging
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_bytes
from funasr.train_utils.device_funcs import to_device
from tqdm import tqdm
from funasr.train_utils.load_pretrained_model import load_pretrained_model
import time
import random
import string
+from funasr.utils.register import registry_tables
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
- assert "model" in kwargs
- pipeline = infer(**kwargs)
- res = pipeline(input=kwargs["input"])
- print(res)
-
-def infer(**kwargs):
-
- if ":" not in kwargs["model"]:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
- kwargs = download_model(**kwargs)
-
- set_all_random_seed(kwargs.get("seed", 0))
-
-
- device = kwargs.get("device", "cuda")
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1):
- device = "cpu"
- batch_size = 1
- kwargs["device"] = device
-
- # build_tokenizer
- tokenizer = build_tokenizer(
- token_type=kwargs.get("token_type", "char"),
- bpemodel=kwargs.get("bpemodel", None),
- delimiter=kwargs.get("delimiter", None),
- space_symbol=kwargs.get("space_symbol", "<space>"),
- non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
- g2p_type=kwargs.get("g2p_type", None),
- token_list=kwargs.get("token_list", None),
- unk_symbol=kwargs.get("unk_symbol", "<unk>"),
- )
-
- import pdb;
- pdb.set_trace()
- # build model
- model_class = dynamic_import(kwargs.get("model"))
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- model.eval()
- model.to(device)
- frontend = model.frontend
- kwargs["token_list"] = tokenizer.token_list
-
-
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- logging.info(f"Loading pretrained params from {init_param}")
- load_pretrained_model(
- model=model,
- init_param=init_param,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
- oss_bucket=kwargs.get("oss_bucket", None),
- )
-
- def _forward(input, input_len=None, **cfg):
- cfg = OmegaConf.merge(kwargs, cfg)
- date_type = cfg.get("date_type", "sound")
-
- key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)
-
- speed_stats = {}
- asr_result_list = []
- num_samples = len(data_list)
- pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
- for beg_idx in range(0, num_samples, batch_size):
-
- end_idx = min(num_samples, beg_idx + batch_size)
- data_batch = data_list[beg_idx:end_idx]
- key_batch = key_list[beg_idx:end_idx]
- batch = {"data_in": data_batch, "key": key_batch}
-
- time1 = time.perf_counter()
- results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
- time2 = time.perf_counter()
-
- asr_result_list.append(results)
- pbar.update(1)
-
- # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
- batch_data_time = meta_data.get("batch_data_time", -1)
- speed_stats["load_data"] = meta_data["load_data"]
- speed_stats["extract_feat"] = meta_data["extract_feat"]
- speed_stats["forward"] = f"{time2 - time1:0.3f}"
- speed_stats["rtf"] = f"{(time2 - time1)/batch_data_time:0.3f}"
- description = (
- f"{speed_stats}, "
- )
- pbar.set_description(description)
-
- torch.cuda.empty_cache()
- return asr_result_list
-
- return _forward
-
-
-def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
+def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
"""
- :param input:
- :param input_len:
- :param date_type:
- :param frontend:
- :return:
+ :param data_in: wav path; filelist (wav.scp, .jsonl, .txt); raw text; audio bytes; audio samples or fbank tensor; or a list of these
+ :param input_len: optional length of data_in
+ :param data_type: input data type, e.g. "sound"
+ :return: (key_list, data_list), generated example keys and the normalized data list
"""
@@ -131,7 +33,7 @@
chars = string.ascii_letters + string.digits
- if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
+ if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
@@ -153,10 +55,10 @@
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
- elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
+ elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
- else: # raw text; audio sample point, fbank
+ else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
@@ -165,6 +67,112 @@
return key_list, data_list
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+ logging.basicConfig(level=log_level)
+
+ model = AutoModel(**kwargs)
+ res = model.generate(input=kwargs["input"])
+ print(res)
+
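+# AutoModel resolves the tokenizer, frontend, and model classes from
+# registry_tables, loads optional pretrained parameters, and exposes
+# generate() for batched inference.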
+class AutoModel:
+ def __init__(self, **kwargs):
+ registry_tables.print_register_tables()
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+ # fall back to CPU when CUDA is unavailable or ngpu is 0
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend
+ frontend = kwargs.get("frontend", None)
+ if frontend is not None:
+ frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+ frontend = frontend_class(**kwargs["frontend_conf"])
+ kwargs["frontend"] = frontend
+
+ # build model
+ model_class = registry_tables.model_classes.get(kwargs["model"].lower())
+ vocab_size = len(tokenizer.token_list) if tokenizer is not None else -1
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
+ model.eval()
+ model.to(device)
+
+ kwargs["token_list"] = tokenizer.token_list
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ init_param=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ )
+ self.kwargs = kwargs
+ self.model = model
+ self.tokenizer = tokenizer
+
+ def generate(self, input, input_len=None, **cfg):
+ self.kwargs.update(cfg)
+ data_type = self.kwargs.get("data_type", "sound")
+ batch_size = self.kwargs.get("batch_size", 1)
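+ # decoding on CPU is forced to batch_size 1; batching is only enabled on GPU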
+ if self.kwargs.get("device", "cpu") == "cpu":
+ batch_size = 1
+
+ key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+ if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
+ batch["data_batch"] = data_batch[0]
+ batch["data_lengths"] = input_len
+
+ time1 = time.perf_counter()
+ results, meta_data = self.model.generate(**batch, **self.kwargs)
+ time2 = time.perf_counter()
+
+ asr_result_list.append(results)
+ pbar.update(1)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+ speed_stats["forward"] = f"{time2 - time1:0.3f}"
+ speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
+ description = f"{speed_stats}, "
+ pbar.set_description(description)
+
+ torch.cuda.empty_cache()
+ return asr_result_list
if __name__ == '__main__':
main_hydra()
\ No newline at end of file
--
Gitblit v1.9.1