游雁
2023-02-14 1d4ab65c8bfebaecbcb0eec0064bae9a321cad75
funasr/samplers/length_batch_sampler.py
@@ -1,5 +1,6 @@
from typing import Iterator
from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
@@ -13,7 +14,7 @@
    def __init__(
        self,
        batch_bins: int,
        shape_files: Union[Tuple[str, ...], List[str]],
        shape_files: Union[Tuple[str, ...], List[str], Dict],
        min_batch_size: int = 1,
        sort_in_batch: str = "descending",
        sort_batch: str = "ascending",
@@ -40,9 +41,12 @@
        # utt2shape: (Length, ...)
        #    uttA 100,...
        #    uttB 201,...
        utt2shapes = [
            load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
        ]
        if isinstance(shape_files, dict):
            utt2shapes = [shape_files]
        else:
            utt2shapes = [
                load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
            ]
        first_utt2shape = utt2shapes[0]
        for s, d in zip(shape_files, utt2shapes):