shixian.shi
2024-01-10 668b830cb2a8f69c1cfb131ec9542d27f91b7283
funasr/bin/inference.py
@@ -159,6 +159,9 @@
         tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
         kwargs["tokenizer"] = tokenizer
         kwargs["token_list"] = tokenizer.token_list
         vocab_size = len(tokenizer.token_list)
      else:
         vocab_size = -1
      
      # build frontend
      frontend = kwargs.get("frontend", None)
@@ -170,8 +173,7 @@
      
      # build model
      model_class = tables.model_classes.get(kwargs["model"].lower())
      model = model_class(**kwargs, **kwargs["model_conf"],
                          vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
      model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
      model.eval()
      model.to(device)