kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
examples/aishell/paraformer/utils/extract_embeds.py
@@ -5,6 +5,7 @@
import torch
from kaldiio import WriteHelper
import re
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
@@ -16,17 +17,17 @@
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
with open(text_file_json, 'r') as f:
with open(text_file_json, "r") as f:
    js = f.readlines()
f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
with WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer:
    with torch.no_grad():
        for idx, line in enumerate(js):
            id, tokens = line.strip().split(" ", 1)
            tokens = re.sub(" ", "", tokens.strip())
            tokens = ' '.join([j for j in tokens])
            tokens = " ".join([j for j in tokens])
            token_num = len(tokens.split(" "))
            outputs = extractor(tokens)
            outputs = np.array(outputs)
@@ -38,10 +39,11 @@
                shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
                f_shape.write(shape_line)
            else:
                print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
                print(
                    "{}, size has changed, {}, {}, {}".format(
                        id, token_num, token_num_embeds, tokens
                    )
                )
f_shape.close()