From a0f03bd2a87d97d47a1636bbe6f0855a43160331 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 15 五月 2024 19:48:50 +0800
Subject: [PATCH] Dev gzf deepspeed (#1732)
---
funasr/datasets/audio_datasets/espnet_samplers.py | 7
funasr/train_utils/trainer.py | 50 +
funasr/bin/train_ds.py | 241 ++++++++
README_zh.md | 28
examples/industrial_data_pretraining/sense_voice/demo.py | 5
README.md | 10
funasr/datasets/sense_voice_datasets/datasets.py | 24
funasr/models/sense_voice/model.py | 412 +++++++++++++++
funasr/utils/misc.py | 14
funasr/tokenizer/sentencepiece_tokenizer.py | 8
examples/industrial_data_pretraining/emotion2vec/demo.py | 8
funasr/bin/train.py | 3
funasr/train_utils/trainer_ds.py | 800 +++++++++++++++++++++++++++++
examples/industrial_data_pretraining/sense_voice/demo_fsmn.py | 1
14 files changed, 1,553 insertions(+), 58 deletions(-)
diff --git a/README.md b/README.md
index ba23f3f..e02b3e2 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@
<a name="whats-new"></a>
## What's new:
+- 2024/05/15: Emotion recognition models are now supported: [emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary), [emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary), [emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary). They currently support the following categories: 0: angry 1: happy 2: neutral 3: sad 4: unknown.
- 2024/05/15: Offline File Transcription Service 4.5, Offline File Transcription Service of English 1.6锛孯eal-time Transcription Service 1.10 released锛宎dapting to FunASR 1.0 model structure锛�([docs](runtime/readme.md))
- 2024/03/05锛欰dded the Qwen-Audio and Qwen-Audio-Chat large-scale audio-text multimodal models, which have topped multiple audio domain leaderboards. These models support speech dialogue, [usage](examples/industrial_data_pretraining/qwen_audio).
- 2024/03/05锛欰dded support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
@@ -84,10 +85,11 @@
| fsmn-vad <br> ( [猸怾(https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [馃](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M |
| fa-zh <br> ( [猸怾(https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [馃](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M |
| cam++ <br> ( [猸怾(https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [馃](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M |
-| Whisper-large-v2 <br> ([猸怾(https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [馃崁](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
-| Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [馃崁](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
-| Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py) [馃](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B |
-| Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B |
+| Whisper-large-v2 <br> ([猸怾(https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [馃崁](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
+| Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [馃崁](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
+| Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py) [馃](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B |
+| Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B |
+| emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | speech emotion recognition | 40000 hours | 300M |
diff --git a/README_zh.md b/README_zh.md
index 44f92e6..d3c9ff6 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -29,6 +29,7 @@
<a name="鏈�鏂板姩鎬�"></a>
## 鏈�鏂板姩鎬�
+- 2024/05/15：新增加情感识别模型，[emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)，[emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary)，[emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary)，输出情感类别为：生气/angry，开心/happy，中立/neutral，难过/sad。
- 2024/05/15: 涓枃绂荤嚎鏂囦欢杞啓鏈嶅姟 4.5銆佽嫳鏂囩绾挎枃浠惰浆鍐欐湇鍔� 1.6銆佷腑鏂囧疄鏃惰闊冲惉鍐欐湇鍔� 1.10 鍙戝竷锛岄�傞厤FunASR 1.0妯″瀷缁撴瀯锛涜缁嗕俊鎭弬闃�([閮ㄧ讲鏂囨。](runtime/readme_cn.md))
- 2024/03/05锛氭柊澧炲姞Qwen-Audio涓嶲wen-Audio-Chat闊抽鏂囨湰妯℃�佸ぇ妯″瀷锛屽湪澶氫釜闊抽棰嗗煙娴嬭瘯姒滃崟鍒锋锛屼腑鏀寔璇煶瀵硅瘽锛岃缁嗙敤娉曡 [绀轰緥](examples/industrial_data_pretraining/qwen_audio)銆�
- 2024/03/05锛氭柊澧炲姞Whisper-large-v3妯″瀷鏀寔锛屽璇█璇煶璇嗗埆/缈昏瘧/璇璇嗗埆锛屾敮鎸佷粠 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)浠撳簱涓嬭浇锛屼篃鏀寔浠� [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)浠撳簱涓嬭浇妯″瀷銆�
@@ -75,19 +76,20 @@
锛堟敞锛氣瓙 琛ㄧずModelScope妯″瀷浠撳簱锛岎煠� 琛ㄧずHuggingface妯″瀷浠撳簱锛岎煃�琛ㄧずOpenAI妯″瀷浠撳簱锛�
-| 妯″瀷鍚嶅瓧 | 浠诲姟璇︽儏 | 璁粌鏁版嵁 | 鍙傛暟閲� |
-|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:------------:|:----:|
-| paraformer-zh <br> ([猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [馃](https://huggingface.co/funasr/paraformer-tp) ) | 璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃� | 60000灏忔椂锛屼腑鏂� | 220M |
-| paraformer-zh-streaming <br> ( [猸怾(https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [馃](https://huggingface.co/funasr/paraformer-zh-streaming) ) | 璇煶璇嗗埆锛屽疄鏃� | 60000灏忔椂锛屼腑鏂� | 220M |
-| paraformer-en <br> ( [猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [馃](https://huggingface.co/funasr/paraformer-en) ) | 璇煶璇嗗埆锛岄潪瀹炴椂 | 50000灏忔椂锛岃嫳鏂� | 220M |
-| conformer-en <br> ( [猸怾(https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [馃](https://huggingface.co/funasr/conformer-en) ) | 璇煶璇嗗埆锛岄潪瀹炴椂 | 50000灏忔椂锛岃嫳鏂� | 220M |
-| ct-punc <br> ( [猸怾(https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [馃](https://huggingface.co/funasr/ct-punc) ) | 鏍囩偣鎭㈠ | 100M锛屼腑鏂囦笌鑻辨枃 | 1.1B |
-| fsmn-vad <br> ( [猸怾(https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [馃](https://huggingface.co/funasr/fsmn-vad) ) | 璇煶绔偣妫�娴嬶紝瀹炴椂 | 5000灏忔椂锛屼腑鏂囦笌鑻辨枃 | 0.4M |
-| fa-zh <br> ( [猸怾(https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [馃](https://huggingface.co/funasr/fa-zh) ) | 瀛楃骇鍒椂闂存埑棰勬祴 | 50000灏忔椂锛屼腑鏂� | 38M |
-| cam++ <br> ( [猸怾(https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [馃](https://huggingface.co/funasr/campplus) ) | 璇磋瘽浜虹‘璁�/鍒嗗壊 | 5000灏忔椂 | 7.2M |
-| Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [馃崁](https://github.com/openai/whisper) ) | 璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃� | 澶氳瑷� | 1550 M |
-| Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py) [馃](https://huggingface.co/Qwen/Qwen-Audio) ) | 闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛堥璁粌锛� | 澶氳瑷� | 8B |
-| Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | 闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛坈hat鐗堟湰锛� | 澶氳瑷� | 8B |
+| 妯″瀷鍚嶅瓧 | 浠诲姟璇︽儏 | 璁粌鏁版嵁 | 鍙傛暟閲� |
+|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:--------------:|:------:|
+| paraformer-zh <br> ([猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [馃](https://huggingface.co/funasr/paraformer-tp) ) | 璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃� | 60000灏忔椂锛屼腑鏂� | 220M |
+| paraformer-zh-streaming <br> ( [猸怾(https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [馃](https://huggingface.co/funasr/paraformer-zh-streaming) ) | 璇煶璇嗗埆锛屽疄鏃� | 60000灏忔椂锛屼腑鏂� | 220M |
+| paraformer-en <br> ( [猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [馃](https://huggingface.co/funasr/paraformer-en) ) | 璇煶璇嗗埆锛岄潪瀹炴椂 | 50000灏忔椂锛岃嫳鏂� | 220M |
+| conformer-en <br> ( [猸怾(https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [馃](https://huggingface.co/funasr/conformer-en) ) | 璇煶璇嗗埆锛岄潪瀹炴椂 | 50000灏忔椂锛岃嫳鏂� | 220M |
+| ct-punc <br> ( [猸怾(https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [馃](https://huggingface.co/funasr/ct-punc) ) | 鏍囩偣鎭㈠ | 100M锛屼腑鏂囦笌鑻辨枃 | 1.1B |
+| fsmn-vad <br> ( [猸怾(https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [馃](https://huggingface.co/funasr/fsmn-vad) ) | 璇煶绔偣妫�娴嬶紝瀹炴椂 | 5000灏忔椂锛屼腑鏂囦笌鑻辨枃 | 0.4M |
+| fa-zh <br> ( [猸怾(https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [馃](https://huggingface.co/funasr/fa-zh) ) | 瀛楃骇鍒椂闂存埑棰勬祴 | 50000灏忔椂锛屼腑鏂� | 38M |
+| cam++ <br> ( [猸怾(https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [馃](https://huggingface.co/funasr/campplus) ) | 璇磋瘽浜虹‘璁�/鍒嗗壊 | 5000灏忔椂 | 7.2M |
+| Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [馃崁](https://github.com/openai/whisper) ) | 璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃� | 澶氳瑷� | 1550 M |
+| Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py) [馃](https://huggingface.co/Qwen/Qwen-Audio) ) | 闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛堥璁粌锛� | 澶氳瑷� | 8B |
+| Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | 闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛坈hat鐗堟湰锛� | 澶氳瑷� | 8B |
+| emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | 情感识别模型 | 40000小时，4种情感类别 | 300M |
<a name="蹇�熷紑濮�"></a>
## 蹇�熷紑濮�
diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index f33dfee..71f69bb 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -6,14 +6,20 @@
from funasr import AutoModel
# model="iic/emotion2vec_base"
+# model="iic/emotion2vec_base_finetuned"
+# model="iic/emotion2vec_plus_seed"
+# model="iic/emotion2vec_plus_base"
+model = "iic/emotion2vec_plus_large"
+
model = AutoModel(
- model="iic/emotion2vec_base_finetuned",
+ model=model,
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
# vad_model_revision="master",
# vad_kwargs={"max_single_segment_time": 2000},
)
wav_file = f"{model.model_path}/example/test.wav"
+
res = model.generate(
wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False
)
diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py
index ed583f0..5303999 100644
--- a/examples/industrial_data_pretraining/sense_voice/demo.py
+++ b/examples/industrial_data_pretraining/sense_voice/demo.py
@@ -7,8 +7,8 @@
model = AutoModel(
model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
- vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- vad_kwargs={"max_single_segment_time": 30000},
+ # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ # vad_kwargs={"max_single_segment_time": 30000},
)
@@ -21,6 +21,7 @@
"language": "auto",
"fp16": True,
"gain_event": True,
+ "beam_size": 5,
}
res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions)
diff --git a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
index e063e1f..ce4bdf8 100644
--- a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
+++ b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
@@ -21,6 +21,7 @@
"language": "auto",
"fp16": True,
"gain_event": True,
+ "beam_size": 5,
}
res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions, beam_size=5)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 7695e51..c3556d1 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -223,6 +223,7 @@
torch.cuda.empty_cache()
+ trainer.start_data_split_i = 0
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
)
@@ -240,6 +241,8 @@
f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
)
+ trainer.train_acc_avg = 0.0
+ trainer.train_loss_avg = 0.0
if trainer.rank == 0:
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
new file mode 100644
index 0000000..e4db533
--- /dev/null
+++ b/funasr/bin/train_ds.py
@@ -0,0 +1,241 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
+import os
+import sys
+import torch
+import torch.nn as nn
+import hydra
+import logging
+import time
+import argparse
+from io import BytesIO
+
+from contextlib import nullcontext
+import torch.distributed as dist
+
+from omegaconf import DictConfig, OmegaConf
+from torch.cuda.amp import autocast, GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from funasr.train_utils.average_nbest_models import average_checkpoints
+
+from funasr.register import tables
+from funasr.optimizers import optim_classes
+from funasr.train_utils.trainer_ds import Trainer
+from funasr.schedulers import scheduler_classes
+from funasr.train_utils.initialize import initialize
+from funasr.download.download_from_hub import download_model
+from funasr.models.lora.utils import mark_only_lora_as_trainable
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
+from funasr import AutoModel
+
+# DeepSpeed is an optional dependency: fall back to ``None`` when it cannot be
+# imported so the module itself can still be loaded.
+try:
+    import deepspeed
+except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
+    deepspeed = None
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+    """Hydra entry point: resolve the model configuration and start training.
+
+    Args:
+        kwargs: Configuration assembled by Hydra. It must contain a ``model``
+            entry. When ``model_conf`` is absent the configuration is
+            completed through ``download_model``.
+    """
+    debug_requested = kwargs.get("debug", False)
+    if debug_requested:
+        import pdb
+
+        pdb.set_trace()
+
+    assert "model" in kwargs
+
+    has_model_conf = "model_conf" in kwargs
+    if not has_model_conf:
+        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
+        is_training = kwargs.get("is_training", True)
+        kwargs = download_model(is_training=is_training, **kwargs)
+
+    main(**kwargs)
+
+
+def main(**kwargs):
+    """Build model, optimizer, scheduler and dataloader, then run the training loop.
+
+    The process rank layout is read from the ``RANK``, ``LOCAL_RANK`` and
+    ``WORLD_SIZE`` environment variables. When ``use_deepspeed`` is set the
+    distributed backend is initialised through DeepSpeed; otherwise
+    ``torch.distributed`` is initialised for DDP/FSDP runs. Each epoch iterates
+    over the dataloader's data splits, then validates, steps the scheduler and
+    saves a checkpoint. After the last epoch rank 0 averages the checkpoints.
+
+    Args:
+        **kwargs: Training configuration. Keys read directly here include
+            ``seed``, ``cudnn_enabled``, ``cudnn_benchmark``,
+            ``cudnn_deterministic``, ``enable_tf32``, ``use_fsdp``,
+            ``use_deepspeed``, ``backend``, ``device``, ``freeze_param``,
+            ``output_dir``, ``train_conf``, ``optim``, ``optim_conf``,
+            ``scheduler``, ``scheduler_conf``, ``deepspeed_config`` and
+            ``dataset_conf``. The full mapping is also forwarded to
+            ``AutoModel`` and to the dataloader class.
+
+    Raises:
+        ImportError: If ``use_deepspeed`` is set but ``deepspeed`` could not
+            be imported.
+    """
+
+    # set random seed
+    set_all_random_seed(kwargs.get("seed", 0))
+    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # open tf32
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
+
+    rank = int(os.environ.get("RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+    if local_rank == 0:
+        tables.print()
+
+    use_ddp = world_size > 1
+    use_fsdp = kwargs.get("use_fsdp", False)
+    use_deepspeed = kwargs.get("use_deepspeed", False)
+    if use_deepspeed:
+        # The module-level import is best-effort; fail with a clear message
+        # instead of an AttributeError on ``None``.
+        if deepspeed is None:
+            raise ImportError(
+                "use_deepspeed=True requires the `deepspeed` package, but it could not be imported"
+            )
+        logging.info(f"use_deepspeed: {use_deepspeed}")
+        deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
+    elif use_ddp or use_fsdp:
+        logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
+        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
+        torch.cuda.set_device(local_rank)
+
+    logging.info("Build model, frontend, tokenizer")
+    # Build the model on CPU first; the requested device is restored below.
+    device = kwargs.get("device", "cuda")
+    kwargs["device"] = "cpu"
+    model = AutoModel(**kwargs)
+
+    # save config.yaml
+    if rank == 0:
+        prepare_model_dir(**kwargs)
+
+    # parse kwargs
+    kwargs = model.kwargs
+    kwargs["device"] = device
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    model = model.model
+    del kwargs["model"]
+
+    # freeze_param
+    freeze_param = kwargs.get("freeze_param", None)
+    if freeze_param is not None:
+        if "," in freeze_param:
+            # NOTE(review): eval() executes arbitrary code from the configuration
+            # string; consider ast.literal_eval if only literals are expected.
+            freeze_param = eval(freeze_param)
+        if not isinstance(freeze_param, (list, tuple)):
+            freeze_param = (freeze_param,)
+        logging.info("freeze_param is not None: %s", freeze_param)
+        for t in freeze_param:
+            for k, p in model.named_parameters():
+                if k.startswith(t + ".") or k == t:
+                    logging.info(f"Setting {k}.requires_grad = False")
+                    p.requires_grad = False
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
+
+    trainer = Trainer(
+        rank=rank,
+        local_rank=local_rank,
+        world_size=world_size,
+        use_ddp=use_ddp,
+        use_fsdp=use_fsdp,
+        device=kwargs["device"],
+        output_dir=kwargs.get("output_dir", "./exp"),
+        **kwargs.get("train_conf"),
+    )
+
+    model = trainer.warp_model(model)
+
+    # Use whatever device the wrapped model's parameters ended up on.
+    kwargs["device"] = next(model.parameters()).device
+    trainer.device = kwargs["device"]
+
+    # optim
+    logging.info("Build optim")
+    optim = kwargs.get("optim", "adam")
+    assert optim in optim_classes
+    optim_class = optim_classes.get(optim)
+    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+
+    # scheduler
+    logging.info("Build scheduler")
+    scheduler = kwargs.get("scheduler", "warmuplr")
+    assert scheduler in scheduler_classes
+    scheduler_class = scheduler_classes.get(scheduler)
+    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+
+    if use_deepspeed:
+        args = OmegaConf.create({"deepspeed_config": kwargs.get("deepspeed_config", "")})
+        # NOTE(review): the optimizer returned by DeepSpeed is bound to
+        # `optimizer` and never read; `optim` below is still the object created
+        # above. Confirm this is intended.
+        model, optimizer, _, scheduler = deepspeed.initialize(
+            args=args,
+            model=model,
+            optimizer=optim,
+            lr_scheduler=scheduler,
+            model_parameters=model.parameters(),
+        )
+
+    # dataset
+    logging.info("Build dataloader")
+    dataloader_class = tables.dataloader_classes.get(
+        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
+    )
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+
+    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
+
+    # Use the same "./exp" fallback as the Trainer so a missing `output_dir`
+    # does not turn into os.path.join(None, ...).
+    tensorboard_dir = os.path.join(kwargs.get("output_dir", "./exp"), "tensorboard")
+    os.makedirs(tensorboard_dir, exist_ok=True)
+    try:
+        from tensorboardX import SummaryWriter
+
+        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
+    except Exception:
+        # Tensorboard logging is best-effort; training continues without it.
+        logging.warning("failed to create tensorboardX SummaryWriter, tensorboard logging is disabled")
+        writer = None
+
+    dataloader_tr, dataloader_val = None, None
+    for epoch in range(trainer.start_epoch, trainer.max_epoch):
+        time1 = time.perf_counter()
+
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+
+            trainer.train_epoch(
+                model=model,
+                optim=optim,
+                scheduler=scheduler,
+                scaler=scaler,
+                dataloader_train=dataloader_tr,
+                dataloader_val=dataloader_val,
+                epoch=epoch,
+                writer=writer,
+                data_split_i=data_split_i,
+                data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
+            )
+            # A resumed step offset only applies to the first split processed.
+            trainer.start_step = 0
+
+            torch.cuda.empty_cache()
+
+        # Later epochs start again from the first data split.
+        trainer.start_data_split_i = 0
+        trainer.validate_epoch(
+            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
+        )
+        scheduler.step()
+        trainer.step_in_epoch = 0
+        trainer.save_checkpoint(
+            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
+        )
+
+        time2 = time.perf_counter()
+        time_escaped = (time2 - time1) / 3600.0
+        logging.info(
+            f"rank: {local_rank}, "
+            f"time_escaped_epoch: {time_escaped:.3f} hours, "
+            f"estimated to finish {trainer.max_epoch} "
+            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
+        )
+        # Reset the running training averages before the next epoch.
+        trainer.train_acc_avg = 0.0
+        trainer.train_loss_avg = 0.0
+
+    if trainer.rank == 0:
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+
+    trainer.close()
+
+
+# Script entry point: hand control to the Hydra-decorated launcher.
+if __name__ == "__main__":
+ main_hydra()
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index e155cd7..528f593 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -146,10 +146,9 @@
start_idx = self.rank * batches_per_rank
end_idx = start_idx + batches_per_rank
rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
- if self.start_step > 0:
- logging.info(
- f"Warning, rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num_before: {end_idx-start_idx}, now: {len(rank_batches)}"
- )
+ logging.info(
+ f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
+ )
# Return an iterator over the batches for the current rank
return iter(rank_batches)
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index ee2f13d..690a1c5 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -53,6 +53,12 @@
self.prompt_ids_len = 0
self.retry = kwargs.get("retry", 5)
+ self.permute = False
+ from funasr.frontends.whisper_frontend import WhisperFrontend
+
+ if isinstance(self.frontend, WhisperFrontend):
+ self.permute = True
+
def get_source_len(self, index):
item = self.index_ds[index]
return self.index_ds.get_source_len(item)
@@ -92,7 +98,8 @@
if speech_lengths > self.batch_size:
continue
- speech = speech.permute(0, 2, 1)
+ if self.permute:
+ speech = speech.permute(0, 2, 1)
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
@@ -100,8 +107,14 @@
task = item.get("prompt", "<|ASR|>")
text_language = item.get("text_language", "<|zh|>")
- prompt = f"{self.sos}{task}{text_language}"
- prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+ if isinstance(self.sos, str):
+ prompt = f"{self.sos}{task}{text_language}"
+ prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+ else:
+ prompt = f"{task}{text_language}"
+ prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+ prompt_ids = [self.sos] + prompt_ids
+
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
self.prompt_ids_len = prompt_ids_len
@@ -110,7 +123,10 @@
if target_ids_len > 200:
continue
- eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
+ if isinstance(self.eos, str):
+ eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
+ else:
+ eos = [self.eos]
ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos]
ids_lengths = len(ids)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 56e61e7..127d5a0 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -966,3 +966,415 @@
ibest_writer["text"][key[i]] = text
return results, meta_data
+
+
+@tables.register("model_classes", "SenseVoiceSANM")
+class SenseVoiceSANM(nn.Module):
+
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+        """Assemble the model from registry names and their configurations.
+
+        Args:
+            specaug: Registry name of the spec-augmentation class, or ``None``
+                to disable augmentation.
+            specaug_conf: Keyword arguments for the spec-augmentation class.
+            encoder: Registry name of the encoder class.
+            encoder_conf: Keyword arguments for the encoder (``input_size`` is
+                passed separately).
+            decoder: Registry name of the decoder class.
+            decoder_conf: Keyword arguments for the decoder (``vocab_size`` and
+                ``encoder_output_size`` are passed separately).
+            input_size: Input size handed to the encoder.
+            vocab_size: Vocabulary size used by the decoder and the loss.
+            ignore_id: Label id used as ``padding_idx`` of the loss.
+            blank_id: Stored as ``self.blank_id``.
+            sos: Start id; falls back to ``vocab_size - 1`` when ``None``.
+            eos: End id; falls back to ``vocab_size - 1`` when ``None``.
+            lsm_weight: Smoothing value passed to ``LabelSmoothingLoss``.
+            length_normalized_loss: Passed to the loss as ``normalize_length``
+                and stored on the instance.
+            **kwargs: ``activation_checkpoint`` is read from here (default
+                ``False``).
+
+        The remaining parameters (``normalize``, ``normalize_conf``,
+        ``report_cer``, ``report_wer``, ``sym_space``, ``sym_blank``,
+        ``share_embedding``) are accepted but not used in this constructor.
+        """
+
+        super().__init__()
+
+        # The *_conf arguments default to None; treat that as "no extra options"
+        # so the ** unpacking below does not raise TypeError.
+        specaug_conf = {} if specaug_conf is None else specaug_conf
+        encoder_conf = {} if encoder_conf is None else encoder_conf
+        decoder_conf = {} if decoder_conf is None else decoder_conf
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            **decoder_conf,
+        )
+
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+
+        self.specaug = specaug
+
+        self.encoder = encoder
+
+        self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        self.error_calculator = None
+
+        self.length_normalized_loss = length_normalized_loss
+        # Created on first use by inference().
+        self.beam_search = None
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        """Encode the speech, decode against the text and return the loss.
+
+        Args:
+            speech: 3-D input tensor; its first two dimensions are used as
+                batch size and number of frames.
+            speech_lengths: Lengths for ``speech``; when it has more than one
+                dimension only column 0 is used.
+            text: 2-D target tensor; its second dimension is the token count.
+            text_lengths: Lengths for ``text``; when it has more than one
+                dimension only column 0 is used.
+            **kwargs: ``target_mask`` is forwarded to ``_calc_att_loss``.
+
+        Returns:
+            The ``(loss, stats, weight)`` triple produced by
+            ``force_gatherable``.
+        """
+        target_mask = kwargs.get("target_mask", None)
+
+        # import pdb;
+        # pdb.set_trace()
+        if text_lengths.dim() > 1:
+            text_lengths = text_lengths[:, 0]
+        if speech_lengths.dim() > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size, frames, _ = speech.shape
+        _, text_tokens = text.shape
+
+        if not self.activation_checkpoint:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        else:
+            from torch.utils.checkpoint import checkpoint
+
+            encoder_out, encoder_out_lens = checkpoint(
+                self.encode, speech, speech_lengths, use_reentrant=False
+            )
+
+        loss_att, acc_att, _cer_att, _wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+
+        loss = loss_att
+
+        # Padded sizes versus the sizes given by the length tensors.
+        padded_frames = frames * batch_size
+        real_frames = speech_lengths.sum().item()
+        padded_tokens = text_tokens * batch_size
+        real_tokens = text_lengths.sum().item()
+        stats = {
+            "acc": acc_att,
+            "loss": torch.clone(loss.detach()),
+            "batch_size": batch_size,
+            "batch_size_x_frames": padded_frames,
+            "batch_size_real_frames": real_frames,
+            "padding_frames": padded_frames - real_frames,
+            "batch_size_x_tokens": padded_tokens,
+            "batch_size_real_tokens": real_tokens,
+            "padding_tokens": padded_tokens - real_tokens,
+            "batch_size_x_frames_plus_tokens": (text_tokens + frames) * batch_size,
+        }
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            weight_base = int((text_lengths + 1).sum())
+        else:
+            weight_base = batch_size
+        loss, stats, weight = force_gatherable((loss, stats, weight_base), loss.device)
+        return loss, stats, weight
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ **kwargs,
+ ):
+ """Optional SpecAug followed by the encoder.
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ **kwargs: accepted but not read in this method.
+
+ Returns:
+ Tuple ``(encoder_out, encoder_out_lens)``. When the encoder returns a
+ tuple or list as its first output, only element 0 is kept.
+ """
+ # NOTE(review): indentation is collapsed in this view, so it is not visible
+ # whether the encoder call below sits inside the autocast(False) context
+ # or after it - confirm against the original file.
+ with autocast(False):
+
+ # Data augmentation (only when a specaug module exists and the model is in training mode)
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+
+ # The third value returned by the encoder is discarded.
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ if isinstance(encoder_out, (tuple, list)):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, encoder_out_lens
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        **kwargs,
+    ):
+        """Run the decoder and compute the attention loss and accuracy.
+
+        Args:
+            encoder_out: Encoder output passed to the decoder.
+            encoder_out_lens: Lengths belonging to ``encoder_out``.
+            ys_pad: Target token ids; entries equal to -1 are replaced by 0
+                before decoding. The caller's tensor is not modified.
+            ys_pad_lens: Lengths belonging to ``ys_pad``.
+            **kwargs: ``target_mask`` (same shape as ``ys_pad``) selects the
+                positions that contribute to the loss: positions where it is 0
+                are set to -1. When omitted, every position is selected.
+
+        Returns:
+            ``(loss_att, acc_att, None, None)``.
+        """
+        target_mask = kwargs.get("target_mask", None)
+
+        # 1. Forward decoder
+        # masked_fill returns a new tensor, so the batch owned by the caller keeps
+        # its original -1 entries.
+        ys_pad = ys_pad.masked_fill(ys_pad == -1, 0)
+        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+        if isinstance(decoder_out, (list, tuple)):
+            decoder_out = decoder_out[0]
+
+        # 2. Compute attention loss
+        if target_mask is None:
+            # No mask supplied: treat every position as a target.
+            target_mask = torch.ones_like(ys_pad)
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
+        # Positions holding id 0 (including the former -1 entries) are set to -1 as well.
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        # Output at step t is scored against the label at step t + 1.
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(
+                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
+            )
+
+        return loss_att, acc_att, None, None
+
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        """Create the ``BeamSearch`` object and store it on ``self.beam_search``.
+
+        Args:
+            **kwargs: ``penalty`` (default 0.0) becomes the weight of the
+                length-bonus scorer; ``beam_size`` (default 5) is the beam width.
+        """
+        from .search import BeamSearch
+
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        # Scorers used by the search: the decoder plus a length bonus.
+        scorers = {
+            "decoder": self.decoder,
+            "length_bonus": LengthBonus(self.vocab_size),
+        }
+
+        # Only the decoder and the length bonus can carry a non-zero weight.
+        weights = {
+            "decoder": 1.0,
+            "ctc": 0.0,
+            "lm": 0.0,
+            "ngram": 0.0,
+            "length_bonus": kwargs.get("penalty", 0.0),
+        }
+
+        # sos/eos are left as None here and assigned later by inference().
+        self.beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=None,
+            eos=None,
+            vocab_size=self.vocab_size,
+            token_list=None,
+            pre_beam_score_key="full",
+        )
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        """
+        Decode a single utterance and return ``(results, meta_data)``.
+
+        Args:
+            data_in: Either a precomputed feature tensor (used as-is when
+                ``kwargs["data_type"] == "fbank"``) or whatever
+                ``load_audio_text_image_video`` accepts.
+            data_lengths: Optional lengths for ``data_in`` on the fbank path.
+            key: Utterance ids; ``key[i]`` labels the i-th result.
+            tokenizer: Tokenizer providing ``encode(..., allowed_special="all")`` and ``decode``.
+            frontend: Feature extractor; a ``WhisperFrontend`` is created and cached
+                on ``self`` when none is given and none is cached yet.
+            **kwargs: Reads ``device`` and ``model_conf`` (required), and optionally
+                ``batch_size``, ``nbest``, ``data_type``, ``fs``, ``DecodingOptions``,
+                ``initial_prompt``, ``output_dir``, ``maxlenratio``, ``minlenratio``.
+
+        Returns:
+            tuple: ``results`` is a list of ``{"key": ..., "text": ...}`` dicts and
+            ``meta_data`` holds timing information collected while loading audio.
+
+        Raises:
+            NotImplementedError: If ``batch_size`` is greater than 1.
+        """
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        # init beamsearch
+        if not hasattr(self, "beam_search") or self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(
+                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
+            )
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+
+        meta_data = {}
+        # Bind up front: the fbank path never loads reference text, and the
+        # teacher-forcing check further down must not hit an unbound name.
+        text_token_int = None
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+            if not isinstance(speech_lengths, torch.Tensor):
+                # ``.to(device=...)`` is called on the lengths below, so they must be a tensor.
+                speech_lengths = torch.as_tensor(speech_lengths, dtype=torch.int64).reshape(-1)
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+
+            if (
+                isinstance(kwargs.get("data_type", None), (list, tuple))
+                and len(kwargs.get("data_type", [])) > 1
+            ):
+                # Multi-modal input: the loader returns (audio, text tokens).
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0]
+
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])[0, :, :]
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        DecodingOptions = kwargs.get("DecodingOptions", {})
+        task = DecodingOptions.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+
+        language = DecodingOptions.get("language", None)
+        language = None if language == "auto" else language
+
+        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+        sos_int = tokenizer.encode(sos, allowed_special="all")
+        eos = kwargs.get("model_conf").get("eos")
+        eos_int = tokenizer.encode(eos, allowed_special="all")
+        self.beam_search.sos = sos_int
+        self.beam_search.eos = eos_int[0]
+
+        # Parameters for rich decoding
+        self.beam_search.emo_unk = tokenizer.encode(
+            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+        )[0]
+        self.beam_search.emo_unk_score = 1
+        self.beam_search.emo_tokens = tokenizer.encode(
+            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+            allowed_special="all",
+        )
+        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+        self.beam_search.event_bg_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_ed_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
+        encoder_out, encoder_out_lens = self.encode(
+            speech[None, :, :].permute(0, 2, 1), speech_lengths
+        )
+
+        if text_token_int is not None:
+            # Reference text was supplied: run one decoder pass over prompt + text
+            # and take the arg-max token at every position instead of beam searching.
+            i = 0
+            results = []
+            ibest_writer = None
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+                ibest_writer = self.writer[f"1best_recog"]
+
+            # 1. Forward decoder
+            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
+                None, :
+            ]
+            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
+                kwargs["device"]
+            )[None, :]
+            decoder_out = self.model.decoder(
+                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+            )
+
+            token_int = decoder_out.argmax(-1)[0, :].tolist()
+            text = tokenizer.decode(token_int)
+
+            result_i = {"key": key[i], "text": text}
+            results.append(result_i)
+
+            if ibest_writer is not None:
+                # ibest_writer["token"][key[i]] = " ".join(token)
+                ibest_writer["text"][key[i]] = text
+            return results, meta_data
+
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        b, n, d = encoder_out.size()
+        for i in range(b):
+
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # Change integer-ids to text
+                text = tokenizer.decode(token_int)
+
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    # ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
+
+        return results, meta_data
diff --git a/funasr/tokenizer/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
index ff4b3a2..1be1b81 100644
--- a/funasr/tokenizer/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -20,6 +20,7 @@
# "TypeError: can't pickle SwigPyObject objects",
# when giving it as argument of "multiprocessing.Process()".
self.sp = None
+ self._build_sentence_piece_processor()
def __repr__(self):
return f'{self.__class__.__name__}(model="{self.bpemodel}")'
@@ -38,10 +39,13 @@
self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens))
- def encode(self, line: str) -> List[int]:
+ def encode(self, line: str, **kwargs) -> List[int]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsIds(line)
- def decode(self, line: List[int]):
+ def decode(self, line: List[int], **kwargs):
self._build_sentence_piece_processor()
return self.sp.DecodeIds(line)
+
+    def get_vocab_size(self):
+        """Return the number of pieces in the loaded SentencePiece model."""
+        # Build the processor on demand, like encode()/decode() do, so this also
+        # works on an instance whose ``self.sp`` is still None.
+        self._build_sentence_piece_processor()
+        return self.sp.GetPieceSize()
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 01e2924..50f99f0 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -382,8 +382,6 @@
):
torch.cuda.empty_cache()
- time3 = time.perf_counter()
- speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
@@ -398,34 +396,28 @@
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
+ # loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
+
+ time3 = time.perf_counter()
+ speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
if self.use_fp16:
scaler.scale(loss).backward()
else:
loss.backward()
time4 = time.perf_counter()
- speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+ speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
self.train_loss_avg = (
- self.train_loss_avg * (self.step_in_epoch - 1) + loss.detach().cpu().item()
- ) / self.step_in_epoch
+ self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
+ + loss.detach().cpu().item()
+ ) / (batch_idx + kwargs.get("start_step", 0) + 1)
if "acc" in stats:
self.train_acc_avg = (
- self.train_acc_avg * (self.step_in_epoch - 1)
+ self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
+ stats["acc"].detach().cpu().item()
- ) / self.step_in_epoch
- if self.use_ddp or self.use_fsdp:
- train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
- self.device
- )
- train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
- self.device
- )
- dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
- self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
- self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+ ) / (batch_idx + kwargs.get("start_step", 0) + 1)
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0:
@@ -454,8 +446,22 @@
scheduler.step()
# Clear gradients for the next accumulation stage
optim.zero_grad(set_to_none=True)
- total_time = f"{time.perf_counter() - time5:0.3f}"
+
+ if self.use_ddp or self.use_fsdp:
+ train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
+ self.device
+ )
+ train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
+ self.device
+ )
+ dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+ self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+ self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
+ total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
time5 = time.perf_counter()
+
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
@@ -662,9 +668,9 @@
f"data_slice: {data_split_i}/{data_split_num}, "
f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "
- f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
- f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
- f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+ f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+ f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+ f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), "
f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
new file mode 100644
index 0000000..7188921
--- /dev/null
+++ b/funasr/train_utils/trainer_ds.py
@@ -0,0 +1,800 @@
+import math
+import os
+import time
+import torch
+import logging
+from tqdm import tqdm
+from datetime import datetime
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
+from pathlib import Path
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.recursive_op import recursive_average
+from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+try:
+ import wandb
+except:
+ wandb = None
+
+
+@contextmanager
+def maybe_autocast(enabled):
+    """Run the managed block under ``autocast()`` when *enabled* is truthy, otherwise unchanged."""
+    amp_context = autocast() if enabled else nullcontext()
+    with amp_context:
+        yield
+
+
+class Trainer:
+ """
+ A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
+ and optionally resuming from a saved checkpoint.
+
+ Attributes:
+ max_epoch (int): Maximum number of epochs for training.
+ model (torch.nn.Module): The model to be trained.
+ optim (torch.optim.Optimizer): The optimizer to use for training.
+ scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+ dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
+ dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+ output_dir (str): Directory where model checkpoints will be saved.
+ resume (str, optional): Path to a checkpoint to resume training from.
+ """
+
+    def __init__(
+        self,
+        rank=0,
+        local_rank=0,
+        world_size=1,
+        use_ddp: bool = False,
+        use_fsdp: bool = False,
+        use_fp16: bool = False,
+        use_deepspeed: bool = False,
+        output_dir: str = "./",
+        **kwargs,
+    ):
+        """
+        Store the distributed setup and training hyper-parameters.
+
+        Args:
+            rank (int): Global rank of this process.
+            local_rank (int): Rank of this process on its node.
+            world_size (int): Total number of processes.
+            use_ddp (bool): Whether the model is wrapped in DistributedDataParallel.
+            use_fsdp (bool): Whether FSDP is in use.
+            use_fp16 (bool): Whether forward passes run under autocast with a grad scaler.
+            use_deepspeed (bool): Whether the model is driven by a DeepSpeed engine.
+            output_dir (str): Directory for checkpoints (created if missing).
+            **kwargs: Optional settings read here: ``device``, ``resume``, ``max_epoch``,
+                ``log_interval``, ``save_checkpoint_interval``, ``validate_interval``,
+                ``keep_nbest_models``, ``avg_keep_nbest_models_type``, ``avg_nbest_model``,
+                ``accum_grad``, ``grad_clip``, ``grad_clip_type``, ``reset_gpu_cache``,
+                ``use_wandb`` and the ``wandb_*`` options.
+        """
+        # ``rank`` is a named parameter, so it can never show up inside **kwargs;
+        # reading it from kwargs made every worker believe it was rank 0.
+        self.rank = rank
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.use_ddp = use_ddp
+        self.use_fsdp = use_fsdp
+        self.use_deepspeed = use_deepspeed
+        self.device = kwargs.get("device", "cuda")
+
+        self.output_dir = output_dir
+        os.makedirs(self.output_dir, exist_ok=True)
+        self.resume = kwargs.get("resume", True)
+        self.start_epoch = 0
+        self.max_epoch = kwargs.get("max_epoch", 100)
+
+        # self.kwargs = kwargs
+        self.log_interval = kwargs.get("log_interval", 50)
+        self.batch_total = 0
+        self.use_fp16 = use_fp16
+        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+        self.validate_interval = kwargs.get("validate_interval", 5000)
+        self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
+        self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
+        self.accum_grad = kwargs.get("accum_grad", 1)
+        self.grad_clip = kwargs.get("grad_clip", 10.0)
+        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
+
+        self.train_acc_avg = 0.0
+        self.train_loss_avg = 0.0
+        self.val_acc_avg = 0.0
+        self.val_loss_avg = 0.0
+        self.best_acc_idx = 0
+        self.saved_ckpts = {}
+        self.step_or_epoch = -1
+        self.best_step_or_epoch = ""
+        self.val_acc_step_or_eoch = {}
+        self.val_loss_step_or_eoch = {}
+
+        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+        self.start_data_split_i = 0
+        self.start_step = 0
+        self.step_in_epoch = 0
+        self.use_wandb = kwargs.get("use_wandb", False)
+        if self.use_wandb and wandb is None:
+            # The module-level import is optional; fall back to no wandb logging
+            # instead of failing with an AttributeError on ``None.login``.
+            logging.warning("use_wandb is set but wandb is not installed, wandb logging is disabled")
+            self.use_wandb = False
+        if self.use_wandb:
+            wandb.login(key=kwargs.get("wandb_token"))
+            wandb.init(
+                config=kwargs,
+                project=kwargs.get("wandb_project", "my_project"),
+                entity=kwargs.get("wandb_team", "my_team"),
+                name=kwargs.get("wandb_exp_name", "my_exp"),
+                dir=output_dir,
+                job_type="training",
+                reinit=True,
+            )
+
+ def save_checkpoint(
+ self,
+ epoch,
+ step=None,
+ model=None,
+ optim=None,
+ scheduler=None,
+ scaler=None,
+ step_in_epoch=None,
+ **kwargs,
+ ):
+ """
+ Saves the model, optimizer and scheduler state together with the trainer's
+ bookkeeping, then updates the "best" checkpoint and prunes old checkpoints.
+
+ Files are written into ``self.output_dir``: ``model.pt.ep{epoch}`` (or
+ ``model.pt.ep{epoch}.{step_in_epoch}`` when ``step`` is given), ``model.pt``
+ (latest) and ``model.pt.best``.
+
+ Args:
+ epoch (int): The epoch number at which the checkpoint is being saved.
+ step (int, optional): Step within the current data slice; stored in the state.
+ model: Model to save; ``model.module`` is saved instead when that attribute exists.
+ optim: Optimizer whose ``state_dict()`` is stored.
+ scheduler: LR scheduler whose ``state_dict()`` is stored.
+ scaler: Optional grad scaler; stored under ``"scaler_state"`` when truthy.
+ step_in_epoch (int, optional): Step within the epoch; ignored unless ``step`` is given.
+ **kwargs: ``data_split_i``, ``data_split_num``, ``train_loss_avg`` and
+ ``train_acc_avg`` are copied into the saved state.
+ """
+
+ # A step-level name is only used when ``step`` was passed explicitly.
+ step_in_epoch = None if step is None else step_in_epoch
+ if self.rank == 0:
+ logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+ # self.step_or_epoch += 1
+ state = {
+ "epoch": epoch,
+ "state_dict": model.state_dict(),
+ "optimizer": optim.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "saved_ckpts": self.saved_ckpts,
+ "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
+ "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "best_step_or_epoch": self.best_step_or_epoch,
+ "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+ "step": step,
+ "step_in_epoch": step_in_epoch,
+ "data_split_i": kwargs.get("data_split_i", 0),
+ "data_split_num": kwargs.get("data_split_num", 1),
+ "batch_total": self.batch_total,
+ "train_loss_avg": kwargs.get("train_loss_avg", 0),
+ "train_acc_avg": kwargs.get("train_acc_avg", 0),
+ }
+ # From here on ``step`` only feeds the file name; state["step"] keeps the caller's value.
+ step = step_in_epoch
+ # Store the unwrapped weights when the model is a wrapper exposing ``.module``.
+ if hasattr(model, "module"):
+ state["state_dict"] = model.module.state_dict()
+
+ if scaler:
+ state["scaler_state"] = scaler.state_dict()
+ # Create output directory if it does not exist
+ os.makedirs(self.output_dir, exist_ok=True)
+ if step is None:
+ ckpt_name = f"model.pt.ep{epoch}"
+ else:
+ ckpt_name = f"model.pt.ep{epoch}.{step}"
+ filename = os.path.join(self.output_dir, ckpt_name)
+ torch.save(state, filename)
+
+ logging.info(f"\nCheckpoint saved to {filename}\n")
+ latest = Path(os.path.join(self.output_dir, f"model.pt"))
+ torch.save(state, latest)
+ if self.best_step_or_epoch == "":
+ self.best_step_or_epoch = ckpt_name
+
+ # NOTE(review): the lookups below index val_*_step_or_eoch by ckpt_name; those
+ # entries are written by validate_epoch under the same naming scheme, so a
+ # checkpoint saved without a matching validation run raises KeyError here.
+ if self.avg_keep_nbest_models_type == "acc":
+ if (
+ self.val_acc_step_or_eoch[ckpt_name]
+ >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+ ):
+ self.best_step_or_epoch = ckpt_name
+ best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
+ torch.save(state, best_ckpt)
+ logging.info(
+ f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ )
+ else:
+ logging.info(
+ f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ )
+ elif self.avg_keep_nbest_models_type == "loss":
+ if (
+ self.val_loss_step_or_eoch[ckpt_name]
+ <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+ ):
+ self.best_step_or_epoch = ckpt_name
+ best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
+ torch.save(state, best_ckpt)
+ logging.info(
+ f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ )
+ else:
+ logging.info(
+ f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ )
+ else:
+ print("Undo")
+ # Remember this checkpoint's validation score for n-best pruning.
+ self.saved_ckpts[ckpt_name] = getattr(
+ self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+ )[ckpt_name]
+ if self.keep_nbest_models > 0:
+ if len(self.saved_ckpts) > self.keep_nbest_models:
+ # Drop the worst entry: lowest accuracy, or highest loss otherwise.
+ if self.avg_keep_nbest_models_type == "acc":
+ key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+ else:
+ key = max(self.saved_ckpts, key=self.saved_ckpts.get)
+ if key in self.saved_ckpts:
+ del self.saved_ckpts[key]
+ filename = os.path.join(self.output_dir, key)
+ logging.info(f"Delete: {filename}")
+ if os.path.exists(filename):
+ os.remove(filename)
+
+ # Synchronise workers after saving (distributed runs only).
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+    def resume_checkpoint(
+        self,
+        model=None,
+        optim=None,
+        scheduler=None,
+        scaler=None,
+    ):
+        """
+        Restore training state from ``<output_dir>/model.pt`` when ``self.resume`` is set.
+
+        Loads model weights (matching keys with or without a leading ``module.``
+        prefix), optimizer, scheduler and optional scaler state, and restores the
+        trainer's bookkeeping. Bookkeeping fields missing from older checkpoints
+        fall back to their defaults.
+
+        Args:
+            model: Model that receives the weights; moved to ``self.device`` afterwards.
+            optim: Optimizer whose state is restored.
+            scheduler: LR scheduler whose state is restored.
+            scaler: Optional grad scaler; restored only if the checkpoint has ``"scaler_state"``.
+        """
+        if self.resume:
+            ckpt = os.path.join(self.output_dir, "model.pt")
+            if os.path.isfile(ckpt):
+                checkpoint = torch.load(ckpt, map_location="cpu")
+                self.start_epoch = checkpoint["epoch"]
+                # self.model.load_state_dict(checkpoint['state_dict'])
+                src_state = checkpoint["state_dict"]
+                dst_state = model.state_dict()
+                prefix = "module."
+                for k in dst_state.keys():
+                    # Prefer an exact match, then try adding or removing the wrapper
+                    # prefix. The previous test looked for "module." + k even when k was
+                    # already prefixed, so an exactly matching key was never used.
+                    if k in src_state:
+                        k_ddp = k
+                    elif prefix + k in src_state:
+                        k_ddp = prefix + k
+                    elif k.startswith(prefix):
+                        k_ddp = k[len(prefix) :]
+                    else:
+                        k_ddp = k
+                    if k_ddp in src_state:
+                        dst_state[k] = src_state[k_ddp]
+                    else:
+                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+
+                model.load_state_dict(dst_state)
+                optim.load_state_dict(checkpoint["optimizer"])
+                scheduler.load_state_dict(checkpoint["scheduler"])
+                if scaler is not None and "scaler_state" in checkpoint:
+                    scaler.load_state_dict(checkpoint["scaler_state"])
+
+                self.saved_ckpts = checkpoint["saved_ckpts"]
+                self.val_acc_step_or_eoch = checkpoint.get("val_acc_step_or_eoch", {})
+                self.val_loss_step_or_eoch = checkpoint.get("val_loss_step_or_eoch", {})
+                self.best_step_or_epoch = checkpoint.get("best_step_or_epoch", "")
+                self.start_data_split_i = checkpoint.get("data_split_i", 0)
+                self.batch_total = checkpoint.get("batch_total", 0)
+                # "step" / "step_in_epoch" are stored as None for epoch-level checkpoints.
+                start_step = checkpoint.get("step", 0)
+                self.start_step = 0 if start_step is None else start_step
+                step_in_epoch = checkpoint.get("step_in_epoch", 0)
+                self.step_in_epoch = 0 if step_in_epoch is None else step_in_epoch
+                # The unconditional debug print of checkpoint["train_acc_avg"] was removed:
+                # it raised KeyError for checkpoints without that key.
+                self.train_acc_avg = checkpoint.get("train_acc_avg", 0)
+                self.train_loss_avg = checkpoint.get("train_loss_avg", 0)
+                model.to(self.device)
+                print(f"Checkpoint loaded successfully from '{ckpt}'")
+            else:
+                print(f"No checkpoint found at '{ckpt}', does not resume status!")
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+
+    def train_epoch(
+        self,
+        model=None,
+        optim=None,
+        scheduler=None,
+        scaler=None,
+        dataloader_train=None,
+        dataloader_val=None,
+        epoch=None,
+        writer=None,
+        **kwargs,
+    ):
+        """
+        Run one training epoch (or data slice) with gradient accumulation.
+
+        Per batch: forward, backward, running-average update, optional optimizer
+        step with logging, then interval-based validation and checkpointing.
+
+        Args:
+            model: Model (or DDP/DeepSpeed wrapper) to train.
+            optim: Optimizer.
+            scheduler: LR scheduler; ``get_last_lr()`` is used for logging.
+            scaler: Grad scaler, used when ``self.use_fp16`` is set.
+            dataloader_train: Training loader; its ``batch_sampler.set_epoch`` is called.
+            dataloader_val: Validation loader passed to ``validate_epoch``.
+            epoch (int): The current epoch number.
+            writer: Optional summary writer passed to ``log``.
+            **kwargs: ``start_step``, ``data_split_i`` and ``data_split_num`` are read.
+        """
+        distributed = self.use_ddp or self.use_fsdp
+        if distributed:
+            dist.barrier()
+        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
+        model.train()
+
+        # Set the number of steps for gradient accumulation
+        accum_grad = self.accum_grad
+        start_step = kwargs.get("start_step", 0)
+        # Initialize the gradient accumulation
+        optim.zero_grad()
+        speed_stats = {}
+
+        # Summed across ranks every iteration; becomes > 0 once any rank runs out of data.
+        iterator_stop = torch.tensor(0).to(self.device)
+
+        dataloader_train.batch_sampler.set_epoch(epoch)
+        time_beg = time.perf_counter()
+        time5 = time_beg
+        for batch_idx, batch in enumerate(dataloader_train):
+            if distributed:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+            self.batch_total += 1
+            self.step_in_epoch += 1
+            time1 = time.perf_counter()
+            speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
+
+            batch = to_device(batch, self.device)
+
+            my_context = nullcontext
+            if distributed and (batch_idx + 1) % accum_grad != 0:
+                # Skip gradient all-reduce on every micro-batch except the one that is
+                # followed by the optimizer step. The previous ``batch_idx % accum_grad``
+                # test synchronised on the first micro-batch of a window instead.
+                my_context = model.no_sync
+            with my_context():
+                time2 = time.perf_counter()
+                loss_dict = {}
+                self.forward_step(model, batch, loss_dict=loss_dict)
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                self.backward_step(model, scaler, loss_dict=loss_dict)
+
+                time4 = time.perf_counter()
+                speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
+
+            loss = loss_dict["loss"]
+            stats = loss_dict["stats"]
+
+            # Running averages over the slice; without this update the averages that
+            # ``log`` prints and ``save_checkpoint`` stores never change.
+            seen = batch_idx + start_step
+            self.train_loss_avg = (self.train_loss_avg * seen + loss.detach().cpu().item()) / (
+                seen + 1
+            )
+            if "acc" in stats:
+                self.train_acc_avg = (
+                    self.train_acc_avg * seen + stats["acc"].detach().cpu().item()
+                ) / (seen + 1)
+
+            # Perform an optimizer step only after accumulating enough gradients
+            stepped = self.update_step(
+                model, optim, scheduler, scaler, batch_idx=batch_idx, loss_dict=loss_dict
+            )
+
+            if stepped:
+                total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
+                time5 = time.perf_counter()
+
+                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+
+                speed_stats["total_time"] = total_time
+                lr = scheduler.get_last_lr()[0]
+                batch_num_epoch = 1
+                if hasattr(dataloader_train, "__len__"):
+                    batch_num_epoch = len(dataloader_train)
+                self.log(
+                    epoch,
+                    batch_idx,
+                    log_step=batch_idx + start_step,
+                    step_in_epoch=self.step_in_epoch,
+                    batch_num_epoch=batch_num_epoch,
+                    lr=lr,
+                    loss=loss.detach().cpu().item(),
+                    speed_stats=speed_stats,
+                    stats=stats,
+                    writer=writer,
+                    tag="train",
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
+                )
+
+            if self.step_in_epoch % self.validate_interval == 0:
+                self.validate_epoch(
+                    model=model,
+                    dataloader_val=dataloader_val,
+                    epoch=epoch,
+                    writer=writer,
+                    step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
+                )
+
+            if self.step_in_epoch % self.save_checkpoint_interval == 0:
+                self.save_checkpoint(
+                    epoch,
+                    model=model,
+                    optim=optim,
+                    scheduler=scheduler,
+                    scaler=scaler,
+                    step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
+                    train_loss_avg=self.train_loss_avg,
+                    train_acc_avg=self.train_acc_avg,
+                )
+
+            time_beg = time.perf_counter()
+        else:
+            # This rank exhausted its data: raise the flag so the other ranks break too.
+            if distributed:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+
+        if distributed:
+            dist.barrier()
+
+    def forward_step(self, model, batch, loss_dict=None):
+        """
+        Run the model on one batch and record the outputs in ``loss_dict``.
+
+        Args:
+            model: Callable returning ``(loss, stats, weight)`` for ``model(**batch)``.
+            batch (dict): Keyword arguments for the model.
+            loss_dict (dict, optional): Receives the keys ``"loss"``, ``"stats"``
+                (entries with value None removed) and ``"weight"``. A new dict is
+                created when omitted.
+
+        Returns:
+            dict: The populated ``loss_dict``.
+        """
+        if loss_dict is None:
+            loss_dict = {}
+        with maybe_autocast(self.use_fp16):
+            retval = model(**batch)
+
+            if (
+                self.reset_gpu_cache
+                and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
+            ):
+                torch.cuda.empty_cache()
+
+        loss, stats, weight = retval
+        stats = {k: v for k, v in stats.items() if v is not None}
+        # NOTE: the loss is used as returned by the model; no cross-worker
+        # weight normalisation is applied in this trainer.
+
+        loss_dict["loss"] = loss
+        loss_dict["stats"] = stats
+        loss_dict["weight"] = weight
+        return loss_dict
+
+    def backward_step(self, model, scaler, loss_dict=None):
+        """
+        Back-propagate ``loss_dict["loss"]``.
+
+        With DeepSpeed the engine's ``backward`` is called on the raw loss; otherwise
+        the loss is divided by ``self.accum_grad`` first and, under fp16, scaled by
+        ``scaler``.
+
+        Raises:
+            KeyError: If ``loss_dict`` has no ``"loss"`` entry.
+        """
+        loss_dict = {} if loss_dict is None else loss_dict
+        loss = loss_dict["loss"]
+
+        if self.use_deepspeed:
+            model.backward(loss)
+        else:
+            # Scale the loss since we're not updating for every mini-batch
+            loss = loss / self.accum_grad
+            if self.use_fp16:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
+    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
+        """
+        Apply an optimizer step when ``batch_idx`` closes an accumulation window.
+
+        Args:
+            model: Model whose gradients are clipped (or the DeepSpeed engine).
+            optim: Optimizer.
+            scheduler: LR scheduler, stepped after the optimizer.
+            scaler: Grad scaler, used when ``self.use_fp16`` is set.
+            batch_idx (int): Zero-based index of the current batch.
+            loss_dict (dict, optional): Accepted for interface compatibility; not read.
+
+        Returns:
+            bool: True if a step was applied, False if this batch only accumulated
+            gradients or the update was skipped because of a non-finite grad norm.
+        """
+        if self.use_deepspeed:
+            # The engine's step() is the counterpart to the backward() call in
+            # backward_step. NOTE(review): accumulation, clipping and LR scheduling
+            # are assumed to come from the DeepSpeed config - confirm.
+            model.step()
+            return True
+
+        if (batch_idx + 1) % self.accum_grad != 0:
+            return False
+
+        # Perform gradient clipping if it is set
+        if self.grad_clip > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                model.parameters(),
+                max_norm=self.grad_clip,
+                norm_type=self.grad_clip_type,
+            )
+            if not torch.isfinite(grad_norm):
+                logging.warning(f"The grad norm is {grad_norm}. Skipping updating the model.")
+                optim.zero_grad()  # Reset gradients
+                return False
+
+        # Execute an optimization step (update model parameters)
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        if self.use_fp16:
+            scaler.step(optim)
+            scaler.update()
+        else:
+            optim.step()
+        scheduler.step()
+        # Clear gradients for the next accumulation stage
+        optim.zero_grad(set_to_none=True)
+
+        if self.use_ddp or self.use_fsdp:
+            # Average the running loss/acc across workers.
+            train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
+                self.device
+            )
+            train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
+                self.device
+            )
+            dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+            dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+            self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
+        return True
+
+ def validate_epoch(
+ self,
+ model=None,
+ dataloader_val=None,
+ epoch=None,
+ writer=None,
+ **kwargs,
+ ):
+ """
+ Runs the model over the validation loader without gradients, tracks the
+ running validation loss/accuracy, and records them under the checkpoint
+ name that save_checkpoint later looks up.
+
+ Args:
+ model: Model to evaluate; switched to eval mode and back to train mode.
+ dataloader_val: Validation loader; its ``batch_sampler.set_epoch`` is called.
+ epoch (int): The current epoch number.
+ writer: Optional summary writer passed to ``log``.
+ **kwargs: ``step_in_epoch`` selects a step-level checkpoint name when given.
+ """
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+ logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
+ model.eval()
+
+ with torch.no_grad():
+
+ speed_stats = {}
+ time5 = time.perf_counter()
+ # Summed across ranks each iteration; becomes > 0 once any rank has run out of batches.
+ iterator_stop = torch.tensor(0).to(self.device)
+ dataloader_val.batch_sampler.set_epoch(epoch)
+ for batch_idx, batch in enumerate(dataloader_val):
+ if self.use_ddp or self.use_fsdp:
+ dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+ time1 = time.perf_counter()
+ speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+ batch = to_device(batch, self.device)
+ time2 = time.perf_counter()
+ retval = model(**batch)
+ time3 = time.perf_counter()
+ speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+ loss, stats, weight = retval
+ stats = {k: v for k, v in stats.items() if v is not None}
+ if self.use_ddp or self.use_fsdp:
+ # Apply weighted averaging for loss and stats
+ loss = (loss * weight.type(loss.dtype)).sum()
+ # if distributed, this method can also apply all_reduce()
+ # stats, weight = recursive_average(stats, weight, distributed=True)
+ if self.use_ddp or self.use_fsdp:
+ dist.all_reduce(weight, op=dist.ReduceOp.SUM)
+ # Now weight is summation over all workers
+ loss /= weight.sum() # shape:[1] -> shape:[]
+ # Multiply world_size because DistributedDataParallel
+ # automatically normalizes the gradient by world_size.
+ loss *= self.world_size
+ # Scale the loss since we're not updating for every mini-batch
+ # (no-op assignment; nothing is scaled during validation)
+ loss = loss
+ # time4 is recorded but not read afterwards.
+ time4 = time.perf_counter()
+
+ # Incremental mean over batches; at batch_idx == 0 the previous value is multiplied by zero.
+ self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
+ batch_idx + 1
+ )
+ if "acc" in stats:
+ self.val_acc_avg = (
+ self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+ ) / (batch_idx + 1)
+ if self.use_ddp or self.use_fsdp:
+ # Average the running means across workers.
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+ self.device
+ )
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+ self.device
+ )
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+ time5 = time.perf_counter()
+ batch_num_epoch = 1
+ if hasattr(dataloader_val, "__len__"):
+ batch_num_epoch = len(dataloader_val)
+ self.log(
+ epoch,
+ batch_idx,
+ batch_num_epoch=batch_num_epoch,
+ lr=0.0,
+ loss=loss.detach().cpu().item(),
+ speed_stats=speed_stats,
+ stats=stats,
+ writer=writer,
+ tag="val",
+ )
+
+ else:
+ # Loader exhausted on this rank: raise the stop flag for the other ranks.
+ if self.use_ddp or self.use_fsdp:
+ iterator_stop.fill_(1)
+ dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+
+ # Same naming scheme as the checkpoint files written by save_checkpoint.
+ if kwargs.get("step_in_epoch", None) is None:
+ ckpt_name = f"model.pt.ep{epoch}"
+ else:
+ ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
+ self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
+ self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+ model.train()
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+ # Rebinds the local flag only; it is not read again before returning.
+ iterator_stop = torch.tensor(0).to(self.device)
+
+    def log(
+        self,
+        epoch=0,
+        batch_idx=0,
+        step_in_epoch=0,
+        batch_num_epoch=-1,
+        lr=0.0,
+        loss=0.0,
+        speed_stats=None,
+        stats=None,
+        writer=None,
+        tag="train",
+        data_split_i=0,
+        data_split_num=1,
+        log_step=None,
+        **kwargs,
+    ):
+        """
+        Emit a progress line every ``self.log_interval`` batches and mirror the
+        scalars to the summary writer and wandb when enabled.
+
+        Args:
+            epoch (int): Current epoch.
+            batch_idx (int): Zero-based batch index; decides whether this call logs.
+            step_in_epoch (int): Step counter shown in the message.
+            batch_num_epoch (int): Number of batches shown as the denominator.
+            lr (float): Learning rate to report.
+            loss (float): Loss of the current batch on this rank.
+            speed_stats (dict, optional): Timing values as numeric strings.
+            stats (dict, optional): Tensors to report via ``.item()``.
+            writer: Optional object with ``add_scalar(name, value, step)``.
+            tag (str): ``"train"`` or ``"val"``; selects ``self.<tag>_loss_avg`` / ``_acc_avg``.
+            data_split_i (int): Index of the current data slice.
+            data_split_num (int): Number of data slices.
+            log_step (int, optional): Overrides ``batch_idx`` in the printed message.
+        """
+        if (batch_idx + 1) % self.log_interval != 0:
+            return
+
+        # Both default to None in the signature but are iterated below.
+        stats = {} if stats is None else stats
+        speed_stats = {} if speed_stats is None else speed_stats
+
+        batch_idx = log_step if log_step is not None else batch_idx
+        gpu_info = (
+            "GPU, memory: usage: {:.3f} GB, "
+            "peak: {:.3f} GB, "
+            "cache: {:.3f} GB, "
+            "cache_peak: {:.3f} GB".format(
+                torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+                torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+                torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+                torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
+            )
+        )
+
+        loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+        acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
+        description = (
+            f"{tag}, "
+            f"rank: {self.rank}, "
+            f"epoch: {epoch}/{self.max_epoch}, "
+            f"data_slice: {data_split_i}/{data_split_num}, "
+            f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
+            f"(loss_avg_rank: {loss:.3f}), "
+            f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+            f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+            f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
+            f"(lr: {lr:.3e}), "
+            f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
+            f"{speed_stats}, "
+            f"{gpu_info}"
+        )
+        logging.info(description)
+
+        description_dict = {
+            f"rank{self.rank}_loss/{tag}": loss,
+            f"rank{self.rank}_lr/{tag}": lr,
+        }
+        for key, var in stats.items():
+            description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
+        for key, var in speed_stats.items():
+            # Timing values are formatted float strings; parse them with float(), not eval().
+            description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = float(var)
+
+        if writer is not None:
+            for name, value in description_dict.items():
+                writer.add_scalar(name, value, self.batch_total)
+
+        if self.use_wandb and wandb is not None:
+            # The keyword is ``step``; the former ``setp`` raised TypeError.
+            wandb.log(
+                description_dict,
+                step=self.batch_total,
+            )
+
+    def close(self, writer=None):
+        """Synchronise workers, close ``writer`` if one is given, then tear down the process group."""
+        distributed = self.use_ddp or self.use_fsdp
+
+        if distributed:
+            dist.barrier()
+
+        if writer is not None:
+            writer.close()
+
+        if distributed:
+            dist.destroy_process_group()
+
+    def warp_model(self, model, **kwargs):
+        """
+        Prepare ``model`` for the configured backend and return it.
+
+        - DeepSpeed: rank 0 logs ZeRO-2 / ZeRO-3 memory estimates; the model itself
+          is returned untouched (the engine is initialised later).
+        - DDP: the model is moved to ``cuda:LOCAL_RANK`` and wrapped in
+          ``DistributedDataParallel``.
+        - Otherwise: the model is moved to ``kwargs["device"]`` (default ``"cuda"``).
+
+        Args:
+            model: The model to prepare.
+            **kwargs: ``train_conf["find_unused_parameters"]`` (DDP) and ``device`` are read.
+
+        Returns:
+            The prepared (possibly wrapped) model.
+        """
+        if self.use_deepspeed:
+            from deepspeed.runtime.zero.stage_1_and_2 import (
+                estimate_zero2_model_states_mem_needs_all_live,
+            )
+            from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
+
+            local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+            world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+            # NOTE(xcsong): look in detail how the memory estimator API works:
+            #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
+            if int(os.environ.get("RANK", 0)) == 0:
+                logging.info("Estimating model states memory needs (zero2)...")
+                estimate_zero2_model_states_mem_needs_all_live(
+                    model,
+                    num_gpus_per_node=local_world_size,
+                    num_nodes=world_size // local_world_size,
+                )
+                logging.info("Estimating model states memory needs (zero3)...")
+                estimate_zero3_model_states_mem_needs_all_live(
+                    model,
+                    num_gpus_per_node=local_world_size,
+                    num_nodes=world_size // local_world_size,
+                )
+            # Device placement and DeepSpeed initialisation happen later.
+
+        elif self.use_ddp:
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            model = model.cuda(local_rank)
+            model = DDP(
+                model,
+                device_ids=[local_rank],
+                find_unused_parameters=kwargs.get("train_conf", {}).get(
+                    "find_unused_parameters", False
+                ),
+            )
+        else:
+            # NOTE: there is no FSDP wrapping branch; with ``use_fsdp`` set the model
+            # is only moved to the device here.
+            model = model.to(device=kwargs.get("device", "cuda"))
+
+        return model
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index 9f01955..4613cb3 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -70,14 +70,16 @@
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
OmegaConf.save(config=kwargs, f=yaml_file)
- print(kwargs)
+ logging.info(f"kwargs: {kwargs}")
logging.info("config.yaml is saved to: %s", yaml_file)
- # model_path = kwargs.get("model_path")
- # if model_path is not None:
- # config_json = os.path.join(model_path, "configuration.json")
- # if os.path.exists(config_json):
- # shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
+ model_path = kwargs.get("model_path", None)
+ if model_path is not None:
+ config_json = os.path.join(model_path, "configuration.json")
+ if os.path.exists(config_json):
+ shutil.copy(
+ config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")
+ )
def extract_filename_without_extension(file_path):
--
Gitblit v1.9.1