VirtuosoQ
2024-04-26 e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc
funasr/models/sense_voice/whisper_lib/model.py
@@ -74,7 +74,10 @@
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
@@ -87,12 +90,13 @@
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        wv, qk = self.qkv_attention(q, k, v, mask)
        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk
    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -101,10 +105,20 @@
        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
            if not is_pad_mask:
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(
                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
                )
                qk = qk.masked_fill(mask, min_value)
        qk = qk.float()
        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
@@ -132,10 +146,13 @@
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
@@ -145,7 +162,7 @@
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
@@ -163,8 +180,10 @@
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)
        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)
        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        # x = (x + self.positional_embedding).to(x.dtype)
        x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
        for block in self.blocks:
            x = block(x)
@@ -242,7 +261,9 @@
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
@@ -311,4 +332,4 @@
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
    decode = decode_function