游雁
2024-04-29 2779602177ae5374547c7a7e17de0b11a166326d
funasr/frontends/s3prl.py
@@ -29,11 +29,11 @@
    """Speech Pretrained Representation frontend structure for ASR."""
    def __init__(
            self,
            fs: Union[int, str] = 16000,
            frontend_conf: Optional[dict] = None,
            download_dir: str = None,
            multilayer_feature: bool = False,
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = None,
        download_dir: str = None,
        multilayer_feature: bool = False,
    ):
        super().__init__()
        if isinstance(fs, str):
@@ -74,7 +74,7 @@
        ).to("cpu")
        if getattr(
                s3prl_upstream, "model", None
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
@@ -102,9 +102,9 @@
        Output - sequence of tiled representations
                 shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
                len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        assert len(feature.shape) == 3, "Input argument `feature` has invalid shape: {}".format(
            feature.shape
        )
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
@@ -115,7 +115,7 @@
        return self.output_dim
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        self.upstream.eval()