嘉渊
2023-04-24 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1
egs/mars/sd/scripts/simu_chunk_with_labels.py
@@ -8,6 +8,14 @@
import argparse
from collections import OrderedDict
import random
from typing import List, Dict
from copy import deepcopy
import json
logging.basicConfig(
    level="INFO",
    format=f"[{os.uname()[1].split('.')[0]}]"
           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
class MyRunner(MultiProcessRunnerV3):
@@ -19,63 +27,233 @@
        parser.add_argument("--spk2meeting", type=str, required=True)
        parser.add_argument("--utt2xvec", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--chunk_size", type=int, default=16)
        parser.add_argument("--chunk_shift", type=int, default=4)
        parser.add_argument("--chunk_size", type=float, default=16)
        parser.add_argument("--chunk_shift", type=float, default=4)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        parser.add_argument("--embedding_dim", type=int, default=None)
        parser.add_argument("--average_emb_num", type=int, default=0)
        parser.add_argument("--subset", type=int, default=0)
        parser.add_argument("--data_json", type=str, default=None)
        parser.add_argument("--seed", type=int, default=1234)
        parser.add_argument("--log_interval", type=int, default=100)
        args = parser.parse_args()
        random.seed(args.seed)
        np.random.seed(args.seed)
        logging.info("Loading data...")
        if not os.path.exists(args.data_json):
            label_list = load_scp_as_list(args.label_scp)
            wav_scp = load_scp_as_dict(args.wav_scp)
            utt2spk = load_scp_as_dict(args.utt2spk)
            utt2xvec = load_scp_as_dict(args.utt2xvec)
            spk2meeting = load_scp_as_dict(args.spk2meeting)
            meeting2spks = OrderedDict()
            for spk, meeting in spk2meeting.items():
                if meeting not in meeting2spks:
                    meeting2spks[meeting] = []
                meeting2spks[meeting].append(spk)
            spk2utts = OrderedDict()
            for utt, spk in utt2spk.items():
                if spk not in spk2utts:
                    spk2utts[spk] = []
                spk2utts[spk].append(utt)
            os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
            logging.info("Dump data...")
            json.dump({
                "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
                "spk2utts": spk2utts, "meeting2spks": meeting2spks
            }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
        else:
            data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
            label_list = data_dict["label_list"]
            wav_scp = data_dict["wav_scp"]
            utt2xvec = data_dict["utt2xvec"]
            spk2utts = data_dict["spk2utts"]
            meeting2spks = data_dict["meeting2spks"]
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        label_list = load_scp_as_list(args.label_scp)
        wav_scp = load_scp_as_dict(args.wav_scp)
        utt2spk = load_scp_as_dict(args.utt2spk)
        utt2xvec = load_scp_as_dict(args.utt2xvec)
        spk2meeting = load_scp_as_dict(args.spk2meeting)
        args.chunk_size = int(args.chunk_size / args.frame_shift)
        args.chunk_shift = int(args.chunk_shift / args.frame_shift)
        meeting2spks = OrderedDict()
        for spk, meeting in spk2meeting.items():
            if meeting not in meeting2spks:
                meeting2spks[meeting] = []
            meeting2spks[meeting].append(spk)
        if args.embedding_dim is None:
            args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
            logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
        spk2utts = OrderedDict()
        for utt, spk in utt2spk.items():
            if spk not in spk2utts:
                spk2utts[spk] = []
            spk2utts[spk].append(utt)
        logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
            len(wav_scp), len(spk2utts), len(meeting2spks)
        ))
        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
    def post(self, results_list, args):
        pass
        logging.info("[main]: Got {} chunks.".format(sum(results_list)))
def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
    utt_list = spk2utts[spk]
    wav_list = []
    cur_length = 0
    while cur_length < sample_length:
        uttid = random.choice(utt_list)
        wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
        wav_list.append(wav)
        cur_length += len(wav)
    concat_wav = np.concatenate(wav_list, axis=0)
    start = random.randint(0, len(concat_wav) - sample_length)
    return concat_wav[start: start+sample_length]
def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
    # process for dummy speaker
    if spk == "None":
        return np.zeros((1, embedding_dim), dtype=np.float32)
    # calculate averaged speaker embeddings
    utt_list = spk2utts[spk]
    if average_emb_num == 0 or average_emb_num > len(utt_list):
        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
    else:
        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
    xvec = np.concatenate(xvec_list, axis=0)
    xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
    xvec = np.mean(xvec, axis=0)
    return xvec
def simu_chunk(
        frame_label: np.ndarray,
        sample_label: np.ndarray,
        wav_scp: Dict[str, str],
        utt2xvec: Dict[str, str],
        spk2utts: Dict[str, List[str]],
        meeting2spks: Dict[str, List[str]],
        all_speaker_list: List[str],
        meeting_list: List[str],
        embedding_dim: int,
        average_emb_num: int,
):
    frame_length, max_spk_num = frame_label.shape
    sample_length = sample_label.shape[0]
    positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
    pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
    # get positive speakers
    if len(pos_speaker_list) >= positive_speaker_num:
        pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
    else:
        while len(pos_speaker_list) < positive_speaker_num:
            _spk = random.choice(all_speaker_list)
            if _spk not in pos_speaker_list:
                pos_speaker_list.append(_spk)
    # get negative speakers
    negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
    neg_speaker_list = []
    while len(neg_speaker_list) < negative_speaker_num:
        _spk = random.choice(all_speaker_list)
        if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
            neg_speaker_list.append(_spk)
    neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
    random.shuffle(pos_speaker_list)
    random.shuffle(neg_speaker_list)
    seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
    this_spk_list = []
    for idx, frame_num in enumerate(frame_label.sum(axis=0)):
        if frame_num > 0:
            spk = pos_speaker_list.pop(0)
            this_spk_list.append(spk)
            simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
            seperated_wav[:, idx] = simu_spk_wav
        else:
            spk = neg_speaker_list.pop(0)
            this_spk_list.append(spk)
    # calculate mixed wav
    mixed_wav = np.sum(seperated_wav * sample_label, axis=1)
    # shuffle the order of speakers
    shuffle_idx = list(range(max_spk_num))
    random.shuffle(shuffle_idx)
    this_spk_list = [this_spk_list[x] for x in shuffle_idx]
    seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
    frame_label = frame_label.transpose()[shuffle_idx].transpose()
    # calculate profile
    profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
               for spk in this_spk_list]
    profile = np.vstack(profile)
    # pse_weights = 2 ** np.arange(max_spk_num)
    # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
    # pse_label = pse_label.astype(str).tolist()
    return mixed_wav, seperated_wav, profile, frame_label
def process(task_args):
    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
    logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
    out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
    wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
    wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
    # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    out_path = os.path.join(args.out_dir, "label.{}".format(task_idx + 1))
    out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
    profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    idx = 0
    for _, label_path in task_list:
        rand_shift = random.randint(0, int(args.chunk_shift / args.frame_shift))
    speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
    labels_list = []
    total_chunks = 0
    for org_mid, label_path in task_list:
        whole_label = kaldiio.load_mat(label_path)
        whole_label = whole_label[rand_shift:]
        num_chunk = (whole_label.shape[0] - args.chunk_size) // args.chunk_shift
        # random offset to keep diversity
        rand_shift = random.randint(0, args.chunk_shift)
        num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
        labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
        total_chunks += num_chunk
    idx = 0
    simu_chunk_count = 0
    for org_mid, whole_label, rand_shift, num_chunk in labels_list:
        for i in range(num_chunk):
            utt_id = "part{}_utt{:10d}".format(task_idx + 1, idx + 1)
            idx = idx + 1
            st = i * args.chunk_shift + rand_shift
            ed = i * args.chunk_shift + args.chunk_size + rand_shift
            utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
                args.subset + 1, task_idx + 1, org_mid, st, ed
            )
            frame_label = whole_label[st: ed, :]
            sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
            mix_wav, seg_wav, profile, frame_label = simu_chunk(
                frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
                speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
            )
            wav_mix_writer(utt_id, mix_wav)
            # wav_sep_writer(utt_id, seg_wav)
            profile_writer(utt_id, profile)
            label_writer(utt_id, frame_label)
            simu_chunk_count += 1
            if simu_chunk_count % args.log_interval == 0:
                logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
                    task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
    wav_mix_writer.close()
    wav_sep_writer.close()
    # wav_sep_writer.close()
    profile_writer.close()
    label_writer.close()
    return None
    logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
    return simu_chunk_count
if __name__ == '__main__':