游雁
2023-02-10 59a791121fccd3c9ca177c4f6d33105a82d23ef3
funasr/bin/lm_calc_perplexity.py
@@ -56,7 +56,7 @@
    set_all_random_seed(seed)
    # 2. Build LM
    model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
    model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
    # Wrape model to make model.nll() data-parallel
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).eval()
@@ -111,6 +111,7 @@
                    utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
                # Write PPL of each utts for debugging or analysis
                writer["utt2nll"][key] = str(-_nll)
                writer["utt2ppl"][key] = str(utt_ppl)
                writer["utt2ntokens"][key] = str(ntoken)