From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fix hotwords bug: wrong import path for crop_to_max_size

---
 funasr/datasets/large_datasets/utils/clipping.py |   10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
index f5c2940..92f7d70 100644
--- a/funasr/datasets/large_datasets/utils/clipping.py
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 
-from funasr.datasets.collate_fn import crop_to_max_size
+from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
 
 
 def clipping(data):
@@ -25,7 +25,9 @@
             tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
 
             length_clip = min(tensor_lengths)
-            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+            tensor_clip = tensor_list[0].new_zeros(
+                len(tensor_list), length_clip, tensor_list[0].shape[1]
+            )
             for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
                 diff = length - length_clip
                 assert diff >= 0
@@ -35,6 +37,8 @@
                     tensor_clip[i] = crop_to_max_size(tensor, length_clip)
 
             batch[data_name] = tensor_clip
-            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+            batch[data_name + "_lengths"] = torch.tensor(
+                [tensor.shape[0] for tensor in tensor_clip], dtype=torch.long
+            )
 
     return keys, batch

--
Gitblit v1.9.1