kongdeqiang
6 days ago  28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/whisper_lid/decoder.py
@@ -29,9 +29,7 @@
         super().__init__()
         assert whisper_model in whisper.available_models()
-        _model = whisper.load_model(
-            whisper_model, download_root=download_dir, device="cpu"
-        )
+        _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
         self.decoders = copy.deepcopy(_model.decoder)
         attention_dim = self.decoders.token_embedding.embedding_dim
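Note: the constructor reuses the stock openai-whisper text decoder rather than building one from scratch. A minimal standalone sketch of the same pattern (the "base" checkpoint and the print are illustrative assumptions, not part of this commit):

    import copy
    import whisper

    # Load a stock Whisper checkpoint on CPU, then keep an independent copy of
    # its text decoder so it can be wrapped/fine-tuned without the encoder.
    _model = whisper.load_model("base", download_root=None, device="cpu")
    decoder = copy.deepcopy(_model.decoder)

    # The token embedding dim doubles as the decoder's attention/model dim.
    print(decoder.token_embedding.embedding_dim)  # 512 for "base"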
@@ -67,10 +65,7 @@
             olens: (batch, )
         """
         tgt, memory = ys_in_pad, hs_pad
-        tgt = (
-            self.decoders.token_embedding(tgt)
-            + self.decoders.positional_embedding[: tgt.size(1)]
-        )
+        tgt = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
         tgt = self.dropout(tgt)
         x = tgt.to(memory.dtype)
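Note: the collapsed line embeds the target token ids and adds the learned positional embedding sliced to the current target length, before dropout and a cast to the encoder output's dtype. A rough sketch with assumed sizes (batch 2, length 5, dim 512, vocab 51865 and the dropout rate are illustrative):

    import torch

    token_embedding = torch.nn.Embedding(51865, 512)
    positional_embedding = torch.nn.Parameter(torch.randn(448, 512))
    dropout = torch.nn.Dropout(p=0.1)             # rate is an assumption

    ys_in_pad = torch.randint(0, 51865, (2, 5))   # (batch, tgt_len) token ids
    memory = torch.randn(2, 1500, 512)            # encoder output (hs_pad)

    # Embedding + positional embedding truncated to tgt_len, then dropout,
    # then match the encoder output's dtype before the decoder blocks.
    tgt = token_embedding(ys_in_pad) + positional_embedding[: ys_in_pad.size(1)]
    tgt = dropout(tgt)
    x = tgt.to(memory.dtype)                      # (2, 5, 512)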
@@ -81,15 +76,20 @@
             memory_mask = None
         for layer, block in enumerate(self.decoders.blocks):
-            x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+            x = block(
+                x,
+                memory,
+                mask=self.decoders.mask,
+                memory_mask=memory_mask,
+                is_pad_mask=False,
+                is_pad_memory_mask=True,
+            )
             if layer < len(self.decoders.blocks) - 1:
                 x = self.dropout(x)
         x = self.decoders.ln(x)
-        x = (
-            x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        x = (x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
         return x, ys_in_lens
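Note: the final projection is weight-tied to the token embedding, as in the stock Whisper decoder: logits are the hidden states multiplied by the transpose of the embedding matrix. A small sketch with assumed sizes:

    import torch

    vocab_size, d_model = 51865, 512              # illustrative sizes
    token_embedding = torch.nn.Embedding(vocab_size, d_model)

    x = torch.randn(2, 5, d_model)                # decoder output after ln
    logits = (x @ torch.transpose(token_embedding.weight.to(x.dtype), 0, 1)).float()
    print(logits.shape)                           # torch.Size([2, 5, 51865])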
@@ -116,10 +116,7 @@
             cache implementation is ignored for now
             for simplicity & correctness
         """
-        x = (
-            self.decoders.token_embedding(tgt)
-            + self.decoders.positional_embedding[: tgt.size(1)]
-        )
+        x = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
         x = self.dropout(x)
         x = x.to(memory.dtype)
@@ -130,9 +127,7 @@
         x = self.decoders.ln(x)
         y = x[:, -1]
-        y = (
-            y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        y = (y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
         y = torch.log_softmax(y, dim=-1)
         return y, None
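Note: the one-step path applies the same tied projection only to the newest position and returns log-probabilities for beam search. A sketch with assumed sizes (3 hypotheses, 7 decoded steps):

    import torch

    vocab_size, d_model = 51865, 512
    token_embedding = torch.nn.Embedding(vocab_size, d_model)

    x = torch.randn(3, 7, d_model)                # states for 3 hyps, 7 steps
    y = x[:, -1]                                  # keep only the last position
    y = (y @ torch.transpose(token_embedding.weight.to(x.dtype), 0, 1)).float()
    y = torch.log_softmax(y, dim=-1)              # (3, 51865) log-probs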