From fc9595625855be5b63f86a38ac785e49c142c0ae Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期二, 21 三月 2023 14:10:03 +0800
Subject: [PATCH] embed debug
---
funasr/models_transducer/encoder/encoder.py | 12 +++++++-----
1 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models_transducer/encoder/encoder.py
index 45c99c1..b486a11 100644
--- a/funasr/models_transducer/encoder/encoder.py
+++ b/funasr/models_transducer/encoder/encoder.py
@@ -134,14 +134,11 @@
)
mask = make_source_mask(x_len)
- if self.unified_model_training:
- x, mask = self.embed(x, mask, self.default_chunk_size)
- else:
- x, mask = self.embed(x, mask)
- pos_enc = self.pos_enc(x)
if self.unified_model_training:
chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
@@ -178,6 +175,9 @@
else:
chunk_size = (chunk_size % self.short_chunk_size) + 1
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+
chunk_mask = make_chunk_mask(
x.size(1),
chunk_size,
@@ -185,6 +185,8 @@
device=x.device,
)
else:
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
chunk_mask = None
x = self.encoders(
x,
--
Gitblit v1.9.1