From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 27 三月 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/datasets/large_datasets/utils/clipping.py |   40 ++++++++++++++++++++++++++++++++++++++++
 1 files changed, 40 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
new file mode 100644
index 0000000..2554aba
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
+    assert isinstance(data, list)
+    assert "key" in data[0]
+
+    keys = [x["key"] for x in data]
+
+    batch = {}
+    data_names = data[0].keys()
+    for data_name in data_names:
+        if data_name == "key":
+            continue
+        else:
+            if data[0][data_name].dtype.kind == "i":
+                tensor_type = torch.int64
+            else:
+                tensor_type = torch.float32
+
+            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+            length_clip = min(tensor_lengths)
+            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+                diff = length - length_clip
+                assert diff >= 0
+                if diff == 0:
+                    tensor_clip[i] = tensor
+                else:
+                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+            batch[data_name] = tensor_clip
+            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+    return keys, batch

--
Gitblit v1.9.1