jmwang66
2023-05-16 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906
funasr/main_funcs/collect_stats.py
@@ -17,12 +17,12 @@
from funasr.fileio.npy_scp import NpyScpWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    model: FunASRModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,