From 8c24f52fd25fce2df46dc4d9dffb45619dc38f9f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: Wed, 24 May 2023 23:40:03 +0800
Subject: [PATCH] Dev aky2 (#549)
---
funasr/datasets/large_datasets/build_dataloader.py | 4 ++--
funasr/datasets/large_datasets/utils/tokenize.py | 4 +---
funasr/modules/subsampling.py | 6 +++---
3 files changed, 6 insertions(+), 8 deletions(-)
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 318ae0b..339292f 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -73,8 +73,8 @@
seg_dict = load_seg_dict(args.seg_dict_file)
if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
punc_dict = read_symbol_table(args.punc_dict_file)
- if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
- bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
+ if hasattr(args, "bpemodel") and args.bpemodel is not None:
+ bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
self.dataset_conf = args.dataset_conf
self.frontend_conf = args.frontend_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index cf7d255..a7eb6d2 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -46,10 +46,8 @@
text = data["text"]
token = []
vad = -2
-
if bpe_tokenizer is not None:
- text = bpe_tokenizer.text2tokens(text)
-
+ text = bpe_tokenizer.text2tokens(" ".join(text))
if seg_dict is not None:
assert isinstance(seg_dict, dict)
text = seg_tokenize(text, seg_dict)
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index 623be65..a2b91a7 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -506,9 +506,9 @@
)
self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size, 3, 2),
+ torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+ torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
torch.nn.ReLU(),
)
@@ -597,7 +597,7 @@
mask: Mask of output sequences. (B, sub(T))
"""
if self.subsampling_factor > 1:
- return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+ return mask[:, ::2][:, ::self.stride_2]
else:
return mask
--
Gitblit v1.9.1