| | |
| | | |
| | | |
| | | from funasr.models.whisper.utils.decoding import detect_language as detect_language_function, decode as decode_function |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @dataclass |
| | | class ModelDimensions: |
| | |
| | | return x |
| | | |
| | | |
| | | |
| | | @tables.register("encoder_classes", "WhisperEncoder") |
| | | class AudioEncoder(nn.Module): |
| | | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): |
| | | super().__init__() |
| | |
| | | x = self.ln_post(x) |
| | | return x |
| | | |
| | | |
| | | @tables.register("decoder_classes", "WhisperDecoder") |
| | | class TextDecoder(nn.Module): |
| | | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): |
| | | super().__init__() |
| | |
| | | |
| | | return logits |
| | | |
| | | |
| | | @tables.register("model_classes", "Whisper") |
| | | class Whisper(nn.Module): |
| | | def __init__(self, dims: dict): |
| | | super().__init__() |