Lizerui9926
2023-03-14 a7ab8bd688d21e45f194dd9d87cb060d2cbc21bd
Merge pull request #230 from alibaba-damo-academy/dev_wjm

Dev wjm
2个文件已修改
3个文件已添加
1051 ■■■■■ 已修改文件
funasr/bin/eend_ola_inference.py 413 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_diar_eend_ola.py 242 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/eend_ola_feature.py 51 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/eend_ola/encoder.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/diar.py 329 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/eend_ola_inference.py
New file
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
class Speech2Diarization:
    """Inference wrapper around an EEND-OLA diarization model.

    The model and its training arguments are loaded with
    ``EENDOLADiarTask.build_model_from_file``. When the training arguments
    define both ``frontend`` and ``frontend_conf``, a ``WavFrontendMel23`` is
    built from ``frontend_conf`` and applied to the input before the model.

    Examples:
        >>> speech2diar = Speech2Diarization("config.yaml", "model.pth")
        >>> results = speech2diar(speech, speech_lengths)
    """

    def __init__(
            self,
            diar_train_config: Union[Path, str] = None,
            diar_model_file: Union[Path, str] = None,
            device: str = "cpu",
            dtype: str = "float32",
    ):
        """Load the model, build the optional frontend and seed the RNGs.

        Args:
            diar_train_config: Path to the training configuration file.
            diar_model_file: Path to the model parameter file.
            device: Device name handed to the model builder and ``to_device``.
            dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``).
        """
        assert check_argument_types()

        # 1. Build Diarization model
        model, train_args = EENDOLADiarTask.build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
        )

        # A frontend is only created when both its name and its config are set.
        has_frontend = train_args.frontend is not None and train_args.frontend_conf is not None
        frontend = WavFrontendMel23(**train_args.frontend_conf) if has_frontend else None

        # set up seed for eda
        seed = train_args.seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        os.environ['PYTORCH_SEED'] = str(seed)

        logging.info("diar_model: {}".format(model))
        logging.info("diar_train_args: {}".format(train_args))
        model.to(dtype=getattr(torch, dtype)).eval()

        self.diar_model = model
        self.diar_train_args = train_args
        self.device = device
        self.dtype = dtype
        self.frontend = frontend

    @torch.no_grad()
    def __call__(
            self,
            speech: Union[torch.Tensor, np.ndarray],
            speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Run diarization on one batch.

        Args:
            speech: Input speech data; a NumPy array is converted to a tensor.
            speech_lengths: Lengths that accompany ``speech``.

        Returns:
            Whatever ``diar_model.estimate_sequential`` returns for the batch.
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if self.frontend is None:
            # No wrapper-level frontend: pass the input through untouched.
            feats, feats_len = speech, speech_lengths
        else:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Features are already extracted here, so the model-side frontend is cleared.
            self.diar_model.frontend = None

        batch = to_device({"speech": feats, "speech_lengths": feats_len}, device=self.device)
        return self.diar_model.estimate_sequential(**batch)

    @staticmethod
    def from_pretrained(
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ):
        """Build Speech2Diarization instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
            **kwargs: Constructor arguments; entries returned by the model
                downloader are merged into them when ``model_tag`` is given.

        Returns:
            Speech2Diarization: Speech2Diarization instance.

        Raises:
            ImportError: If ``model_tag`` is given but ``espnet_model_zoo`` is
                not installed.
        """
        if model_tag is None:
            return Speech2Diarization(**kwargs)

        try:
            from espnet_model_zoo.downloader import ModelDownloader
        except ImportError:
            logging.error(
                "`espnet_model_zoo` is not installed. "
                "Please install via `pip install -U espnet_model_zoo`."
            )
            raise

        downloader = ModelDownloader()
        kwargs.update(**downloader.download_and_unpack(model_tag))
        return Speech2Diarization(**kwargs)
def inference_modelscope(
        diar_train_config: str,
        diar_model_file: str,
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        num_workers: int = 0,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
        model_tag: Optional[str] = None,
        allow_variable_data_keys: bool = True,
        streaming: bool = False,
        param_dict: Optional[dict] = None,
        **kwargs,
):
    """Build the diarization model once and return a reusable inference callable.

    Args:
        diar_train_config: Path to the diarization training configuration.
        diar_model_file: Path to the diarization model parameters.
        output_dir: Default directory for ``result.txt``; ``None`` disables writing.
        batch_size: Must be 1; larger values raise ``NotImplementedError``.
        dtype: Torch dtype name used for the model and the data loader.
        ngpu: 0 for CPU, 1 for a single GPU (used only if CUDA is available).
        num_workers: Number of data-loader workers.
        log_level: Level passed to ``logging.basicConfig``.
        key_file: Optional key file passed to the data iterator.
        model_tag: Optional pretrained-model tag for ``from_pretrained``.
        allow_variable_data_keys: Passed through to the data iterator.
        streaming: Accepted for interface compatibility; it is not forwarded to
            ``Speech2Diarization`` because its constructor has no such parameter.
        param_dict: Only logged.
        **kwargs: Ignored extra options.

    Returns:
        A function ``_forward(data_path_and_name_and_type, raw_inputs,
        output_dir_v2, param_dict)`` that returns a list of
        ``{"key": ..., "value": ...}`` dicts, one per utterance.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Build speech2diar
    # Only arguments that Speech2Diarization.__init__ accepts are forwarded;
    # passing `streaming` here would raise TypeError on construction.
    speech2diar_kwargs = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar.diar_model.eval()

    def output_results_str(results: dict, uttid: str):
        """Format ``{speaker: [(start, end), ...]}`` as RTTM-style SPEAKER lines.

        Start/end values are divided by 100 before formatting, and the recording
        id is ``uttid`` with its last ``-``-separated field removed.
        """
        mid = uttid.rsplit("-", 1)[0]
        template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
        rst = []
        for spk, segs in results.items():
            # Scale into a local list instead of rewriting the caller's dict.
            scaled = [(x[0] / 100, x[1] / 100) for x in segs]
            rst.extend([template.format(mid, st, ed, spk) for st, ed in scaled])
        return "\n".join(rst)

    def _forward(
            data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
            raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
            output_dir_v2: Optional[str] = None,
            param_dict: Optional[dict] = None,
    ):
        """Run diarization over the given data and optionally write ``result.txt``."""
        # 2. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
            collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 3. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        output_writer = None
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            output_writer = open("{}/result.txt".format(output_path), "w")

        result_list = []
        try:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"

                # NOTE(review): output_results_str expects a dict of
                # speaker -> segments; confirm that speech2diar returns that shape.
                results = speech2diar(**batch)

                # Only supporting batch_size==1
                key, value = keys[0], output_results_str(results, keys[0])
                result_list.append({"key": key, "value": value})

                if output_writer is not None:
                    if value:
                        # Terminate each utterance's block so the next one
                        # starts on its own line.
                        output_writer.write(value + "\n")
                    output_writer.flush()
        finally:
            # Close the result file even when decoding fails part-way through.
            if output_writer is not None:
                output_writer.close()

        return result_list

    return _forward
def inference(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
        diar_train_config: Optional[str],
        diar_model_file: Optional[str],
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        seed: int = 0,
        num_workers: int = 1,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
        model_tag: Optional[str] = None,
        allow_variable_data_keys: bool = True,
        streaming: bool = False,
        smooth_size: int = 83,
        dur_threshold: int = 10,
        out_format: str = "vad",
        **kwargs,
):
    """One-shot entry point: build the pipeline and run it on the given data.

    Every option except ``data_path_and_name_and_type`` is forwarded to
    ``inference_modelscope``; the callable it returns is then invoked with
    ``data_path_and_name_and_type`` and ``raw_inputs=None``.

    Returns:
        The result list produced by the pipeline callable.
    """
    pipeline_options = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        output_dir=output_dir,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
        seed=seed,
        num_workers=num_workers,
        log_level=log_level,
        key_file=key_file,
        model_tag=model_tag,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
        out_format=out_format,
    )
    # Extra keyword arguments can never collide with the named ones above,
    # because named parameters are bound before **kwargs collects the rest.
    pipeline_options.update(kwargs)
    run_pipeline = inference_modelscope(**pipeline_options)
    return run_pipeline(data_path_and_name_and_type, raw_inputs=None)
def get_parser():
    """Build the command-line argument parser for EEND-OLA diarization inference.

    Returns:
        A ``config_argparse.ArgumentParser`` with logging, device, input-data and
        model-configuration options registered.
    """
    parser = config_argparse.ArgumentParser(
        # The previous text described speaker verification, not this tool.
        description="EEND-OLA speaker diarization inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="diarization model parameter file",
    )
    group.add_argument(
        "--dur_threshold",
        type=int,
        default=10,
        help="The threshold for short segments in number frames"
    )
    parser.add_argument(
        "--smooth_size",
        type=int,
        default=83,
        help="The smoothing window length in number frames"
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
             "*_file will be overwritten",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    parser.add_argument("--streaming", type=str2bool, default=False)
    return parser
def main(cmd=None):
    """Command-line entry point: parse arguments, pick a GPU and run inference.

    The job id is read from the text after the last ``.`` in ``--output_dir``
    (e.g. ``out.3`` -> job 3) and selects an entry of ``--gpuid_list``
    round-robin. Without an output directory, or when that suffix is not an
    integer, job 1 is used.

    Args:
        cmd: Optional argument list; ``None`` lets argparse read ``sys.argv``.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    logging.info("args: {}".format(kwargs))

    gpu_ids = args.gpuid_list.split(",")
    jobid = 1
    if args.output_dir is not None:
        suffix = args.output_dir.split(".")[-1]
        try:
            jobid = int(suffix)
        except ValueError:
            # A directory such as "./out" has no numeric job suffix; treat it as
            # a single job instead of crashing before inference starts.
            logging.warning(
                "output_dir '%s' has no numeric job suffix; assuming job id 1",
                args.output_dir,
            )
    gpuid = gpu_ids[(jobid - 1) % len(gpu_ids)]

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    results_list = inference(**kwargs)
    for results in results_list:
        print("{} {}".format(results["key"], results["value"]))
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
funasr/models/e2e_diar_eend_ola.py
New file
@@ -0,0 +1,242 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple
import numpy as np
import torch
import torch.nn as  nn
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
# Version gate on the installed torch: for torch >= 1.6.0 this branch does
# nothing (no `autocast` name is defined here); for older versions a no-op
# stand-in context manager named `autocast` is defined.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    pass
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        """No-op context manager; ``enabled`` is accepted but ignored."""
        yield
def pad_attractor(att, max_n_speakers):
    """Zero-pad ``att`` along dim 0 so that it has ``max_n_speakers`` rows.

    Args:
        att: 2-D tensor of shape ``(C, D)``.
        max_n_speakers: Target number of rows.

    Returns:
        ``att`` itself when ``C >= max_n_speakers``; otherwise ``att`` with
        float32 zero rows appended on the same device.
    """
    n_rows, n_dims = att.shape
    n_missing = max_n_speakers - n_rows
    if n_missing <= 0:
        return att
    filler = torch.zeros(n_missing, n_dims, dtype=torch.float32, device=att.device)
    return torch.cat((att, filler), dim=0)
class DiarEENDOLAModel(AbsESPnetModel):
    """EEND-OLA diarization model

    Combines a frontend, a transformer encoder, an encoder-decoder attractor
    module and an LSTM "PostNet" followed by a linear layer that scores
    power-set speaker labels (``mapping_dict['oov'] + 1`` classes).
    """

    def __init__(
            self,
            frontend: WavFrontendMel23,
            encoder: EENDOLATransformerEncoder,
            encoder_decoder_attractor: EncoderDecoderAttractor,
            n_units: int = 256,
            max_n_speaker: int = 8,
            attractor_loss_weight: float = 1.0,
            mapping_dict=None,
            **kwargs,
    ):
        """Store the sub-modules and build the PostNet.

        Args:
            frontend: Feature frontend; may be cleared later by the caller.
            encoder: Frame-level encoder.
            encoder_decoder_attractor: Attractor estimator.
            n_units: Hidden size of the PostNet LSTM.
            max_n_speaker: Maximum number of speakers handled by the PostNet.
            attractor_loss_weight: Stored on the instance; not used in the code shown.
            mapping_dict: Power-label mapping; generated for ``max_n_speaker``
                when ``None``.
            **kwargs: Ignored.
        """
        assert check_argument_types()
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        # Always keep the mapping, including one supplied by the caller.
        self.mapping_dict = mapping_dict
        # PostNet
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)

    def forward_encoder(self, xs, ilens):
        """Encode a list of variable-length sequences.

        Args:
            xs: Sequence of per-utterance feature tensors.
            ilens: Per-utterance lengths.

        Returns:
            List of per-utterance embeddings, each trimmed to its own length.
        """
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        # Mask is 1 on valid frames and 0 on padding.
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.encoder(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb

    def forward_post_net(self, logits, ilens):
        """Run the PostNet LSTM and output layer over per-utterance logits.

        Args:
            logits: Sequence of per-utterance tensors.
            ilens: Tensor of per-utterance lengths.

        Returns:
            List of per-utterance output-layer scores, trimmed to each length.
        """
        maxlen = torch.max(ilens).to(torch.int).item()
        # pack_padded_sequence requires the lengths tensor to live on the CPU.
        lengths = ilens.cpu().to(torch.int64)
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(logits, lengths, batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.PostNet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # NOTE(review): this body uses self.encode, self.ctc_weight,
        # self.interctc_weight, self._calc_ctc_loss and self._calc_att_loss,
        # none of which are defined in this class as shown. Unless the base
        # class provides them this method raises AttributeError — confirm the
        # intended training objective before relying on it.
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermedaite CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                               1 - self.interctc_weight
                       ) * loss_ctc + self.interctc_weight * loss_interctc
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def estimate_sequential(self,
                            speech: torch.Tensor,
                            speech_lengths: torch.Tensor,
                            n_speakers: int = None,
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        """Estimate per-frame speaker activity for each utterance.

        Args:
            speech: Batch of inputs (features, or raw input if a frontend is set).
            speech_lengths: Per-utterance lengths.
            n_speakers: If a positive number, keep exactly this many attractors.
            shuffle: Shuffle frame order (NumPy RNG) before attractor estimation.
            threshold: Otherwise keep attractors up to the first one whose
                existence probability is below this value.
            **kwargs: Ignored.

        Returns:
            Tuple ``(ys, emb, attractors, raw_n_speakers)``: decoded decisions,
            encoder embeddings, attractors padded/truncated to
            ``max_n_speaker`` rows, and the speaker count chosen per utterance.

        Raises:
            NotImplementedError: If neither ``n_speakers`` nor ``threshold`` is usable.
        """
        if self.frontend is not None:
            speech = self.frontend(speech)
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers, ]
                attractors_active.append(att)
            elif threshold is not None:
                # Indices (along dim 0) of attractors judged inactive. May be
                # empty, in which case every attractor is kept.
                silence = torch.nonzero(p < threshold)[:, 0]
                n_spk = int(silence[0]) if silence.numel() > 0 else None
                att = att[:n_spk, ]
                attractors_active.append(att)
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
              zip(logits, raw_n_speakers)]
        return ys, emb, attractors, raw_n_speakers

    def recover_y_from_powerlabel(self, logit, n_speaker):
        """Turn power-label scores into per-frame binary speaker decisions.

        Args:
            logit: Per-frame class scores for one utterance.
            n_speaker: Number of speaker columns to keep.

        Returns:
            Float32 tensor with one row per frame and ``n_speaker`` 0/1 columns.
        """
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        # Frames predicted as 'oov' inherit the previous frame's label
        # (or label 0 for the first frame).
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        # NOTE(review): self.reporter is not assigned anywhere in this class as
        # shown; confirm where inv_mapping_func is supposed to come from.
        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
        # Binary expansion, least-significant bit first: bit k <-> speaker k.
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
funasr/models/frontend/eend_ola_feature.py
New file
@@ -0,0 +1,51 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features
import librosa
import numpy as np
def transform(Y, dtype=np.float32):
    """Convert an STFT matrix into mean-normalised log-mel features.

    Args:
        Y: 2-D STFT array; ``Y.shape[1]`` is taken as ``n_fft / 2 + 1`` bins.
        dtype: Output dtype.

    Returns:
        Array with 23 mel bands per row (log10 mel power with the mean over
        axis 0 subtracted), cast to ``dtype``.
    """
    Y = np.abs(Y)
    n_fft = 2 * (Y.shape[1] - 1)
    sr = 8000
    n_mels = 23
    # Keyword arguments: recent librosa releases reject positional ones here,
    # while older releases accept the same keywords.
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    Y = np.dot(Y ** 2, mel_basis.T)
    # Floor at 1e-10 so that log10 never sees zero.
    Y = np.log10(np.maximum(Y, 1e-10))
    mean = np.mean(Y, axis=0)
    Y = Y - mean
    return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
    """Keep every ``subsampling``-th element of both ``Y`` and ``T``.

    Returns:
        Tuple ``(Y_subsampled, T_subsampled)``.
    """
    every_nth = slice(None, None, subsampling)
    return Y[every_nth], T[every_nth]
def splice(Y, context_size=0):
    """Concatenate each row of ``Y`` with ``context_size`` neighbours on each side.

    ``Y`` is zero-padded with ``context_size`` rows at the top and bottom, and
    a read-only strided view is returned in which row ``i`` is padded rows
    ``i .. i + 2 * context_size`` laid side by side.

    Returns:
        Read-only array of shape ``(Y.shape[0], Y.shape[1] * (2 * context_size + 1))``.
    """
    n_frames, n_dims = Y.shape[0], Y.shape[1]
    window = 2 * context_size + 1
    padded = np.pad(Y, [(context_size, context_size), (0, 0)], 'constant')
    # Advancing one output row moves one padded row forward in memory;
    # advancing one column moves a single element.
    row_stride = Y.itemsize * n_dims
    return np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(padded),
        shape=(n_frames, n_dims * window),
        strides=(row_stride, Y.itemsize),
        writeable=False,
    )
def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """Compute a frames-first STFT of ``data`` with librosa.

    The FFT size is ``frame_size`` rounded up to a power of two. When
    ``len(data)`` is an exact multiple of ``frame_shift`` the last frame is
    dropped.

    Returns:
        The transposed ``librosa.stft`` output (one row per frame).
    """
    fft_size = 1 << (frame_size - 1).bit_length()
    drop_last_frame = len(data) % frame_shift == 0
    frames = librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                          hop_length=frame_shift).T
    if drop_last_frame:
        return frames[:-1]
    return frames
funasr/modules/eend_ola/encoder.py
@@ -1,5 +1,5 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
@@ -81,10 +81,16 @@
        return self.dropout(x)
class TransformerEncoder(nn.Module):
    def __init__(self, idim, n_layers, n_units,
                 e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
        super(TransformerEncoder, self).__init__()
class EENDOLATransformerEncoder(nn.Module):
    def __init__(self,
                 idim: int,
                 n_layers: int,
                 n_units: int,
                 e_units: int = 2048,
                 h: int = 8,
                 dropout_rate: float = 0.1,
                 use_pos_emb: bool = False):
        super(EENDOLATransformerEncoder, self).__init__()
        self.lnorm_in = nn.LayerNorm(n_units)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_rate)
funasr/tasks/diar.py
@@ -20,19 +20,19 @@
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -41,17 +41,13 @@
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -70,6 +66,7 @@
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        wav_frontend_mel23=WavFrontendMel23,
    ),
    type_check=AbsFrontend,
    default="default",
@@ -107,6 +104,7 @@
    "model",
    classes=dict(
        sond=DiarSondModel,
        eend_ola=DiarEENDOLAModel,
    ),
    type_check=AbsESPnetModel,
    default="sond",
@@ -126,6 +124,7 @@
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
        eend_ola_transformer=EENDOLATransformerEncoder,
    ),
    type_check=torch.nn.Module,
    default="resnet34",
@@ -176,6 +175,15 @@
    ),
    type_check=torch.nn.Module,
    default="fsmn",
)
# encoder_decoder_attractor is used for EEND-OLA.
# Declares the "encoder_decoder_attractor" choice group: the only registered
# class is "eda" (EncoderDecoderAttractor), which is also the default, and
# torch.nn.Module is passed as the type check.
encoder_decoder_attractor_choices = ClassChoices(
    "encoder_decoder_attractor",
    classes=dict(
        eda=EncoderDecoderAttractor,
    ),
    type_check=torch.nn.Module,
    default="eda",
)
@@ -594,3 +602,294 @@
            var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
class EENDOLADiarTask(AbsTask):
    """Task definition for the EEND-OLA diarization model.

    Declares the selectable components (frontend, model, encoder and
    encoder-decoder attractor) and the trainer class used by this task.
    """

    # If you need more than 1 optimizer, change this value
    num_optimizers: int = 1
    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
        encoder_decoder_attractor_choices,
    ]
    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="char",
            choices=["char"],
            help="The text will be tokenized in the specified level token",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="THe probability for applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)
    @classmethod
    def build_collate_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
    @classmethod
    def build_preprocess_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=None,
                non_linguistic_symbols=None,
                text_cleaner=None,
                g2p_type=None,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "rir_scp")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval
    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "profile", "binary_labels")
        else:
            # Recognition mode
            retval = ("speech")
        return retval
    @classmethod
    def optional_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        # 1. frontend
        if args.input_size is None or args.frontend == "wav_frontend_mel23":
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        # 3. EncoderDecoderAttractor
        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
        # 9. Build model
        model_class = model_choices.get_class(args.model)
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            encoder_decoder_attractor=encoder_decoder_attractor,
            **args.model_conf,
        )
        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
            cls,
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ):
        """Build model from the files.
        This method is used for inference or fine-tuning.
        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file for front-end
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        if model_file is not None:
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            checkpoint = torch.load(model_file, map_location=device)
            if "state_dict" in checkpoint.keys():
                model.load_state_dict(checkpoint["state_dict"])
            else:
                model.load_state_dict(checkpoint)
        model.to(device)
        return model, args