From 2ed3f46f40cd5da19cad76a97b52c46b2869d5ed Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 Jan 2024 18:42:37 +0800
Subject: [PATCH] funasr1.0 finetune
---
funasr/train_utils/trainer.py | 22 ++++++++++++++++++++--
funasr/bin/train.py | 5 +++--
funasr/auto/auto_model.py | 2 +-
examples/industrial_data_pretraining/paraformer/finetune.sh | 21 ++++++++++++---------
setup.py | 1 +
funasr/datasets/audio_datasets/samplers.py | 2 +-
6 files changed, 38 insertions(+), 15 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index 6dca09f..93cce73 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -1,14 +1,17 @@
-# download model
-local_path_root=../modelscope_models
-mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
+## download model
+#local_path_root=../modelscope_models
+#mkdir -p ${local_path_root}
+#local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+#git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
python funasr/bin/train.py \
-+model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
-+token_list="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
-+train_data_set_list="data/list/audio_datasets.jsonl" \
++model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++model_revision="v2.0.2" \
++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+++dataset_conf.batch_size=2 \
+++dataset_conf.batch_type="example" \
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu"
\ No newline at end of file
++device="cpu" \
++debug="true"
\ No newline at end of file
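The finetune script now points train_data_set_list at a local jsonl file. A minimal sketch of preparing such a file, assuming the common FunASR convention of one JSON object per line with key/source/source_len/target/target_len fields (field names and paths here are illustrative, not taken from this patch):

import json

# Hypothetical helper: write one training sample per line in jsonl form.
samples = [
    {"key": "utt1", "source": "/data/wav/utt1.wav", "source_len": 52000,
     "target": "今天天气不错", "target_len": 6},
]
with open("asr_task_debug_len.jsonl", "w", encoding="utf-8") as f:
    for s in samples:
        f.write(json.dumps(s, ensure_ascii=False) + "\n")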
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ffb56a5..0bc5e0e 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -146,7 +146,7 @@
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
device = "cpu"
- # kwargs["batch_size"] = 1
+ kwargs["batch_size"] = 1
kwargs["device"] = device
if kwargs.get("ncpu", None):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index ef0d205..7ae687e 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -40,8 +40,7 @@
def main(**kwargs):
- # preprocess_config(kwargs)
- # import pdb; pdb.set_trace()
+
# set random seed
tables.print()
set_all_random_seed(kwargs.get("seed", 0))
@@ -169,6 +168,8 @@
local_rank=local_rank,
use_ddp=use_ddp,
use_fsdp=use_fsdp,
+ output_dir=kwargs.get("output_dir", "./exp"),
+ resume=kwargs.get("resume", True),
**kwargs.get("train_conf"),
)
trainer.run()
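train.py now passes output_dir and resume to the Trainer explicitly while still expanding **kwargs.get("train_conf"). One caveat worth noting, shown with a stand-in constructor (not the real Trainer): if train_conf ever carried its own output_dir key, the call would raise a duplicate-keyword TypeError.

def make_trainer(output_dir="./exp", resume=True, **kwargs):
    # Stand-in for the real Trainer constructor.
    return {"output_dir": output_dir, "resume": resume, **kwargs}

train_conf = {"max_epoch": 50, "keep_nbest_models": 5}
trainer = make_trainer(output_dir="outputs/exp2", resume=True, **train_conf)

# If train_conf also contained "output_dir", Python would raise:
# TypeError: make_trainer() got multiple values for keyword argument 'output_dir'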
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index bc71b28..4af35e9 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -20,7 +20,7 @@
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
- self.batch_size = batch_size
+ self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
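The int(batch_size) cast matters because overrides such as ++dataset_conf.batch_size=2 generally reach the sampler as strings, and Python 3 will not compare str against int. A minimal illustration (not FunASR code):

batch_size = "2"               # value as parsed from a CLI override
try:
    ok = 5 >= batch_size       # int vs. str comparison
except TypeError as err:
    print("without the cast:", err)

batch_size = int(batch_size)
print("with the cast:", 5 >= batch_size)   # True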
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 0f0acc2..da346c3 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -5,6 +5,8 @@
from tqdm import tqdm
import torch.distributed as dist
from contextlib import nullcontext
+# from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
@@ -34,6 +36,7 @@
local_rank,
use_ddp=False,
use_fsdp=False,
+ output_dir: str="./",
**kwargs):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
@@ -55,7 +58,7 @@
self.scheduler = scheduler
self.dataloader_train = dataloader_train
self.dataloader_val = dataloader_val
- self.output_dir = kwargs.get('output_dir', './')
+ self.output_dir = output_dir
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
@@ -77,6 +80,10 @@
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
+
+ os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
+ self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
+
def _save_checkpoint(self, epoch):
"""
@@ -128,6 +135,8 @@
if self.rank == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
+        if self.writer:
+            self.writer.close()
def _train_epoch(self, epoch):
"""
@@ -215,7 +224,16 @@
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
)
pbar.set_description(description)
-
+ if self.writer:
+ self.writer.add_scalar('Loss/train', loss.item(),
+ epoch*len(self.dataloader_train) + batch_idx)
+ for key, var in stats.items():
+ self.writer.add_scalar(f'{key}/train', var.item(),
+ epoch * len(self.dataloader_train) + batch_idx)
+ for key, var in speed_stats.items():
+ self.writer.add_scalar(f'{key}/train', float(var),
+ epoch * len(self.dataloader_train) + batch_idx)
+
# if batch_idx == 2:
# break
pbar.close()
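With the trainer changes above, training scalars are written to <output_dir>/tensorboard on rank 0 only. A condensed, self-contained sketch of that logging pattern with tensorboardX (paths, tags, and loss values are illustrative); the resulting event files can then be viewed with the standard TensorBoard UI:

import os
from tensorboardX import SummaryWriter

output_dir, rank = "outputs/debug/ckpt/funasr2/exp2", 0
log_dir = os.path.join(output_dir, "tensorboard")
os.makedirs(log_dir, exist_ok=True)

# Only rank 0 owns a writer; other ranks keep None and skip logging.
writer = SummaryWriter(log_dir) if rank == 0 else None

for step, loss in enumerate([1.2, 0.9, 0.7]):
    if writer:
        writer.add_scalar("Loss/train", loss, step)

if writer:
    writer.close()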
diff --git a/setup.py b/setup.py
index f7e6ee6..84a958c 100644
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@
"train": [
"editdistance",
"wandb",
+ "pip install tensorboardX",
],
# all: The modules should be optionally installled due to some reason.
# Please consider moving them to "install" occasionally
--
Gitblit v1.9.1