雾聪
2024-01-08 2acef4bdaea588adee3098a057a395937dff4e6a
funasr/bin/diar_inference_launch.py
@@ -15,10 +15,10 @@
from typing import Union
import numpy as np
import soundfile
# import librosa
import librosa
import torch
from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
from funasr.datasets.iterable_dataset import load_bytes
@@ -52,7 +52,6 @@
        mode: str = "sond",
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -94,10 +93,7 @@
            embedding_node="resnet1_dense"
        )
        logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
        speech2xvector = Speech2Xvector.from_pretrained(
            model_tag=model_tag,
            **speech2xvector_kwargs,
        )
        speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
        speech2xvector.sv_model.eval()
    # 2b. Build speech2diar
@@ -111,10 +107,7 @@
        dur_threshold=dur_threshold,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2DiarizationSOND.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
    speech2diar.diar_model.eval()
    def output_results_str(results: dict, uttid: str):
@@ -152,7 +145,9 @@
                        # read waveform file
                        example = [load_bytes(x) if isinstance(x, bytes) else x
                                   for x in example]
                        example = [soundfile.read(x)[0] if isinstance(x, str) else x
                        # example = [librosa.load(x)[0] if isinstance(x, str) else x
                        #            for x in example]
                        example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
                                   for x in example]
                        # convert torch tensor to numpy array
                        example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
@@ -233,7 +228,6 @@
        param_dict: Optional[dict] = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -260,10 +254,7 @@
        dtype=dtype,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2DiarizationEEND.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
    speech2diar.diar_model.eval()
    def output_results_str(results: dict, uttid: str):
@@ -465,11 +456,17 @@
        help="The batch size for inference",
    )
    group.add_argument(
        "--diar_smooth_size",
        "--smooth_size",
        type=int,
        default=121,
        help="The smoothing size for post-processing"
    )
    group.add_argument(
        "--dur_threshold",
        type=int,
        default=10,
        help="The threshold of minimum duration"
    )
    return parser