From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 23:48:00 +0800
Subject: [PATCH] train
---
funasr/main_funcs/collect_stats.py | 4 ++--
1 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/funasr/main_funcs/collect_stats.py b/funasr/main_funcs/collect_stats.py
index bacda8f..584b85a 100644
--- a/funasr/main_funcs/collect_stats.py
+++ b/funasr/main_funcs/collect_stats.py
@@ -17,12 +17,12 @@
from funasr.fileio.npy_scp import NpyScpWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
@torch.no_grad()
def collect_stats(
- model: AbsESPnetModel,
+ model: FunASRModel,
train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
output_dir: Path,
--
Gitblit v1.9.1