From 1243938b7bf56b08688530f8ed85bce7a8c1ef7e Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 11 五月 2023 14:21:50 +0800
Subject: [PATCH] update repo

---
 egs/librispeech_100h/conformer/local/spm_train  |   12 ++++++
 egs/librispeech_100h/conformer/run.sh           |    4 +-
 egs/librispeech_100h/conformer/local/spm_encode |   98 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech_100h/conformer/local/spm_encode b/egs/librispeech_100h/conformer/local/spm_encode
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_encode
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True,
+                        help="sentencepiece model to use for encoding")
+    parser.add_argument("--inputs", nargs="+", default=['-'],
+                        help="input files to filter/encode")
+    parser.add_argument("--outputs", nargs="+", default=['-'],
+                        help="path to save encoded outputs")
+    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+    parser.add_argument("--min-len", type=int, metavar="N",
+                        help="filter sentence pairs with fewer than N tokens")
+    parser.add_argument("--max-len", type=int, metavar="N",
+                        help="filter sentence pairs with more than N tokens")
+    args = parser.parse_args()
+
+    assert len(args.inputs) == len(args.outputs), \
+        "number of input and output paths should match"
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.output_format == "piece":
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+    elif args.output_format == "id":
+        def encode(l):
+            return list(map(str, sp.EncodeAsIds(l)))
+    else:
+        raise NotImplementedError
+
+    if args.min_len is not None or args.max_len is not None:
+        def valid(line):
+            return (
+                (args.min_len is None or len(line) >= args.min_len) and
+                (args.max_len is None or len(line) <= args.max_len)
+            )
+    else:
+        def valid(lines):
+            return True
+
+    with contextlib.ExitStack() as stack:
+        inputs = [
+            stack.enter_context(open(input, "r", encoding="utf-8"))
+            if input != "-" else sys.stdin
+            for input in args.inputs
+        ]
+        outputs = [
+            stack.enter_context(open(output, "w", encoding="utf-8"))
+            if output != "-" else sys.stdout
+            for output in args.outputs
+        ]
+
+        stats = {
+            "num_empty": 0,
+            "num_filtered": 0,
+        }
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        for i, lines in enumerate(zip(*inputs), start=1):
+            enc_lines = list(map(encode_line, lines))
+            if not any(enc_line is None for enc_line in enc_lines):
+                for enc_line, output_h in zip(enc_lines, outputs):
+                    print(" ".join(enc_line), file=output_h)
+            if i % 10000 == 0:
+                print("processed {} lines".format(i), file=sys.stderr)
+
+        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech_100h/conformer/local/spm_train b/egs/librispeech_100h/conformer/local/spm_train
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_train
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+    spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
diff --git a/egs/librispeech_100h/conformer/run.sh b/egs/librispeech_100h/conformer/run.sh
index e879b5e..a855daa 100755
--- a/egs/librispeech_100h/conformer/run.sh
+++ b/egs/librispeech_100h/conformer/run.sh
@@ -100,8 +100,8 @@
     echo "<s>" >> ${dict}
     echo "</s>" >> ${dict}
     cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
-    spm_train --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
-    spm_encode --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${dict}
+    local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+    local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${dict}
     echo "<unk>" >> ${dict}
 fi
 

--
Gitblit v1.9.1