From fdf74bb85cfe3dd0ce6cbaf51ec8d5b3ca3d2039 Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 09 二月 2023 17:18:43 +0800
Subject: [PATCH] update persian model recipe

---
 funasr/tasks/abs_task.py |    6 ++++++
 1 files changed, 6 insertions(+), 0 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5424f13..7899400 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -44,11 +44,13 @@
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
 from funasr.optimizers.sgd import SGD
+from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
 from funasr.schedulers.warmup_lr import WarmupLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -83,6 +85,7 @@
 
 optim_classes = dict(
     adam=torch.optim.Adam,
+    fairseq_adam=FairseqAdam,
     adamw=torch.optim.AdamW,
     sgd=SGD,
     adadelta=torch.optim.Adadelta,
@@ -149,6 +152,7 @@
     CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
     noamlr=NoamLR,
     warmuplr=WarmupLR,
+    tri_stage=TriStageLR,
     cycliclr=torch.optim.lr_scheduler.CyclicLR,
     onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
     CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -1783,6 +1787,7 @@
             collate_fn,
             key_file: str = None,
             batch_size: int = 1,
+            fs: dict = None,
             dtype: str = np.float32,
             num_workers: int = 1,
             allow_variable_data_keys: bool = False,
@@ -1800,6 +1805,7 @@
         dataset = IterableESPnetDataset(
             data_path_and_name_and_type,
             float_dtype=dtype,
+            fs=fs,
             preprocess=preprocess_fn,
             key_file=key_file,
         )

--
Gitblit v1.9.1