From e971e000ad582c767ae44c9650470899f5bb46d0 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 26 Apr 2024 01:11:18 +0800
Subject: [PATCH] Dev gzf exp (#1663)

---
 funasr/datasets/sense_voice_datasets/datasets.py |   32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 5468ea6..4f14b35 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -1,3 +1,5 @@
+import logging
+
 import torch
 import random
 
@@ -46,6 +48,8 @@
         self.float_pad_value = float_pad_value
         self.sos = kwargs.get("sos", "<|startoftranscript|>")
         self.eos = kwargs.get("eos", "<|endoftext|>")
+        self.batch_size = kwargs.get("batch_size")
+        self.batch_type = kwargs.get("batch_type")
 
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -124,4 +128,32 @@
                 outputs[key] = torch.nn.utils.rnn.pad_sequence(
                     data_list, batch_first=True, padding_value=pad_value
                 )
+
+        if self.batch_type != "example":
+            b, t, _ = outputs["speech"].shape
+            if b * t > self.batch_size:
+                beg = torch.randint(0, 2, ()).item()
+                logging.info(
+                    f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 1st, beg:{beg}"
+                )
+                for key, data_list in outputs.items():
+                    outputs[key] = outputs[key][beg : beg + b : 2]
+
+            b, t, _ = outputs["speech"].shape
+            if b * t > self.batch_size:
+                beg = torch.randint(0, 2, ()).item()
+                logging.info(
+                    f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 2nd, beg:{beg}"
+                )
+                for key, data_list in outputs.items():
+                    outputs[key] = outputs[key][beg : beg + b : 2]
+
+            b, t, _ = outputs["speech"].shape
+            if b * t > self.batch_size:
+                beg = torch.randint(0, 2, ()).item()
+                logging.info(
+                    f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 3rd, beg:{beg}"
+                )
+                for key, data_list in outputs.items():
+                    outputs[key] = outputs[key][beg : beg + b : 2]
         return outputs

--
Gitblit v1.9.1