hnluo
2023-08-10 ea2c102e6162c924c682aabfe8a052ce9a766a4d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
 
 
import logging
from pathlib import Path
from typing import List
from typing import Union
 
import numpy as np
import torch
 
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
 
 
class SpeechSeparator:
    """Speech-separation inference wrapper.

    Builds a separation model from a config/checkpoint pair and exposes a
    callable interface that splits a mixed waveform into one waveform per
    speaker.

    Examples:
        >>> import soundfile
        >>> speech_separator = SpeechSeparator("ss_config.yml", "ss.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> separated_wavs = speech_separator(audio)

    """

    def __init__(
            self,
            ss_infer_config: Union[Path, str] = None,
            ss_model_file: Union[Path, str] = None,
            device: str = "cpu",
            batch_size: int = 1,
            dtype: str = "float32",
            **kwargs,
    ):
        """Build the separation model and store inference settings.

        Args:
            ss_infer_config: Path to the inference yaml config.
            ss_model_file: Path to the model checkpoint.
            device: Device string used for inference (e.g. "cpu", "cuda:0").
            batch_size: Stored for callers; decoding itself is per-utterance.
            dtype: Name of the torch dtype the model is cast to.
        """
        # 1. Build ss model
        ss_model, ss_infer_args = build_model_from_file(
            ss_infer_config, ss_model_file, None, device, task_name="ss"
        )

        logging.info("ss_model: {}".format(ss_model))
        logging.info("ss_infer_args: {}".format(ss_infer_args))

        ss_model.to(dtype=getattr(torch, dtype)).eval()

        self.ss_model = ss_model
        self.ss_infer_args = ss_infer_args
        self.device = device
        self.dtype = dtype
        self.batch_size = batch_size

    def decode(self, model, args, inputs, nsamples):
        """Run separation on one mixed waveform, segmenting if it is long.

        Args:
            model: Separation network; called as ``model(tensor)`` and
                expected to return a list with one tensor per speaker.
            args: Inference args providing ``sample_rate``, ``decode_window``,
                ``one_time_decode_length`` and ``num_spks``.
            inputs: 2-D array of shape (batch, samples) holding the mixture.
                NOTE(review): the padding below uses ``np.concatenate``, so
                this path assumes a numpy array even though ``__call__``'s
                annotation also admits torch.Tensor — confirm callers.
            nsamples: Number of valid samples; outputs are trimmed to it.

        Returns:
            List of 1-D numpy arrays, one separated waveform per speaker,
            jointly peak-normalized so relative speaker levels are kept.
        """
        decode_do_segment = False
        with torch.no_grad():
            out = []
            # Decoding window length in samples.  Cast to int so the
            # np.zeros padding shapes below stay integral even when
            # decode_window is configured as a float (e.g. 0.5 s).
            window = int(args.sample_rate * args.decode_window)
            stride = int(window * 0.75)  # decoding stride if segmentation is used
            b, t = inputs.shape
            if t > window * args.one_time_decode_length:
                decode_do_segment = True  # set segment decoding to true for very long sequence

            # Zero-pad the tail so the (possibly segmented) decoding covers
            # the whole input.
            if t < window:
                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window - t))], 1)
            elif t < window + stride:
                padding = window + stride - t
                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
            else:
                if (t - window) % stride != 0:
                    padding = t - (t - window) // stride * stride
                    inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
            inputs = torch.from_numpy(np.float32(inputs))
            inputs = to_device(inputs, device=self.device)
            b, t = inputs.shape
            if decode_do_segment:
                outputs = np.zeros((args.num_spks, t))
                # Samples discarded at each inner segment edge to avoid
                # boundary artifacts in the overlap region.
                give_up_length = (window - stride) // 2
                current_idx = 0
                while current_idx + window <= t:
                    tmp_input = inputs[:, current_idx:current_idx + window]
                    tmp_out_list = model(tmp_input,)
                    for spk in range(args.num_spks):
                        tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
                        if current_idx == 0:
                            # First segment keeps its left edge untrimmed.
                            outputs[spk, current_idx:current_idx + window - give_up_length] = \
                                tmp_out_list[spk][:-give_up_length]
                        else:
                            outputs[spk, current_idx + give_up_length:current_idx + window - give_up_length] = \
                                tmp_out_list[spk][give_up_length:-give_up_length]
                    current_idx += stride
                for spk in range(args.num_spks):
                    out.append(outputs[spk, :])
            else:
                out_list = model(inputs)
                for spk in range(args.num_spks):
                    out.append(out_list[spk][0, :].cpu().numpy())

            # Shared peak-normalization across speakers.  np.max(np.abs(...))
            # replaces the Python-level max(abs(...)) loop; the max_abs > 0
            # guard fixes a division by zero on an all-silent output.
            max_abs = 0
            for spk in range(args.num_spks):
                max_abs = max(max_abs, float(np.max(np.abs(out[spk]))))
            for spk in range(args.num_spks):
                out[spk] = out[spk][:nsamples]
                if max_abs > 0:
                    out[spk] = out[spk] / max_abs

        return out

    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
    ) -> List[torch.Tensor]:
        """Inference

        Args:
            speech: Input speech data of shape (batch, samples).
            speech_lengths: Number of valid samples; separated outputs are
                trimmed to this length.
        Returns:
            speech list: list of separated speech data, one entry per speaker

        """

        out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)

        return out