| | |
| | | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| | | logging.basicConfig(level=log_level) |
| | | |
| | | if not kwargs.get("disable_log", True): |
| | | tables.print() |
| | | |
| | | model, kwargs = self.build_model(**kwargs) |
| | | |
| | | # if vad_model is not None, build vad model else None |
| | |
| | | self.spk_kwargs = spk_kwargs |
| | | self.model_path = kwargs.get("model_path") |
| | | |
| | | def build_model(self, **kwargs): |
| | | @staticmethod |
| | | def build_model(**kwargs): |
| | | assert "model" in kwargs |
| | | if "model_conf" not in kwargs: |
| | | logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms"))) |
| | |
| | | kwargs["frontend"] = frontend |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | assert model_class is not None, f'{kwargs["model"]} is not registered' |
| | | model_conf = {} |
| | | deep_update(model_conf, kwargs.get("model_conf", {})) |
| | | deep_update(model_conf, kwargs) |
| | |
| | | elif kwargs.get("bf16", False): |
| | | model.to(torch.bfloat16) |
| | | model.to(device) |
| | | |
| | | if not kwargs.get("disable_log", True): |
| | | tables.print() |
| | | |
| | | return model, kwargs |
| | | |
| | | def __call__(self, *args, **cfg): |