From 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8 Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: Fri, 05 Jul 2024 00:55:32 +0800
Subject: [PATCH] 优化speakid和语句匹配逻辑,部分解决speakid不从0递增问题 (#1870)

---
 examples/aishell/paraformer/utils/extract_embeds.py |   16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/examples/aishell/paraformer/utils/extract_embeds.py b/examples/aishell/paraformer/utils/extract_embeds.py
index 7b817d8..e0cf98d 100755
--- a/examples/aishell/paraformer/utils/extract_embeds.py
+++ b/examples/aishell/paraformer/utils/extract_embeds.py
@@ -5,6 +5,7 @@
 import torch
 from kaldiio import WriteHelper
 import re
+
 text_file_json = sys.argv[1]
 out_ark = sys.argv[2]
 out_scp = sys.argv[3]
@@ -16,17 +17,17 @@
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
 
-with open(text_file_json, 'r') as f:
+with open(text_file_json, "r") as f:
     js = f.readlines()
 
 
 f_shape = open(out_shape, "w")
-with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
+with WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer:
     with torch.no_grad():
         for idx, line in enumerate(js):
             id, tokens = line.strip().split(" ", 1)
             tokens = re.sub(" ", "", tokens.strip())
-            tokens = ' '.join([j for j in tokens])
+            tokens = " ".join([j for j in tokens])
             token_num = len(tokens.split(" "))
             outputs = extractor(tokens)
             outputs = np.array(outputs)
@@ -38,10 +39,11 @@
                 shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
                 f_shape.write(shape_line)
             else:
-                print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
-
+                print(
+                    "{}, size has changed, {}, {}, {}".format(
+                        id, token_num, token_num_embeds, tokens
+                    )
+                )
 
 
 f_shape.close()
-
-

--
Gitblit v1.9.1