1. bug fix:list(mean)和list(var),由于mean和var是numpy,导致写入到文件的格式错误,参考上面的话,大概率是list(mean.tolist()),其实外层list没有必要 (#2437)

2. 删除不必要的代码list(numpy_array.tolist())-->numpy_array.tolist()
3. 性能优化:replace没有必要,性能慢,性能为O(nm),n是源字符串长度,m是需要替换的字符串长度,虽然这里的m长度是1,且list转字符串的"[]",只有首尾有,直接拼接即可。

Co-authored-by: tiandiweizun <qq1274949542@163.com>
1个文件已修改
10 ■■■■■ 已修改文件
funasr/bin/compute_audio_cmvn.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/compute_audio_cmvn.py
@@ -88,8 +88,8 @@
        total_frames += fbank.shape[0]
    cmvn_info = {
        "mean_stats": list(mean_stats.tolist()),
        "var_stats": list(var_stats.tolist()),
        "mean_stats": mean_stats.tolist(),
        "var_stats": var_stats.tolist(),
        "total_frames": total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
@@ -118,11 +118,9 @@
            + str(dims)
            + "\n"
        )
        mean_str = str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
        fout.write("<LearnRateCoef> 0 [ " + " ".join([str(item) for item in mean]) + " ]\n")
        fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
        var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        fout.write("<LearnRateCoef> 0 " + var_str + "\n")
        fout.write("<LearnRateCoef> 0 [ " + " ".join([str(item) for item in var]) + " ]\n")
        fout.write("</Nnet>" + "\n")