From 6bb6af36ac4e3a3bea69b36c7022896e18f9a079 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 06 二月 2023 16:16:28 +0800
Subject: [PATCH] update data2vec pretrain
---
funasr/tasks/abs_task.py | 2 ++
1 files changed, 2 insertions(+), 0 deletions(-)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5424f13..83926f4 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -44,6 +44,7 @@
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.optimizers.sgd import SGD
+from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
@@ -83,6 +84,7 @@
optim_classes = dict(
adam=torch.optim.Adam,
+ fairseq_adam=FairseqAdam,
adamw=torch.optim.AdamW,
sgd=SGD,
adadelta=torch.optim.Adadelta,
--
Gitblit v1.9.1