From b01c9f1c25282c8376f8e25eabcc6dd29d14ad13 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 17 六月 2024 14:08:57 +0800
Subject: [PATCH] decoding

---
 funasr/datasets/openai_datasets/datasets.py |    9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index ae9f289..04ddcfd 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -300,9 +300,9 @@
         return len(self.index_ds)
 
     def __getitem__(self, index):
-        import pdb
-
-        pdb.set_trace()
+        # import pdb
+        #
+        # pdb.set_trace()
 
         output = None
 
@@ -397,6 +397,7 @@
                 labels += source_mask + target_ids
                 fbank.append(speech[0, :, :])
                 fbank_mask += fbank_mask_i
+                fbank_lens.append(speech_lengths)
 
             if len(input_ids) > self.max_token_length:
                 logging.info(
@@ -410,7 +411,7 @@
             labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
 
             # fbank = speech[0, :, :]
-            fbank_lens = speech_lengths
+            # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
             fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
             fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
             fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)

--
Gitblit v1.9.1