游雁
2023-12-19 0e622e694e6cb4459955f1e5942a7c53349ce640
funasr/bin/inference.py
@@ -5,123 +5,25 @@
import hydra
import json
from omegaconf import DictConfig, OmegaConf
from funasr.utils.dynamic_import import dynamic_import
import logging
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.tokenizer.funtoken import build_tokenizer
from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_bytes
from funasr.train_utils.device_funcs import to_device
from tqdm import tqdm
from funasr.train_utils.load_pretrained_model import load_pretrained_model
import time
import random
import string
from funasr.utils.register import registry_tables
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: build the inference callable and apply it to the input.

    :param kwargs: Hydra configuration; must contain a ``"model"`` entry and
        an ``"input"`` entry.
    """
    assert "model" in kwargs
    # infer() returns a callable that performs the actual inference.
    run_inference = infer(**kwargs)
    outputs = run_inference(input=kwargs["input"])
    print(outputs)
def infer(**kwargs):
    """Build tokenizer and model from ``kwargs`` and return an inference callable.

    Steps performed:
      1. If ``kwargs["model"]`` contains no ``":"``, the configuration is
         resolved through ``download_model``.
      2. The random seed is set from ``kwargs["seed"]`` (default 0).
      3. The device is chosen: CPU when CUDA is unavailable or ``ngpu == 0``,
         otherwise ``kwargs["device"]`` (default ``"cuda"``).
      4. The tokenizer and the model are built; the model is put into eval
         mode and moved to the device.
      5. Pretrained parameters are loaded when ``init_param`` is given.

    :param kwargs: configuration entries; ``"model"`` and ``"model_conf"`` are
        required, everything else falls back to defaults.
    :return: ``_forward(input, input_len=None, **cfg)``, which returns a list
        holding the result of ``model.generate`` for every batch.
    """
    if ":" not in kwargs["model"]:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        kwargs = download_model(**kwargs)

    set_all_random_seed(kwargs.get("seed", 0))

    device = kwargs.get("device", "cuda")
    # batch_size must be bound on every path: _forward reads it as a closure
    # variable, and it used to be assigned only in the CPU branch.
    batch_size = kwargs.get("batch_size", 1)
    # Fall back to CPU only when CUDA is missing or GPUs were explicitly
    # disabled (ngpu == 0); a truthy ngpu must not force CPU.
    if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
        device = "cpu"
        batch_size = 1
    kwargs["device"] = device

    # build_tokenizer
    tokenizer = build_tokenizer(
        token_type=kwargs.get("token_type", "char"),
        bpemodel=kwargs.get("bpemodel", None),
        delimiter=kwargs.get("delimiter", None),
        space_symbol=kwargs.get("space_symbol", "<space>"),
        non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
        g2p_type=kwargs.get("g2p_type", None),
        token_list=kwargs.get("token_list", None),
        unk_symbol=kwargs.get("unk_symbol", "<unk>"),
    )

    # build model
    model_class = dynamic_import(kwargs.get("model"))
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
    model.eval()
    model.to(device)
    frontend = model.frontend
    kwargs["token_list"] = tokenizer.token_list

    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        logging.info(f"Loading pretrained params from {init_param}")
        load_pretrained_model(
            model=model,
            init_param=init_param,
            ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
            oss_bucket=kwargs.get("oss_bucket", None),
        )

    def _forward(input, input_len=None, **cfg):
        """Run the model over ``input`` in batches and collect the results.

        :param input: data accepted by ``build_iter_for_infer``.
        :param input_len: optional lengths forwarded to ``build_iter_for_infer``.
        :param cfg: per-call overrides merged on top of the build-time kwargs.
        :return: list with one ``model.generate`` result per batch.
        """
        cfg = OmegaConf.merge(kwargs, cfg)
        date_type = cfg.get("date_type", "sound")

        key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}

            time1 = time.perf_counter()
            results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
            time2 = time.perf_counter()

            asr_result_list.append(results)
            # The bar total counts samples, so advance by the batch's sample count.
            pbar.update(end_idx - beg_idx)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time2 - time1:0.3f}"
            # Only report a real-time factor when the batch duration is known
            # and positive; this avoids negative values and division by zero.
            if batch_data_time > 0:
                speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
            else:
                speed_stats["rtf"] = "n/a"
            description = (
                f"{speed_stats}, "
            )
            pbar.set_description(description)

        torch.cuda.empty_cache()
        return asr_result_list

    return _forward
def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
   """
   
   :param input:
   :param input_len:
   :param date_type:
   :param data_type:
   :param frontend:
   :return:
   """
@@ -131,7 +33,7 @@
   
   chars = string.ascii_letters + string.digits
   
   if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
   if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
      _, file_extension = os.path.splitext(data_in)
      file_extension = file_extension.lower()
      if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
@@ -153,10 +55,10 @@
         key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
         data_list = [data_in]
         key_list = [key]
   elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
   elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank]
      data_list = data_in
      key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
   else: # raw text; audio sample point, fbank
   else: # raw text; audio sample point, fbank; bytes
      if isinstance(data_in, bytes): # audio bytes
         data_in = load_bytes(data_in)
      key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
@@ -165,6 +67,112 @@
   
   return key_list, data_list
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: configure logging, build an ``AutoModel`` and run it.

    :param kwargs: Hydra configuration. ``"log_level"`` (default ``"INFO"``)
        selects the logging level by name, ``"input"`` is the data passed to
        ``AutoModel.generate``; all entries are forwarded to ``AutoModel``.
    """
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    # No debugger breakpoint here: an unconditional pdb.set_trace() blocks
    # every non-interactive invocation of the script.
    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)
class AutoModel:
    """Build a registered model with optional tokenizer/frontend and run inference.

    The constructor resolves the configuration (through ``download_model`` when
    ``model_conf`` is missing), selects the device, instantiates the tokenizer,
    frontend and model classes looked up in ``registry_tables``, and loads
    pretrained parameters when ``init_param`` is given. ``generate`` then feeds
    data to ``model.generate`` batch by batch.
    """

    def __init__(self, **kwargs):
        """Create the tokenizer, frontend and model described by ``kwargs``.

        :param kwargs: configuration entries; ``"model"`` is required.
            ``"tokenizer"`` / ``"frontend"`` are registry names and, when
            present, need matching ``"tokenizer_conf"`` / ``"frontend_conf"``.
        """
        registry_tables.print_register_tables()
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        # Fall back to CPU only when CUDA is missing or GPUs were explicitly
        # disabled (ngpu == 0); a truthy ngpu must not force CPU.
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
            kwargs["tokenizer"] = tokenizer

        # build frontend
        frontend = kwargs.get("frontend", None)
        if frontend is not None:
            frontend_class = registry_tables.frontend_classes.get(frontend.lower())
            frontend = frontend_class(**kwargs["frontend_conf"])
            kwargs["frontend"] = frontend

        # build model
        # Without a tokenizer there is no token list to measure; pass -1 as a
        # placeholder instead of failing on ``None.token_list``.
        vocab_size = len(tokenizer.token_list) if tokenizer is not None else -1
        model_class = registry_tables.model_classes.get(kwargs["model"].lower())
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
        model.eval()
        model.to(device)
        if tokenizer is not None:
            kwargs["token_list"] = tokenizer.token_list

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            logging.info(f"Loading pretrained params from {init_param}")
            load_pretrained_model(
                model=model,
                init_param=init_param,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                oss_bucket=kwargs.get("oss_bucket", None),
            )

        self.kwargs = kwargs
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, input, input_len=None, **cfg):
        """Run the model over ``input`` in batches and collect the results.

        :param input: data accepted by ``build_iter_for_infer``.
        :param input_len: optional lengths; forwarded to
            ``build_iter_for_infer`` and used as ``data_lengths`` when a
            single tensor is processed.
        :param cfg: overrides that are merged into the stored kwargs (the
            update persists for later calls).
        :return: list with one ``model.generate`` result per batch.
        """
        self.kwargs.update(cfg)
        data_type = self.kwargs.get("data_type", "sound")
        batch_size = self.kwargs.get("batch_size", 1)
        if self.kwargs.get("device", "cpu") == "cpu":
            batch_size = 1

        key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}
            if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor):  # fbank
                batch["data_batch"] = data_batch[0]
                batch["data_lengths"] = input_len

            time1 = time.perf_counter()
            results, meta_data = self.model.generate(**batch, **self.kwargs)
            time2 = time.perf_counter()

            asr_result_list.append(results)
            # The bar total counts samples, so advance by the batch's sample count.
            pbar.update(end_idx - beg_idx)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time2 - time1:0.3f}"
            # Only report a real-time factor when the batch duration is known
            # and positive; this avoids negative values and division by zero.
            if batch_data_time > 0:
                speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
            else:
                speed_stats["rtf"] = "n/a"
            description = (
                f"{speed_stats}, "
            )
            pbar.set_description(description)

        torch.cuda.empty_cache()
        return asr_result_list
# Script entry point: run the Hydra-decorated main function when this file is
# executed directly rather than imported.
if __name__ == '__main__':
   main_hydra()