huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
funasr/modules/streaming_utils/utils.py
@@ -1,6 +1,7 @@
import os
import torch
from torch.nn import functional as F
import yaml
import numpy as np
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
@@ -45,3 +46,46 @@
      outputs += inputs
   return outputs
def proc_tf_vocab(vocab_path):
   with open(vocab_path, encoding="utf-8") as f:
      token_list = [line.rstrip() for line in f]
      if '<unk>' not in token_list:
         token_list.append('<unk>')
   return token_list
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
   token_list = proc_tf_vocab(vocab_path)
   with open(config_path, encoding="utf-8") as f:
      config = yaml.safe_load(f)
   config['token_list'] = token_list
   if not os.path.exists(output_dir):
      os.makedirs(output_dir)
   with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
      yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
class NoAliasSafeDumper(yaml.SafeDumper):
   # Disable anchor/alias in yaml because looks ugly
   def ignore_aliases(self, data):
      return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
   """Safe-dump in yaml with no anchor/alias"""
   return yaml.dump(
      data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
   )
if __name__ == '__main__':
   import sys
   config_path = sys.argv[1]
   vocab_path = sys.argv[2]
   output_dir = sys.argv[3]
   gen_config_for_tfmodel(config_path, vocab_path, output_dir)