funasr/auto/auto_model.py
@@ -213,7 +213,6 @@ deep_update(model_conf, kwargs.get("model_conf", {})) deep_update(model_conf, kwargs) model = model_class(**model_conf, vocab_size=vocab_size) model.to(device) # init_param init_param = kwargs.get("init_param", None) @@ -236,6 +235,7 @@ model.to(torch.float16) elif kwargs.get("bf16", False): model.to(torch.bfloat16) model.to(device) return model, kwargs def __call__(self, *args, **cfg):