From f964078e9ce3a3257f57e3d3dd4f95a7f98941a0 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 17 五月 2023 19:10:07 +0800
Subject: [PATCH] Merge branch 'dev_infer' of https://github.com/alibaba-damo-academy/FunASR into dev_infer
---
funasr/train/trainer.py | 5 ++---
1 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 405268a..4052448 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -3,7 +3,6 @@
"""Trainer module."""
import argparse
-from audioop import bias
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
@@ -40,7 +39,7 @@
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
@@ -167,7 +166,7 @@
@classmethod
def run(
cls,
- model: AbsESPnetModel,
+ model: FunASRModel,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
train_iter_factory: AbsIterFactory,
--
Gitblit v1.9.1