| | |
| | | import torch |
| | | from torch.nn.parallel import data_parallel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.fileio.npy_scp import NpyScpWriter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.forward_adaptor import ForwardAdaptor |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | |
| | | @torch.no_grad() |
| | | def collect_stats( |
| | | model: AbsESPnetModel, |
| | | model: FunASRModel, |
| | | train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], |
| | | valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], |
| | | output_dir: Path, |
| | |
| | | This method is used before executing train(). |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | npy_scp_writers = {} |
| | | for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]): |