From 6e5f075b1d9f189dd4e5400a0a228c670aa4696e Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 09 二月 2023 14:15:18 +0800
Subject: [PATCH] Merge pull request #80 from alibaba-damo-academy/dev
---
funasr/tasks/abs_task.py | 4 ++++
1 files changed, 4 insertions(+), 0 deletions(-)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 83926f4..7899400 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -50,6 +50,7 @@
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.warmup_lr import WarmupLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -151,6 +152,7 @@
CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
noamlr=NoamLR,
warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
cycliclr=torch.optim.lr_scheduler.CyclicLR,
onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -1785,6 +1787,7 @@
collate_fn,
key_file: str = None,
batch_size: int = 1,
+ fs: dict = None,
dtype: str = np.float32,
num_workers: int = 1,
allow_variable_data_keys: bool = False,
@@ -1802,6 +1805,7 @@
dataset = IterableESPnetDataset(
data_path_and_name_and_type,
float_dtype=dtype,
+ fs=fs,
preprocess=preprocess_fn,
key_file=key_file,
)
--
Gitblit v1.9.1