游雁
2023-09-13 33d3d2084403fd34b79c835d2f2fe04f6cd8f738
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -*- coding: utf-8 -*-
"""
Process the textgrid files
"""
import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
import pdb
import numpy as np
import sys
import math
 
 
class Segment:
    """One labelled interval spoken by a single speaker of an utterance.

    The start and end times given to the constructor are rounded to two
    decimal places; the ``change_*`` methods store their argument as-is.
    """

    def __init__(self, uttid, spkr, stime, etime, text):
        self.uttid, self.spkr = uttid, spkr
        # Round once here so every consumer sees boundaries on a 0.01 grid.
        self.stime, self.etime = round(stime, 2), round(etime, 2)
        self.text = text

    def change_stime(self, time):
        """Overwrite the start time (no rounding is applied)."""
        self.stime = time

    def change_etime(self, time):
        """Overwrite the end time (no rounding is applied)."""
        self.etime = time
 
 
def get_args():
    """Parse the command line; the only option, ``--path``, is required."""
    cli = argparse.ArgumentParser(description="process the textgrid files")
    cli.add_argument(
        "--path",
        required=True,
        type=str,
        help="Data path",
    )
    return cli.parse_args()
 
 
 
# Length of one label frame, in seconds. Interval boundaries are quantised to
# this grid, and segment times are written as frame_index * _FRAME_SHIFT.
_FRAME_SHIFT = 0.01


def _time_to_frame(seconds):
    """Return the index of the frame nearest to ``seconds``.

    Rounds instead of truncating: ``0.29 / 0.01`` evaluates to 28.999...,
    so ``int()`` would move the boundary one frame early.
    """
    return int(round(seconds / _FRAME_SHIFT))


def _speaker_runs(frame_labels, weight2spk):
    """Yield ``(weight, first_frame, last_frame)`` for single-speaker runs.

    A run is a maximal stretch of consecutive frames carrying the same label.
    Only runs whose label is a key of ``weight2spk`` (exactly one speaker
    active) are yielded; ``last_frame`` is inclusive. The run that reaches
    the final frame is yielded as well.
    """
    num_frames = len(frame_labels)
    run_start = 0
    for pos in range(1, num_frames + 1):
        # Keep extending the current run while the label is unchanged.
        if pos < num_frames and frame_labels[pos] == frame_labels[run_start]:
            continue
        weight = int(frame_labels[run_start])
        if weight in weight2spk:
            yield weight, run_start, pos - 1
        run_start = pos


def _write_segment(segment_file, utt2spk, uttid, spkid, first, last):
    """Write one ``segments`` line and the matching ``utt2spk`` line.

    The segment id is ``<uttid>_<spkid>_<first>_<last>`` with both frame
    indices zero-padded to 7 digits; times are ``frame * _FRAME_SHIFT``
    rounded to two decimals.
    """
    seg_id = "%s_%s_%s_%s" % (uttid, spkid, str(first).zfill(7), str(last).zfill(7))
    start = round(first * _FRAME_SHIFT, 2)
    end = round(last * _FRAME_SHIFT, 2)
    segment_file.write("%s %s %s %s\n" % (seg_id, uttid, start, end))
    utt2spk.write("%s %s\n" % (seg_id, uttid + "_" + spkid))


def _write_textgrid_segments(tg, uttid, segment_file, utt2spk):
    """Write the single-speaker segments of one parsed TextGrid ``tg``.

    Each tier is treated as one speaker. Each new tier name gets the next
    power-of-two weight (2, 4, 8, ...), so the per-frame sum of weights
    identifies exactly which speakers are active. Silence (0) and overlap
    (a sum of several weights) are not keys of ``weight2spk`` and are dropped.
    """
    spk2row = {}       # tier name -> row of the label matrix
    spk2weight = {}    # tier name -> power-of-two weight
    weight2spk = {}    # inverse of spk2weight
    spk2segments = {}  # tier name -> marked intervals of that tier
    num_frames = 0
    for tier in tg:
        spk_name = tier.name
        if spk_name not in spk2weight:
            row = len(spk2weight)
            spk2row[spk_name] = row
            spk2weight[spk_name] = 2 ** (row + 1)
            weight2spk[spk2weight[spk_name]] = spk_name
        segments = []
        for interval in tier:
            if interval.mark:
                seg = Segment(uttid, spk_name, interval.minTime,
                              interval.maxTime, interval.mark.strip())
                segments.append(seg)
                # Size the grid from the rounded end times so every segment fits.
                num_frames = max(num_frames, _time_to_frame(seg.etime))
        # NOTE: tiers sharing a name share one weight; the last such tier's
        # intervals are the ones used.
        spk2segments[spk_name] = segments

    # int64 keeps the power-of-two weights exact beyond 30 speakers.
    labels = np.zeros((len(spk2weight), num_frames), dtype=np.int64)
    for spk_name, segments in spk2segments.items():
        row, weight = spk2row[spk_name], spk2weight[spk_name]
        for seg in segments:
            labels[row, _time_to_frame(seg.stime):_time_to_frame(seg.etime)] = weight
    frame_sum = labels.sum(axis=0)

    for weight, first, last in _speaker_runs(frame_sum, weight2spk):
        if first == last:
            continue  # one-frame runs are skipped
        spkid = weight2spk[weight].split("_")[-1]
        _write_segment(segment_file, utt2spk, uttid, spkid, first, last)


def main(args):
    """Build ``segments`` and ``utt2spk`` under ``args.path`` from TextGrids.

    ``<args.path>/textgrid.flist`` lists one ``<uttid> <textgrid-path>`` pair
    per line; blank lines are ignored. For every TextGrid, maximal stretches
    of 0.01 s frames in which exactly one tier (speaker) has a marked interval
    become segments. Silence, overlapped speech and one-frame stretches are
    dropped. Each segment is written to ``segments`` as
    ``<segid> <uttid> <start> <end>`` and to ``utt2spk`` as
    ``<segid> <uttid>_<spk>``, where ``<spk>`` is the last ``_``-separated
    part of the tier name.

    Raises:
        ValueError: if a non-blank flist line has fewer than two fields.
        RuntimeError: if a TextGrid file cannot be read; the underlying
            error is chained as ``__cause__``.
    """
    data_dir = Path(args.path)
    # ``with`` guarantees all three files (including utt2spk) are flushed and closed.
    with codecs.open(data_dir / "textgrid.flist", "r", "utf-8") as textgrid_flist, \
            codecs.open(data_dir / "segments", "w", "utf-8") as segment_file, \
            codecs.open(data_dir / "utt2spk", "w", "utf-8") as utt2spk:
        for line in textgrid_flist:
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError("malformed textgrid.flist line: %r" % line)
            uttid, path = fields[0], Path(fields[1])
            try:
                tg = textgrid.TextGrid.fromFile(path)
            except Exception as err:
                # A failed parse must not fall through: ``tg`` would be unbound
                # or still hold the previous line's TextGrid.
                raise RuntimeError(
                    "failed to read TextGrid %s for utterance %s" % (path, uttid)
                ) from err
            _write_textgrid_segments(tg, uttid, segment_file, utt2spk)
 
 
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the conversion.
    main(get_args())