| | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | |
| | | @tables.register("model_classes", "WhisperWarp") |
| | | @tables.register("model_classes", "Whisper-tiny.en") |
| | | @tables.register("model_classes", "Whisper-tiny") |
| | | @tables.register("model_classes", "Whisper-base.en") |
| | | @tables.register("model_classes", "Whisper-base") |
| | | @tables.register("model_classes", "Whisper-small.en") |
| | | @tables.register("model_classes", "Whisper-small") |
| | | @tables.register("model_classes", "Whisper-medium.en") |
| | | @tables.register("model_classes", "Whisper-medium") |
| | | @tables.register("model_classes", "Whisper-large-v1") |
| | | @tables.register("model_classes", "Whisper-large-v2") |
| | | @tables.register("model_classes", "Whisper-large-v3") |
| | | @tables.register("model_classes", "Whisper-WhisperWarp") |
| | | class WhisperWarp(nn.Module): |
| | | def __init__(self, whisper_dims: dict, **kwargs): |
| | | def __init__(self, *args, **kwargs): |
| | | super().__init__() |
| | | hub = kwargs.get("hub", "funasr") |
| | | if hub == "openai": |
| | | init_param_path = kwargs.get("init_param_path", "large-v3") |
| | | model = whisper.load_model(init_param_path) |
| | | model_or_path = kwargs.get("model_path", "Whisper-large-v3") |
| | | if model_or_path.startswith("Whisper-"): |
| | | model_or_path = model_or_path.replace("Whisper-", "") |
| | | model = whisper.load_model(model_or_path) |
| | | else: |
| | | whisper_dims = kwargs.get("whisper_dims", {}) |
| | | dims = whisper.model.ModelDimensions(**whisper_dims) |
| | | model = whisper.model.Whisper(dims=dims) |
| | | |