Merge pull request #230 from alibaba-damo-academy/dev_wjm
Dev wjm
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.tasks.diar import EENDOLADiarTask |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | |
class Speech2Diarization:
    """Speech2Diarization class: wraps an EEND-OLA diarization model for inference.

    Examples:
        >>> import soundfile
        >>> speech2diar = Speech2Diarization("diar_eend_ola_config.yml", "diar_eend_ola.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> results = speech2diar(audio)

    """

    def __init__(
        self,
        diar_train_config: Union[Path, str] = None,
        diar_model_file: Union[Path, str] = None,
        device: str = "cpu",
        dtype: str = "float32",
        streaming: bool = False,
    ):
        """Build the diarization model and, if configured, its feature frontend.

        Args:
            diar_train_config: Training config file of the diarization model.
            diar_model_file: Model parameter file.
            device: Device the model is built on and inputs are moved to.
            dtype: Name of the torch dtype the model is cast to (e.g. "float32").
            streaming: Accepted so that callers such as ``inference_modelscope``
                (which pass ``streaming=``) can construct this class; it is
                currently not used.
        """
        assert check_argument_types()

        # 1. Build Diarization model
        diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device,
        )
        # Features are extracted here, outside the model, when the training
        # config declares both a frontend and its configuration.
        frontend = None
        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
            frontend = WavFrontendMel23(**diar_train_args.frontend_conf)

        # Seed every RNG: attractor estimation shuffles frames with np.random,
        # so results are only reproducible when seeded.
        np.random.seed(diar_train_args.seed)
        torch.manual_seed(diar_train_args.seed)
        torch.cuda.manual_seed(diar_train_args.seed)
        os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
        logging.info("diar_model: {}".format(diar_model))
        logging.info("diar_train_args: {}".format(diar_train_args))
        diar_model.to(dtype=getattr(torch, dtype)).eval()

        self.diar_model = diar_model
        self.diar_train_args = diar_train_args
        self.device = device
        self.dtype = dtype
        self.frontend = frontend

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Run diarization inference on one batch.

        Args:
            speech: Input to the frontend when one is configured; otherwise it
                is passed to the model unchanged as features.
            speech_lengths: Valid lengths of ``speech`` (tensor or ndarray).

        Returns:
            Whatever ``diar_model.estimate_sequential`` returns.

        """
        assert check_argument_types()
        # Accept numpy input for both the data and its lengths.
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(speech_lengths, np.ndarray):
            speech_lengths = torch.tensor(speech_lengths)

        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats_len = feats_len.int()
            # Features were already extracted above; disable the model's own
            # frontend so they are not transformed a second time.
            self.diar_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        # A single to_device moves both features and lengths.
        batch = to_device({"speech": feats, "speech_lengths": feats_len}, device=self.device)
        return self.diar_model.estimate_sequential(**batch)

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Diarization instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
            **kwargs: Constructor arguments; entries downloaded for
                ``model_tag`` are merged in and override them.

        Returns:
            Speech2Diarization: Speech2Diarization instance.

        Raises:
            ImportError: if ``model_tag`` is given but espnet_model_zoo is missing.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2Diarization(**kwargs)
| | | |
| | | |
def inference_modelscope(
    diar_train_config: str,
    diar_model_file: str,
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 0,
    num_workers: int = 0,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    param_dict: Optional[dict] = None,
    **kwargs,
):
    """Build an EEND-OLA diarization pipeline and return its ``_forward`` callable.

    Args:
        diar_train_config: Training config of the diarization model.
        diar_model_file: Model parameter file.
        output_dir: Default directory for ``result.txt`` (may be overridden
            per call through ``output_dir_v2``).
        batch_size: Must be 1.
        dtype: Model/data dtype name.
        ngpu: 0 for CPU, 1 for a single GPU (if CUDA is available).
        num_workers: DataLoader workers.
        log_level: Logging level passed to ``logging.basicConfig``.
        key_file: Optional key file for the data iterator.
        model_tag: Optional pretrained model tag.
        allow_variable_data_keys: Passed to the data iterator.
        streaming: Unused; kept for signature compatibility.
        param_dict: Only logged.
        **kwargs: Extra options (e.g. ``seed``, ``smooth_size`` from
            :func:`inference`); ignored here.

    Returns:
        ``_forward(data_path_and_name_and_type=None, raw_inputs=None,
        output_dir_v2=None, param_dict=None)``, which returns a list of
        ``{"key": utt_id, "value": formatted_text}`` dicts.

    Raises:
        NotImplementedError: if ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))

    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Build speech2diar
    # ``streaming`` is not forwarded: Speech2Diarization does not use it.
    speech2diar_kwargs = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar.diar_model.eval()

    def output_results_str(results: dict, uttid: str):
        """Format ``{speaker: [(start, end), ...]}`` segments as SPEAKER lines.

        Segment values are divided by 100 (presumably frames at a 10 ms shift
        -> seconds; confirm). The recording id is ``uttid`` with its last
        ``-``-separated suffix removed. ``results`` is not modified.

        NOTE(review): assumes ``speech2diar`` returns a dict of
        speaker -> segments; DiarEENDOLAModel.estimate_sequential in this
        change returns a tuple. Confirm the intended contract.
        """
        mid = uttid.rsplit("-", 1)[0]
        template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
        lines = []
        for spk, segs in results.items():
            lines.extend(
                template.format(mid, seg[0] / 100, seg[1] / 100, spk) for seg in segs
            )
        return "\n".join(lines)

    def _forward(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
        raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
        output_dir_v2: Optional[str] = None,
        param_dict: Optional[dict] = None,
    ):
        # 2. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
            collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 3. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        output_writer = None
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            output_writer = open("{}/result.txt".format(output_path), "w")
        result_list = []
        try:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"

                results = speech2diar(**batch)
                # Only supporting batch_size==1
                key, value = keys[0], output_results_str(results, keys[0])
                result_list.append({"key": key, "value": value})
                if output_writer is not None and value:
                    # Terminate each utterance's lines so consecutive
                    # utterances do not run together on one line.
                    output_writer.write(value + "\n")
                    output_writer.flush()
        finally:
            # Close the result file even if inference fails part-way.
            if output_writer is not None:
                output_writer.close()

        return result_list

    return _forward
| | | |
| | | |
def inference(
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    diar_train_config: Optional[str],
    diar_model_file: Optional[str],
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 0,
    seed: int = 0,
    num_workers: int = 1,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    smooth_size: int = 83,
    dur_threshold: int = 10,
    out_format: str = "vad",
    **kwargs,
):
    """Run diarization over ``data_path_and_name_and_type`` in a single call.

    Builds the pipeline with :func:`inference_modelscope`, forwarding every
    option. ``seed``, ``smooth_size``, ``dur_threshold`` and ``out_format``
    are not parameters of that function, so they land in its ``**kwargs``.
    The pipeline is then applied to the given data.

    Returns:
        The list of ``{"key": ..., "value": ...}`` dicts produced by the pipeline.
    """
    pipeline_options = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        output_dir=output_dir,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
        seed=seed,
        num_workers=num_workers,
        log_level=log_level,
        key_file=key_file,
        model_tag=model_tag,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
        out_format=out_format,
    )
    run_pipeline = inference_modelscope(**pipeline_options, **kwargs)
    return run_pipeline(data_path_and_name_and_type, raw_inputs=None)
| | | |
| | | |
def get_parser():
    """Build the command-line parser for EEND-OLA diarization inference.

    The parsed options are passed by :func:`main` as keyword arguments to
    :func:`inference`.
    """
    parser = config_argparse.ArgumentParser(
        description="Speaker diarization inference (EEND-OLA)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="diarization model parameter file",
    )
    group.add_argument(
        "--dur_threshold",
        type=int,
        default=10,
        help="The threshold for short segments in number frames"
    )
    # Post-processing option, listed next to --dur_threshold in --help.
    group.add_argument(
        "--smooth_size",
        type=int,
        default=83,
        help="The smoothing window length in number frames"
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    parser.add_argument("--streaming", type=str2bool, default=False)

    return parser
| | | |
| | | |
def main(cmd=None):
    """Command-line entry point: parse ``cmd`` (or ``sys.argv``), run
    :func:`inference` and print ``"<key> <value>"`` per utterance.

    The job id is the numeric suffix of ``--output_dir`` (e.g.
    ``exp/output.3`` -> job 3). Without such a suffix the job id is 1.
    It selects a GPU from ``--gpuid_list`` round-robin.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    logging.info("args: {}".format(kwargs))

    jobid = 1
    if args.output_dir is not None:
        try:
            jobid = int(args.output_dir.split(".")[-1])
        except ValueError:
            # No ".N" job suffix: treat as a single job instead of crashing.
            jobid = 1
    gpuids = args.gpuid_list.split(",")
    gpuid = gpuids[(jobid - 1) % len(gpuids)]
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if gpuid:
        # Only restrict visibility when an id was actually given; an empty
        # value would hide every GPU and silently force CPU decoding.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
    results_list = inference(**kwargs)
    for results in results_list:
        print("{} {}".format(results["key"], results["value"]))


if __name__ == "__main__":
    main()
| New file |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from typing import Dict |
| | | from typing import Tuple |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.modules.eend_ola.utils.power import generate_mapping_dict |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    # Only old torch gets this fallback: a no-op context manager named
    # `autocast` with an `enabled` flag. On torch>=1.6.0 this module does
    # not define `autocast` at all.
    @contextmanager
    def autocast(enabled=True):
        yield
| | | |
| | | |
def pad_attractor(att, max_n_speakers):
    """Zero-pad an attractor matrix along its first (attractor) axis.

    Args:
        att: 2-D tensor of shape ``(C, D)``.
        max_n_speakers: Desired number of rows.

    Returns:
        ``att`` itself when ``C >= max_n_speakers``. Otherwise a
        ``(max_n_speakers, D)`` tensor whose first ``C`` rows are ``att``
        and whose remaining rows are zeros.
    """
    n_rows, dim = att.shape
    if n_rows < max_n_speakers:
        # new_zeros keeps att's dtype and device, so padding never changes
        # the result's precision (e.g. float16 stays float16).
        padding = att.new_zeros(max_n_speakers - n_rows, dim)
        att = torch.cat([att, padding], dim=0)
    return att
| | | |
| | | |
class DiarEENDOLAModel(AbsESPnetModel):
    """EEND-OLA diarization model.

    Components:
        - a Transformer encoder;
        - an encoder-decoder attractor (EDA) module;
        - a power-set post-net: an LSTM over per-frame attractor scores,
          plus a linear classifier over ``mapping_dict`` labels (ids
          ``0..mapping_dict['oov']``).
    """

    def __init__(
        self,
        frontend: WavFrontendMel23,
        encoder: EENDOLATransformerEncoder,
        encoder_decoder_attractor: EncoderDecoderAttractor,
        n_units: int = 256,
        max_n_speaker: int = 8,
        attractor_loss_weight: float = 1.0,
        mapping_dict=None,
        **kwargs,
    ):
        assert check_argument_types()

        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet: LSTM over max_n_speaker scores per frame, then a classifier
        # over power-set label ids 0..oov (hence oov + 1 outputs).
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)

    def forward_encoder(self, xs, ilens):
        """Encode a list of variable-length sequences.

        Args:
            xs: list of per-utterance feature tensors (presumably ``(T_i, D)``;
                TODO confirm).
            ilens: valid length of each utterance.

        Returns:
            List of per-utterance encoder outputs trimmed to ``ilens``.
        """
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        # 1 for valid frames, 0 for padding; unsqueeze(-2) adds the axis the
        # encoder's attention mask expects.
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.encoder(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb

    def forward_post_net(self, logits, ilens):
        """Apply the LSTM post-net and the output layer.

        Args:
            logits: list of ``(T_i, max_n_speaker)`` attractor scores.
            ilens: 1-D tensor of lengths ``T_i``.

        Returns:
            List of ``(T_i, oov + 1)`` power-set label logits.
        """
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        # pack_padded_sequence requires its lengths as a CPU int64 tensor,
        # even when the data itself lives on a GPU.
        lengths_cpu = ilens.detach().to(device="cpu", dtype=torch.int64)
        logits = nn.utils.rnn.pack_padded_sequence(logits, lengths_cpu, batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.PostNet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        NOTE(review): this training path uses ``self.encode``,
        ``self.ctc_weight``, ``self.interctc_weight``, ``self._calc_ctc_loss``
        and ``self._calc_att_loss``. None of these are defined in this class,
        and it looks copied from an ASR model. Unless AbsESPnetModel provides
        them (confirm), it raises AttributeError; only ``estimate_sequential``
        (inference) is usable.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermedaite CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def estimate_sequential(self,
                            speech: torch.Tensor,
                            speech_lengths: torch.Tensor,
                            n_speakers: int = None,
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        """Estimate per-frame speaker activities for a batch.

        Args:
            speech: batch of inputs; each item is trimmed to ``speech_lengths``.
            speech_lengths: valid length of each item.
            n_speakers: if truthy and non-negative, keep exactly the first
                ``n_speakers`` attractors.
            shuffle: shuffle frame order (via np.random) before attractor
                estimation.
            threshold: otherwise, keep the attractors before the first one
                whose existence probability is below ``threshold`` (all of
                them if none is).

        Returns:
            Tuple ``(ys, emb, attractors, raw_n_speakers)``:
                - ``ys``: per-utterance float32 0/1 activity tensors;
                - ``emb``: encoder embeddings;
                - ``attractors``: padded/truncated to ``max_n_speaker`` rows;
                - ``raw_n_speakers``: attractor counts before padding.

        Raises:
            NotImplementedError: if neither ``n_speakers`` nor ``threshold``
                selects the attractors.
        """
        if self.frontend is not None:
            speech = self.frontend(speech)
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att in zip(probs, attractors):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers, ]
            elif threshold is not None:
                # Index of the first attractor below threshold. flatten()
                # makes an empty result safe: no such attractor -> keep all.
                silence = torch.nonzero(p < threshold).flatten()
                n_spk = int(silence[0]) if silence.numel() > 0 else None
                att = att[:n_spk, ]
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
            attractors_active.append(att)
        raw_n_speakers = [att.shape[0] for att in attractors_active]
        # The post-net consumes exactly max_n_speaker scores per frame.
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
              zip(logits, raw_n_speakers)]

        return ys, emb, attractors, raw_n_speakers

    def recover_y_from_powerlabel(self, logit, n_speaker):
        """Decode power-set label logits into per-speaker 0/1 activities.

        Args:
            logit: ``(T, oov + 1)`` logits from the post-net.
            n_speaker: number of leading columns (speakers) to keep.

        Returns:
            float32 tensor of shape ``(T, n_speaker)`` on ``logit``'s device.
        """
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        # Replace OOV predictions with the previous frame's label (0 for the
        # first frame). Frames are processed in order, so a run of OOV frames
        # inherits the last in-vocabulary label.
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.inv_mapping_func(i, self.mapping_dict) for i in pred]
        # Each decimal value is a bitmask: bit k set -> column k active.
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions

    @staticmethod
    def inv_mapping_func(label, mapping_dict):
        """Map a power-set label id back to its decimal speaker bitmask.

        Returns -1 for labels missing from the table.

        NOTE(review): assumes ``mapping_dict['label2dec']`` is the
        label -> decimal table built by
        ``funasr.modules.eend_ola.utils.power.generate_mapping_dict``.
        Confirm against that module.
        """
        if not isinstance(label, int):
            label = int(label)
        return mapping_dict['label2dec'].get(label, -1)
| New file |
| | |
| | | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) |
| | | # Licensed under the MIT license. |
| | | # |
| | | # This module is for computing audio features |
| | | |
| | | import librosa |
| | | import numpy as np |
| | | |
| | | |
def transform(Y, dtype=np.float32, sr=8000, n_mels=23):
    """Convert an STFT matrix into mean-normalized log-mel features.

    Args:
        Y: STFT matrix of shape ``(frames, n_fft // 2 + 1)``; its absolute
            value is taken, so complex input is fine.
        dtype: dtype of the returned array.
        sr: sampling rate used to build the mel filterbank (default 8000).
        n_mels: number of mel bands (default 23).

    Returns:
        ``(frames, n_mels)`` array of log10 mel power, with each band's mean
        over frames subtracted.
    """
    Y = np.abs(Y)
    # Y.shape[1] is the number of frequency bins, i.e. n_fft // 2 + 1.
    n_fft = 2 * (Y.shape[1] - 1)
    # Keyword arguments: librosa >= 0.10 no longer accepts these positionally.
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    Y = np.dot(Y ** 2, mel_basis.T)
    # Floor before log10 to avoid -inf on silent bins.
    Y = np.log10(np.maximum(Y, 1e-10))
    mean = np.mean(Y, axis=0)
    Y = Y - mean
    return Y.astype(dtype)
| | | |
| | | |
def subsample(Y, T, subsampling=1):
    """Keep every ``subsampling``-th row of features ``Y`` and labels ``T``.

    Returns:
        Tuple ``(Y_ss, T_ss)`` with the subsampled ``Y`` and ``T``.
    """
    every_nth = slice(None, None, subsampling)
    return Y[every_nth], T[every_nth]
| | | |
| | | |
def splice(Y, context_size=0):
    """Stack each frame with its ``context_size`` neighbours on both sides.

    Args:
        Y: 2-D array of shape ``(frames, dim)``.
        context_size: neighbouring frames taken on each side; the edges are
            zero-padded.

    Returns:
        Read-only array of shape ``(frames, dim * (2 * context_size + 1))``.
        Row ``t`` concatenates frames ``t - context_size`` to
        ``t + context_size``. The result is a strided view onto a padded copy
        of ``Y``, not a per-row copy.
    """
    n_frames, dim = Y.shape
    window = 2 * context_size + 1
    padded = np.ascontiguousarray(
        np.pad(Y, [(context_size, context_size), (0, 0)], 'constant'))
    # Successive output rows start one frame (dim elements) apart in the
    # padded buffer, so each row covers `window` consecutive frames.
    frame_bytes = Y.itemsize * dim
    return np.lib.stride_tricks.as_strided(
        padded,
        shape=(n_frames, dim * window),
        strides=(frame_bytes, Y.itemsize),
        writeable=False,
    )
| | | |
| | | |
def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """Compute the STFT of ``data`` as a ``(frames, bins)`` matrix.

    The FFT size is ``frame_size`` rounded up to the next power of two.
    When ``len(data)`` is an exact multiple of ``frame_shift``, the last
    frame returned by librosa is dropped.
    """
    fft_size = 1 << (frame_size - 1).bit_length()
    drop_last_frame = len(data) % frame_shift == 0
    spectrogram = librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                               hop_length=frame_shift).T
    return spectrogram[:-1] if drop_last_frame else spectrogram
| | |
| | | import math |
| | | import numpy as np |
| | | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import nn |
| | |
| | | return self.dropout(x) |
| | | |
| | | |
| | | class TransformerEncoder(nn.Module): |
| | | def __init__(self, idim, n_layers, n_units, |
| | | e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False): |
| | | super(TransformerEncoder, self).__init__() |
| | | class EENDOLATransformerEncoder(nn.Module): |
| | | def __init__(self, |
| | | idim: int, |
| | | n_layers: int, |
| | | n_units: int, |
| | | e_units: int = 2048, |
| | | h: int = 8, |
| | | dropout_rate: float = 0.1, |
| | | use_pos_emb: bool = False): |
| | | super(EENDOLATransformerEncoder, self).__init__() |
| | | self.lnorm_in = nn.LayerNorm(n_units) |
| | | self.n_layers = n_layers |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | | from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder |
| | | from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.e2e_diar_sond import DiarSondModel |
| | | from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | | from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.postencoder.hugging_face_transformers_postencoder import ( |
| | | HuggingFaceTransformersPostEncoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.preencoder.linear import LinearProjection |
| | | from funasr.models.preencoder.sinc import LightweightSincConvs |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | |
| | | s3prl=S3prlFrontend, |
| | | fused=FusedFrontends, |
| | | wav_frontend=WavFrontend, |
| | | wav_frontend_mel23=WavFrontendMel23, |
| | | ), |
| | | type_check=AbsFrontend, |
| | | default="default", |
| | |
| | | "model", |
| | | classes=dict( |
| | | sond=DiarSondModel, |
| | | eend_ola=DiarEENDOLAModel, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | default="sond", |
| | |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | ecapa_tdnn=ECAPA_TDNN, |
| | | eend_ola_transformer=EENDOLATransformerEncoder, |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default="resnet34", |
| | |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default="fsmn", |
| | | ) |
# Choices for the encoder-decoder attractor (EDA) module, used only by EEND-OLA.
# Presumably exposed as --encoder_decoder_attractor / --encoder_decoder_attractor_conf,
# like the other ClassChoices.
encoder_decoder_attractor_choices = ClassChoices(
    "encoder_decoder_attractor",
    classes={"eda": EncoderDecoderAttractor},
    type_check=torch.nn.Module,
    default="eda",
)
| | | |
| | | |
| | |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |
| | | |
| | | |
| | | class EENDOLADiarTask(AbsTask): |
| | | # If you need more than 1 optimizer, change this value |
| | | num_optimizers: int = 1 |
| | | |
| | | # Add variable objects configurations |
| | | class_choices_list = [ |
| | | # --frontend and --frontend_conf |
| | | frontend_choices, |
| | |         # --model and --model_conf |
| | | model_choices, |
| | | # --encoder and --encoder_conf |
| | | encoder_choices, |
| | |         # --encoder_decoder_attractor and --encoder_decoder_attractor_conf |
| | | encoder_decoder_attractor_choices, |
| | | ] |
| | | |
| | | # If you need to modify train() or eval() procedures, change Trainer class here |
| | | trainer = Trainer |
| | | |
| | | @classmethod |
| | | def add_task_arguments(cls, parser: argparse.ArgumentParser): |
| | | group = parser.add_argument_group(description="Task related") |
| | | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | # required = parser.get_default("required") |
| | | # required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="A text mapping int-id to token", |
| | | ) |
| | | group.add_argument( |
| | | "--split_with_space", |
| | | type=str2bool, |
| | | default=True, |
| | | help="whether to split text using <space>", |
| | | ) |
| | | group.add_argument( |
| | | "--seg_dict_file", |
| | | type=str, |
| | | default=None, |
| | | help="seg_dict_file for text processing", |
| | | ) |
| | | group.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | |
| | | group.add_argument( |
| | | "--input_size", |
| | | type=int_or_none, |
| | | default=None, |
| | | help="The number of input dimension of the feature", |
| | | ) |
| | | |
| | | group = parser.add_argument_group(description="Preprocess related") |
| | | group.add_argument( |
| | | "--use_preprocessor", |
| | | type=str2bool, |
| | | default=True, |
| | | help="Apply preprocessing to data or not", |
| | | ) |
| | | group.add_argument( |
| | | "--token_type", |
| | | type=str, |
| | | default="char", |
| | | choices=["char"], |
| | | help="The text will be tokenized in the specified level token", |
| | | ) |
| | | parser.add_argument( |
| | | "--speech_volume_normalize", |
| | | type=float_or_none, |
| | | default=None, |
| | | help="Scale the maximum amplitude to the given value.", |
| | | ) |
| | | parser.add_argument( |
| | | "--rir_scp", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The file path of rir scp file.", |
| | | ) |
| | | parser.add_argument( |
| | | "--rir_apply_prob", |
| | | type=float, |
| | | default=1.0, |
| | | help="THe probability for applying RIR convolution.", |
| | | ) |
| | | parser.add_argument( |
| | | "--cmvn_file", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The file path of noise scp file.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_scp", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The file path of noise scp file.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_apply_prob", |
| | | type=float, |
| | | default=1.0, |
| | | help="The probability applying Noise adding.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_db_range", |
| | | type=str, |
| | | default="13_15", |
| | | help="The range of noise decibel level.", |
| | | ) |
| | | |
| | | for class_choices in cls.class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(group) |
| | | |
| | | @classmethod |
| | | def build_collate_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Callable[ |
| | | [Collection[Tuple[str, Dict[str, np.ndarray]]]], |
| | | Tuple[List[str], Dict[str, torch.Tensor]], |
| | | ]: |
| | | assert check_argument_types() |
| | | # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol |
| | | return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | |
| | | @classmethod |
| | | def build_preprocess_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | assert check_argument_types() |
| | | if args.use_preprocessor: |
| | | retval = CommonPreprocessor( |
| | | train=train, |
| | | token_type=args.token_type, |
| | | token_list=args.token_list, |
| | | bpemodel=None, |
| | | non_linguistic_symbols=None, |
| | | text_cleaner=None, |
| | | g2p_type=None, |
| | | split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | # NOTE(kamo): Check attribute existence for backward compatibility |
| | | rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, |
| | | rir_apply_prob=args.rir_apply_prob |
| | | if hasattr(args, "rir_apply_prob") |
| | | else 1.0, |
| | | noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, |
| | | noise_apply_prob=args.noise_apply_prob |
| | | if hasattr(args, "noise_apply_prob") |
| | | else 1.0, |
| | | noise_db_range=args.noise_db_range |
| | | if hasattr(args, "noise_db_range") |
| | | else "13_15", |
| | | speech_volume_normalize=args.speech_volume_normalize |
| | | if hasattr(args, "rir_scp") |
| | | else None, |
| | | ) |
| | | else: |
| | | retval = None |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def required_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | if not inference: |
| | | retval = ("speech", "profile", "binary_labels") |
| | | else: |
| | | # Recognition mode |
| | | retval = ("speech") |
| | | return retval |
| | | |
| | | @classmethod |
| | | def optional_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = () |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | # 1. frontend |
| | | if args.input_size is None or args.frontend == "wav_frontend_mel23": |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | | # Give features from data-loader |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | # 2. Encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size=input_size, **args.encoder_conf) |
| | | |
| | | # 3. EncoderDecoderAttractor |
| | | encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) |
| | | encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf) |
| | | |
| | | # 9. Build model |
| | | model_class = model_choices.get_class(args.model) |
| | | model = model_class( |
| | | frontend=frontend, |
| | | encoder=encoder, |
| | | encoder_decoder_attractor=encoder_decoder_attractor, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | # 10. Initialize |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | | @classmethod |
| | | def build_model_from_file( |
| | | cls, |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | ): |
| | | """Build model from the files. |
| | | |
| | | This method is used for inference or fine-tuning. |
| | | |
| | | Args: |
| | | config_file: The yaml file saved when training. |
| | | model_file: The model file saved when training. |
| | | cmvn_file: The cmvn file for front-end |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | | "if the argument 'config_file' is not specified." |
| | | ) |
| | | config_file = Path(model_file).parent / "config.yaml" |
| | | else: |
| | | config_file = Path(config_file) |
| | | |
| | | with config_file.open("r", encoding="utf-8") as f: |
| | | args = yaml.safe_load(f) |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | ) |
| | | if model_file is not None: |
| | | if device == "cuda": |
| | | device = f"cuda:{torch.cuda.current_device()}" |
| | | checkpoint = torch.load(model_file, map_location=device) |
| | | if "state_dict" in checkpoint.keys(): |
| | | model.load_state_dict(checkpoint["state_dict"]) |
| | | else: |
| | | model.load_state_dict(checkpoint) |
| | | model.to(device) |
| | | return model, args |