aky15
2023-05-30 556429d8a98fcab3e77ce1a0d5e7163691643872
funasr/train/trainer.py
@@ -44,6 +44,7 @@
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
from funasr.utils.build_dataclass import build_dataclass
from funasr.utils.kwargs2args import kwargs2args
if torch.distributed.is_available():
    from torch.distributed import ReduceOp
@@ -620,6 +621,24 @@
                all_steps_are_invalid = False
                continue
            if iiter == 1 and summary_writer is not None:
                try:
                    args = kwargs2args(model.forward, batch)
                except (ValueError, TypeError):
                    logging.warning(
                        "inpect.signature() is failed for the model. "
                        "The graph can't be added for tensorboard."
                    )
                else:
                    try:
                        summary_writer.add_graph(model, args, use_strict_trace=False)
                    except Exception:
                        logging.warning(
                            "summary_writer.add_graph() is failed for the model. "
                            "The graph can't be added for tensorboard."
                        )
                    del args
            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)