From 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8 Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: 星期五, 05 七月 2024 00:55:32 +0800
Subject: [PATCH] 优化speakid和语句匹配逻辑,部分解决speakid不从0递增问题 (#1870)
---
examples/aishell/paraformer/utils/extract_embeds.py | 16 +++++++++-------
1 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/examples/aishell/paraformer/utils/extract_embeds.py b/examples/aishell/paraformer/utils/extract_embeds.py
index 7b817d8..e0cf98d 100755
--- a/examples/aishell/paraformer/utils/extract_embeds.py
+++ b/examples/aishell/paraformer/utils/extract_embeds.py
@@ -5,6 +5,7 @@
import torch
from kaldiio import WriteHelper
import re
+
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
@@ -16,17 +17,17 @@
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
-with open(text_file_json, 'r') as f:
+with open(text_file_json, "r") as f:
js = f.readlines()
f_shape = open(out_shape, "w")
-with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
+with WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer:
with torch.no_grad():
for idx, line in enumerate(js):
id, tokens = line.strip().split(" ", 1)
tokens = re.sub(" ", "", tokens.strip())
- tokens = ' '.join([j for j in tokens])
+ tokens = " ".join([j for j in tokens])
token_num = len(tokens.split(" "))
outputs = extractor(tokens)
outputs = np.array(outputs)
@@ -38,10 +39,11 @@
shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
f_shape.write(shape_line)
else:
- print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
-
+ print(
+ "{}, size has changed, {}, {}, {}".format(
+ id, token_num, token_num_embeds, tokens
+ )
+ )
f_shape.close()
-
-
--
Gitblit v1.9.1