import numpy as np
import torch
from typeguard import check_argument_types
from scipy.signal import medfilt

from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import EENDOLADiarTask

# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

# Run diarization on the batch. `speech2diar`, `batch`, `keys`,
# `result_list`, `output_path` and `output_writer` are defined earlier
# in this file, outside the visible chunk.
results = speech2diar(**batch)

# Post process: median-filter the per-frame speaker-activity matrix
# (presumably frames x speakers -- TODO confirm against speech2diar)
# with an 11-frame kernel per speaker column to smooth spurious on/off
# flickers, then convert each speaker's activity track into RTTM lines.
smoothed = medfilt(results[0], (11, 1))

# RTTM: SPEAKER <file> <chan> <onset> <duration> <NA> <NA> <name> <NA>
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
rst = []
for spkid, frames in enumerate(smoothed.T):
    # Pad one inactive frame on each side so diff() produces a change
    # point even for segments touching the start/end of the recording.
    frames = np.pad(frames, (1, 1), 'constant')
    changes, = np.where(np.diff(frames, axis=0) != 0)
    # Change points alternate onset/offset. The /10. converts frame
    # index to seconds, i.e. assumes a 100 ms frame shift -- TODO
    # confirm against the frontend configuration.
    for s, e in zip(changes[::2], changes[1::2]):
        st = s / 10.
        ed = e / 10.
        # BUGFIX: RTTM field 5 is the segment *duration*, not its end
        # time, so emit (ed - st) rather than ed.
        rst.append(fmt.format(keys[0], st, ed - st, "{}_{}".format(keys[0], str(spkid))))

# Only supporting batch_size==1
# NOTE(review): the original also computed
#     key, value = keys[0], output_results_str(results, keys[0])
#     item = {"key": key, "value": value}
# and immediately overwrote both `value` and `item` below, so that pair
# was dead code. `key` is kept in case later (unseen) code reads it;
# the output_results_str call is dropped (assumed side-effect free --
# confirm).
key = keys[0]
value = "\n".join(rst)
item = {"key": keys[0], "value": value}
result_list.append(item)
if output_path is not None:
    output_writer.write(value)