speech_asr
2023-03-15 fbec0f003d4de9e4b6ccb6bb58d2d4926a0ff332
funasr/bin/eend_ola_inference.py
@@ -27,9 +27,6 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from modelscope.utils.logger import get_logger
logger = get_logger()
class Speech2Diarization:
    """Speech2Diarization class
@@ -148,7 +145,7 @@
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        ngpu: int = 1,
        num_workers: int = 0,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
@@ -210,8 +207,7 @@
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        logger.info(data_path_and_name_and_type)
            data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
@@ -231,8 +227,6 @@
            output_writer = open("{}/result.txt".format(output_path), "w")
        result_list = []
        for keys, batch in loader:
            logger.info("keys: {}".format(keys))
            logger.info("batch: {}".format(batch))
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))