游雁
2024-04-30 9a8086bdf5c15d45087175f170ac33d574c0ed79
batch
3个文件已修改
9 ■■■■■ 已修改文件
funasr/bin/train.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/update_jsonl.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py
@@ -205,7 +205,6 @@
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
-            trainer.start_step = 0
            trainer.train_epoch(
                model=model,
@@ -218,7 +217,9 @@
                writer=writer,
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
            )
+            trainer.start_step = 0
            torch.cuda.empty_cache()
funasr/datasets/audio_datasets/update_jsonl.py
@@ -50,8 +50,8 @@
    sample_num = len(waveform)
    source_len = int(sample_num / 16000 * 1000 / 10)
    source_len_old = data["source_len"]
-    if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
-        logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
+    # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
+    #     logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
    data["source_len"] = source_len
    data["source"] = wav_path
    jsonl_line = json.dumps(data, ensure_ascii=False)
funasr/train_utils/trainer.py
@@ -456,7 +456,7 @@
                    batch_num_epoch = len(dataloader_train)
                self.log(
                    epoch,
-                    batch_idx,
+                    batch_idx + kwargs.get("start_step", 0),
                    step_in_epoch=self.step_in_epoch,
                    batch_num_epoch=batch_num_epoch,
                    lr=lr,