From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages
---
funasr/samplers/length_batch_sampler.py | 14 ++++++++------
1 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/funasr/samplers/length_batch_sampler.py b/funasr/samplers/length_batch_sampler.py
index cdf0e58..28404e3 100644
--- a/funasr/samplers/length_batch_sampler.py
+++ b/funasr/samplers/length_batch_sampler.py
@@ -1,9 +1,9 @@
from typing import Iterator
from typing import List
+from typing import Dict
from typing import Tuple
from typing import Union
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -13,14 +13,13 @@
def __init__(
self,
batch_bins: int,
- shape_files: Union[Tuple[str, ...], List[str]],
+ shape_files: Union[Tuple[str, ...], List[str], Dict],
min_batch_size: int = 1,
sort_in_batch: str = "descending",
sort_batch: str = "ascending",
drop_last: bool = False,
padding: bool = True,
):
- assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(
@@ -40,9 +39,12 @@
# utt2shape: (Length, ...)
# uttA 100,...
# uttB 201,...
- utt2shapes = [
- load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
- ]
+ if isinstance(shape_files, dict):
+ utt2shapes = [shape_files]
+ else:
+ utt2shapes = [
+ load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+ ]
first_utt2shape = utt2shapes[0]
for s, d in zip(shape_files, utt2shapes):
--
Gitblit v1.9.1