zhifu gao
2023-03-15 729ac44000a5c4aa23dcb1a3b80adc119a350b23
Merge pull request #235 from alibaba-damo-academy/dev_wjm

Dev wjm
8 files changed
1 file added
232 ■■■■ changed files
egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py 10 ●●●●●
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py 7 ●●●●●
funasr/bin/diar_inference_launch.py 3 ●●●●●
funasr/bin/eend_ola_inference.py 24 ●●●●
funasr/models/e2e_diar_eend_ola.py 35 ●●●●●
funasr/models/frontend/wav_frontend.py 77 ●●●●
funasr/modules/eend_ola/encoder_decoder_attractor.py 6 ●●●●
funasr/tasks/diar.py 68 ●●●●
tests/test_asr_inference_pipeline.py 2 ●●●
egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py
New file
@@ -0,0 +1,10 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
inference_diar_pipline = pipeline(
    task=Tasks.speaker_diarization,
    model='damo/speech_diarization_eend-ola-en-us-callhome-8k',
    model_revision="v1.0.0",
)
results = inference_diar_pipline(audio_in=["https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record2.wav"])
print(results)
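Note: each result item is assembled as a dict with "key" and "value" fields, where "value" carries RTTM-style segment lines (see the post-processing change in funasr/bin/eend_ola_inference.py below). An illustrative sample only; the key and timestamps are invented:

example = [{"key": "record2",
            "value": "SPEAKER record2 1    0.50    3.20 <NA> <NA> record2_0 <NA>"}]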
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
@@ -14,13 +14,12 @@
)
# Use audio_list as input: the first audio is the speech to be diarized, and the following audios are voiceprint enrollment recordings from different speakers
audio_list = [[
audio_list = [
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
]]
]
results = inference_diar_pipline(audio_in=audio_list)
for rst in results:
    print(rst["value"])
print(results)
funasr/bin/diar_inference_launch.py
@@ -142,6 +142,9 @@
        else:
            kwargs["param_dict"] = param_dict
        return inference_modelscope(mode=mode, **kwargs)
    elif mode == "eend-ola":
        from funasr.bin.eend_ola_inference import inference_modelscope
        return inference_modelscope(mode=mode, **kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
funasr/bin/eend_ola_inference.py
@@ -16,6 +16,7 @@
import numpy as np
import torch
from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
@@ -146,7 +147,7 @@
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        ngpu: int = 1,
        num_workers: int = 0,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
@@ -179,7 +180,6 @@
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
        streaming=streaming,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
@@ -209,7 +209,7 @@
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
            data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
@@ -236,9 +236,23 @@
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            results = speech2diar(**batch)
            # post process
            a = results[0][0].cpu().numpy()
            a = medfilt(a, (11, 1))
            rst = []
            for spkid, frames in enumerate(a.T):
                frames = np.pad(frames, (1, 1), 'constant')
                changes, = np.where(np.diff(frames, axis=0) != 0)
                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
                for s, e in zip(changes[::2], changes[1::2]):
                    st = s / 10.
                    dur = (e - s) / 10.
                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
            # Only supporting batch_size==1
            key, value = keys[0], output_results_str(results, keys[0])
            item = {"key": key, "value": value}
            value = "\n".join(rst)
            item = {"key": keys[0], "value": value}
            result_list.append(item)
            if output_path is not None:
                output_writer.write(value)
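Note: the added post-processing turns frame-level speaker activity (10 frames per second, hence the divisions by 10.) into RTTM segment lines: each speaker track is median-filtered, padded with silence at both ends, and its on/off change points are paired into (start, duration) segments. A minimal standalone sketch of the same logic on toy data:

import numpy as np
from scipy.signal import medfilt

a = np.zeros((50, 2))                     # toy activity: 50 frames, 2 speakers
a[5:20, 0] = 1
a[25:45, 1] = 1
a = medfilt(a, (11, 1))                   # median-smooth each speaker track
rst = []
for spkid, frames in enumerate(a.T):
    frames = np.pad(frames, (1, 1), 'constant')        # silence guard at both ends
    changes, = np.where(np.diff(frames, axis=0) != 0)  # on/off boundaries
    for s, e in zip(changes[::2], changes[1::2]):
        rst.append("SPEAKER demo 1 {:7.2f} {:7.2f} <NA> <NA> demo_{} <NA>".format(
            s / 10., (e - s) / 10., spkid))
print("\n".join(rst))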
funasr/models/e2e_diar_eend_ola.py
@@ -52,15 +52,15 @@
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.enc = encoder
        self.eda = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
            self.mapping_dict = mapping_dict
        # PostNet
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
    def forward_encoder(self, xs, ilens):
@@ -68,7 +68,7 @@
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.encoder(xs, xs_mask)
        emb = self.enc(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb
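Note: the mask built here is a per-utterance length mask padded to the batch maximum, with an extra broadcast dimension for attention. A minimal sketch with toy lengths:

import torch
ilens = [4, 2]
xs_mask = [torch.ones(l) for l in ilens]
xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
print(xs_mask.shape)  # torch.Size([2, 1, 4]); second row is [1., 1., 0., 0.]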
@@ -76,8 +76,8 @@
    def forward_post_net(self, logits, ilens):
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.PostNet(logits)
        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.postnet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
@@ -112,7 +112,7 @@
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
@@ -190,18 +190,16 @@
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        if self.frontend is not None:
            speech = self.frontend(speech)
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
            attractors, probs = self.eda.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
            attractors, probs = self.eda.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
@@ -233,10 +231,23 @@
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
        pred = [self.inv_mapping_func(i) for i in pred]
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
    def inv_mapping_func(self, label):
        if not isinstance(label, int):
            label = int(label)
        if label in self.mapping_dict['label2dec'].keys():
            num = self.mapping_dict['label2dec'][label]
        else:
            num = -1
        return num
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        pass
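Note: the new inv_mapping_func drops the reporter indirection. Each predicted label indexes mapping_dict['label2dec'] to get an integer whose binary representation, zero-padded to max_n_speaker bits and reversed, marks which speakers are active in that frame. A minimal sketch of the decoding step, with a toy decimal code:

max_n_speaker = 4
num = 5                                          # toy stand-in for label2dec[label]
dec = bin(num)[2:].zfill(max_n_speaker)[::-1]    # '101' -> '0101' -> '1010'
print([int(b) for b in dec])                     # [1, 0, 1, 0]: speakers 0 and 2 active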
funasr/models/frontend/wav_frontend.py
@@ -1,14 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from abc import ABC
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.models.frontend.abs_frontend import AbsFrontend
def load_cmvn(cmvn_file):
@@ -275,7 +276,8 @@
    # inputs tensor has catted the cache tensor
    # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
    #               is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
        torch.Tensor, torch.Tensor, int]:
        """
        Apply lfr with data
        """
@@ -376,7 +378,8 @@
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
                                                                                     is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
@@ -398,9 +401,10 @@
        assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            #if self.reserve_waveforms is None and self.lfr_m > 1:
            # if self.reserve_waveforms is None and self.lfr_m > 1:
            #    self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
                (self.reserve_waveforms, waveforms), dim=1)
            if not self.lfr_splice_cache:  # initialize splice_cache
                for i in range(batch_size):
                    self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
@@ -409,7 +413,8 @@
                lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                if self.lfr_m == 1:
@@ -423,14 +428,15 @@
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
                self.reserve_waveforms = self.waveforms[:,
                                         :-(self.frame_sample_length - self.frame_shift_sample_length)]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                return torch.empty(0), feats_lengths
        else:
            if is_final:
                self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                feats = torch.stack(self.lfr_splice_cache)
                feats = torch.stack(self.lfr_splice_cache)
                feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
@@ -444,3 +450,54 @@
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []
class WavFrontendMel23(AbsFrontend):
    """Conventional frontend structure for ASR.
    """
    def __init__(
            self,
            fs: int = 16000,
            frame_length: int = 25,
            frame_shift: int = 10,
            lfr_m: int = 1,
            lfr_n: int = 1,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.n_mels = 23
    def output_size(self) -> int:
        return self.n_mels * (2 * self.lfr_m + 1)
    def forward(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform.numpy()
            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
            mat = eend_ola_feature.transform(mat)
            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
            mat = mat[::self.lfr_n]
            mat = torch.from_numpy(mat)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
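A hedged usage sketch for the new WavFrontendMel23; the parameter values are assumptions rather than a model config, and AbsFrontend is assumed to subclass nn.Module so the instance is callable:

import torch
from funasr.models.frontend.wav_frontend import WavFrontendMel23

frontend = WavFrontendMel23(fs=8000, frame_length=25, frame_shift=10, lfr_m=7, lfr_n=10)
wav = torch.randn(1, 8000)                       # one 1-second, 8 kHz waveform
feats, feats_lens = frontend(wav, torch.tensor([8000]))
print(frontend.output_size())                    # 23 * (2*7 + 1) = 345
print(feats.shape, feats_lens)                   # (1, T, 345) with T frames after subsampling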
funasr/modules/eend_ola/encoder_decoder_attractor.py
@@ -16,12 +16,12 @@
        self.n_units = n_units
    def forward_core(self, xs, zeros):
        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.float32).to(xs[0].device)
        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.int64)
        xs = [self.enc0_dropout(x) for x in xs]
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        xs = nn.utils.rnn.pack_padded_sequence(xs, ilens, batch_first=True, enforce_sorted=False)
        _, (hx, cx) = self.encoder(xs)
        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.float32).to(zeros[0].device)
        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.int64)
        max_zlen = torch.max(zlens).to(torch.int).item()
        zeros = [self.enc0_dropout(z) for z in zeros]
        zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
@@ -47,4 +47,4 @@
        zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
        attractors = self.forward_core(xs, zeros)
        probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
        return attractors, probs
        return attractors, probs
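Note: this change and the forward_post_net change in funasr/models/e2e_diar_eend_ola.py apply the same fix: pack_padded_sequence expects sequence lengths as a 1-D CPU int64 tensor (or a Python list), so float32 or CUDA length tensors raise an error. A minimal sketch of the corrected pattern:

import torch
from torch import nn

xs = [torch.randn(5, 8), torch.randn(3, 8)]
ilens = torch.tensor([x.shape[0] for x in xs], dtype=torch.int64)   # CPU int64, as in the fix
padded = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
packed = nn.utils.rnn.pack_padded_sequence(padded, ilens, batch_first=True, enforce_sorted=False)
lstm = nn.LSTM(8, 4, batch_first=True)
out, _ = nn.utils.rnn.pad_packed_sequence(lstm(packed)[0], batch_first=True)
print(out.shape)  # torch.Size([2, 5, 4])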
funasr/tasks/diar.py
@@ -750,47 +750,47 @@
            cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=None,
                non_linguistic_symbols=None,
                text_cleaner=None,
                g2p_type=None,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "rir_scp")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval
        # if args.use_preprocessor:
        #     retval = CommonPreprocessor(
        #         train=train,
        #         token_type=args.token_type,
        #         token_list=args.token_list,
        #         bpemodel=None,
        #         non_linguistic_symbols=None,
        #         text_cleaner=None,
        #         g2p_type=None,
        #         split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
        #         seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
        #         # NOTE(kamo): Check attribute existence for backward compatibility
        #         rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
        #         rir_apply_prob=args.rir_apply_prob
        #         if hasattr(args, "rir_apply_prob")
        #         else 1.0,
        #         noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
        #         noise_apply_prob=args.noise_apply_prob
        #         if hasattr(args, "noise_apply_prob")
        #         else 1.0,
        #         noise_db_range=args.noise_db_range
        #         if hasattr(args, "noise_db_range")
        #         else "13_15",
        #         speech_volume_normalize=args.speech_volume_normalize
        #         if hasattr(args, "rir_scp")
        #         else None,
        #     )
        # else:
        #     retval = None
        # assert check_return_type(retval)
        return None
    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "profile", "binary_labels")
            retval = ("speech", )
        else:
            # Recognition mode
            retval = ("speech")
            retval = ("speech", )
        return retval
    @classmethod
@@ -823,7 +823,7 @@
        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder = encoder_class(**args.encoder_conf)
        # 3. EncoderDecoderAttractor
        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
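Note: the required_data_names change above fixes a classic Python gotcha: parentheses alone do not create a tuple, so ("speech") is just the string "speech", while ("speech", ) is a one-element tuple. Quick check:

print(type(("speech")))    # <class 'str'>
print(type(("speech", )))  # <class 'tuple'>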
tests/test_asr_inference_pipeline.py
@@ -451,7 +451,7 @@
    def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
        inference_pipeline = pipeline(
            task=Tasks.,
            task=Tasks.auto_speech_recognition,
            model='damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
        rec_result = inference_pipeline(
            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',