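Emo2Vec: restrict the selectable emotion categories (#1730)
* Restrict the selectable emotion categories
* Use none to disable emotion label output
* Modify the output interface
* Use the "unuse" prefix to disable tokens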
---------
Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
```python
if self.proj:
    x = x.mean(dim=1)
    x = self.proj(x)
    for idx, lab in enumerate(labels):
        x[:,idx] = -np.inf if lab.startswith("unuse") else x[:,idx]
    x = torch.softmax(x, dim=-1)
    scores = x[0].tolist()

result_i = {"key": key[i], "labels": labels, "scores": scores}
select_label = [lb for lb in labels if not lb.startswith("unuse")]
select_score = [scores[idx] for idx, lb in enumerate(labels) if not lb.startswith("unuse")]

# result_i = {"key": key[i], "labels": labels, "scores": scores}
result_i = {"key": key[i], "labels": select_label, "scores": select_score}

if extract_embedding:
    result_i["feats"] = feats
results.append(result_i)
```
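The change works in two steps. Logits for labels that start with "unuse" are set to -inf before the softmax, so those classes receive zero probability and the remaining probabilities renormalize over the enabled labels only. The disabled labels and their scores are then dropped from the result. The sketch below reproduces that idea in plain Python without torch. The function name `masked_emotion_scores` and the example labels are hypothetical and are not part of the FunASR API.

```python
import math
from typing import List, Tuple


def masked_emotion_scores(
    logits: List[float], labels: List[str], disabled_prefix: str = "unuse"
) -> Tuple[List[str], List[float]]:
    """Hypothetical helper: softmax over `logits` with disabled labels masked out.

    Labels starting with `disabled_prefix` are given a logit of -inf, so they
    get zero probability, and are then removed from the returned lists.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")

    # Mask disabled classes with -inf, mirroring x[:, idx] = -np.inf in the diff.
    masked = [float("-inf") if lab.startswith(disabled_prefix) else z
              for z, lab in zip(logits, labels)]

    # Numerically stable softmax: subtract the largest finite logit first.
    finite = [z for z in masked if z != float("-inf")]
    if not finite:
        # torch.softmax would return NaN here; this sketch raises instead.
        raise ValueError("every label is disabled")
    m = max(finite)
    exps = [math.exp(z - m) for z in masked]  # math.exp(-inf) == 0.0
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep only the enabled labels, like select_label / select_score.
    kept = [(lab, p) for lab, p in zip(labels, probs)
            if not lab.startswith(disabled_prefix)]
    return [lab for lab, _ in kept], [p for _, p in kept]


if __name__ == "__main__":
    labels, scores = masked_emotion_scores([0.0, 3.0, 0.0], ["happy", "unuse_x", "sad"])
    print(labels, scores)  # ['happy', 'sad'] [0.5, 0.5]
```

Even though the disabled label has the largest logit (3.0), it contributes nothing, and the two enabled labels split the probability mass evenly. Masking before the softmax, instead of deleting scores afterwards, keeps the returned scores summing to 1.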