
speech_lengths = speech_lengths[:, 0]  # lengths arrive 2-D; keep the first column as (batch,)

batch_size, frames, _ = speech.shape  # speech: (batch, frames, feat_dim)
_, text_tokens = text.shape  # text: (batch, text_tokens)

if self.activation_checkpoint:
    # Lazy import: only load the checkpoint utility when
    # activation checkpointing is actually enabled.
    from torch.utils.checkpoint import checkpoint

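# Per-batch size accounting: compare the padded tensor sizes
# (batch * max_length) against the true lengths, so the padding
# overhead of each batch shows up in the training stats.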
| | | stats["batch_size_x_frames"] = frames * batch_size |
| | | stats["batch_size_real_frames"] = speech_lengths.sum().item() |
| | | stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] |
| | | stats["batch_size_x_tokens"] = text_tokens * batch_size |
| | | stats["batch_size_real_tokens"] = text_lengths.sum().item() |
| | | stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] |
| | | stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size |

# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss: