zhifu gao
2024-03-04 790bf549448c92f8a19ae1455ace15ff5d7a2e31
funasr/models/whisper/model.py
@@ -13,16 +13,29 @@
from funasr.register import tables
@tables.register("model_classes", "WhisperWarp")
@tables.register("model_classes", "Whisper-tiny.en")
@tables.register("model_classes", "Whisper-tiny")
@tables.register("model_classes", "Whisper-base.en")
@tables.register("model_classes", "Whisper-base")
@tables.register("model_classes", "Whisper-small.en")
@tables.register("model_classes", "Whisper-small")
@tables.register("model_classes", "Whisper-medium.en")
@tables.register("model_classes", "Whisper-medium")
@tables.register("model_classes", "Whisper-large-v1")
@tables.register("model_classes", "Whisper-large-v2")
@tables.register("model_classes", "Whisper-large-v3")
@tables.register("model_classes", "Whisper-WhisperWarp")
class WhisperWarp(nn.Module):
    def __init__(self, whisper_dims: dict, **kwargs):
    def __init__(self, *args, **kwargs):
        super().__init__()
        hub = kwargs.get("hub", "funasr")
        if hub == "openai":
            init_param_path = kwargs.get("init_param_path", "large-v3")
            model = whisper.load_model(init_param_path)
            model_or_path = kwargs.get("model_path", "Whisper-large-v3")
            if model_or_path.startswith("Whisper-"):
                model_or_path = model_or_path.replace("Whisper-", "")
            model = whisper.load_model(model_or_path)
        else:
            whisper_dims = kwargs.get("whisper_dims", {})
            dims = whisper.model.ModelDimensions(**whisper_dims)
            model = whisper.model.Whisper(dims=dims)