| | |
| | | from funasr.train.reporter import Reporter |
| | | from funasr.train.reporter import SubReporter |
| | | from funasr.utils.build_dataclass import build_dataclass |
| | | from funasr.utils.kwargs2args import kwargs2args |
| | | |
| | | if torch.distributed.is_available(): |
| | | from torch.distributed import ReduceOp |
| | |
| | | all_steps_are_invalid = False |
| | | continue |
| | | |
| | | if iiter == 1 and summary_writer is not None: |
| | | try: |
| | | args = kwargs2args(model.forward, batch) |
| | | except (ValueError, TypeError): |
| | | logging.warning( |
| | | "inpect.signature() is failed for the model. " |
| | | "The graph can't be added for tensorboard." |
| | | ) |
| | | else: |
| | | try: |
| | | summary_writer.add_graph(model, args, use_strict_trace=False) |
| | | except Exception: |
| | | logging.warning( |
| | | "summary_writer.add_graph() is failed for the model. " |
| | | "The graph can't be added for tensorboard." |
| | | ) |
| | | del args |
| | | |
| | | with autocast(scaler is not None): |
| | | with reporter.measure_time("forward_time"): |
| | | retval = model(**batch) |