游雁
2023-12-21 c8bae0ec85eee25d66de6b1e4502eff74d750b24
funasr/bin/inference.py
@@ -101,6 +101,7 @@
         tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
         tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
         kwargs["tokenizer"] = tokenizer
         kwargs["token_list"] = tokenizer.token_list
      
      # build frontend
      frontend = kwargs.get("frontend", None)
@@ -112,11 +113,9 @@
      
      # build model
      model_class = registry_tables.model_classes.get(kwargs["model"].lower())
      model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
      model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
      model.eval()
      model.to(device)
      kwargs["token_list"] = tokenizer.token_list
      
      # init_param
      init_param = kwargs.get("init_param", None)