From 559cc2c6e296bc80917a7408911f671dfcc2b68b Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Fri, 12 May 2023 17:25:54 +0800
Subject: [PATCH] update repo
---
egs/aishell2/transformer/utils/extract_embeds.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+), 0 deletions(-)
diff --git a/egs/aishell2/transformer/utils/extract_embeds.py b/egs/aishell2/transformer/utils/extract_embeds.py
new file mode 100755
index 0000000..7b817d8
--- /dev/null
+++ b/egs/aishell2/transformer/utils/extract_embeds.py
@@ -0,0 +1,47 @@
"""Extract per-character text embeddings with a HuggingFace model and save them in Kaldi format.

Usage:
    python extract_embeds.py <text_file> <out_ark> <out_scp> <out_shape> <device> <model_path>

Each input line is "<utt_id> <transcript>". The transcript is stripped of
spaces, split into individual characters, and fed through a
feature-extraction pipeline; the per-character embeddings (with the
[CLS]/[SEP] positions removed) are written to out_ark/out_scp, and their
shapes to out_shape. Utterances whose embedding count does not match the
character count are skipped with a diagnostic print.
"""
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
import sys
import torch
from kaldiio import WriteHelper


def main():
    text_file_json = sys.argv[1]
    out_ark = sys.argv[2]
    out_scp = sys.argv[3]
    out_shape = sys.argv[4]
    device = int(sys.argv[5])   # GPU index for the pipeline (-1 would mean CPU)
    model_path = sys.argv[6]    # local path or hub name of the pretrained model

    model = AutoModel.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)

    with open(text_file_json, 'r') as f:
        js = f.readlines()

    # Context managers guarantee the shape file and the ark/scp writer are
    # closed even if an utterance raises mid-loop (the original leaked f_shape).
    with open(out_shape, "w") as f_shape, \
            WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer, \
            torch.no_grad():
        for line in js:
            # "utt_id transcript..." — split only on the first space.
            utt_id, tokens = line.strip().split(" ", 1)
            # Drop literal spaces, then re-join with one space per character
            # so the tokenizer sees each character as its own token.
            tokens = tokens.strip().replace(" ", "")
            tokens = ' '.join([ch for ch in tokens])
            token_num = len(tokens.split(" "))

            outputs = np.array(extractor(tokens))
            # Index 0: single sequence in the batch; 1:-1 strips the
            # [CLS]/[SEP] special-token embeddings at both ends.
            embeds = outputs[0, 1:-1, :]

            token_num_embeds, dim = embeds.shape
            if token_num == token_num_embeds:
                writer(utt_id, embeds)
                f_shape.write("{} {},{}\n".format(utt_id, token_num_embeds, dim))
            else:
                # Tokenizer merged/split characters; skip this utterance.
                print("{}, size has changed, {}, {}, {}".format(utt_id, token_num, token_num_embeds, tokens))


if __name__ == "__main__":
    main()
--
Gitblit v1.9.1