From 8c24f52fd25fce2df46dc4d9dffb45619dc38f9f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: Wed, 24 May 2023 23:40:03 +0800
Subject: [PATCH] Dev aky2 (#549)

---
 funasr/datasets/large_datasets/build_dataloader.py |    4 ++--
 funasr/datasets/large_datasets/utils/tokenize.py   |    4 +---
 funasr/modules/subsampling.py                      |    6 +++---
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 318ae0b..339292f 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -73,8 +73,8 @@
             seg_dict = load_seg_dict(args.seg_dict_file)
         if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
             punc_dict = read_symbol_table(args.punc_dict_file)
-        if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
-            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
+        if hasattr(args, "bpemodel") and args.bpemodel is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
         self.dataset_conf = args.dataset_conf
         self.frontend_conf = args.frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index cf7d255..a7eb6d2 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -46,10 +46,8 @@
     text = data["text"]
     token = []
     vad = -2
-
     if bpe_tokenizer is not None:
-        text = bpe_tokenizer.text2tokens(text)
-
+        text = bpe_tokenizer.text2tokens(" ".join(text))
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         text = seg_tokenize(text, seg_dict)
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index 623be65..a2b91a7 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -506,9 +506,9 @@
                 )
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size, 3, 2),
+                    torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
                     torch.nn.ReLU(),
                 )
 
@@ -597,7 +597,7 @@
             mask: Mask of output sequences. (B, sub(T))
         """
         if self.subsampling_factor > 1:
-            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+            return mask[:, ::2][:, ::self.stride_2]
         else:
             return mask
 

--
Gitblit v1.9.1