EMO_UNK禁用和Merge VAD修复 (Disable EMO_UNK and fix merge VAD) (#1940)
* 添加富文本解码约束 (add rich-text decoding constraints)
* special token
* bug fix
* fix
* 增加unk score的参数 (add a parameter for the unk score)
* emo_unk banned
* kwargs2cfg
* merge_vad bug fix
---------
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
| | |
| | | end_vad = time.time() |
| | | |
| | | # FIX(gcf): concat the vad clips for sense vocie model for better aed |
| | | if kwargs.get("merge_vad", False): |
| | | if cfg.get("merge_vad", False): |
| | | for i in range(len(res)): |
| | | res[i]["value"] = merge_vad( |
| | | res[i]["value"], kwargs.get("merge_length_s", 15) * 1000 |
| | |
| | | self.embed = torch.nn.Embedding( |
| | | 7 + len(self.lid_dict) + len(self.textnorm_dict), input_size |
| | | ) |
| | | self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004} |
| | | |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=self.vocab_size, |
| | |
| | | |
| | | # c. Passed the encoder result and the beam search |
| | | ctc_logits = self.ctc.log_softmax(encoder_out) |
| | | if kwargs.get("ban_emo_unk", False): |
| | | ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf") |
| | | |
| | | results = [] |
| | | b, n, d = encoder_out.size() |
| | |
| | | return speech_list, speech_lengths_list |
| | | |
| | | |
| | | def merge_vad(vad_result, max_length=15000): |
| | | def merge_vad(vad_result, max_length=15000, min_length=0): |
| | | new_result = [] |
| | | if len(vad_result) <= 1: |
| | | return vad_result |
| | | time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result] |
| | | time_step = sorted(list(set(time_step))) |
| | | if len(time_step) == 0: |
| | |
| | | time = time_step[i] |
| | | if time_step[i + 1] - bg < max_length: |
| | | continue |
| | | if time - bg < max_length * 1.5: |
| | | if time - bg > min_length: |
| | | new_result.append([bg, time]) |
| | | else: |
| | | split_num = int(time - bg) // max_length + 1 |
| | | spl_l = int(time - bg) // split_num |
| | | for j in range(split_num): |
| | | new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l]) |
| | | # if time - bg < max_length * 1.5: |
| | | # new_result.append([bg, time]) |
| | | # else: |
| | | # split_num = int(time - bg) // max_length + 1 |
| | | # spl_l = int(time - bg) // split_num |
| | | # for j in range(split_num): |
| | | # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l]) |
| | | bg = time |
| | | new_result.append([bg, time_step[-1]]) |
| | | return new_result |