liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/whisper_lid/encoder.py
@@ -22,13 +22,13 @@
    """
    def __init__(
            self,
            dropout_rate: float = 0.0,
            whisper_model: str = "small",
            download_dir: str = None,
            use_specaug: bool = False,
            use_padmask: bool = False,
            specaug_conf: Union[dict, None] = None,
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
        use_specaug: bool = False,
        use_padmask: bool = False,
        specaug_conf: Union[dict, None] = None,
    ):
        super().__init__()
@@ -36,9 +36,7 @@
         self.dropout = torch.nn.Dropout(dropout_rate)
         assert whisper_model in whisper.available_models()
-        _model = whisper.load_model(
-            whisper_model, download_root=download_dir, device="cpu"
-        )
+        _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
         self.encoders = copy.deepcopy(_model.encoder)
         self.encoders.train()
@@ -51,9 +49,9 @@
         self.use_padmask = use_padmask

     def whisper_encode(
-            self,
-            input: torch.Tensor,
-            ilens: torch.Tensor = None,
+        self,
+        input: torch.Tensor,
+        ilens: torch.Tensor = None,
     ) -> torch.Tensor:
         x = F.gelu(self.encoders.conv1(input))
         x = F.gelu(self.encoders.conv2(x))
@@ -69,13 +67,9 @@
         if ilens is not None:
             olens = (
-                    1
-                    + (
-                            ilens
-                            - self.encoders.conv2.kernel_size[0]
-                            + 2 * self.encoders.conv2.padding[0]
-                    )
-                    // self.encoders.conv2.stride[0]
+                1
+                + (ilens - self.encoders.conv2.kernel_size[0] + 2 * self.encoders.conv2.padding[0])
+                // self.encoders.conv2.stride[0]
             )
             olens = torch.clamp(olens, max=max_pos)
         else:
@@ -102,10 +96,10 @@
         return self.encoders.conv2.weight.shape[0]

     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         feats, feats_lens = xs_pad, ilens
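
Note: the olens hunk above is a formatting-only change; both versions compute the standard Conv1d output-length formula, L_out = 1 + (L_in - kernel + 2 * padding) // stride, for the encoder's second convolution. A minimal sketch of that arithmetic follows, assuming the usual Whisper conv2 hyperparameters (kernel_size=3, stride=2, padding=1); conv_out_lens is an illustrative name, not a helper from this repository.

import torch

# Illustrative only: mirrors the olens expression in the diff above,
# assuming conv2 has kernel_size=3, stride=2, padding=1.
def conv_out_lens(ilens: torch.Tensor, kernel: int = 3, stride: int = 2, padding: int = 1) -> torch.Tensor:
    return 1 + (ilens - kernel + 2 * padding) // stride

print(conv_out_lens(torch.tensor([3000, 1500])))  # tensor([1500,  750])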