From a1fee4e8adf80634ff11b5a597188f95202b214f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 18 六月 2023 13:05:48 +0800
Subject: [PATCH] Update README.md
---
funasr/train/trainer.py | 49 ++++++++++++++++++++++++++++++++++++++++---------
1 files changed, 40 insertions(+), 9 deletions(-)
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 405268a..f066909 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -3,7 +3,6 @@
"""Trainer module."""
import argparse
-from audioop import bias
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
@@ -40,11 +39,12 @@
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
from funasr.utils.build_dataclass import build_dataclass
+from funasr.utils.kwargs2args import kwargs2args
if torch.distributed.is_available():
from torch.distributed import ReduceOp
@@ -144,11 +144,23 @@
schedulers: Sequence[Optional[AbsScheduler]],
scaler: Optional[GradScaler],
ngpu: int = 0,
+ oss_bucket=None,
):
- states = torch.load(
- checkpoint,
- map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
- )
+ if oss_bucket is None:
+ if os.path.exists(checkpoint):
+ states = torch.load(
+ checkpoint,
+ map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
+ )
+
+ else:
+ return 0
+ else:
+ if oss_bucket.object_exists(checkpoint):
+ buffer = BytesIO(oss_bucket.get_object(checkpoint).read())
+ states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",)
+ else:
+ return 0
model.load_state_dict(states["model"])
reporter.load_state_dict(states["reporter"])
for optimizer, state in zip(optimizers, states["optimizers"]):
@@ -167,7 +179,7 @@
@classmethod
def run(
cls,
- model: AbsESPnetModel,
+ model: FunASRModel,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
train_iter_factory: AbsIterFactory,
@@ -207,15 +219,16 @@
else:
scaler = None
- if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
+ if trainer_options.resume:
cls.resume(
- checkpoint=output_dir / "checkpoint.pb",
+ checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb",
model=model,
optimizers=optimizers,
schedulers=schedulers,
reporter=reporter,
scaler=scaler,
ngpu=trainer_options.ngpu,
+ oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None,
)
start_epoch = reporter.get_epoch() + 1
@@ -608,6 +621,24 @@
all_steps_are_invalid = False
continue
+ if iiter == 1 and summary_writer is not None:
+ try:
+ args = kwargs2args(model.forward, batch)
+ except (ValueError, TypeError):
+ logging.warning(
+ "inpect.signature() is failed for the model. "
+ "The graph can't be added for tensorboard."
+ )
+ else:
+ try:
+ summary_writer.add_graph(model, args, use_strict_trace=False)
+ except Exception:
+ logging.warning(
+ "summary_writer.add_graph() is failed for the model. "
+ "The graph can't be added for tensorboard."
+ )
+ del args
+
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
retval = model(**batch)
--
Gitblit v1.9.1