| | |
| | | import logging |
| | | from dataclasses import dataclass |
| | | from typing import Dict |
| | | from typing import Iterable, Optional |
| | |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | batch_size, frames, _ = speech.shape |
| | | |
| | | if self.activation_checkpoint: |
| | | from torch.utils.checkpoint import checkpoint |
| | |
| | | stats["acc"] = acc_att |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["batch_size"] = batch_size |
| | | stats["batch_size_x_frames"] = frames * batch_size |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |