From ab2148ec182db2b5d18d38f2dd690922d2cf56b8 Mon Sep 17 00:00:00 2001
From: Logan <52682203+ding-ding666@users.noreply.github.com>
Date: 星期一, 26 五月 2025 14:34:47 +0800
Subject: [PATCH] 更新go client 的原生实现 (#2532)

---
 funasr/frontends/s3prl.py |   20 ++++++++++----------
 1 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/funasr/frontends/s3prl.py b/funasr/frontends/s3prl.py
index ff60592..5c358cc 100644
--- a/funasr/frontends/s3prl.py
+++ b/funasr/frontends/s3prl.py
@@ -29,11 +29,11 @@
     """Speech Pretrained Representation frontend structure for ASR."""
 
     def __init__(
-            self,
-            fs: Union[int, str] = 16000,
-            frontend_conf: Optional[dict] = None,
-            download_dir: str = None,
-            multilayer_feature: bool = False,
+        self,
+        fs: Union[int, str] = 16000,
+        frontend_conf: Optional[dict] = None,
+        download_dir: str = None,
+        multilayer_feature: bool = False,
     ):
         super().__init__()
         if isinstance(fs, str):
@@ -74,7 +74,7 @@
         ).to("cpu")
 
         if getattr(
-                s3prl_upstream, "model", None
+            s3prl_upstream, "model", None
         ) is not None and s3prl_upstream.model.__class__.__name__ in [
             "Wav2Vec2Model",
             "HubertModel",
@@ -102,9 +102,9 @@
         Output - sequence of tiled representations
                  shape: (batch_size, seq_len * factor, feature_dim)
         """
-        assert (
-                len(feature.shape) == 3
-        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
+        assert len(feature.shape) == 3, "Input argument `feature` has invalid shape: {}".format(
+            feature.shape
+        )
         tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
         tiled_feature = tiled_feature.reshape(
             feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
@@ -115,7 +115,7 @@
         return self.output_dim
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+        self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
         self.upstream.eval()

--
Gitblit v1.9.1