| | |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build LM |
| | | model, train_args = LMTask.build_model_from_file(train_config, model_file, device) |
| | | model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device) |
| | | # Wrape model to make model.nll() data-parallel |
| | | wrapped_model = ForwardAdaptor(model, "nll") |
| | | wrapped_model.to(dtype=getattr(torch, dtype)).eval() |
| | |
| | | utt_ppl = log_base ** (_nll / ntoken / np.log(log_base)) |
| | | |
| | | # Write PPL of each utts for debugging or analysis |
| | | writer["utt2nll"][key] = str(-_nll) |
| | | writer["utt2ppl"][key] = str(utt_ppl) |
| | | writer["utt2ntokens"][key] = str(ntoken) |
| | | |