From 2a66366be4c2715870e4859fd5a5db6e8a9dc00a Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: 星期四, 14 九月 2023 19:00:17 +0800
Subject: [PATCH] Merge pull request #956 from alibaba-damo-academy/chenmengzheAAA-patch-4

---
 funasr/samplers/length_batch_sampler.py |   14 ++++++++------
 1 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/funasr/samplers/length_batch_sampler.py b/funasr/samplers/length_batch_sampler.py
index cdf0e58..28404e3 100644
--- a/funasr/samplers/length_batch_sampler.py
+++ b/funasr/samplers/length_batch_sampler.py
@@ -1,9 +1,9 @@
 from typing import Iterator
 from typing import List
+from typing import Dict
 from typing import Tuple
 from typing import Union
 
-from typeguard import check_argument_types
 
 from funasr.fileio.read_text import load_num_sequence_text
 from funasr.samplers.abs_sampler import AbsSampler
@@ -13,14 +13,13 @@
     def __init__(
         self,
         batch_bins: int,
-        shape_files: Union[Tuple[str, ...], List[str]],
+        shape_files: Union[Tuple[str, ...], List[str], Dict],
         min_batch_size: int = 1,
         sort_in_batch: str = "descending",
         sort_batch: str = "ascending",
         drop_last: bool = False,
         padding: bool = True,
     ):
-        assert check_argument_types()
         assert batch_bins > 0
         if sort_batch != "ascending" and sort_batch != "descending":
             raise ValueError(
@@ -40,9 +39,12 @@
         # utt2shape: (Length, ...)
         #    uttA 100,...
         #    uttB 201,...
-        utt2shapes = [
-            load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
-        ]
+        if isinstance(shape_files, dict):
+            utt2shapes = [shape_files]
+        else:
+            utt2shapes = [
+                load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+            ]
 
         first_utt2shape = utt2shapes[0]
         for s, d in zip(shape_files, utt2shapes):

--
Gitblit v1.9.1