From 559cc2c6e296bc80917a7408911f671dfcc2b68b Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 12 五月 2023 17:25:54 +0800
Subject: [PATCH] update repo

---
 egs/aishell2/transformer/utils/extract_embeds.py |   47 +++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 47 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/extract_embeds.py b/egs/aishell2/transformer/utils/extract_embeds.py
new file mode 100755
index 0000000..7b817d8
--- /dev/null
+++ b/egs/aishell2/transformer/utils/extract_embeds.py
@@ -0,0 +1,47 @@
+from transformers import AutoTokenizer, AutoModel, pipeline
+import numpy as np
+import sys
+import os
+import torch
+from kaldiio import WriteHelper
+import re
+text_file_json = sys.argv[1]
+out_ark = sys.argv[2]
+out_scp = sys.argv[3]
+out_shape = sys.argv[4]
+device = int(sys.argv[5])
+model_path = sys.argv[6]
+
+model = AutoModel.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
+
+with open(text_file_json, 'r') as f:
+    js = f.readlines()
+
+
+f_shape = open(out_shape, "w")
+with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
+    with torch.no_grad():
+        for idx, line in enumerate(js):
+            id, tokens = line.strip().split(" ", 1)
+            tokens = re.sub(" ", "", tokens.strip())
+            tokens = ' '.join([j for j in tokens])
+            token_num = len(tokens.split(" "))
+            outputs = extractor(tokens)
+            outputs = np.array(outputs)
+            embeds = outputs[0, 1:-1, :]
+
+            token_num_embeds, dim = embeds.shape
+            if token_num == token_num_embeds:
+                writer(id, embeds)
+                shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
+                f_shape.write(shape_line)
+            else:
+                print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
+
+
+
+f_shape.close()
+
+

--
Gitblit v1.9.1