From 7ae979bc5e9a9a09236848cc879c2cbd2bfa0837 Mon Sep 17 00:00:00 2001
From: ShiLiang Zhang <sly.zsl@alibaba-inc.com>
Date: Wed, 08 May 2024 17:17:38 +0800
Subject: [PATCH] Reformat indentation in funasr/frontends/s3prl.py

---
 funasr/frontends/s3prl.py |   20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/funasr/frontends/s3prl.py b/funasr/frontends/s3prl.py
index ff60592..5c358cc 100644
--- a/funasr/frontends/s3prl.py
+++ b/funasr/frontends/s3prl.py
@@ -29,11 +29,11 @@
     """Speech Pretrained Representation frontend structure for ASR."""
 
     def __init__(
-            self,
-            fs: Union[int, str] = 16000,
-            frontend_conf: Optional[dict] = None,
-            download_dir: str = None,
-            multilayer_feature: bool = False,
+        self,
+        fs: Union[int, str] = 16000,
+        frontend_conf: Optional[dict] = None,
+        download_dir: str = None,
+        multilayer_feature: bool = False,
     ):
         super().__init__()
         if isinstance(fs, str):
@@ -74,7 +74,7 @@
         ).to("cpu")
 
         if getattr(
-                s3prl_upstream, "model", None
+            s3prl_upstream, "model", None
         ) is not None and s3prl_upstream.model.__class__.__name__ in [
             "Wav2Vec2Model",
             "HubertModel",
@@ -102,9 +102,9 @@
         Output - sequence of tiled representations
                  shape: (batch_size, seq_len * factor, feature_dim)
         """
-        assert (
-                len(feature.shape) == 3
-        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
+        assert len(feature.shape) == 3, "Input argument `feature` has invalid shape: {}".format(
+            feature.shape
+        )
         tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
         tiled_feature = tiled_feature.reshape(
             feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
@@ -115,7 +115,7 @@
         return self.output_dim
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+        self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
         self.upstream.eval()

--
Gitblit v1.9.1