funasr/auto/auto_model.py
@@ -212,7 +212,6 @@ deep_update(model_conf, kwargs.get("model_conf", {})) deep_update(model_conf, kwargs) model = model_class(**model_conf, vocab_size=vocab_size) model.to(device) # init_param init_param = kwargs.get("init_param", None) @@ -235,6 +234,7 @@ model.to(torch.float16) elif kwargs.get("bf16", False): model.to(torch.bfloat16) model.to(device) return model, kwargs def __call__(self, *args, **cfg):