From 4482bbcbb912f699a4faecaafd65aa15aec64a51 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 11:49:30 +0800
Subject: [PATCH] train (#1521)
---
funasr/datasets/large_datasets/utils/filter.py | 26
funasr/train_utils/trainer.py | 501 ++++++----
funasr/register.py | 1
funasr/datasets/large_datasets/utils/clipping.py | 40
examples/industrial_data_pretraining/paraformer/finetune_from_local.sh | 18
examples/industrial_data_pretraining/paraformer/finetune.sh | 11
funasr/datasets/large_datasets/build_dataloader.py | 97 ++
funasr/datasets/large_datasets/utils/padding.py | 74 +
funasr/bin/train.py | 215 ++--
funasr/datasets/large_datasets/abs_iter_factory.py | 9
funasr/datasets/large_datasets/datapipes/batch.py | 213 ++++
funasr/datasets/large_datasets/datapipes/filter.py | 24
funasr/datasets/large_datasets/datapipes/map.py | 22
funasr/datasets/large_datasets/utils/__init__.py | 0
funasr/datasets/large_datasets/utils/low_frame_rate.py | 30
funasr/datasets/large_datasets/collate_fn.py | 196 ++++
funasr/datasets/large_datasets/__init__.py | 0
funasr/datasets/dataloader_entry.py | 38
funasr/datasets/large_datasets/dataset.py | 274 +++++
funasr/utils/misc.py | 23
/dev/null | 242 -----
funasr/datasets/large_datasets/utils/hotword_utils.py | 33
funasr/datasets/large_datasets/utils/tokenize.py | 95 ++
funasr/datasets/audio_datasets/samplers.py | 550 ++++++-----
funasr/datasets/large_datasets/datapipes/__init__.py | 0
25 files changed, 1924 insertions(+), 808 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index e1273da..9fc8bf0 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -36,9 +36,14 @@
++model_revision="v2.0.4" \
++train_data_set_list="${train_data}" \
++valid_data_set_list="${val_data}" \
-++dataset_conf.batch_size=32 \
-++dataset_conf.batch_type="example" \
+++dataset_conf.batch_size=20000 \
+++dataset_conf.batch_type="token" \
++dataset_conf.num_workers=4 \
-++train_conf.max_epoch=20 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=10 \
+++train_conf.resume=false \
+++train_conf.validate_interval=15 \
+++train_conf.save_checkpoint_interval=15 \
+++train_conf.keep_nbest_models=50 \
++optim_conf.lr=0.0002 \
++output_dir="${output_dir}" &> ${log_file}
\ No newline at end of file
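For reference, the `++key=value` arguments above are Hydra-style dotlist overrides of the model config. A minimal sketch of how such a dotlist resolves into a nested config, assuming plain OmegaConf semantics (Hydra adds the `++` force-add/override behavior on top):

    from omegaconf import OmegaConf

    # each override addresses a nested key with dot notation;
    # scalar values are type-inferred (20000 becomes an int)
    cfg = OmegaConf.from_dotlist([
        "dataset_conf.batch_size=20000",
        "dataset_conf.batch_type=token",
        "train_conf.max_epoch=50",
    ])
    print(cfg.dataset_conf.batch_size)  # 20000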
diff --git a/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh b/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh
index dcbdf77..b883908 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh
@@ -59,13 +59,17 @@
--config-name "${config_name}" \
++train_data_set_list="${train_data}" \
++valid_data_set_list="${val_data}" \
+++dataset_conf.batch_size=20000 \
+++dataset_conf.batch_type="token" \
+++dataset_conf.num_workers=4 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=10 \
+++train_conf.resume=false \
+++train_conf.validate_interval=15 \
+++train_conf.save_checkpoint_interval=15 \
+++train_conf.keep_nbest_models=50 \
+++optim_conf.lr=0.0002 \
+++init_param="${init_param}" \
++tokenizer_conf.token_list="${tokens}" \
++frontend_conf.cmvn_file="${cmvn_file}" \
-++dataset_conf.batch_size=32 \
-++dataset_conf.batch_type="example" \
-++dataset_conf.num_workers=4 \
-++train_conf.max_epoch=20 \
-++optim_conf.lr=0.0002 \
-++train_conf.log_interval=1 \
-++init_param="${init_param}" \
++output_dir="${output_dir}" &> ${log_file}
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 3c93371..3f97f9e 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -4,15 +4,23 @@
import os
import sys
import torch
+import torch.nn as nn
import hydra
import logging
+import time
import argparse
from io import BytesIO
+
+from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
+from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
from funasr.optimizers import optim_classes
@@ -23,10 +31,8 @@
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
-# from funasr.tokenizer.build_tokenizer import build_tokenizer
-# from funasr.tokenizer.token_id_converter import TokenIDConverter
-# from funasr.tokenizer.funtoken import build_tokenizer
-
+from funasr.utils.misc import prepare_model_dir
+from funasr import AutoModel
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
@@ -43,7 +49,6 @@
def main(**kwargs):
- print(kwargs)
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
@@ -56,65 +61,29 @@
tables.print()
# Check if we are using DDP or FSDP
use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
- use_fsdp = kwargs.get("use_fsdp", None)
+ use_fsdp = kwargs.get("use_fsdp", False)
+ # use_ddp = False if use_fsdp else use_fsdp
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
torch.cuda.set_device(local_rank)
+
+ logging.info("Build model, frontend, tokenizer")
+ device = kwargs.get("device", "cuda")
+ kwargs["device"] = "cpu"
+ model = AutoModel(**kwargs)
+
# save config.yaml
if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
- os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
- yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
- OmegaConf.save(config=kwargs, f=yaml_file)
- logging.info("config.yaml is saved to: %s", yaml_file)
-
- tokenizer = kwargs.get("tokenizer", None)
- if tokenizer is not None:
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
- tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
- kwargs["tokenizer"] = tokenizer
+ prepare_model_dir(**kwargs)
- # build frontend if frontend is none None
- frontend = kwargs.get("frontend", None)
- if frontend is not None:
- frontend_class = tables.frontend_classes.get(frontend)
- frontend = frontend_class(**kwargs["frontend_conf"])
- kwargs["frontend"] = frontend
- kwargs["input_size"] = frontend.output_size()
-
-
- # build model
- model_class = tables.model_classes.get(kwargs["model"])
- vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
- vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
-
-
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- if not isinstance(init_param, (list, tuple)):
- init_param = (init_param,)
- logging.info("init_param is not None: %s", init_param)
- for p in init_param:
- if os.path.exists(p):
- logging.info(f"Loading pretrained params from {p}")
- load_pretrained_model(
- model=model,
- path=p,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
- oss_bucket=kwargs.get("oss_bucket", None),
- scope_map=kwargs.get("scope_map", []),
- excludes=kwargs.get("excludes", None),
- )
- else:
- logging.info(f"Checkpoint does not exist, init randomly: {p}")
- elif kwargs.get("init", None):
- initialize(model, kwargs.get("init", "kaiming_normal"))
- else:
- print("No initialize method")
-
+ # parse kwargs
+ kwargs = model.kwargs
+ kwargs["device"] = device
+ tokenizer = kwargs["tokenizer"]
+ frontend = kwargs["frontend"]
+ model = model.model
+ del kwargs["model"]
# freeze_param
freeze_param = kwargs.get("freeze_param", None)
@@ -135,18 +104,42 @@
model = DDP(model, device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
- model = FSDP(model).cuda(local_rank)
+ # model = FSDP(model).cuda(local_rank)
+
+ def custom_auto_wrap_policy(
+ module: nn.Module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ # Additional custom arguments
+ min_num_params: int = int(1e8),
+ ) -> bool:
+ # Decide whether to wrap this module based on custom logic
+ is_large = nonwrapped_numel >= min_num_params
+ requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+ return is_large and requires_grad_uniform
+
+ from functools import partial  # local import: bind a custom `min_num_params`
+ my_auto_wrap_policy = partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+ torch.cuda.set_device(local_rank)
+ model = FSDP(model,
+ auto_wrap_policy=my_auto_wrap_policy,
+ mixed_precision=None,
+ device_id=torch.cuda.current_device())
else:
model = model.to(device=kwargs.get("device", "cuda"))
-
+
+ logging.info(f"{model}")
+ kwargs["device"] = next(model.parameters()).device
# optim
+ logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
+ logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
@@ -154,45 +147,75 @@
# dataset
- dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
- dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
+ logging.info("Build dataloader")
+ dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
+ dataloader_tr, dataloader_val = dataloader_class(**kwargs)
- # dataloader
- batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
- batch_sampler_val = None
- if batch_sampler is not None:
- batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
- batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
- batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
- dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
- collate_fn=dataset_tr.collator,
- batch_sampler=batch_sampler,
- num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
- pin_memory=True)
-
- dataloader_val = torch.utils.data.DataLoader(dataset_val,
- collate_fn=dataset_val.collator,
- batch_sampler=batch_sampler_val,
- num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
- pin_memory=True)
- trainer = Trainer(
- model=model,
- optim=optim,
- scheduler=scheduler,
- dataloader_train=dataloader_tr,
- dataloader_val=dataloader_val,
- local_rank=local_rank,
- use_ddp=use_ddp,
- use_fsdp=use_fsdp,
- output_dir=kwargs.get("output_dir", "./exp"),
- resume=kwargs.get("resume", True),
- **kwargs.get("train_conf"),
- )
- trainer.run()
-
+ trainer = Trainer(local_rank=local_rank,
+ use_ddp=use_ddp,
+ use_fsdp=use_fsdp,
+ device=kwargs["device"],
+ output_dir=kwargs.get("output_dir", "./exp"),
+ **kwargs.get("train_conf"),
+ )
+
+ scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+ scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+
+ trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+ tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
+ os.makedirs(tensorboard_dir, exist_ok=True)
+ try:
+ from tensorboardX import SummaryWriter
+ writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+ except Exception:
+ writer = None
+
if use_ddp or use_fsdp:
- torch.distributed.destroy_process_group()
+ context = Join([model])
+ else:
+ context = nullcontext()
+
+ for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
+ time1 = time.perf_counter()
+ with context:
+
+ trainer.train_epoch(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ scaler=scaler,
+ dataloader_train=dataloader_tr,
+ dataloader_val=dataloader_val,
+ epoch=epoch,
+ writer=writer
+ )
+ scheduler.step()
+ trainer.validate_epoch(
+ model=model,
+ dataloader_val=dataloader_val,
+ epoch=epoch,
+ writer=writer
+ )
+
+
+ trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+ time2 = time.perf_counter()
+ time_elapsed = (time2 - time1) / 3600.0
+ logging.info(
+ f"rank: {local_rank}, "
+ f"time_elapsed_epoch: {time_elapsed:.3f} hours, "
+ f"estimated to finish {trainer.max_epoch} "
+ f"epochs: {(trainer.max_epoch - epoch) * time_elapsed:.3f} hours\n")
+
+
+ if trainer.rank == 0:
+ average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
+
+ trainer.close()
+
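The GradScaler/ShardedGradScaler wiring above follows the standard torch.cuda.amp recipe. A self-contained sketch of one fp16 optimization step on a CUDA device (generic PyTorch, not FunASR's trainer internals):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(80, 10).cuda()
    optim = torch.optim.Adam(model.parameters(), lr=2e-4)
    scaler = GradScaler()

    x = torch.randn(4, 80, device="cuda")
    optim.zero_grad()
    with autocast():                   # forward pass in mixed precision
        loss = model(x).sum()
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 underflow
    scaler.step(optim)                 # unscales gradients, then steps
    scaler.update()                    # adapt the scale factor for the next step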
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
deleted file mode 100644
index 89f5db3..0000000
--- a/funasr/bin/train_llm.py
+++ /dev/null
@@ -1,241 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-
-import os
-import sys
-import torch
-import torch.nn as nn
-import hydra
-import logging
-import time
-import argparse
-from io import BytesIO
-
-from contextlib import nullcontext
-import torch.distributed as dist
-from collections.abc import Sequence
-from omegaconf import DictConfig, OmegaConf
-from torch.cuda.amp import autocast, GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.algorithms.join import Join
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from funasr.train_utils.average_nbest_models import average_checkpoints
-
-from funasr.register import tables
-from funasr.optimizers import optim_classes
-from funasr.train_utils.trainer_llm import Trainer
-from funasr.schedulers import scheduler_classes
-from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
-from funasr.models.lora.utils import mark_only_lora_as_trainable
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-# from funasr.tokenizer.build_tokenizer import build_tokenizer
-# from funasr.tokenizer.token_id_converter import TokenIDConverter
-# from funasr.tokenizer.funtoken import build_tokenizer
-from funasr import AutoModel
-
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
- if kwargs.get("debug", False):
- import pdb; pdb.set_trace()
-
- assert "model" in kwargs
- if "model_conf" not in kwargs:
- logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
- kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-
-
- main(**kwargs)
-
-
-def main(**kwargs):
-
- # set random seed
- set_all_random_seed(kwargs.get("seed", 0))
- torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
- torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
- torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-
- local_rank = int(os.environ.get('LOCAL_RANK', 0))
- if local_rank == 0:
- tables.print()
- # Check if we are using DDP or FSDP
- use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
- use_fsdp = kwargs.get("use_fsdp", False)
- # use_ddp = False if use_fsdp else use_fsdp
- if use_ddp or use_fsdp:
- dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
- torch.cuda.set_device(local_rank)
-
- logging.info("Build model, frontend, tokenizer")
- device = kwargs.get("device", "cuda")
- kwargs["device"] = "cpu"
- model = AutoModel(**kwargs)
-
-
- # save config.yaml
- if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
- os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
- yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
- OmegaConf.save(config=kwargs, f=yaml_file)
- print(kwargs)
- logging.info("config.yaml is saved to: %s", yaml_file)
-
- # parse kwargs
- kwargs = model.kwargs
- kwargs["device"] = device
- tokenizer = kwargs["tokenizer"]
- frontend = kwargs["frontend"]
- model = model.model
- del kwargs["model"]
-
- # freeze_param
- freeze_param = kwargs.get("freeze_param", None)
- if freeze_param is not None:
- freeze_param = eval(freeze_param)
- if isinstance(freeze_param, Sequence):
- freeze_param = (freeze_param,)
- logging.info("freeze_param is not None: %s", freeze_param)
- for t in freeze_param:
- for k, p in model.named_parameters():
- if k.startswith(t + ".") or k == t:
- logging.info(f"Setting {k}.requires_grad = False")
- p.requires_grad = False
-
-
- if use_ddp:
- model = model.cuda(local_rank)
- model = DDP(model, device_ids=[local_rank],
- find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
- elif use_fsdp:
- # model = FSDP(model).cuda(local_rank)
-
- def custom_auto_wrap_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- # Additional custom arguments
- min_num_params: int = int(1e8),
- ) -> bool:
- # Decide whether to wrap this module based on custom logic
- is_large = unwrapped_params >= min_num_params
- requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
- return is_large and requires_grad_uniform
-
- # Configure a custom `min_num_params`
- my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
- torch.cuda.set_device(local_rank)
- model = FSDP(model,
- auto_wrap_policy=custom_auto_wrap_policy,
- mixed_precision=None,
- device_id=torch.cuda.current_device())
- else:
- model = model.to(device=kwargs.get("device", "cuda"))
-
- logging.info(f"{model}")
- kwargs["device"] = next(model.parameters()).device
-
- # optim
- logging.info("Build optim")
- optim = kwargs.get("optim", "adam")
- assert optim in optim_classes
- optim_class = optim_classes.get(optim)
- optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-
- # scheduler
- logging.info("Build scheduler")
- scheduler = kwargs.get("scheduler", "warmuplr")
- assert scheduler in scheduler_classes
- scheduler_class = scheduler_classes.get(scheduler)
- scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
-
-
- # dataset
- logging.info("Build dataloader")
- dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
- dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
-
- # dataloader
- batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
- batch_sampler_val = None
- if batch_sampler is not None:
- batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
- batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
- batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
-
- dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler)
- dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val)
-
- trainer = Trainer(local_rank=local_rank,
- use_ddp=use_ddp,
- use_fsdp=use_fsdp,
- device=kwargs["device"],
- output_dir=kwargs.get("output_dir", "./exp"),
- **kwargs.get("train_conf"),
- )
-
- scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
- scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
-
- trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
- tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
- os.makedirs(tensorboard_dir, exist_ok=True)
- try:
- from tensorboardX import SummaryWriter
- writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
- except:
- writer = None
-
- if use_ddp or use_fsdp:
- context = Join([model])
- else:
- context = nullcontext()
-
- for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
- time1 = time.perf_counter()
- with context:
-
- trainer.train_epoch(
- model=model,
- optim=optim,
- scheduler=scheduler,
- scaler=scaler,
- dataloader_train=dataloader_tr,
- dataloader_val=dataloader_val,
- epoch=epoch,
- writer=writer
- )
- scheduler.step()
- trainer.validate_epoch(
- model=model,
- dataloader_val=dataloader_val,
- epoch=epoch,
- writer=writer
- )
-
-
- trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
- time2 = time.perf_counter()
- time_escaped = (time2 - time1) / 3600.0
- logging.info(
- f"rank: {local_rank}, "
- f"time_escaped_epoch: {time_escaped:.3f} hours, "
- f"estimated to finish {trainer.max_epoch} "
- f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
-
-
- if trainer.rank == 0:
- average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
-
- trainer.close()
-
-
-
-
-if __name__ == "__main__":
- main_hydra()
\ No newline at end of file
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 914e776..a0ff4b6 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -1,277 +1,327 @@
import torch
import numpy as np
import logging
+import math
+from torch.utils.data import DistributedSampler
+from torch.utils.data import BatchSampler, Sampler
import torch.distributed as dist
from funasr.register import tables
-@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
-class BatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = False,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 5000)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
-
- def __len__(self):
- return (self.total_samples-1) // self.batch_size + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- sample_len_cur = source_len + target_len
-
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- if self.batch_type != 'example':
- max_token_padding *= max_token_cur
- if max_token_padding <= self.batch_size:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- yield batch
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
-
-
@tables.register("batch_sampler_classes", "BatchSampler")
+@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
+@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
+@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
-class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = True,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 1500)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- self.rank = rank
- self.world_size = world_size
-
- def __len__(self):
- return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- batch_size_total = self.batch_size * self.world_size
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- # if iter == iter_num -1 and self.drop_last:
- # continue
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- sample_len_cur = source_len + target_len
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- # if self.batch_type != 'example':
- # max_token_padding *= max_token_cur
- if max_token_padding <= batch_size_total:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
- yield batch_rank
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
-
-
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
-class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = True,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
+def CustomDistributedBatchSampler_fn(dataset, **kwargs):
+ dataloader_args = {}
+ batch_type = kwargs.get("batch_type", "example")
+ if batch_type == "example":
+ batch_sampler = CustomDistributedBatchSampler(dataset, **kwargs)
- self.drop_last = drop_last
- self.pre_idx = -1
+ else:
+ batch_sampler = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
+
+ dataloader_args["batch_sampler"] = batch_sampler
+ dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+ dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+
+ return dataloader_args
+
+class CustomDistributedBatchSampler(Sampler):
+ def __init__(self, dataset,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ **kwargs,
+ ):
+
+ try:
+ rank = dist.get_rank()
+ num_replicas = dist.get_world_size()
+ except:
+ rank = 0
+ num_replicas = 1
+ self.rank = rank
+ self.num_replicas = num_replicas
self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 1500)
- self.shuffle_idx = np.arange(self.total_samples)
+ self.batch_size = batch_size
+ self.is_training = is_training
self.shuffle = shuffle and is_training
+ self.drop_last = drop_last
+ # self.total_size = len(dataset)
+ if self.drop_last:
+ self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+ else:
+ self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+ self.num_samples = int(self.total_size // self.num_replicas)
+ self.epoch = 0
+ self.max_token_length = kwargs.get("max_token_length", None)
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
+ def __iter__(self):
+ # Generate a list of indices
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ # Add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+
+ assert len(indices) == self.total_size
+
+ # Subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ # Filter out indices with length greater than the max length, if provided
+ if self.max_token_length is not None:
+ filtered_indices = []
+ for idx in indices:
+ source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+ if source_len <= self.max_token_length:
+ filtered_indices.append(idx)
+ indices = filtered_indices
+
+ # Now that we have only the indices for this replica, chunk them into batches
+ batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
+
+ # Drop the last batch if it's not full and drop_last is True
+ if self.drop_last and len(batches[-1]) != self.batch_size:
+ batches = batches[:-1]
+
+ return iter(batches)
+
+ def __len__(self):
+
+ return self.num_samples // self.batch_size
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+class CustomDistributedBufferBatchSampler(Sampler):
+ def __init__(self, dataset,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ sort_size: int = 1024,
+ **kwargs,
+ ):
try:
rank = dist.get_rank()
- world_size = dist.get_world_size()
+ num_replicas = dist.get_world_size()
except:
rank = 0
- world_size = 1
+ num_replicas = 1
self.rank = rank
- self.world_size = world_size
-
- def __len__(self):
- return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
+ self.num_replicas = num_replicas
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.shuffle = shuffle and is_training
+ self.drop_last = drop_last
+ # self.total_size = len(dataset)
+ if self.drop_last:
+ self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+ else:
+ self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+ self.num_samples = int(self.total_size // self.num_replicas)
+ self.epoch = 0
+ self.max_token_length = kwargs.get("max_token_length", None)
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+ self.sort_size = sort_size
def __iter__(self):
-
- batch_size_total = self.batch_size * self.world_size
+ # Generate a list of indices
if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
- batch_list_all_rank = []
- batch_list_cur = []
- max_token = 0
- num_sample = 0
+ # Add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- # if iter == iter_num - 1 and self.drop_last:
- # continue
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- sample_len_cur = source_len + target_len
-
- datalen_with_index.append([idx, sample_len_cur])
+ assert len(indices) == self.total_size
+
+ # Subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ # Filter out indices with length greater than the max length, if provided
+ if self.max_token_length is not None:
+ filtered_indices = []
+ for idx in indices:
+ source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+ if source_len <= self.max_token_length:
+ filtered_indices.append(idx)
+ indices = filtered_indices
+
+ # Buffer sorting logic
+ sorted_batches = []
+ buffer = []
+
+ for idx in indices:
+ buffer.append(idx)
+ if len(buffer) >= self.sort_size:
+ # Sort the buffer based on some criteria, e.g., dataset sample length
+ buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+ sorted_batches.extend(self._create_batches_from_buffer(buffer))
+ buffer = []
+
+ # Handle the remaining items in the buffer
+ if buffer:
+ buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+ sorted_batches.extend(self._create_batches_from_buffer(buffer))
+
+ return iter(sorted_batches)
+
+ def _create_batches_from_buffer(self, buffer):
+ # Function to convert the sorted buffer into batches
+ batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
+ if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
+ batched_buffer = batched_buffer[:-1]
+ return batched_buffer
+
+ def __len__(self):
+
+ return self.num_samples // self.batch_size
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+class CustomDistributedDynamicBatchSampler(Sampler):
+ def __init__(self, dataset,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ **kwargs,
+ ):
+
+ try:
+ rank = dist.get_rank()
+ num_replicas = dist.get_world_size()
+ except:
+ rank = 0
+ num_replicas = 1
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.shuffle = shuffle and is_training
+ self.drop_last = drop_last
+
+ self.total_size = len(self.dataset)
+ # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
+ self.epoch = 0
+
+ def __iter__(self):
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+
+ batches = []
+ batch = []
+ max_len_in_batch = 0
+ current_batch_length = 0
+
+ for idx in indices:
+ sample_length = self.dataset.get_source_len(idx)
+ potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
+ len(batch) + 1)
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for ii, item in enumerate(datalen_with_index_sort):
- is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
-
- if self.batch_type != 'example':
- max_token_padding *= max_token_cur
- if len(batch_list_all_rank) < self.world_size:
-
- if max_token_padding <= self.batch_size:
- batch_list_cur.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- batch_list_all_rank.append(batch_list_cur)
- batch_list_cur = []
- else:
- batch_rank = batch_list_all_rank[self.rank]
- yield batch_rank
- batch_list_all_rank = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
+ if potential_batch_length <= self.batch_size:
+ batch.append(idx)
+ if sample_length > max_len_in_batch:
+ max_len_in_batch = sample_length
+ current_batch_length = max_len_in_batch * len(batch)
+ else:
+ batches.append(batch)
+ batch = [idx]
+ max_len_in_batch = sample_length
+ current_batch_length = max_len_in_batch
+
+ # Add the last batch if it's not empty and we're not dropping it
+ if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
+ batches.append(batch)
+
+ return iter(batches)
+
+ def __len__(self):
+ # the number of dynamic batches is unknown until iteration; placeholder
+ return 1
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class DistributedSamplerWarp(BatchSampler):
+ def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False):
+ if num_replicas is None:
+ if not torch.distributed.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = torch.distributed.get_world_size()
+ if rank is None:
+ if not torch.distributed.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = torch.distributed.get_rank()
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.shuffle = shuffle
+ self.drop_last = drop_last
+ self.epoch = 0  # read by __iter__ via set_epoch before any external call
+
+ # Create an instance of the DistributedSampler
+ self.sampler = DistributedSampler(
+ self.dataset,
+ num_replicas=self.num_replicas,
+ rank=self.rank,
+ shuffle=self.shuffle
+ )
+
+ # Call BatchSampler's constructor
+ super().__init__(self.sampler, batch_size, drop_last)
+
+ def __iter__(self):
+ # If we shuffle, we need to call the set_epoch method
+ if self.shuffle:
+ self.sampler.set_epoch(self.epoch)
+
+ # Generate batch indices using the parent class
+ return super().__iter__()
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
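The packing rule in CustomDistributedDynamicBatchSampler admits a sample only while max_len_in_batch * (len(batch) + 1) stays within the token budget. A toy single-process illustration (ToyDataset is hypothetical; only __len__ and get_source_len are required):

    from funasr.datasets.audio_datasets.samplers import CustomDistributedDynamicBatchSampler

    class ToyDataset:
        def __init__(self, lengths):
            self.lengths = lengths
        def __len__(self):
            return len(self.lengths)
        def get_source_len(self, idx):
            return self.lengths[idx]

    sampler = CustomDistributedDynamicBatchSampler(
        ToyDataset([120, 80, 200, 60, 150]), batch_size=400, shuffle=False)
    print(list(sampler))  # [[0, 1], [2, 3], [4]]: each batch pads to <= 400 tokens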
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
new file mode 100644
index 0000000..21e3834
--- /dev/null
+++ b/funasr/datasets/dataloader_entry.py
@@ -0,0 +1,38 @@
+
+import logging
+import torch
+
+from funasr.register import tables
+
+@tables.register("dataloader_classes", "DataloaderMapStyle")
+def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
+ # dataset
+ logging.info("Build dataloader")
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
+ dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+ dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
+
+ # dataloader
+ batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+ batch_sampler_val = None
+ if batch_sampler is not None:
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+ batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
+
+ dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler)
+ dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val)
+
+ return dataloader_tr, dataloader_val
+
+
+@tables.register("dataloader_classes", "DataloaderIterable")
+def DataloaderIterable(frontend=None, tokenizer=None, **kwargs):
+ logging.info("Build dataloader")
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "LargeDataset"))
+ dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+ is_training=True, **kwargs.get("dataset_conf"))
+ dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+ is_training=False, **kwargs.get("dataset_conf"))
+
+ return dataset_tr, dataset_val
\ No newline at end of file
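These entries are resolved by name through the registry rather than imported directly, so new dataloader styles can be plugged in from config. A minimal sketch of the registration/lookup round trip ("MyToyDataloader" is a hypothetical name):

    from funasr.register import tables

    @tables.register("dataloader_classes", "MyToyDataloader")  # hypothetical entry
    def MyToyDataloader(frontend=None, tokenizer=None, **kwargs):
        ...  # build and return (dataloader_tr, dataloader_val)

    # funasr/bin/train.py selects the entry via dataset_conf["dataloader"]:
    dataloader_class = tables.dataloader_classes.get("MyToyDataloader")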
diff --git a/funasr/datasets/large_datasets/__init__.py b/funasr/datasets/large_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/large_datasets/__init__.py
diff --git a/funasr/datasets/large_datasets/abs_iter_factory.py b/funasr/datasets/large_datasets/abs_iter_factory.py
new file mode 100644
index 0000000..36e4dd2
--- /dev/null
+++ b/funasr/datasets/large_datasets/abs_iter_factory.py
@@ -0,0 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Iterator
+
+
+class AbsIterFactory(ABC):
+ @abstractmethod
+ def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
+ raise NotImplementedError
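A toy subclass showing the contract this ABC encodes: build a fresh iterator every epoch and derive any shuffling from the epoch number, so all ranks shuffle identically (names here are illustrative):

    import random

    class ListIterFactory(AbsIterFactory):
        def __init__(self, data):
            self.data = list(data)

        def build_iter(self, epoch: int, shuffle: bool = True):
            items = self.data[:]
            if shuffle:
                # seeding from the epoch keeps shuffling reproducible per epoch
                random.Random(epoch).shuffle(items)
            return iter(items)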
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
new file mode 100644
index 0000000..8a255f9
--- /dev/null
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -0,0 +1,97 @@
+import logging
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import sentencepiece as spm
+from torch.utils.data import DataLoader
+
+from funasr.datasets.large_datasets.dataset import Dataset
+from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+
+from funasr.register import tables
+
+def read_symbol_table(symbol_table_file):
+ if isinstance(symbol_table_file, str):
+ symbol_table = {}
+ with open(symbol_table_file, "r", encoding="utf8") as fin:
+ for i, line in enumerate(fin):
+ char = line.strip()
+ symbol_table[char] = i
+ else:
+ assert isinstance(symbol_table_file, list)
+ symbol_table = {}
+ for i, char in enumerate(symbol_table_file):
+ symbol_table[char] = i
+ return symbol_table
+
+
+def load_seg_dict(seg_dict_file):
+ seg_dict = {}
+ assert isinstance(seg_dict_file, str)
+ with open(seg_dict_file, "r", encoding="utf8") as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ seg_dict[key] = " ".join(value)
+ return seg_dict
+
+
+class SentencepiecesTokenizer(AbsTokenizer):
+ def __init__(self, model: Union[Path, str]):
+ self.model = str(model)
+ self.sp = None
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(model="{self.model}")'
+
+ def _build_sentence_piece_processor(self):
+ if self.sp is None:
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.load(self.model)
+
+ def text2tokens(self, line: str) -> List[str]:
+ self._build_sentence_piece_processor()
+ return self.sp.EncodeAsPieces(line)
+
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ self._build_sentence_piece_processor()
+ return self.sp.DecodePieces(list(tokens))
+
+@tables.register("dataset_classes", "LargeDataset")
+class LargeDataLoader(AbsIterFactory):
+ def __init__(self, args, mode="train"):
+ symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
+ if hasattr(args, "token_list") and args.token_list is not None:
+ symbol_table = read_symbol_table(args.token_list)
+ if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
+ seg_dict = load_seg_dict(args.seg_dict_file)
+ if hasattr(args, "punc_list") and args.punc_list is not None:
+ punc_dict = read_symbol_table(args.punc_list)
+ if hasattr(args, "bpemodel") and args.bpemodel is not None:
+ bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
+ self.dataset_conf = args.dataset_conf
+ if "frontend_conf" not in args:
+ self.frontend_conf = None
+ else:
+ self.frontend_conf = args.frontend_conf
+ self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None
+ logging.info("dataloader config: {}".format(self.dataset_conf))
+ batch_mode = self.dataset_conf.get("batch_mode", "padding")
+ data_list = args.train_data_file if mode == "train" else args.valid_data_file
+ self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
+ self.dataset_conf, self.frontend_conf,
+ speed_perturb=self.speed_perturb if mode == "train" else None,
+ mode=mode, batch_mode=batch_mode)
+
+ def build_iter(self, epoch, shuffle=True):
+ self.dataset.set_epoch(epoch)
+ data_loader = DataLoader(self.dataset,
+ batch_size=None,
+ pin_memory=True,
+ num_workers=self.dataset_conf.get("num_workers", 8))
+ return data_loader
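Usage sketch of the factory: one DataLoader is rebuilt per epoch so shuffling inside the IterableDataset is re-seeded each time. The config below is a hypothetical minimal `args` (the code reads it with both attribute access and `in`, so an OmegaConf config fits); it will not run end-to-end without real manifest files:

    from omegaconf import OmegaConf
    from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader

    args = OmegaConf.create({
        "token_list": ["<blank>", "a", "b"],
        "dataset_conf": {"batch_mode": "padding", "num_workers": 0},
        "train_data_file": "train_data.list",  # hypothetical manifest
        "valid_data_file": "valid_data.list",
    })
    factory = LargeDataLoader(args, mode="train")
    for epoch in range(1, 3):
        for batch in factory.build_iter(epoch):  # fresh DataLoader per epoch
            ...

Note that SentencepiecesTokenizer above defers loading the sentencepiece model until first use, which keeps the object cheap to pickle into DataLoader worker processes.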
diff --git a/funasr/datasets/large_datasets/collate_fn.py b/funasr/datasets/large_datasets/collate_fn.py
new file mode 100644
index 0000000..aff25a8
--- /dev/null
+++ b/funasr/datasets/large_datasets/collate_fn.py
@@ -0,0 +1,196 @@
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim
+
+
+class CommonCollateFn:
+ """Functor class of common_collate_fn()"""
+
+ def __init__(
+ self,
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+ max_sample_size=None
+ ):
+ self.float_pad_value = float_pad_value
+ self.int_pad_value = int_pad_value
+ self.not_sequence = set(not_sequence)
+ self.max_sample_size = max_sample_size
+
+ def __repr__(self):
+ return (
+ f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+ f"int_pad_value={self.float_pad_value})"
+ )
+
+ def __call__(
+ self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+ ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ return common_collate_fn(
+ data,
+ float_pad_value=self.float_pad_value,
+ int_pad_value=self.int_pad_value,
+ not_sequence=self.not_sequence,
+ )
+
+
+def common_collate_fn(
+ data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ """Concatenate ndarray-list to an array and convert to torch.Tensor.
+ """
+ uttids = [u for u, _ in data]
+ data = [d for _, d in data]
+
+ assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+ assert all(
+ not k.endswith("_lengths") for k in data[0]
+ ), f"*_lengths is reserved: {list(data[0])}"
+
+ output = {}
+ for key in data[0]:
+ if data[0][key].dtype.kind == "i":
+ pad_value = int_pad_value
+ else:
+ pad_value = float_pad_value
+
+ array_list = [d[key] for d in data]
+ tensor_list = [torch.from_numpy(a) for a in array_list]
+ tensor = pad_list(tensor_list, pad_value)
+ output[key] = tensor
+
+ if key not in not_sequence:
+ lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+ output[key + "_lengths"] = lens
+
+ output = (uttids, output)
+ return output
+
+
+class DiarCollateFn:
+ """Functor class of common_collate_fn()"""
+
+ def __init__(
+ self,
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+ max_sample_size=None
+ ):
+ self.float_pad_value = float_pad_value
+ self.int_pad_value = int_pad_value
+ self.not_sequence = set(not_sequence)
+ self.max_sample_size = max_sample_size
+
+ def __repr__(self):
+ return (
+ f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+ f"int_pad_value={self.float_pad_value})"
+ )
+
+ def __call__(
+ self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+ ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ return diar_collate_fn(
+ data,
+ float_pad_value=self.float_pad_value,
+ int_pad_value=self.int_pad_value,
+ not_sequence=self.not_sequence,
+ )
+
+
+def diar_collate_fn(
+ data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ """Concatenate ndarray-list to an array and convert to torch.Tensor.
+ """
+ uttids = [u for u, _ in data]
+ data = [d for _, d in data]
+
+ assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+ assert all(
+ not k.endswith("_lengths") for k in data[0]
+ ), f"*_lengths is reserved: {list(data[0])}"
+
+ output = {}
+ for key in data[0]:
+ if data[0][key].dtype.kind == "i":
+ pad_value = int_pad_value
+ else:
+ pad_value = float_pad_value
+
+ array_list = [d[key] for d in data]
+ tensor_list = [torch.from_numpy(a) for a in array_list]
+ tensor = pad_list_all_dim(tensor_list, pad_value)
+ output[key] = tensor
+
+ if key not in not_sequence:
+ lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+ output[key + "_lengths"] = lens
+
+ output = (uttids, output)
+ return output
+
+
+def crop_to_max_size(feature, target_size):
+ size = len(feature)
+ diff = size - target_size
+ if diff <= 0:
+ return feature
+
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return feature[start:end]
+
+
+def clipping_collate_fn(
+ data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+ max_sample_size=None,
+ not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ # mainly for pre-training
+ uttids = [u for u, _ in data]
+ data = [d for _, d in data]
+
+ assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+ assert all(
+ not k.endswith("_lengths") for k in data[0]
+ ), f"*_lengths is reserved: {list(data[0])}"
+
+ output = {}
+ for key in data[0]:
+ array_list = [d[key] for d in data]
+ tensor_list = [torch.from_numpy(a) for a in array_list]
+ sizes = [len(s) for s in tensor_list]
+ if max_sample_size is None:
+ target_size = min(sizes)
+ else:
+ target_size = min(min(sizes), max_sample_size)
+ tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
+ for i, (source, size) in enumerate(zip(tensor_list, sizes)):
+ diff = size - target_size
+ if diff == 0:
+ tensor[i] = source
+ else:
+ tensor[i] = crop_to_max_size(source, target_size)
+ output[key] = tensor
+
+ if key not in not_sequence:
+ lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
+ output[key + "_lengths"] = lens
+
+ output = (uttids, output)
+ return output
\ No newline at end of file
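A minimal run of common_collate_fn on dummy data (pad_list pads along the first dimension; -32768 marks integer padding):

    import numpy as np
    from funasr.datasets.large_datasets.collate_fn import common_collate_fn

    data = [
        ("utt1", {"speech": np.zeros((30, 80), dtype=np.float32),
                  "text": np.array([5, 9], dtype=np.int64)}),
        ("utt2", {"speech": np.zeros((20, 80), dtype=np.float32),
                  "text": np.array([7], dtype=np.int64)}),
    ]
    uttids, batch = common_collate_fn(data)
    # batch["speech"]:         (2, 30, 80), float-padded with 0.0
    # batch["speech_lengths"]: tensor([30, 20])
    # batch["text"]:           (2, 2), int-padded with -32768
    # batch["text_lengths"]:   tensor([2, 1])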
diff --git a/funasr/datasets/large_datasets/datapipes/__init__.py b/funasr/datasets/large_datasets/datapipes/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/__init__.py
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
new file mode 100644
index 0000000..35e5dba
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -0,0 +1,213 @@
+import random
+
+from itertools import count
+from functools import partial
+from torch.utils.data import IterableDataset
+from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+
+tiebreaker = count()
+
+
+def _default_len_fn(token):
+ return len(token), next(tiebreaker)
+
+
+def _token_len_fn(token, len_fn):
+ return len_fn(token), next(tiebreaker), token
+
+
+class MaxTokenBucketizerIterDataPipe(IterableDataset):
+
+ def __init__(
+ self,
+ datapipe,
+ batch_size=8000,
+ len_fn=_default_len_fn,
+ buffer_size=10240,
+ sort_size=500,
+ batch_mode="padding",
+ ):
+ assert batch_size > 0, "Batch size is required to be larger than 0!"
+ assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
+ assert sort_size > 0, "Sort size is required to be larger than 0!"
+
+ datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
+ self.datapipe = datapipe
+ self.batch_size = batch_size
+ self.buffer_size = buffer_size
+ self.sort_size = sort_size
+ self.batch_mode = batch_mode
+
+ def set_epoch(self, epoch):
+ self.datapipe.set_epoch(epoch)
+
+ def __iter__(self):
+ buffer = []
+ batch = []
+ bucket = []
+ max_lengths = 0
+ min_lengths = 999999
+ batch_lengths = 0
+
+ if self.batch_mode == "clipping":
+ assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 0"
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
+
+ else:
+ if self.buffer_size == -1:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ buffer.sort()
+ for sample in buffer:
+ length, _, token = sample
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ bucket.append(batch)
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ random.shuffle(bucket)
+ if bucket:
+ for batch_sample in bucket:
+ yield batch_sample
+ if batch:
+ yield batch
+
+ elif self.buffer_size == 0:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ length, _, token = d
+ if length > self.batch_size:
+ continue
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ if batch:
+ yield batch
+
+ else:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
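A toy run of the bucketizer in its unbuffered mode (buffer_size=0). A scalar len_fn is passed explicitly here: the module-level _default_len_fn returns a (length, tiebreaker) tuple intended for direct sorting, which is not comparable against the integer batch_size budget:

    from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe

    seqs = [[0] * n for n in (3, 7, 2, 9, 4, 5)]
    pipe = MaxTokenBucketizerIterDataPipe(seqs, batch_size=16, buffer_size=0, len_fn=len)
    for batch in pipe:
        # within each batch, max(len) * len(batch) stays <= 16
        print([len(s) for s in batch])  # [3, 7], [2], [9], [4, 5]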
diff --git a/funasr/datasets/large_datasets/datapipes/filter.py b/funasr/datasets/large_datasets/datapipes/filter.py
new file mode 100644
index 0000000..6fe7153
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/filter.py
@@ -0,0 +1,24 @@
+from torch.utils.data import IterableDataset
+
+def default_fn(data):
+ return data
+
+
+class FilterIterDataPipe(IterableDataset):
+
+ def __init__(self,
+ datapipe,
+ fn=default_fn):
+ self.datapipe = datapipe
+ self.fn = fn
+
+ def set_epoch(self, epoch):
+ self.datapipe.set_epoch(epoch)
+
+ def __iter__(self):
+ assert callable(self.fn)
+ for data in self.datapipe:
+ if self.fn(data):
+ yield data
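+
+
+# Hedged usage sketch (illustrative only, not part of the pipeline wiring):
+# any upstream pipe exposing __iter__ and set_epoch can be filtered by a
+# predicate.
+if __name__ == "__main__":
+    class _ListPipe:
+        def __init__(self, items):
+            self.items = items
+
+        def set_epoch(self, epoch):  # no-op for a static source
+            pass
+
+        def __iter__(self):
+            return iter(self.items)
+
+    evens = FilterIterDataPipe(_ListPipe(range(6)), fn=lambda x: x % 2 == 0)
+    print(list(evens))  # [0, 2, 4]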
diff --git a/funasr/datasets/large_datasets/datapipes/map.py b/funasr/datasets/large_datasets/datapipes/map.py
new file mode 100644
index 0000000..dfcd6a0
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/map.py
@@ -0,0 +1,22 @@
+from torch.utils.data import IterableDataset
+
+
+def default_fn(data):
+ return data
+
+
+class MapperIterDataPipe(IterableDataset):
+
+ def __init__(self,
+ datapipe,
+ fn=default_fn):
+ self.datapipe = datapipe
+ self.fn = fn
+
+ def set_epoch(self, epoch):
+ self.datapipe.set_epoch(epoch)
+
+ def __iter__(self):
+ assert callable(self.fn)
+ for data in self.datapipe:
+ yield self.fn(data)
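+
+
+# Hedged usage sketch (illustrative only): fn is applied lazily to every
+# element, which is how dataset.py attaches tokenize and padding further
+# down in this patch.
+if __name__ == "__main__":
+    class _ListPipe:
+        def __init__(self, items):
+            self.items = items
+
+        def set_epoch(self, epoch):  # no-op for a static source
+            pass
+
+        def __iter__(self):
+            return iter(self.items)
+
+    doubled = MapperIterDataPipe(_ListPipe([1, 2, 3]), fn=lambda x: x * 2)
+    print(list(doubled))  # [2, 4, 6]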
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
new file mode 100644
index 0000000..d3489c1
--- /dev/null
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -0,0 +1,274 @@
+import logging
+import os
+import random
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torchaudio
+import numpy as np
+import librosa
+from kaldiio import ReadHelper
+from torch.utils.data import IterableDataset
+
+from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
+from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
+from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+from funasr.datasets.large_datasets.utils.clipping import clipping
+from funasr.datasets.large_datasets.utils.filter import filter
+from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.tokenize import tokenize
+
+
+def read_lists(list_file):
+ lists = []
+ with open(list_file, 'r', encoding='utf8') as fin:
+ for line in fin:
+ lists.append(line.strip())
+ return lists
+
+
+class AudioDataset(IterableDataset):
+ def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
+ mode="train"):
+ self.scp_lists = scp_lists
+ self.data_names = data_names
+ self.data_types = data_types
+ self.frontend_conf = frontend_conf
+ self.shuffle = shuffle
+ self.mode = mode
+ self.epoch = -1
+ self.rank = 0
+ self.world_size = 1
+ self.worker_id = 0
+ self.num_workers = 1
+ self.speed_perturb = speed_perturb
+ if self.speed_perturb is not None:
+ logging.info("Using speed_perturb: {}".format(speed_perturb))
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def get_rank_data_list(self, data_index):
+ assert dist.is_available()
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ else:
+ self.rank = 0
+ self.world_size = 1
+
+ if self.mode == "train":
+ if self.shuffle:
+ random.seed(self.epoch)
+ random.shuffle(data_index)
+ return data_index[self.rank::self.world_size]
+
+ return data_index
+
+ def get_worker_data_list(self, rank_data_index):
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is None:
+ self.worker_id = 0
+ self.num_workers = 1
+ else:
+ self.worker_id = worker_info.id
+ self.num_workers = worker_info.num_workers
+
+ return rank_data_index[self.worker_id::self.num_workers]
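+        # Hedged sharding example: with 8 list entries, world_size=2 and
+        # num_workers=2, rank 0 keeps indices [0, 2, 4, 6] and its worker 1
+        # then reads [2, 6]; every (rank, worker) pair sees a disjoint slice.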
+
+ def close_reader(self, reader_list):
+ for reader in reader_list:
+ reader.close()
+
+ def __iter__(self):
+ data_index = list(range(len(self.scp_lists)))
+ rank_data_index = self.get_rank_data_list(data_index)
+ worker_data_index = self.get_worker_data_list(rank_data_index)
+
+ for index in worker_data_index:
+ data = dict(scp=self.scp_lists[index])
+
+ assert 'scp' in data
+ scp = data['scp']
+ data_file_list = scp.strip().split()
+ data_name_list = self.data_names.split(",")
+ data_type_list = self.data_types.split(",")
+
+ for file in data_file_list:
+ assert os.path.exists(file), "{} does not exist".format(file)
+
+ assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
+ "data, data_names and data_types must contain the same number of items"
+
+ reader_list = []
+ for data_file, data_type in zip(data_file_list, data_type_list):
+ if data_type == "kaldi_ark":
+ ark_reader = ReadHelper('ark:{}'.format(data_file))
+ reader_list.append(ark_reader)
+ elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword':
+ text_reader = open(data_file, "r", encoding="utf-8")
+ reader_list.append(text_reader)
+ elif data_type == "none":
+ continue
+ else:
+ raise TypeError("Data type {} is not supported".format(data_type))
+
+ for items in zip(*reader_list):
+ sample_dict = {}
+ for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
+ if data_type == "kaldi_ark":
+ key, mat = item
+ sample_dict[data_name] = mat
+ if data_name == "speech":
+ sample_dict["key"] = key
+ elif data_type == "sound":
+ key, path = item.strip().split()
+ try:
+ waveform, sampling_rate = torchaudio.load(path)
+ except Exception:
+ waveform, sampling_rate = librosa.load(path, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ waveform = np.expand_dims(waveform, axis=0)
+ waveform = torch.tensor(waveform)
+ if self.frontend_conf is not None:
+ if sampling_rate != self.frontend_conf["fs"]:
+ waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
+ new_freq=self.frontend_conf["fs"])(waveform)
+ sampling_rate = self.frontend_conf["fs"]
+ waveform = waveform.numpy()
+ mat = waveform[0]
+ if self.speed_perturb is not None:
+ speed = random.choice(self.speed_perturb)
+ if speed != 1.0:
+ mat, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
+ mat = mat.view(-1).numpy()
+ sample_dict[data_name] = mat
+ sample_dict["sampling_rate"] = sampling_rate
+ if data_name == "speech":
+ sample_dict["key"] = key
+ elif data_type == "text_hotword":
+ text = item
+ segs = text.strip().split()
+ sample_dict[data_name] = segs[1:]
+ if "key" not in sample_dict:
+ sample_dict["key"] = segs[0]
+ sample_dict['hw_tag'] = 1
+ elif data_type == "text_nospace":
+ text = item
+ segs = text.strip().split(maxsplit=1)
+ sample_dict[data_name] = [x for x in segs[1]]
+ if "key" not in sample_dict:
+ sample_dict["key"] = segs[0]
+ else:
+ text = item
+ segs = text.strip().split()
+ sample_dict[data_name] = segs[1:]
+ if "key" not in sample_dict:
+ sample_dict["key"] = segs[0]
+ yield sample_dict
+
+ self.close_reader(reader_list)
+
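+# Hedged sketch of the samples yielded above (exact keys depend on
+# data_names/data_types; values here are illustrative):
+#   {"key": "utt1", "speech": np.ndarray, "sampling_rate": 16000,
+#    "text": ["ni", "hao"], "hw_tag": 1}   # hw_tag only for text_hotword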
+
+def len_fn_example(data):
+ return 1
+
+
+def len_fn_token(data):
+ assert "speech" in data
+ if "sampling_rate" in data:
+ return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+ else:
+ return data["speech"].shape[0]
+
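+# Worked example (hedged): a 2 s waveform at 16 kHz carrying "sampling_rate"
+# counts as (32000 / 16000) * 1000 = 2000 "tokens" (milliseconds), so a
+# token-type batch_size of 20000 packs roughly ten such utterances; for
+# kaldi_ark features the raw frame count is used instead.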
+
+def Dataset(data_list_file,
+ dict,
+ seg_dict,
+ punc_dict,
+ bpe_tokenizer,
+ conf,
+ frontend_conf,
+ speed_perturb=None,
+ mode="train",
+ batch_mode="padding"):
+ scp_lists = read_lists(data_list_file)
+ shuffle = conf.get('shuffle', True)
+ data_names = conf.get("data_names", "speech,text")
+ data_types = conf.get("data_types", "kaldi_ark,text")
+
+ pre_hwfile = conf.get("pre_hwlist", None)
+ if pre_hwfile is not None:
+ pre_hwlist = []
+ with open(pre_hwfile, 'r', encoding="utf-8") as fin:
+ for line in fin.readlines():
+ pre_hwlist.append(line.strip())
+ else:
+ pre_hwlist = None
+
+ hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
+ "double_rate": conf.get("double_rate", 0.1),
+ "hotword_min_length": conf.get("hotword_min_length", 2),
+ "hotword_max_length": conf.get("hotword_max_length", 8),
+ "pre_prob": conf.get("pre_prob", 0.0),
+ "pre_hwlist": pre_hwlist}
+
+ dataset = AudioDataset(scp_lists,
+ data_names,
+ data_types,
+ frontend_conf=frontend_conf,
+ shuffle=shuffle,
+ speed_perturb=speed_perturb,
+ mode=mode,
+ )
+
+ if "text" in data_names:
+ vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
+ tokenize_fn = partial(tokenize, **vocab)
+ dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
+
+ filter_conf = conf.get('filter_conf', {})
+ filter_fn = partial(filter, **filter_conf)
+ dataset = FilterIterDataPipe(dataset, fn=filter_fn)
+
+ if shuffle:
+ buffer_conf = conf.get('shuffle_conf', {})
+ buffer_size = buffer_conf['shuffle_size']
+ sort_size = buffer_conf['sort_size']
+ else:
+ buffer_size = 0
+ sort_size = 1
+
+ batch_conf = conf.get('batch_conf', {})
+ batch_size = batch_conf['batch_size']
+ batch_type = batch_conf['batch_type']
+
+ assert batch_type in ["example", "token"]
+ if batch_type == 'example':
+ len_fn = len_fn_example
+ else:
+ len_fn = len_fn_token
+
+ dataset = MaxTokenBucketizerIterDataPipe(dataset,
+ batch_size=batch_size,
+ len_fn=len_fn,
+ buffer_size=buffer_size,
+ sort_size=sort_size,
+ batch_mode=batch_mode)
+
+ int_pad_value = conf.get("int_pad_value", -1)
+ float_pad_value = conf.get("float_pad_value", 0.0)
+ padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
+ padding_fn = partial(padding, **padding_conf)
+ dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
+
+ return dataset
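+
+
+# Hedged configuration sketch (key names match the conf lookups above; the
+# values and file names are illustrative, not shipped defaults):
+#
+# conf = {
+#     "shuffle": True,
+#     "shuffle_conf": {"shuffle_size": 10000, "sort_size": 500},
+#     "batch_conf": {"batch_size": 20000, "batch_type": "token"},
+#     "data_names": "speech,text",
+#     "data_types": "kaldi_ark,text",
+#     "int_pad_value": -1,
+#     "float_pad_value": 0.0,
+# }
+# dataset = Dataset("train.scp.list", vocab, seg_dict, punc_dict, None,
+#                   conf, frontend_conf={"fs": 16000})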
diff --git a/funasr/datasets/large_datasets/utils/__init__.py b/funasr/datasets/large_datasets/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/__init__.py
diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
new file mode 100644
index 0000000..2554aba
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
+ assert isinstance(data, list)
+ assert "key" in data[0]
+
+ keys = [x["key"] for x in data]
+
+ batch = {}
+ data_names = data[0].keys()
+ for data_name in data_names:
+ if data_name == "key":
+ continue
+ else:
+ if data[0][data_name].dtype.kind == "i":
+ tensor_type = torch.int64
+ else:
+ tensor_type = torch.float32
+
+ tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+ tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+ length_clip = min(tensor_lengths)
+ tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+ for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+ diff = length - length_clip
+ assert diff >= 0
+ if diff == 0:
+ tensor_clip[i] = tensor
+ else:
+ tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+ batch[data_name] = tensor_clip
+ batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+ return keys, batch
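+
+
+# Hedged note on the design choice: clipping assumes each per-sample array is
+# 2-D (frames x dim) -- see the new_zeros(..., shape[1]) call above -- and
+# crops every sample to the shortest length in the batch via crop_to_max_size,
+# so unlike padding() no filler values are introduced and all *_lengths
+# entries are equal.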
diff --git a/funasr/datasets/large_datasets/utils/filter.py b/funasr/datasets/large_datasets/utils/filter.py
new file mode 100644
index 0000000..1260a47
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/filter.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+
+
+def filter(data,
+ speech_length_min=100,
+ speech_length_max=15000,
+ token_length_min=0,
+ token_length_max=200):
+ assert "speech" in data or "text" in data
+
+ if "speech" in data and "text" in data:
+ if "sampling_rate" in data:
+ speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+ else:
+ speech_length = data["speech"].shape[0]
+ num_tokens = len(data['text'])
+ return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+ elif "speech" in data:
+ if "sampling_rate" in data:
+ speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+ else:
+ speech_length = data["speech"].shape[0]
+ return speech_length_min < speech_length < speech_length_max
+ else:
+ num_tokens = len(data['text'])
+ return token_length_min < num_tokens < token_length_max
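+
+
+# Worked example (hedged): with the defaults above, a 2 s utterance at 16 kHz
+# that carries "sampling_rate" scores (32000 / 16000) * 1000 = 2000 ms and
+# passes 100 < 2000 < 15000; a 30-token transcript passes 0 < 30 < 200, so
+# the sample is kept.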
diff --git a/funasr/datasets/large_datasets/utils/hotword_utils.py b/funasr/datasets/large_datasets/utils/hotword_utils.py
new file mode 100644
index 0000000..73f8bdd
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/hotword_utils.py
@@ -0,0 +1,33 @@
+import random
+
+def sample_hotword(length,
+ hotword_min_length,
+ hotword_max_length,
+ sample_rate,
+ double_rate,
+ pre_prob,
+ pre_index=None,
+ pre_hwlist=None):
+ if length < hotword_min_length:
+ return [-1]
+ if random.random() < sample_rate:
+ if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
+ return pre_index
+ if length == hotword_min_length:
+ return [0, length-1]
+ elif random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
+ # sample two hotwords in a sentence
+ _max_hw_length = min(hotword_max_length, length // 2)
+ # first hotword
+ start1 = random.randint(0, length // 3)
+ end1 = random.randint(start1 + hotword_min_length - 1, start1 + _max_hw_length - 1)
+ # second hotword
+ start2 = random.randint(end1 + 1, length - hotword_min_length)
+ end2 = random.randint(min(length-1, start2+hotword_min_length-1), min(length-1, start2+hotword_max_length-1))
+ return [start1, end1, start2, end2]
+ else: # single hotword
+ start = random.randint(0, length - hotword_min_length)
+ end = random.randint(min(length-1, start+hotword_min_length-1), min(length-1, start+hotword_max_length-1))
+ return [start, end]
+ else:
+ return [-1]
\ No newline at end of file
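+# Hedged illustration of the return convention above: [-1] means no hotword
+# was sampled; [s, e] marks one inclusive token span; [s1, e1, s2, e2] marks
+# two spans. For length=10, hotword_min_length=2, hotword_max_length=4, a
+# single draw might return [3, 5], i.e. tokens text[3:6] form the hotword.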
diff --git a/funasr/datasets/large_datasets/utils/low_frame_rate.py b/funasr/datasets/large_datasets/utils/low_frame_rate.py
new file mode 100644
index 0000000..76eb2da
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/low_frame_rate.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+
+def build_LFR_features(data, m, n):
+ """
+ Actually, this implements stacking frames and skipping frames.
+ if m = 1 and n = 1, just return the origin features.
+ if m = 1 and n > 1, it works like skipping.
+ if m > 1 and n = 1, it works like stacking but only support right frames.
+ if m > 1 and n > 1, it works like LFR.
+
+ Args:
+ inputs_batch: inputs is T x D np.ndarray
+ m: number of frames to stack
+ n: number of frames to skip
+ """
+
+ LFR_inputs = []
+ T = data.shape[0]
+ T_lfr = int(np.ceil(T / n))
+ for i in range(T_lfr):
+ if m <= T - i * n:
+ LFR_inputs.append(np.hstack(data[i*n:i*n+m]))
+ else:
+ num_padding = m - (T - i * n)
+ frame = np.hstack(data[i*n:])
+ for _ in range(num_padding):
+ frame = np.hstack((frame, data[-1]))
+ LFR_inputs.append(frame)
+ return np.vstack(LFR_inputs)
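+
+
+# Worked example (hedged, illustrative shapes): 5 frames of dim 2 with m=3,
+# n=2 give ceil(5/2) = 3 output frames of dim 6 -- [f0 f1 f2], [f2 f3 f4],
+# and [f4 f4 f4], the tail padded by repeating the last frame.
+if __name__ == "__main__":
+    feats = np.arange(10, dtype=np.float32).reshape(5, 2)
+    lfr = build_LFR_features(feats, m=3, n=2)
+    print(lfr.shape)  # (3, 6)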
diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
new file mode 100644
index 0000000..26c6e84
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def padding(data, float_pad_value=0.0, int_pad_value=-1):
+ assert isinstance(data, list)
+ assert "key" in data[0]
+ assert "speech" in data[0] or "text" in data[0]
+
+ keys = [x["key"] for x in data]
+
+ batch = {}
+ data_names = data[0].keys()
+ for data_name in data_names:
+ if data_name == "key" or data_name == "sampling_rate":
+ continue
+ else:
+ if data_name != 'hotword_indxs':
+ if data[0][data_name].dtype.kind == "i":
+ pad_value = int_pad_value
+ tensor_type = torch.int64
+ else:
+ pad_value = float_pad_value
+ tensor_type = torch.float32
+
+ tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+ tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+ tensor_pad = pad_sequence(tensor_list,
+ batch_first=True,
+ padding_value=pad_value)
+ batch[data_name] = tensor_pad
+ batch[data_name + "_lengths"] = tensor_lengths
+
+ # SAC LABEL INCLUDE
+ if "hotword_indxs" in batch:
+ # if hotword indxs in batch
+ # use it to slice hotwords out
+ hotword_list = []
+ hotword_lengths = []
+ text = batch['text']
+ text_lengths = batch['text_lengths']
+ hotword_indxs = batch['hotword_indxs']
+ dha_pad = torch.ones_like(text) * -1
+ _, t1 = text.shape
+ t1 += 1 # TODO: as parameter which is same as predictor_bias
+ nth_hw = 0
+ for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
+ dha_pad[b][:length] = 8405  # NOTE: hard-coded token id tied to a specific vocab; TODO: make configurable
+ if hotword_indx[0] != -1:
+ start, end = int(hotword_indx[0]), int(hotword_indx[1])
+ hotword = one_text[start: end+1]
+ hotword_list.append(hotword)
+ hotword_lengths.append(end-start+1)
+ dha_pad[b][start: end+1] = one_text[start: end+1]
+ nth_hw += 1
+ if len(hotword_indx) == 4 and hotword_indx[2] != -1:
+ # the second hotword if exist
+ start, end = int(hotword_indx[2]), int(hotword_indx[3])
+ hotword_list.append(one_text[start: end+1])
+ hotword_lengths.append(end-start+1)
+ dha_pad[b][start: end+1] = one_text[start: end+1]
+ nth_hw += 1
+ hotword_list.append(torch.tensor([1]))
+ hotword_lengths.append(1)
+ hotword_pad = pad_sequence(hotword_list,
+ batch_first=True,
+ padding_value=0)
+ batch["hotword_pad"] = hotword_pad
+ batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
+ batch['dha_pad'] = dha_pad
+ del batch['hotword_indxs']
+ del batch['hotword_indxs_lengths']
+ return keys, batch
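+
+
+# Hedged usage sketch: integer arrays (dtype kind "i", e.g. token ids) are
+# padded with int_pad_value and float arrays with float_pad_value, per the
+# branches above. Keys and values here are illustrative.
+if __name__ == "__main__":
+    demo = [{"key": "utt1", "text": np.array([4, 9])},
+            {"key": "utt2", "text": np.array([7])}]
+    keys, batch = padding(demo)
+    print(keys)                   # ['utt1', 'utt2']
+    print(batch["text"])          # tensor([[ 4,  9], [ 7, -1]])
+    print(batch["text_lengths"])  # tensor([2, 1], dtype=torch.int32)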
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
new file mode 100644
index 0000000..34a97c1
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+import re
+import numpy as np
+from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
+
+def forward_segment(text, seg_dict):
+ word_list = []
+ i = 0
+ while i < len(text):
+ longest_word = text[i]
+ for j in range(i + 1, len(text) + 1):
+ word = text[i:j]
+ if word in seg_dict:
+ if len(word) > len(longest_word):
+ longest_word = word
+ word_list.append(longest_word)
+ i += len(longest_word)
+ return word_list
+
+def seg_tokenize(txt, seg_dict):
+ pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+ out_txt = ""
+ for word in txt:
+ word = word.lower()
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ if pattern.match(word):
+ for char in word:
+ if char in seg_dict:
+ out_txt += seg_dict[char] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+def tokenize(data,
+ vocab=None,
+ seg_dict=None,
+ punc_dict=None,
+ bpe_tokenizer=None,
+ hw_config=None):
+ assert "text" in data
+ assert isinstance(vocab, dict)
+ text = data["text"]
+ token = []
+ vad = -2
+ if bpe_tokenizer is not None:
+ text = bpe_tokenizer.text2tokens(" ".join(text))
+ if seg_dict is not None:
+ assert isinstance(seg_dict, dict)
+ text = seg_tokenize(text, seg_dict)
+
+ length = len(text)
+ if 'hw_tag' in data:
+ pre_index = None
+ if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0:
+ # enable preset hotword detect in sampling
+ for hw in hw_config['pre_hwlist']:
+ hw = " ".join(seg_tokenize(hw, seg_dict))
+ _find = " ".join(text).find(hw)
+ if _find != -1:
+ # _find = text[:_find].count(" ") # bpe sometimes
+ pre_index = [_find, _find + max(hw.count(" "), 1)]
+ break
+ hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
+ data['hotword_indxs'] = hotword_indxs
+ del data['hw_tag']
+ for i in range(length):
+ x = text[i]
+ if i == length-1 and "punc" in data and x.startswith("vad:"):
+ vad = x[4:]
+ if len(vad) == 0:
+ vad = -1
+ else:
+ vad = int(vad)
+ elif x in vocab:
+ token.append(vocab[x])
+ else:
+ token.append(vocab['<unk>'])
+
+ if "punc" in data and punc_dict is not None:
+ punc_token = []
+ for punc in data["punc"]:
+ if punc in punc_dict:
+ punc_token.append(punc_dict[punc])
+ else:
+ punc_token.append(punc_dict["_"])
+ data["punc"] = np.array(punc_token)
+
+ data["text"] = np.array(token)
+ if vad != -2:
+ data["vad_indexes"] = np.array([vad], dtype=np.int64)
+ return data
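+
+
+# Hedged usage sketch (toy vocab; real vocabularies come from the model's
+# token list): out-of-vocabulary tokens fall back to <unk>, per the loop
+# above.
+if __name__ == "__main__":
+    out = tokenize({"text": ["ni", "hao", "ma"]},
+                   vocab={"ni": 0, "hao": 1, "<unk>": 2})
+    print(out["text"])  # [0 1 2]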
diff --git a/funasr/datasets/llm_datasets/samplers.py b/funasr/datasets/llm_datasets/samplers.py
deleted file mode 100644
index 914e776..0000000
--- a/funasr/datasets/llm_datasets/samplers.py
+++ /dev/null
@@ -1,277 +0,0 @@
-import torch
-import numpy as np
-import logging
-import torch.distributed as dist
-
-from funasr.register import tables
-
-
-@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
-class BatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = False,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 5000)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
-
- def __len__(self):
- return (self.total_samples-1) // self.batch_size + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- sample_len_cur = source_len + target_len
-
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- if self.batch_type != 'example':
- max_token_padding *= max_token_cur
- if max_token_padding <= self.batch_size:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- yield batch
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
-
-
-@tables.register("batch_sampler_classes", "BatchSampler")
-@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
-class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = True,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 1500)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- self.rank = rank
- self.world_size = world_size
-
- def __len__(self):
- return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- batch_size_total = self.batch_size * self.world_size
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- # if iter == iter_num -1 and self.drop_last:
- # continue
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- sample_len_cur = source_len + target_len
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- # if self.batch_type != 'example':
- # max_token_padding *= max_token_cur
- if max_token_padding <= batch_size_total:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
- yield batch_rank
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
-
-
-@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
-class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = True,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 1500)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- self.rank = rank
- self.world_size = world_size
-
- def __len__(self):
- return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- batch_size_total = self.batch_size * self.world_size
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch_list_all_rank = []
- batch_list_cur = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- # if iter == iter_num - 1 and self.drop_last:
- # continue
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- sample_len_cur = source_len + target_len
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for ii, item in enumerate(datalen_with_index_sort):
- is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
-
- if self.batch_type != 'example':
- max_token_padding *= max_token_cur
- if len(batch_list_all_rank) < self.world_size:
-
- if max_token_padding <= self.batch_size:
- batch_list_cur.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- batch_list_all_rank.append(batch_list_cur)
- batch_list_cur = []
- else:
- batch_rank = batch_list_all_rank[self.rank]
- yield batch_rank
- batch_list_all_rank = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
deleted file mode 100644
index 61f7d94..0000000
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ /dev/null
@@ -1,431 +0,0 @@
-import torch
-import numpy as np
-import logging
-import math
-import torch.distributed as dist
-from torch.utils.data import DistributedSampler
-from torch.utils.data import BatchSampler, Sampler
-import torch.distributed as dist
-
-from funasr.register import tables
-
-
-@tables.register("batch_sampler_classes", "RankFullGlobalShuffleBatchSampler")
-class RankFullGlobalShuffleBatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = True,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 1500)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- self.rank = rank
- self.world_size = world_size
-
- def __len__(self):
- return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- batch_size_total = self.batch_size * self.world_size
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- # if iter == iter_num -1 and self.drop_last:
- # continue
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- sample_len_cur = source_len + target_len
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- # if self.batch_type != 'example':
- # max_token_padding *= max_token_cur
- if max_token_padding <= batch_size_total:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
- yield batch_rank
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
-
-@tables.register("batch_sampler_classes", "DistributedSamplerWarp")
-class DistributedSamplerWarp(BatchSampler):
- def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False):
- if num_replicas is None:
- if not torch.distributed.is_available():
- raise RuntimeError("Requires distributed package to be available")
- num_replicas = torch.distributed.get_world_size()
- if rank is None:
- if not torch.distributed.is_available():
- raise RuntimeError("Requires distributed package to be available")
- rank = torch.distributed.get_rank()
-
- self.dataset = dataset
- self.batch_size = batch_size
- self.num_replicas = num_replicas
- self.rank = rank
- self.shuffle = shuffle
- self.drop_last = drop_last
-
- # Create an instance of the DistributedSampler
- self.sampler = DistributedSampler(
- self.dataset,
- num_replicas=self.num_replicas,
- rank=self.rank,
- shuffle=self.shuffle
- )
-
- # Call BatchSampler's constructor
- super().__init__(self.sampler, batch_size, drop_last)
-
- def __iter__(self):
- # If we shuffle, we need to call the set_epoch method
- if self.shuffle:
- self.sampler.set_epoch(self.epoch)
-
- # Generate batch indices using the parent class
- return super().__iter__()
-
- def set_epoch(self, epoch):
- self.epoch = epoch
-
-@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
-def CustomDistributedBatchSampler_fn(dataset, **kwargs):
- dataloader_args = {}
- dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs)
- dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
- dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-
- return dataloader_args
-
-@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
-class CustomDistributedBatchSampler(Sampler):
- def __init__(self, dataset,
- batch_size,
- num_replicas=None,
- rank=None,
- shuffle=True,
- drop_last=False,
- is_training: bool = True,
- **kwargs,
- ):
-
- try:
- rank = dist.get_rank()
- num_replicas = dist.get_world_size()
- except:
- rank = 0
- num_replicas = 1
- self.rank = rank
- self.num_replicas = num_replicas
- self.dataset = dataset
- self.batch_size = batch_size
- self.is_training = is_training
- self.shuffle = shuffle and is_training
- self.drop_last = drop_last
- # self.total_size = len(dataset)
- if self.drop_last:
- self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
- else:
- self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
- self.num_samples = int(self.total_size // self.num_replicas)
- self.epoch = 0
- self.max_token_length = kwargs.get("max_token_length", None)
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
- def __iter__(self):
- # Generate a list of indices
- if self.shuffle:
- g = torch.Generator()
- g.manual_seed(self.epoch)
- indices = torch.randperm(len(self.dataset), generator=g).tolist()
- else:
- indices = list(range(len(self.dataset)))
-
- # Add extra samples to make it evenly divisible
- padding_size = self.total_size - len(indices)
- if padding_size <= len(indices):
- indices += indices[:padding_size]
- else:
- indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
-
- assert len(indices) == self.total_size
-
- # Subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
- assert len(indices) == self.num_samples
-
- # Filter out indices with length greater than the max length, if provided
- if self.max_token_length is not None:
- filtered_indices = []
- for idx in indices:
- source_len = self.dataset.get_source_len(idx) / self.length_scale_source
- if source_len <= self.max_token_length:
- filtered_indices.append(idx)
- indices = filtered_indices
-
- # Now that we have only the indices for this replica, chunk them into batches
- batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
-
- # Drop the last batch if it's not full and drop_last is True
- if self.drop_last and len(batches[-1]) != self.batch_size:
- batches = batches[:-1]
-
- return iter(batches)
-
- def __len__(self):
-
- return self.num_samples // self.batch_size
-
- def set_epoch(self, epoch):
- self.epoch = epoch
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler_fn")
-def CustomDistributedBatchSampler_fn(dataset, **kwargs):
- dataloader_args = {}
- dataloader_args["batch_sampler"] = CustomDistributedBufferBatchSampler(dataset, **kwargs)
- dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
- dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-
- return dataloader_args
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler")
-class CustomDistributedBatchSampler(Sampler):
- def __init__(self, dataset,
- batch_size,
- num_replicas=None,
- rank=None,
- shuffle=True,
- drop_last=False,
- is_training: bool = True,
- sort_size: int = 1024,
- **kwargs,
- ):
-
- try:
- rank = dist.get_rank()
- num_replicas = dist.get_world_size()
- except:
- rank = 0
- num_replicas = 1
- self.rank = rank
- self.num_replicas = num_replicas
- self.dataset = dataset
- self.batch_size = batch_size
- self.is_training = is_training
- self.shuffle = shuffle and is_training
- self.drop_last = drop_last
- # self.total_size = len(dataset)
- if self.drop_last:
- self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
- else:
- self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
- self.num_samples = int(self.total_size // self.num_replicas)
- self.epoch = 0
- self.max_token_length = kwargs.get("max_token_length", None)
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
- self.sort_size = sort_size
-
- def __iter__(self):
- # Generate a list of indices
- if self.shuffle:
- g = torch.Generator()
- g.manual_seed(self.epoch)
- indices = torch.randperm(len(self.dataset), generator=g).tolist()
- else:
- indices = list(range(len(self.dataset)))
-
- # Add extra samples to make it evenly divisible
- padding_size = self.total_size - len(indices)
- if padding_size <= len(indices):
- indices += indices[:padding_size]
- else:
- indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
-
- assert len(indices) == self.total_size
-
- # Subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
- assert len(indices) == self.num_samples
-
- # Filter out indices with length greater than the max length, if provided
- if self.max_token_length is not None:
- filtered_indices = []
- for idx in indices:
- source_len = self.dataset.get_source_len(idx) / self.length_scale_source
- if source_len <= self.max_token_length:
- filtered_indices.append(idx)
- indices = filtered_indices
-
- # Buffer sorting logic
- sorted_batches = []
- buffer = []
-
- for idx in indices:
- buffer.append(idx)
- if len(buffer) >= self.sort_size:
- # Sort the buffer based on some criteria, e.g., dataset sample length
- buffer.sort(key=lambda x: self.dataset.get_source_len(x))
- sorted_batches.extend(self._create_batches_from_buffer(buffer))
- buffer = []
-
- # Handle the remaining items in the buffer
- if buffer:
- buffer.sort(key=lambda x: self.dataset.get_source_len(x))
- sorted_batches.extend(self._create_batches_from_buffer(buffer))
-
- return iter(sorted_batches)
-
- def _create_batches_from_buffer(self, buffer):
- # Function to convert the sorted buffer into batches
- batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
- if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
- batched_buffer = batched_buffer[:-1]
- return batched_buffer
-
- def __len__(self):
-
- return self.num_samples // self.batch_size
-
- def set_epoch(self, epoch):
- self.epoch = epoch
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler_fn")
-def CustomDistributedBatchSampler_fn(dataset, **kwargs):
- dataloader_args = {}
- dataloader_args["batch_sampler"] = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
- dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
- dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-
- return dataloader_args
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
-class CustomDistributedDynamicBatchSampler(Sampler):
- def __init__(self, dataset,
- batch_size,
- num_replicas=None,
- rank=None,
- shuffle=True,
- drop_last=False,
- is_training: bool = True,
- **kwargs,
- ):
-
- try:
- rank = dist.get_rank()
- num_replicas = dist.get_world_size()
- except:
- rank = 0
- num_replicas = 1
- self.rank = rank
- self.num_replicas = num_replicas
- self.dataset = dataset
- self.batch_size = batch_size
- self.is_training = is_training
- self.shuffle = shuffle and is_training
- self.drop_last = drop_last
-
- self.total_size = len(self.dataset)
- # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
- self.epoch = 0
-
- def __iter__(self):
- if self.shuffle:
- g = torch.Generator()
- g.manual_seed(self.epoch)
- indices = torch.randperm(len(self.dataset), generator=g).tolist()
- else:
- indices = list(range(len(self.dataset)))
-
- indices = indices[self.rank:self.total_size:self.num_replicas]
-
- batches = []
- batch = []
- max_len_in_batch = 0
- current_batch_length = 0
-
- for idx in indices:
- sample_length = self.dataset.get_source_len(idx)
- potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
- len(batch) + 1)
-
- if potential_batch_length <= self.batch_size:
- batch.append(idx)
- if sample_length > max_len_in_batch:
- max_len_in_batch = sample_length
- current_batch_length = max_len_in_batch * len(batch)
- else:
- batches.append(batch)
- batch = [idx]
- max_len_in_batch = sample_length
- current_batch_length = max_len_in_batch
-
- # Add the last batch if it's not empty and we're not dropping it
- if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
- batches.append(batch)
-
- return iter(batches)
-
- def __len__(self):
-
- return -1
-
- def set_epoch(self, epoch):
- self.epoch = epoch
diff --git a/funasr/register.py b/funasr/register.py
index cfa1b20..45e2a85 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -15,6 +15,7 @@
predictor_classes = {}
stride_conv_classes = {}
tokenizer_classes = {}
+ dataloader_classes = {}
batch_sampler_classes = {}
dataset_classes = {}
index_ds_classes = {}
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index aae4513..c443c6f 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -1,3 +1,4 @@
+import math
import os
import time
import torch
@@ -7,8 +8,6 @@
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
-# from torch.utils.tensorboard import SummaryWriter
-from tensorboardX import SummaryWriter
from pathlib import Path
from funasr.train_utils.device_funcs import to_device
@@ -40,11 +39,7 @@
resume (str, optional): Path to a checkpoint to resume training from.
"""
- def __init__(self, model,
- optim,
- scheduler,
- dataloader_train,
- dataloader_val,
+ def __init__(self,
local_rank,
use_ddp: bool = False,
use_fsdp: bool = False,
@@ -66,29 +61,31 @@
resume (str, optional): The file path to a checkpoint to resume training from.
"""
- self.model = model
- self.optim = optim
- self.scheduler = scheduler
- self.dataloader_train = dataloader_train
- self.dataloader_val = dataloader_val
self.output_dir = output_dir
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir, exist_ok=True)
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
- self.device = next(model.parameters()).device
+ self.device = kwargs.get('device', "cuda")
self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
- self.kwargs = kwargs
+ # self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.use_fp16 = use_fp16
self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
- scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
- scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
- self.scaler = scaler
+ # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+ # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
+ # self.scaler = scaler
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+ self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
+ self.accum_grad = kwargs.get("accum_grad", 1)
+ self.grad_clip = kwargs.get("grad_clip", 10.0)
+ self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
+ self.validate_interval = kwargs.get("validate_interval", 5000)
try:
@@ -100,13 +97,22 @@
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
-
- os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
- self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
-
-
-
- def _save_checkpoint(self, epoch, step=None):
+ self.train_acc_avg = 0.0
+ self.train_loss_avg = 0.0
+ self.val_acc_avg = 0.0
+ self.val_loss_avg = 0.0
+ self.best_acc_idx = 0
+ self.saved_ckpts = {}
+ self.val_acc_list = []
+ self.step_or_epoch = -1
+
+ def save_checkpoint(self, epoch,
+ step=None,
+ model=None,
+ optim=None,
+ scheduler=None,
+ scaler=None,
+ ):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
@@ -115,29 +121,65 @@
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
- state = {
- 'epoch': epoch,
- 'state_dict': self.model.state_dict(),
- 'optimizer': self.optim.state_dict(),
- 'scheduler': self.scheduler.state_dict(),
- }
- if self.scaler:
- state["scaler_state"] = self.scaler.state_dict()
- # Create output directory if it does not exist
- os.makedirs(self.output_dir, exist_ok=True)
- if step is None:
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
- else:
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
- torch.save(state, filename)
-
- print(f'\nCheckpoint saved to {filename}\n')
- latest = Path(os.path.join(self.output_dir, f'model.pt'))
- torch.save(state, latest)
+ if self.rank == 0:
+ logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+ self.step_or_epoch += 1
+ state = {
+ 'epoch': epoch,
+ 'state_dict': model.state_dict(),
+ 'optimizer': optim.state_dict(),
+ 'scheduler': scheduler.state_dict(),
+ "acc": self.val_acc_list,
+ "step_or_epoch": self.step_or_epoch,
+ }
+ if hasattr(model, "module"):
+ state["state_dict"] = model.module.state_dict()
+
+ if scaler:
+ state["scaler_state"] = scaler.state_dict()
+ # Create output directory if it does not exist
+ os.makedirs(self.output_dir, exist_ok=True)
+ if step is None:
+ ckpt_name = f'model.pt.ep{epoch}'
+ else:
+ ckpt_name = f'model.pt.ep{epoch}.{step}'
+ filename = os.path.join(self.output_dir, ckpt_name)
+ torch.save(state, filename)
+
+ logging.info(f'\nCheckpoint saved to {filename}\n')
+ latest = Path(os.path.join(self.output_dir, f'model.pt'))
+ torch.save(state, latest)
+
+ if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
+ self.best_acc_idx = self.step_or_epoch
+ best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+ torch.save(state, best_ckpt)
+ logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+ else:
+ logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
+
+ if self.keep_nbest_models > 0:
+ self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
+ if len(self.saved_ckpts) > self.keep_nbest_models:
+ min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+ if min_key in self.saved_ckpts:
+ del self.saved_ckpts[min_key]
+ filename = os.path.join(self.output_dir, min_key)
+ logging.info(f"Delete: {filename}")
+ if os.path.exists(filename):
+ os.remove(filename)
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
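+
+        # Checkpoint layout (hedged summary of the state assembled above):
+        # {"epoch", "state_dict", "optimizer", "scheduler",
+        #  "acc" (the val_acc_list history), "step_or_epoch"}, plus
+        #  "scaler_state" when AMP is active; resume_checkpoint below
+        #  restores all of these from <output_dir>/model.pt.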
- def _resume_checkpoint(self, resume_path):
+ def resume_checkpoint(self,
+ model=None,
+ optim=None,
+ scheduler=None,
+ scaler=None,
+ ):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
@@ -145,114 +187,79 @@
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
- ckpt = os.path.join(resume_path, "model.pt")
- if os.path.isfile(ckpt):
- checkpoint = torch.load(ckpt, map_location="cpu")
- self.start_epoch = checkpoint['epoch'] + 1
- # self.model.load_state_dict(checkpoint['state_dict'])
- src_state = checkpoint['state_dict']
- dst_state = self.model.state_dict()
- for k in dst_state.keys():
- if not k.startswith("module.") and "module."+k in src_state.keys():
- k_ddp = "module."+k
- else:
- k_ddp = k
- if k_ddp in src_state.keys():
- dst_state[k] = src_state[k_ddp]
- else:
- print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
-
- self.model.load_state_dict(dst_state)
- self.optim.load_state_dict(checkpoint['optimizer'])
- self.scheduler.load_state_dict(checkpoint['scheduler'])
- if self.scaler and 'scaler_state' in checkpoint:
- self.scaler.load_state_dict(checkpoint['scaler_state'])
- print(f"Checkpoint loaded successfully from '{ckpt}'")
- else:
- print(f"No checkpoint found at '{ckpt}', does not resume status!")
-
- self.model.to(self.device)
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
- def run(self):
- """
- Starts the training process, iterating over epochs, training the model,
- and saving checkpoints at the end of each epoch.
- """
if self.resume:
- self._resume_checkpoint(self.output_dir)
-
- for epoch in range(self.start_epoch, self.max_epoch + 1):
- time1 = time.perf_counter()
- self._train_epoch(epoch)
-
-
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
+ ckpt = os.path.join(self.output_dir, "model.pt")
+ if os.path.isfile(ckpt):
+ checkpoint = torch.load(ckpt, map_location="cpu")
+ self.start_epoch = checkpoint['epoch'] + 1
+ # self.model.load_state_dict(checkpoint['state_dict'])
+ src_state = checkpoint['state_dict']
+ dst_state = model.state_dict()
+ for k in dst_state.keys():
+ if not k.startswith("module.") and "module."+k in src_state.keys():
+ k_ddp = "module."+k
+ else:
+ k_ddp = k
+ if k_ddp in src_state.keys():
+ dst_state[k] = src_state[k_ddp]
+ else:
+ print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+
+ model.load_state_dict(dst_state)
+ optim.load_state_dict(checkpoint['optimizer'])
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ if scaler is not None and 'scaler_state' in checkpoint:
+ scaler.load_state_dict(checkpoint['scaler_state'])
- self._validate_epoch(epoch)
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
-
- if self.rank == 0:
- self._save_checkpoint(epoch)
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
- self.scheduler.step()
-
- time2 = time.perf_counter()
- time_escaped = (time2 - time1)/3600.0
- print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
-
- if self.rank == 0:
- average_checkpoints(self.output_dir, self.avg_nbest_model)
-
+ self.val_acc_list = checkpoint["acc"]
+ self.step_or_epoch = checkpoint["step_or_epoch"]
+
+ print(f"Checkpoint loaded successfully from '{ckpt}'")
+ else:
+ print(f"No checkpoint found at '{ckpt}', does not resume status!")
+
if self.use_ddp or self.use_fsdp:
dist.barrier()
-
-
- if self.writer:
- self.writer.close()
-
- def _train_epoch(self, epoch):
+
+ def train_epoch(self,
+ model=None,
+ optim=None,
+ scheduler=None,
+ scaler=None,
+ dataloader_train=None,
+ dataloader_val=None,
+ epoch=None,
+ writer=None,
+ ):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
- self.model.train()
- pbar = tqdm(colour="blue", desc=f"rank: {self.local_rank}, Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
- dynamic_ncols=True)
-
+ logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
+ model.train()
+
# Set the number of steps for gradient accumulation
- accum_grad = self.kwargs.get("accum_grad", 1)
+ accum_grad = self.accum_grad
# Initialize the gradient accumulation
- self.optim.zero_grad()
+ optim.zero_grad()
speed_stats = {}
time5 = time.perf_counter()
- for batch_idx, batch in enumerate(self.dataloader_train):
+ for batch_idx, batch in enumerate(dataloader_train):
self.batch_total += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time5:0.3f}"
batch = to_device(batch, self.device)
- my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
+ my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
with maybe_autocast(self.use_fp16):
- retval = self.model(**batch)
+ retval = model(**batch)
- if self.disable_gpu_cache: torch.cuda.empty_cache()
-
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
@@ -261,95 +268,105 @@
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed=True)
+ # stats, weight = recursive_average(stats, weight, distributed=True)
+ if self.use_ddp or self.use_fsdp:
+ dist.all_reduce(weight, op=dist.ReduceOp.SUM)
# Now weight is summation over all workers
- loss /= weight
+ loss /= weight.sum() # shape:[1] -> shape:[]
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
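+                # Hedged arithmetic note: after the weighted sum, loss equals
+                # sum_i(loss_i * w_i) / sum_i(w_i); multiplying by world_size
+                # undoes DDP's implicit gradient averaging, and dividing by
+                # accum_grad makes the accumulated step match one large batch.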
if self.use_fp16:
- self.scaler.scale(loss).backward()
+ scaler.scale(loss).backward()
else:
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+
+ self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+ if "acc" in stats:
+ self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+ if self.use_ddp or self.use_fsdp:
+ train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+ train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+ self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+ self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
# Perform an optimizer step only after accumulating enough gradients
- if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
+ if (batch_idx + 1) % accum_grad == 0:
# Perform gradient clipping if it is set
- if self.kwargs.get("grad_clip", None) is not None:
+ if self.grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
- self.model.parameters(),
- max_norm=self.kwargs.get("grad_clip", 10.0),
- norm_type=self.kwargs.get("grad_clip_type", 2.0),
+ model.parameters(),
+ max_norm=self.grad_clip,
+ norm_type=self.grad_clip_type,
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
- self.optim.zero_grad() # Reset gradients
+ optim.zero_grad() # Reset gradients
continue
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
if self.use_fp16:
- self.scaler.step(self.optim)
- self.scaler.update()
+ scaler.step(optim)
+ scaler.update()
else:
- self.optim.step()
- self.scheduler.step()
+ optim.step()
+ scheduler.step()
# Clear gradients for the next accumulation stage
- self.optim.zero_grad(set_to_none=True)
+ optim.zero_grad(set_to_none=True)
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
+ lr = scheduler.get_last_lr()[0]
+ batch_num_epoch = -1
+ if hasattr(dataloader_train, "__len__"):
+ batch_num_epoch = len(dataloader_train)
+ self.log(epoch, batch_idx,
+ batch_num_epoch=batch_num_epoch,
+ lr=lr,
+ loss=loss.detach().cpu().item(),
+ speed_stats=speed_stats,
+ stats=stats,
+ writer=writer,
+ tag="train",
+ )
-
-
- if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_train):
- pbar.update(self.log_interval)
- gpu_info = "GPU, memory: {:.3f} GB, " \
- "{:.3f} GB, "\
- "{:.3f} GB, "\
- "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
- torch.cuda.max_memory_allocated()/1024/1024/1024,
- torch.cuda.memory_reserved()/1024/1024/1024,
- torch.cuda.max_memory_reserved()/1024/1024/1024,
- )
- lr = self.scheduler.get_last_lr()[0]
- time_now = datetime.now()
- time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
- description = (
- f"{time_now}, "
- f"rank: {self.local_rank}, "
- f"epoch: {epoch}/{self.max_epoch}, "
- f"step: {batch_idx+1}/{len(self.dataloader_train)}, total step: {self.batch_total}, "
- f"(loss: {loss.detach().cpu().item():.3f}), "
- f"(lr: {lr:.3e}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
- f"{speed_stats}, "
- f"{gpu_info}"
+ if (batch_idx + 1) % self.validate_interval == 0:
+ self.validate_epoch(
+ model=model,
+ dataloader_val=dataloader_val,
+ epoch=epoch,
+ writer=writer
)
- pbar.set_description(description)
- if self.writer:
- self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(), self.batch_total)
- self.writer.add_scalar(f'rank{self.local_rank}_lr/train', lr, self.batch_total)
- for key, var in stats.items():
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(), self.batch_total)
- for key, var in speed_stats.items():
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
- if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
- self._save_checkpoint(epoch, step=batch_idx+1)
- pbar.close()
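+            # All ranks reach this call; rank handling happens inside save_checkpoint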
+ if (batch_idx+1) % self.save_checkpoint_interval == 0:
+ self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
+
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
- def _validate_epoch(self, epoch):
+ def validate_epoch(self,
+ model=None,
+ dataloader_val=None,
+ epoch=None,
+ writer=None,
+ **kwargs,
+ ):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
@@ -357,18 +374,19 @@
Args:
epoch (int): The current epoch number.
"""
- self.model.eval()
+ logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
+ model.eval()
+
with torch.no_grad():
- pbar = tqdm(colour="red", desc=f"rank: {self.local_rank}, Validation Epoch: {epoch + 1}", total=len(self.dataloader_val),
- dynamic_ncols=True)
+
speed_stats = {}
time5 = time.perf_counter()
- for batch_idx, batch in enumerate(self.dataloader_val):
+ for batch_idx, batch in enumerate(dataloader_val):
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
time2 = time.perf_counter()
- retval = self.model(**batch)
+ retval = model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
@@ -378,8 +396,11 @@
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed=True)
+ if self.use_ddp or self.use_fsdp:
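+                    # weight was already reduced inside recursive_average; the extra scale here is offset by 'loss *= self.world_size' below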
+ dist.all_reduce(weight, op=dist.ReduceOp.SUM)
# Now weight is summation over all workers
- loss /= weight
+ loss /= weight.sum() # shape:[1] -> shape:[]
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
@@ -387,29 +407,97 @@
loss = loss
time4 = time.perf_counter()
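+                # Same online means as training, tracked over the validation pass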
+ self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+ if "acc" in stats:
+ self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
- if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_val):
- pbar.update(self.log_interval)
- time_now = datetime.now()
- time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
- description = (
- f"{time_now}, "
- f"rank: {self.local_rank}, "
- f"validation epoch: {epoch}/{self.max_epoch}, "
- f"step: {batch_idx+1}/{len(self.dataloader_val)}, "
- f"(loss: {loss.detach().cpu().item():.3f}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
- f"{speed_stats}, "
- )
- pbar.set_description(description)
- if self.writer:
- self.writer.add_scalar(f"rank{self.local_rank}_Loss/val", loss.item(),
- epoch*len(self.dataloader_val) + batch_idx)
- for key, var in stats.items():
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', var.item(),
- epoch * len(self.dataloader_val) + batch_idx)
- for key, var in speed_stats.items():
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
- epoch * len(self.dataloader_val) + batch_idx)
+ batch_num_epoch = -1
+ if hasattr(dataloader_val, "__len__"):
+ batch_num_epoch = len(dataloader_val)
+ self.log(epoch, batch_idx,
+ batch_num_epoch=batch_num_epoch,
+ lr=0.0,
+ loss=loss.detach().cpu().item(),
+ speed_stats=speed_stats,
+ stats=stats,
+ writer=writer,
+ tag="val",
+ )
- self.model.train()
\ No newline at end of file
+ self.val_acc_list.append(self.val_acc_avg)
+ model.train()
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+
+ def log(self,
+ epoch=0,
+ batch_idx=0,
+ batch_num_epoch=-1,
+ lr=0.0,
+ loss=0.0,
+ speed_stats=None,
+ stats=None,
+ writer=None,
+ tag="train",
+ ):
+
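+        # Emit a console line and TensorBoard scalars every log_interval batches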
+ if (batch_idx + 1) % self.log_interval == 0:
+
+            gpu_info = "GPU memory: usage: {:.3f} GB, " \
+ "peak: {:.3f} GB, " \
+ "cache: {:.3f} GB, " \
+ "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+ torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+ torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+ torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
+ )
+
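+            # Per-tag running averages maintained by train_epoch / validate_epoch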
+ loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+ acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
+ description = (
+ f"{tag}, "
+ f"rank: {self.local_rank}, "
+ f"epoch: {epoch}/{self.max_epoch}, "
+ f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
+ f"(loss_avg_rank: {loss:.3f}), "
+ f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
+ f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+ f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+ f"(lr: {lr:.3e}), "
+ f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
+ f"{speed_stats}, "
+ f"{gpu_info}"
+ )
+ logging.info(description)
+
+ if writer is not None:
+ writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
+ writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
+ for key, var in stats.items():
+ writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+ for key, var in speed_stats.items():
+                writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', float(var), self.batch_total)  # speed_stats holds "%.3f" strings; parse instead of eval
+
+ def close(self, writer=None):
+
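+        # Final sync across ranks, flush the writer, then tear down the process group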
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+ if writer is not None:
+ writer.close()
+
+ if self.use_ddp or self.use_fsdp:
+ torch.distributed.destroy_process_group()
\ No newline at end of file
diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
deleted file mode 100644
index 5f13b5a..0000000
--- a/funasr/train_utils/trainer_llm.py
+++ /dev/null
@@ -1,502 +0,0 @@
-import math
-import os
-import time
-import torch
-import logging
-from tqdm import tqdm
-from datetime import datetime
-import torch.distributed as dist
-from torch.cuda.amp import autocast, GradScaler
-from contextlib import nullcontext, contextmanager
-from pathlib import Path
-
-from funasr.train_utils.device_funcs import to_device
-from funasr.train_utils.recursive_op import recursive_average
-from funasr.train_utils.average_nbest_models import average_checkpoints
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-@contextmanager
-def maybe_autocast(enabled):
- if enabled:
- with autocast():
- yield
- else:
- yield
-
-class Trainer:
- """
- A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
- and optionally resuming from a saved checkpoint.
-
- Attributes:
- max_epoch (int): Maximum number of epochs for training.
- model (torch.nn.Module): The model to be trained.
- optim (torch.optim.Optimizer): The optimizer to use for training.
- scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
- dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
- dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
- output_dir (str): Directory where model checkpoints will be saved.
- resume (str, optional): Path to a checkpoint to resume training from.
- """
-
- def __init__(self,
- local_rank,
- use_ddp: bool = False,
- use_fsdp: bool = False,
- use_fp16: bool = False,
- output_dir: str="./",
- **kwargs):
- """
- Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
-
- Args:
- model (torch.nn.Module): The model to be trained.
- optim (torch.optim.Optimizer): The optimizer to use for training.
- scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
- dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
- dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
- **kwargs: Additional keyword arguments:
- max_epoch (int): The maximum number of epochs for training.
- output_dir (str): The directory where model checkpoints will be saved. Default is './'.
- resume (str, optional): The file path to a checkpoint to resume training from.
- """
-
- self.output_dir = output_dir
- if not os.path.exists(self.output_dir):
- os.makedirs(self.output_dir, exist_ok=True)
- self.resume = kwargs.get('resume', True)
- self.start_epoch = 0
- self.max_epoch = kwargs.get('max_epoch', 100)
- self.local_rank = local_rank
- self.use_ddp = use_ddp
- self.use_fsdp = use_fsdp
- self.device = kwargs.get('device', "cuda")
- self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
- # self.kwargs = kwargs
- self.log_interval = kwargs.get("log_interval", 50)
- self.batch_total = 0
- self.use_fp16 = use_fp16
- self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
- # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
- # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
- # self.scaler = scaler
- self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
- self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
- self.accum_grad = kwargs.get("accum_grad", 1)
- self.grad_clip = kwargs.get("grad_clip", 10.0)
- self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
- self.validate_interval = kwargs.get("validate_interval", 5000)
-
-
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- logging.warning("distributed is not initialized, only single shard")
- self.rank = rank
- self.world_size = world_size
- self.train_acc_avg = 0.0
- self.train_loss_avg = 0.0
- self.val_acc_avg = 0.0
- self.val_loss_avg = 0.0
- self.best_acc_idx = 0
- self.saved_ckpts = {}
- self.val_acc_list = []
- self.step_or_epoch = -1
-
-
-
-
-
- def save_checkpoint(self, epoch,
- step=None,
- model=None,
- optim=None,
- scheduler=None,
- scaler=None,
- ):
- """
- Saves a checkpoint containing the model's state, the optimizer's state,
- and the scheduler's state at the end of the given epoch. This method is
- intended to be called at the end of each epoch to save the training progress.
-
- Args:
- epoch (int): The epoch number at which the checkpoint is being saved.
- """
-
- if self.rank == 0:
- logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
- self.step_or_epoch += 1
- state = {
- 'epoch': epoch,
- 'state_dict': model.state_dict(),
- 'optimizer': optim.state_dict(),
- 'scheduler': scheduler.state_dict(),
- "acc": self.val_acc_list,
- "step_or_epoch": self.step_or_epoch,
- }
- if hasattr(model, "module"):
- state["state_dict"] = model.module.state_dict()
-
- if scaler:
- state["scaler_state"] = scaler.state_dict()
- # Create output directory if it does not exist
- os.makedirs(self.output_dir, exist_ok=True)
- if step is None:
- ckpt_name = f'model.pt.ep{epoch}'
- else:
- ckpt_name = f'model.pt.ep{epoch}.{step}'
- filename = os.path.join(self.output_dir, ckpt_name)
- torch.save(state, filename)
-
- logging.info(f'\nCheckpoint saved to {filename}\n')
- latest = Path(os.path.join(self.output_dir, f'model.pt'))
- torch.save(state, latest)
-
- if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
- self.best_acc_idx = self.step_or_epoch
- best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
- torch.save(state, best_ckpt)
- logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
- else:
- logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
-
- if self.keep_nbest_models > 0:
- self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
- if len(self.saved_ckpts) > self.keep_nbest_models:
-
- min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
- if min_key in self.saved_ckpts:
- del self.saved_ckpts[min_key]
- filename = os.path.join(self.output_dir, min_key)
- logging.info(f"Delete: {filename}")
- if os.path.exists(filename):
- os.remove(filename)
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
- def resume_checkpoint(self,
- model=None,
- optim=None,
- scheduler=None,
- scaler=None,
- ):
- """
- Resumes training from a checkpoint at the given file path.
- Loads the model's state, the optimizer's state, and the scheduler's state.
-
- Args:
- resume_path (str): The file path to the checkpoint to resume from.
- """
- if self.resume:
- ckpt = os.path.join(self.output_dir, "model.pt")
- if os.path.isfile(ckpt):
- checkpoint = torch.load(ckpt)
- self.start_epoch = checkpoint['epoch'] + 1
- # self.model.load_state_dict(checkpoint['state_dict'])
- src_state = checkpoint['state_dict']
- dst_state = model.state_dict()
- for k in dst_state.keys():
- if not k.startswith("module.") and "module."+k in src_state.keys():
- k_ddp = "module."+k
- else:
- k_ddp = k
- if k_ddp in src_state.keys():
- dst_state[k] = src_state[k_ddp]
- else:
- print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
-
- model.load_state_dict(dst_state)
- optim.load_state_dict(checkpoint['optimizer'])
- scheduler.load_state_dict(checkpoint['scheduler'])
- if scaler is not None and 'scaler_state' in checkpoint:
- scaler.load_state_dict(checkpoint['scaler_state'])
-
- self.val_acc_list = checkpoint["acc"]
- self.step_or_epoch = checkpoint["step_or_epoch"]
-
- print(f"Checkpoint loaded successfully from '{ckpt}'")
- else:
- print(f"No checkpoint found at '{ckpt}', does not resume status!")
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
-
- def train_epoch(self,
- model=None,
- optim=None,
- scheduler=None,
- scaler=None,
- dataloader_train=None,
- dataloader_val=None,
- epoch=None,
- writer=None,
- ):
- """
- Defines the training process for a single epoch with gradient accumulation.
- Args:
- epoch (int): The current epoch number.
- """
- logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
- model.train()
-
- # Set the number of steps for gradient accumulation
- accum_grad = self.accum_grad
- # Initialize the gradient accumulation
- optim.zero_grad()
- speed_stats = {}
- time5 = time.perf_counter()
-
- for batch_idx, batch in enumerate(dataloader_train):
- self.batch_total += 1
- time1 = time.perf_counter()
- speed_stats["data_load"] = f"{time1-time5:0.3f}"
-
- batch = to_device(batch, self.device)
-
- my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
- with my_context():
- time2 = time.perf_counter()
- with maybe_autocast(self.use_fp16):
- retval = model(**batch)
-
- if self.disable_gpu_cache: torch.cuda.empty_cache()
-
- time3 = time.perf_counter()
- speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
- loss, stats, weight = retval
- stats = {k: v for k, v in stats.items() if v is not None}
- if self.use_ddp or self.use_fsdp:
- # Apply weighted averaging for loss and stats
- loss = (loss * weight.type(loss.dtype)).sum()
- # if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed=True)
- # Now weight is summation over all workers
- loss /= weight
- # Multiply world_size because DistributedDataParallel
- # automatically normalizes the gradient by world_size.
- loss *= self.world_size
- # Scale the loss since we're not updating for every mini-batch
- loss = loss / accum_grad
- if self.use_fp16:
- scaler.scale(loss).backward()
- else:
- loss.backward()
- time4 = time.perf_counter()
- speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
-
- self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
- if "acc" in stats:
- self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
- if self.use_ddp or self.use_fsdp:
- train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
- train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
- dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
- self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
- self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
-
-
- # Perform an optimizer step only after accumulating enough gradients
- if (batch_idx + 1) % accum_grad == 0:
- # Perform gradient clipping if it is set
- if self.grad_clip > 0:
- grad_norm = torch.nn.utils.clip_grad_norm_(
- model.parameters(),
- max_norm=self.grad_clip,
- norm_type=self.grad_clip_type,
- )
- if not torch.isfinite(grad_norm):
- logging.warning(
- f"The grad norm is {grad_norm}. Skipping updating the model."
- )
- optim.zero_grad() # Reset gradients
- continue
-
- # Execute an optimization step (update model parameters)
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
- if self.use_fp16:
- scaler.step(optim)
- scaler.update()
- else:
- optim.step()
- scheduler.step()
- # Clear gradients for the next accumulation stage
- optim.zero_grad(set_to_none=True)
- total_time = f"{time.perf_counter() - time5:0.3f}"
- time5 = time.perf_counter()
- speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
-
- speed_stats["total_time"] = total_time
- lr = scheduler.get_last_lr()[0]
- batch_num_epoch = -1
- if hasattr(dataloader_train, "__len__"):
- batch_num_epoch = len(dataloader_train)
- self.log(epoch, batch_idx,
- batch_num_epoch=batch_num_epoch,
- lr=lr,
- loss=loss.detach().cpu().item(),
- speed_stats=speed_stats,
- stats=stats,
- writer=writer,
- tag="train",
- )
-
- if (batch_idx + 1) % self.validate_interval == 0:
- self.validate_epoch(
- model=model,
- dataloader_val=dataloader_val,
- epoch=epoch,
- writer=writer
- )
-
- if (batch_idx+1) % self.save_checkpoint_interval == 0:
- self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
-
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
-
-
- def validate_epoch(self,
- model=None,
- dataloader_val=None,
- epoch=None,
- writer=None,
- **kwargs,
- ):
- """
- Defines the validation process for a single epoch.
- Should be implemented with the actual model validation steps.
-
- Args:
- epoch (int): The current epoch number.
- """
- logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
- model.eval()
-
- with torch.no_grad():
-
- speed_stats = {}
- time5 = time.perf_counter()
- for batch_idx, batch in enumerate(dataloader_val):
- time1 = time.perf_counter()
- speed_stats["data_load"] = f"{time1 - time5:0.3f}"
- batch = to_device(batch, self.device)
- time2 = time.perf_counter()
- retval = model(**batch)
- time3 = time.perf_counter()
- speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
- loss, stats, weight = retval
- stats = {k: v for k, v in stats.items() if v is not None}
- if self.use_ddp or self.use_fsdp:
- # Apply weighted averaging for loss and stats
- loss = (loss * weight.type(loss.dtype)).sum()
- # if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed=True)
- # Now weight is summation over all workers
- loss /= weight
- # Multiply world_size because DistributedDataParallel
- # automatically normalizes the gradient by world_size.
- loss *= self.world_size
- # Scale the loss since we're not updating for every mini-batch
- loss = loss
- time4 = time.perf_counter()
-
- self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
- if "acc" in stats:
- self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
- if self.use_ddp or self.use_fsdp:
- val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
- val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
- dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
- self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
- self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
-
- batch_num_epoch = -1
- if hasattr(dataloader_val, "__len__"):
- batch_num_epoch = len(dataloader_val)
- self.log(epoch, batch_idx,
- batch_num_epoch=batch_num_epoch,
- lr=0.0,
- loss=loss.detach().cpu().item(),
- speed_stats=speed_stats,
- stats=stats,
- writer=writer,
- tag="val",
- )
-
- self.val_acc_list.append(self.val_acc_avg)
- model.train()
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
-
- def log(self,
- epoch=0,
- batch_idx=0,
- batch_num_epoch=-1,
- lr=0.0,
- loss=0.0,
- speed_stats=None,
- stats=None,
- writer=None,
- tag="train",
- ):
-
- if (batch_idx + 1) % self.log_interval == 0:
-
- gpu_info = "GPU, memory: usage: {:.3f} GB, " \
- "peak: {:.3f} GB, " \
- "cache: {:.3f} GB, " \
- "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
- torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
- torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
- torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
- )
-
- loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
- acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
- description = (
- f"{tag}, "
- f"rank: {self.local_rank}, "
- f"epoch: {epoch}/{self.max_epoch}, "
- f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
- f"(loss_avg_rank: {loss:.3f}), "
- f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
- f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
- f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
- f"(lr: {lr:.3e}), "
- f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
- f"{speed_stats}, "
- f"{gpu_info}"
- )
- logging.info(description)
-
- if writer is not None:
- writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
- writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
- writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
- for key, var in stats.items():
- writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
- for key, var in speed_stats.items():
- writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
-
- def close(self, writer=None):
-
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
-
- if writer is not None:
- writer.close()
-
- if self.use_ddp or self.use_fsdp:
- torch.distributed.destroy_process_group()
\ No newline at end of file
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index a08f263..ef18a61 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -1,7 +1,10 @@
+import os
import io
+import shutil
+import logging
from collections import OrderedDict
import numpy as np
-
+from omegaconf import DictConfig, OmegaConf
def statistic_model_parameters(model, prefix=None):
var_dict = model.state_dict()
@@ -52,4 +54,21 @@
if isinstance(value, dict) and key in original:
deep_update(original[key], value)
else:
- original[key] = value
\ No newline at end of file
+ original[key] = value
+
+
+def prepare_model_dir(**kwargs):
+
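+    # Persist the resolved training config alongside checkpoints for reproducibility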
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+ OmegaConf.save(config=kwargs, f=yaml_file)
+ print(kwargs)
+ logging.info("config.yaml is saved to: %s", yaml_file)
+
+ # model_path = kwargs.get("model_path")
+ # if model_path is not None:
+ # config_json = os.path.join(model_path, "configuration.json")
+ # if os.path.exists(config_json):
+ # shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
deleted file mode 100644
index 36eebdc..0000000
--- a/funasr/utils/prepare_data.py
+++ /dev/null
@@ -1,242 +0,0 @@
-import logging
-import os
-import shutil
-from multiprocessing import Pool
-
-import kaldiio
-import numpy as np
-import librosa
-import torch.distributed as dist
-import torchaudio
-
-
-def filter_wav_text(data_dir, dataset):
- wav_file = os.path.join(data_dir, dataset, "wav.scp")
- text_file = os.path.join(data_dir, dataset, "text")
- with open(wav_file) as f_wav, open(text_file) as f_text:
- wav_lines = f_wav.readlines()
- text_lines = f_text.readlines()
- os.rename(wav_file, "{}.bak".format(wav_file))
- os.rename(text_file, "{}.bak".format(text_file))
- wav_dict = {}
- for line in wav_lines:
- parts = line.strip().split()
- if len(parts) < 2:
- continue
- wav_dict[parts[0]] = parts[1]
- text_dict = {}
- for line in text_lines:
- parts = line.strip().split()
- if len(parts) < 2:
- continue
- text_dict[parts[0]] = " ".join(parts[1:])
- filter_count = 0
- with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
- for sample_name, wav_path in wav_dict.items():
- if sample_name in text_dict.keys():
- f_wav.write(sample_name + " " + wav_path + "\n")
- f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
- else:
- filter_count += 1
- logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
- format(filter_count, len(wav_lines), dataset))
-
-
-def wav2num_frame(wav_path, frontend_conf):
- try:
- waveform, sampling_rate = torchaudio.load(wav_path)
- except:
- waveform, sampling_rate = librosa.load(wav_path)
- waveform = np.expand_dims(waveform, axis=0)
- n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
- feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
- return n_frames, feature_dim
-
-
-def calc_shape_core(root_path, args, idx):
- file_name = args.data_file_names.split(",")[0]
- data_name = args.dataset_conf.get("data_names", "speech,text").split(",")[0]
- scp_file = os.path.join(root_path, "{}.{}".format(file_name, idx))
- shape_file = os.path.join(root_path, "{}_shape.{}".format(data_name, idx))
- with open(scp_file) as f:
- lines = f.readlines()
- data_type = args.dataset_conf.get("data_types", "sound,text").split(",")[0]
- if data_type == "sound":
- frontend_conf = args.frontend_conf
- dataset_conf = args.dataset_conf
- length_min = dataset_conf.speech_length_min if hasattr(dataset_conf, "{}_length_min".format(data_name)) else -1
- length_max = dataset_conf.speech_length_max if hasattr(dataset_conf, "{}_length_max".format(data_name)) else -1
- with open(shape_file, "w") as f:
- for line in lines:
- sample_name, wav_path = line.strip().split()
- n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
- write_flag = True
- if n_frames > 0 and length_min > 0:
- write_flag = n_frames >= length_min
- if n_frames > 0 and length_max > 0:
- write_flag = n_frames <= length_max
- if write_flag:
- f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
- f.flush()
- elif data_type == "kaldi_ark":
- dataset_conf = args.dataset_conf
- length_min = dataset_conf.speech_length_min if hasattr(dataset_conf, "{}_length_min".format(data_name)) else -1
- length_max = dataset_conf.speech_length_max if hasattr(dataset_conf, "{}_length_max".format(data_name)) else -1
- with open(shape_file, "w") as f:
- for line in lines:
- sample_name, feature_path = line.strip().split()
- feature = kaldiio.load_mat(feature_path)
- n_frames, feature_dim = feature.shape
- write_flag = True
- if n_frames > 0 and length_min > 0:
- write_flag = n_frames >= length_min
- if n_frames > 0 and length_max > 0:
- write_flag = n_frames <= length_max
- if write_flag:
- f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
- f.flush()
- elif data_type == "text":
- with open(shape_file, "w") as f:
- for line in lines:
- sample_name, text = line.strip().split(maxsplit=1)
- n_tokens = len(text.split())
- f.write("{} {}\n".format(sample_name, str(int(np.ceil(n_tokens)))))
- f.flush()
- else:
- raise RuntimeError("Unsupported data_type: {}".format(data_type))
-
-
-def calc_shape(args, dataset, nj=64):
- data_name = args.dataset_conf.get("data_names", "speech,text").split(",")[0]
- shape_path = os.path.join(args.data_dir, dataset, "{}_shape".format(data_name))
- if os.path.exists(shape_path):
- logging.info('Shape file for small dataset already exists.')
- return
-
- split_shape_path = os.path.join(args.data_dir, dataset, "{}_shape_files".format(data_name))
- if os.path.exists(split_shape_path):
- shutil.rmtree(split_shape_path)
- os.mkdir(split_shape_path)
-
- # split
- file_name = args.data_file_names.split(",")[0]
- scp_file = os.path.join(args.data_dir, dataset, file_name)
- with open(scp_file) as f:
- lines = f.readlines()
- num_lines = len(lines)
- num_job_lines = num_lines // nj
- start = 0
- for i in range(nj):
- end = start + num_job_lines
- file = os.path.join(split_shape_path, "{}.{}".format(file_name, str(i + 1)))
- with open(file, "w") as f:
- if i == nj - 1:
- f.writelines(lines[start:])
- else:
- f.writelines(lines[start:end])
- start = end
-
- p = Pool(nj)
- for i in range(nj):
- p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
- logging.info("Generating shape files, please wait a few minutes...")
- p.close()
- p.join()
-
- # combine
- with open(shape_path, "w") as f:
- for i in range(nj):
- job_file = os.path.join(split_shape_path, "{}_shape.{}".format(data_name, str(i + 1)))
- with open(job_file) as job_f:
- lines = job_f.readlines()
- f.writelines(lines)
- logging.info('Generating shape files done.')
-
-
-def generate_data_list(args, data_dir, dataset, nj=64):
- data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
- file_names = args.data_file_names.split(",")
- concat_data_name = "_".join(data_names)
- list_file = os.path.join(data_dir, dataset, "{}_data.list".format(concat_data_name))
- if os.path.exists(list_file):
- logging.info('Data list for large dataset already exists.')
- return
- split_path = os.path.join(data_dir, dataset, "split")
- if os.path.exists(split_path):
- shutil.rmtree(split_path)
- os.mkdir(split_path)
-
- data_lines_list = []
- for file_name in file_names:
- with open(os.path.join(data_dir, dataset, file_name)) as f:
- lines = f.readlines()
- data_lines_list.append(lines)
- num_lines = len(data_lines_list[0])
- num_job_lines = num_lines // nj
- start = 0
- for i in range(nj):
- end = start + num_job_lines
- split_path_nj = os.path.join(split_path, str(i + 1))
- os.mkdir(split_path_nj)
- for file_id, file_name in enumerate(file_names):
- file = os.path.join(split_path_nj, file_name)
- with open(file, "w") as f:
- if i == nj - 1:
- f.writelines(data_lines_list[file_id][start:])
- else:
- f.writelines(data_lines_list[file_id][start:end])
- start = end
-
- with open(list_file, "w") as f_data:
- for i in range(nj):
- path = ""
- for file_name in file_names:
- path = path + " " + os.path.join(split_path, str(i + 1), file_name)
- f_data.write(path + "\n")
-
-
-def prepare_data(args, distributed_option):
- data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
- data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
- file_names = args.data_file_names.split(",")
- batch_type = args.dataset_conf["batch_conf"]["batch_type"]
- print("data_names: {}, data_types: {}, file_names: {}".format(data_names, data_types, file_names))
- assert len(data_names) == len(data_types) == len(file_names)
- if args.dataset_type == "small":
- args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "{}_shape".format(data_names[0]))]
- args.valid_shape_file = [os.path.join(args.data_dir, args.valid_set, "{}_shape".format(data_names[0]))]
- args.train_data_path_and_name_and_type, args.valid_data_path_and_name_and_type = [], []
- for file_name, data_name, data_type in zip(file_names, data_names, data_types):
- args.train_data_path_and_name_and_type.append(
- ["{}/{}/{}".format(args.data_dir, args.train_set, file_name), data_name, data_type])
- args.valid_data_path_and_name_and_type.append(
- ["{}/{}/{}".format(args.data_dir, args.valid_set, file_name), data_name, data_type])
- if os.path.exists(args.train_shape_file[0]):
- assert os.path.exists(args.valid_shape_file[0])
- print('shape file for small dataset already exists.')
- return
- else:
- concat_data_name = "_".join(data_names)
- args.train_data_file = os.path.join(args.data_dir, args.train_set, "{}_data.list".format(concat_data_name))
- args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "{}_data.list".format(concat_data_name))
- if os.path.exists(args.train_data_file):
- assert os.path.exists(args.valid_data_file)
- print('data list for large dataset already exists.')
- return
-
- distributed = distributed_option.distributed
- if not distributed or distributed_option.dist_rank == 0:
- if hasattr(args, "filter_input") and args.filter_input:
- filter_wav_text(args.data_dir, args.train_set)
- filter_wav_text(args.data_dir, args.valid_set)
-
- if args.dataset_type == "small" and batch_type != "unsorted":
- calc_shape(args, args.train_set)
- calc_shape(args, args.valid_set)
-
- if args.dataset_type == "large":
- generate_data_list(args, args.data_dir, args.train_set)
- generate_data_list(args, args.data_dir, args.valid_set)
-
- if distributed:
- dist.barrier()
--
Gitblit v1.9.1