From cdf117b9746fdb72c6d0a2aa1ada4e1a131895ec Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: Tue, 27 Jun 2023 09:59:50 +0800
Subject: [PATCH] bug fix (#667)
---
funasr/models/encoder/conformer_encoder.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++-----
1 file changed, 50 insertions(+), 5 deletions(-)
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20dee..994607f 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -1081,7 +1081,10 @@
mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ if self.training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ else:
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@
elif self.dynamic_chunk_training:
max_len = x.size(1)
- chunk_size = torch.randint(1, max_len, (1,)).item()
+ if self.training:
+ chunk_size = torch.randint(1, max_len, (1,)).item()
- if chunk_size > (max_len * self.short_chunk_threshold):
- chunk_size = max_len
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
else:
- chunk_size = (chunk_size % self.short_chunk_size) + 1
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@
return x, olens, None
+ def full_utt_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ x_len: Encoder outputs lengths. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len).to(x.device)
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ return x_utt
+
def simu_chunk_forward(
self,
x: torch.Tensor,
--
Gitblit v1.9.1