| | |
| | | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int |
| | | ): |
| | | super().__init__() |
| | | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) |
| | | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1) |
| | | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) |
| | | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) |
| | | |
| | |
| | | x = F.gelu(self.conv2(x)) |
| | | x = x.permute(0, 2, 1) |
| | | |
| | | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" |
| | | x = (x + self.positional_embedding).to(x.dtype) |
| | | # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" |
| | | # x = (x + self.positional_embedding).to(x.dtype) |
| | | x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype) |
| | | |
| | | |
| | | for block in self.blocks: |
| | | x = block(x) |
| | |
| | | |
| | | detect_language = detect_language_function |
| | | transcribe = transcribe_function |
| | | decode = decode_function |
| | | decode = decode_function |