feat: add campplus merge_thr (#2135)
| | |
| | | # if spk_model is not None, build spk model else None |
| | | spk_model = kwargs.get("spk_model", None) |
| | | spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) |
| | | cb_kwargs = {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {}) |
| | | if spk_model is not None: |
| | | logging.info("Building SPK model.") |
| | | spk_kwargs["model"] = spk_model |
| | | spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") |
| | | spk_kwargs["device"] = kwargs["device"] |
| | | spk_model, spk_kwargs = self.build_model(**spk_kwargs) |
| | | self.cb_model = ClusterBackend().to(kwargs["device"]) |
| | | self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"]) |
| | | spk_mode = kwargs.get("spk_mode", "punc_segment") |
| | | if spk_mode not in ["default", "vad_segment", "punc_segment"]: |
| | | logging.error("spk_mode should be one of default, vad_segment and punc_segment.") |
| | |
| | | model_config: The model config. |
| | | """ |
| | | |
| | | def __init__(self): |
| | | def __init__(self, merge_thr=0.78): |
| | | super().__init__() |
| | | self.model_config = {"merge_thr": 0.78} |
| | | self.model_config = {"merge_thr": merge_thr} |
| | | # self.other_config = kwargs |
| | | |
| | | self.spectral_cluster = SpectralCluster() |
| New file |
| | |
| | | import unittest |
| | | import torch |
| | | import numpy as np |
| | | from funasr.auto.auto_model import AutoModel |
| | | |
class TestAutoModel(unittest.TestCase):
    """Tests that ``AutoModel`` forwards speaker-clustering options.

    NOTE(review): these tests construct real models through ``AutoModel``
    (ASR + VAD + punctuation + speaker), so they presumably need the model
    files to be cached locally or downloadable — confirm before enabling
    them in CI.
    """

    def setUp(self):
        """Build the ``AutoModel`` kwargs shared by every test."""
        self.base_kwargs = {
            "model": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "vad_model": "fsmn-vad",
            "punc_model": "ct-punc",
            "device": "cpu",
            "batch_size": 1,
            "disable_update": True,
        }

    def test_merge_thr_in_cb_model(self):
        """An explicit ``cb_kwargs["merge_thr"]`` reaches the cluster backend.

        0.5 is deliberately different from the ``ClusterBackend`` default
        (0.78), so the assertion cannot pass if the option is silently
        dropped on the way through ``spk_kwargs``.
        """
        merge_thr = 0.5
        kwargs = dict(self.base_kwargs)
        kwargs["spk_model"] = "cam++"
        kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}}

        model = AutoModel(**kwargs)

        self.assertEqual(model.cb_model.model_config["merge_thr"], merge_thr)

    def test_merge_thr_default_when_cb_kwargs_is_none(self):
        """``cb_kwargs=None`` falls back to the ``ClusterBackend`` default.

        This covers the ``None`` guard applied when reading ``cb_kwargs``
        out of ``spk_kwargs``, and checks that existing callers keep the
        previous threshold of 0.78.
        """
        kwargs = dict(self.base_kwargs)
        kwargs["spk_model"] = "cam++"
        kwargs["spk_kwargs"] = {"cb_kwargs": None}

        model = AutoModel(**kwargs)

        self.assertEqual(model.cb_model.model_config["merge_thr"], 0.78)


if __name__ == '__main__':
    # Only taken when this file is executed as a script; importing the module
    # (as test runners do) leaves __name__ set to the module name instead.
    unittest.main()