From fe8ebd746bf0c0f57ef85ed342500cbf0e2c4e9e Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期二, 23 七月 2024 16:59:57 +0800
Subject: [PATCH] update gitignore
---
funasr/frontends/s3prl.py | 20 ++++++++++----------
1 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/funasr/frontends/s3prl.py b/funasr/frontends/s3prl.py
index ff60592..5c358cc 100644
--- a/funasr/frontends/s3prl.py
+++ b/funasr/frontends/s3prl.py
@@ -29,11 +29,11 @@
"""Speech Pretrained Representation frontend structure for ASR."""
def __init__(
- self,
- fs: Union[int, str] = 16000,
- frontend_conf: Optional[dict] = None,
- download_dir: str = None,
- multilayer_feature: bool = False,
+ self,
+ fs: Union[int, str] = 16000,
+ frontend_conf: Optional[dict] = None,
+ download_dir: str = None,
+ multilayer_feature: bool = False,
):
super().__init__()
if isinstance(fs, str):
@@ -74,7 +74,7 @@
).to("cpu")
if getattr(
- s3prl_upstream, "model", None
+ s3prl_upstream, "model", None
) is not None and s3prl_upstream.model.__class__.__name__ in [
"Wav2Vec2Model",
"HubertModel",
@@ -102,9 +102,9 @@
Output - sequence of tiled representations
shape: (batch_size, seq_len * factor, feature_dim)
"""
- assert (
- len(feature.shape) == 3
- ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
+ assert len(feature.shape) == 3, "Input argument `feature` has invalid shape: {}".format(
+ feature.shape
+ )
tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
tiled_feature = tiled_feature.reshape(
feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
@@ -115,7 +115,7 @@
return self.output_dim
def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor
+ self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
self.upstream.eval()
--
Gitblit v1.9.1