| | |
| | | ss_results = speech_separator(**batch) |
| | | |
| | | for spk in range(num_spks): |
| | | sf.write(os.path.join(output_path, keys[0].replace('.wav', '_s'+str(spk+1)+'.wav')), ss_results[spk], sample_rate) |
| | | sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate) |
| | | torch.cuda.empty_cache() |
| | | return ss_results |
| | | |
| | |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | |
def load_checkpoint(checkpoint_path, use_cuda=1):
    """Deserialize a checkpoint file with ``torch.load``.

    Args:
        checkpoint_path: path of the file to load.
        use_cuda: when truthy, ``torch.load`` runs with its default
            ``map_location`` (None); when falsy, a ``map_location`` callable
            is used that returns every storage unchanged, which the
            ``torch.load`` documentation describes as loading onto the CPU.

    Returns:
        Whatever object ``torch.load`` produced from the file.
    """
    # map_location=None is torch.load's own default, so the truthy branch is
    # the same as omitting the argument.
    placement = None if use_cuda else (lambda storage, _location: storage)
    return torch.load(checkpoint_path, map_location=placement)
| | | |
def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
    """Load the ``'model'`` entry of a checkpoint into ``model`` in place.

    Args:
        model: object whose ``load_state_dict`` receives the checkpoint's
            ``'model'`` entry (presumably a ``torch.nn.Module`` -- confirm
            against callers).
        checkpoint_path: path forwarded to ``load_checkpoint``.
        use_cuda: forwarded to ``load_checkpoint``; when falsy the checkpoint
            is loaded with a CPU-mapping ``map_location``.

    Returns:
        The value returned by ``model.load_state_dict`` (for an
        ``nn.Module``, a record of missing and unexpected keys). Callers that
        ignored the previous ``None`` return are unaffected.

    Raises:
        KeyError: if the loaded checkpoint is a mapping without ``'model'``.
    """
    import logging

    checkpoint = load_checkpoint(checkpoint_path, use_cuda)
    # strict=False tolerates key mismatches between checkpoint and model.
    # Report them rather than discarding them, so a checkpoint from a
    # different architecture does not load silently with unset weights.
    incompatible = model.load_state_dict(checkpoint['model'], strict=False)
    missing = list(getattr(incompatible, 'missing_keys', []))
    unexpected = list(getattr(incompatible, 'unexpected_keys', []))
    if missing or unexpected:
        logging.warning(
            "Checkpoint %s loaded with strict=False: %d missing key(s) %s, "
            "%d unexpected key(s) %s",
            checkpoint_path, len(missing), missing, len(unexpected), unexpected)
    return incompatible
| | | |
| | | def build_model_from_file( |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | if task_name == 'ss': |
| | | reload_ss_for_eval(model, model_file, use_cuda=True) |
| | | logging.info("model is loaded from path: {}".format(model_file)) |
| | | if task_name == "ss": |
| | | model_dict = model_dict['model'] |
| | | if task_name == "diar" and mode == "sond": |
| | | model_dict = fileter_model_dict(model_dict, model.state_dict()) |
| | | if task_name == "vad": |