liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/scama/utils.py
@@ -15,6 +15,7 @@
    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def apply_cmvn(inputs, mvn):
    device = inputs.device
    dtype = inputs.dtype
@@ -27,15 +28,13 @@
    return inputs.type(torch.float32)
def drop_and_add(inputs: torch.Tensor,
                 outputs: torch.Tensor,
                 training: bool,
                 dropout_rate: float = 0.1,
                 stoch_layer_coeff: float = 1.0):
def drop_and_add(
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    training: bool,
    dropout_rate: float = 0.1,
    stoch_layer_coeff: float = 1.0,
):
    outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
    outputs *= stoch_layer_coeff
@@ -51,8 +50,8 @@
def proc_tf_vocab(vocab_path):
    """Read a vocabulary file into a token list that always contains ``<unk>``.

    Args:
        vocab_path: Path to a UTF-8 text file that is read line by line.

    Returns:
        A list with one entry per line of the file, in file order, with
        trailing whitespace (including the newline) stripped. ``"<unk>"`` is
        appended as the last entry when no line already equals it.
    """
    with open(vocab_path, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    # One membership check is sufficient: after the append the token is
    # present, so repeating the check can never add anything.
    if "<unk>" not in token_list:
        token_list.append("<unk>")
    return token_list
@@ -60,12 +59,12 @@
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    config['token_list'] = token_list
    config["token_list"] = token_list
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
@@ -78,15 +77,13 @@
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias.

    Thin wrapper around ``yaml.dump`` that always passes
    ``allow_unicode=True`` and ``Dumper=NoAliasSafeDumper``.

    Args:
        data: Object to serialize.
        stream: Passed through as the second positional argument of
            ``yaml.dump``; defaults to None.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``yaml.dump``. Passing ``allow_unicode`` or ``Dumper`` here
            raises ``TypeError`` because both are already set by this wrapper.

    Returns:
        Whatever ``yaml.dump`` returns for the given arguments.
    """
    return yaml.dump(data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs)
if __name__ == "__main__":
    # Command-line entry point: <config_path> <vocab_path> <output_dir>.
    import sys

    # Fail with a readable usage line instead of an IndexError traceback
    # when positional arguments are missing.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} <config_path> <vocab_path> <output_dir>")

    config_path, vocab_path, output_dir = sys.argv[1:4]
    # Run the conversion exactly once.
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)