From f97e0eb9eee3f14c410ce905b73d0c83033dc1c9 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 28 四月 2023 15:17:38 +0800
Subject: [PATCH] update
---
funasr/main_funcs/collect_stats.py | 4 ++--
1 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/funasr/main_funcs/collect_stats.py b/funasr/main_funcs/collect_stats.py
index bacda8f..584b85a 100644
--- a/funasr/main_funcs/collect_stats.py
+++ b/funasr/main_funcs/collect_stats.py
@@ -17,12 +17,12 @@
from funasr.fileio.npy_scp import NpyScpWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
@torch.no_grad()
def collect_stats(
- model: AbsESPnetModel,
+ model: FunASRModel,
train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
output_dir: Path,
--
Gitblit v1.9.1