From 500197b8ad561929bd782b5a69f6a1042e909090 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 07 三月 2024 15:24:42 +0800
Subject: [PATCH] Dev gzf (#1443)
---
examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml | 11 +-
examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh | 15 +--
funasr/models/llm_asr/model.py | 2
funasr/datasets/llm_datasets_vicuna/samplers.py | 206 +++++++++++++++++++++++++++++++--------------------
4 files changed, 139 insertions(+), 95 deletions(-)
diff --git a/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml b/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml
index eacbd35..331ac0d 100644
--- a/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml
+++ b/examples/industrial_data_pretraining/llm_asr/conf/whisper_vicuna_linear.yaml
@@ -58,7 +58,7 @@
train_conf:
accum_grad: 1
grad_clip: 5
- max_epoch: 150
+ max_epoch: 15
keep_nbest_models: 10
log_interval: 10
@@ -68,16 +68,15 @@
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
- warmup_steps: 1500
+ warmup_steps: 1000
dataset: AudioLLMVicunaDataset
dataset_conf:
index_ds: IndexDSJsonl
- batch_sampler: RankFullLocalShuffleBatchSampler
+ batch_sampler: CustomDistributedBatchSampler
batch_type: example # example or length
- batch_size: 8 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
- max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
- buffer_size: 500
+ batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 3000 # filter samples if source_token_len+target_token_len > max_token_length,
shuffle: True
num_workers: 4
# preprocessor_text: TextPreprocessRemovePunctuation
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
index 6632175..6b88e75 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
@@ -15,8 +15,8 @@
#++data_type_list='["source", "target"]' \
#++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
-train_data="/nfs/zhifu.gzf/data/datalist/aishell1_aishell2_wav_speech_llm_train_data_del_tail500.json"
-val_data="/nfs/zhifu.gzf/data/datalist/aishell1_aishell2_wav_speech_llm_train_data_tail500.json"
+train_data="/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl"
+val_data="/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl"
# exp output dir
output_dir="/nfs/zhifu.gzf/ckpt/exp/llm_asr_whisper_vicuna_exp1"
@@ -38,10 +38,9 @@
--config-name "${config}" \
++train_data_set_list="${train_data}" \
++valid_data_set_list="${val_data}" \
-++dataset_conf.batch_size=2 \
-++dataset_conf.batch_type="example" \
-++dataset_conf.num_workers=0 \
-++train_conf.max_epoch=11 \
-++optim_conf.lr=0.0002 \
+++dataset_conf.batch_size=4 \
+++dataset_conf.num_workers=4 \
+++train_conf.max_epoch=15 \
+++optim_conf.lr=0.0001 \
++init_param="${init_param}" \
-++output_dir="${output_dir}" &> ${log_file}
+++output_dir="${output_dir}" &> ${log_file} &
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index 29ec11d..512250d 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -1,91 +1,15 @@
import torch
import numpy as np
import logging
+import math
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+from torch.utils.data import BatchSampler, Sampler
import torch.distributed as dist
from funasr.register import tables
-@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
-class BatchSampler(torch.utils.data.BatchSampler):
-
- def __init__(self, dataset,
- batch_type: str = "example",
- batch_size: int = 100,
- buffer_size: int = 30,
- drop_last: bool = False,
- shuffle: bool = True,
- is_training: bool = True,
- **kwargs):
-
- self.drop_last = drop_last
- self.pre_idx = -1
- self.dataset = dataset
- self.total_samples = len(dataset)
- self.batch_type = batch_type
- self.batch_size = int(batch_size)
- self.buffer_size = buffer_size
- self.max_token_length = kwargs.get("max_token_length", 5000)
- self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle and is_training
- self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
-
- def __len__(self):
- return (self.total_samples-1) // self.batch_size + 1
-
- def set_epoch(self, epoch):
- np.random.seed(epoch)
-
- def __iter__(self):
-
- if self.shuffle:
- np.random.shuffle(self.shuffle_idx)
-
- batch = []
- max_token = 0
- num_sample = 0
-
- iter_num = (self.total_samples - 1) // self.buffer_size + 1
- # print("iter_num: ", iter_num)
- for iter in range(self.pre_idx + 1, iter_num):
- datalen_with_index = []
- for i in range(self.buffer_size):
- idx = iter * self.buffer_size + i
- if idx >= self.total_samples:
- continue
-
- idx_map = self.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
- source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
- sample_len_cur = source_len + target_len
-
-
- datalen_with_index.append([idx, sample_len_cur])
-
- datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
- for item in datalen_with_index_sort:
- idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_token_length:
- continue
-
- max_token_cur = max(max_token, sample_len_cur_raw)
- max_token_padding = 1 + num_sample
- if self.batch_type != 'example':
- max_token_padding *= max_token_cur
- if max_token_padding <= self.batch_size:
- batch.append(idx)
- max_token = max_token_cur
- num_sample += 1
- else:
- yield batch
- batch = [idx]
- max_token = sample_len_cur_raw
- num_sample = 1
-
-
-@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullGlobalShuffleBatchSampler")
class RankFullGlobalShuffleBatchSampler(torch.utils.data.BatchSampler):
@@ -177,3 +101,125 @@
max_token = sample_len_cur_raw
num_sample = 1
+@tables.register("batch_sampler_classes", "DistributedSamplerWarp")
+class DistributedSamplerWarp(BatchSampler):
+ def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False):
+ if num_replicas is None:
+ if not torch.distributed.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = torch.distributed.get_world_size()
+ if rank is None:
+ if not torch.distributed.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = torch.distributed.get_rank()
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.shuffle = shuffle
+ self.drop_last = drop_last
+
+ # Create an instance of the DistributedSampler
+ self.sampler = DistributedSampler(
+ self.dataset,
+ num_replicas=self.num_replicas,
+ rank=self.rank,
+ shuffle=self.shuffle
+ )
+
+ # Call BatchSampler's constructor
+ super().__init__(self.sampler, batch_size, drop_last)
+
+ def __iter__(self):
+ # If we shuffle, we need to call the set_epoch method
+ if self.shuffle:
+ self.sampler.set_epoch(self.epoch)
+
+ # Generate batch indices using the parent class
+ return super().__iter__()
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
+class CustomDistributedBatchSampler(Sampler):
+ def __init__(self, dataset,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ **kwargs,
+ ):
+
+ try:
+ rank = dist.get_rank()
+ num_replicas = dist.get_world_size()
+ except:
+ rank = 0
+ num_replicas = 1
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.shuffle = shuffle and is_training
+ self.drop_last = drop_last
+ # self.total_size = len(dataset)
+ if self.drop_last:
+ self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+ else:
+ self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+ self.num_samples = int(self.total_size // self.num_replicas)
+ self.epoch = 0
+ self.max_token_length = kwargs.get("max_token_length", None)
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
+ def __iter__(self):
+ # Generate a list of indices
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ # Add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+
+ assert len(indices) == self.total_size
+
+ # Subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ # Filter out indices with length greater than the max length, if provided
+ if self.max_token_length is not None:
+ filtered_indices = []
+ for idx in indices:
+ source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+ if source_len <= self.max_token_length:
+ filtered_indices.append(idx)
+ indices = filtered_indices
+
+ # Now that we have only the indices for this replica, chunk them into batches
+ batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
+
+ # Drop the last batch if it's not full and drop_last is True
+ if self.drop_last and len(batches[-1]) != self.batch_size:
+ batches = batches[:-1]
+
+ return iter(batches)
+
+ def __len__(self):
+
+ return self.num_samples // self.batch_size
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 3223190..f5dd3b1 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -218,7 +218,7 @@
):
speech = speech.permute(0, 2, 1)
res = self.audio_encoder(speech)
- if len(res) > 1:
+ if isinstance(res, (list, tuple)):
encoder_out, encoder_out_lens = res[0], res[1]
else:
encoder_out, encoder_out_lens = res, speech_lengths
--
Gitblit v1.9.1