Funasr1.0 (#1261)
* funasr1.0 finetune
* funasr1.0 pbar
* update with main (#1260)
* Update websocket_protocol_zh.md
* update
---------
Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
```shell
python funasr/bin/train.py \
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+model_revision="v2.0.2" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
+valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
++dataset_conf.batch_size=2 \
++dataset_conf.batch_type="example" \
++train_conf.max_epoch=2 \
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
+device="cpu" \
+debug="true"
```
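The `+key=value` / `++key=value` overrides follow Hydra-style semantics: `+` adds a new config key and `++` adds or force-overrides an existing one. Below is a minimal sketch of the merge idea using OmegaConf; it illustrates the semantics only and is not FunASR's actual argument parser:

```python
from omegaconf import OmegaConf

# Base config, as if loaded from the model's config.yaml.
base = OmegaConf.create({"dataset_conf": {"batch_size": 32, "batch_type": "token"}})

# "++dataset_conf.batch_size=2" etc. reduce to dotlist overrides
# once the +/++ prefixes are stripped.
overrides = OmegaConf.from_dotlist([
    "dataset_conf.batch_size=2",
    "dataset_conf.batch_type=example",
    "train_conf.max_epoch=2",
])
cfg = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(cfg))
```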
```python
    spk_model_revision="v2.0.2",
)

res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                     hotword='达摩院 魔搭')
print(res)
```
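For context, the truncated lines above are the tail of an `AutoModel(...)` construction. A representative setup might look like the following; the model IDs and revisions here are illustrative assumptions, not part of the diff:

```python
from funasr import AutoModel

# Illustrative AutoModel pipeline; model names/revisions are assumptions.
model = AutoModel(
    model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v2.0.2",
    vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_model_revision="v2.0.2",
    spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
    spk_model_revision="v2.0.2",
)
```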
```python
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
disable_pbar = kwargs.get("disable_pbar", False)
pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True) if not disable_pbar else None
time_speech_total = 0.0
time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
    # ...
    time2 = time.perf_counter()
    # ...
    asr_result_list.extend(results)

    # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
    batch_data_time = meta_data.get("batch_data_time", -1)
    time_escape = time2 - time1

    description = (
        f"{speed_stats}, "
    )
    if pbar:
        pbar.update(1)
        pbar.set_description(description)
    time_speech_total += batch_data_time
    time_escape_total += time_escape

if pbar:
    pbar.update(1)
    pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
```
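`rtf_avg` is the average real-time factor: wall-clock decoding time divided by the duration of the processed audio. A quick illustration:

```python
# RTF = processing time / audio duration; lower is faster.
def rtf(time_escape_total: float, time_speech_total: float) -> float:
    return time_escape_total / time_speech_total

# Decoding 60 s of audio in 6 s of wall-clock time gives RTF = 0.1,
# i.e. 10x faster than real time.
assert abs(rtf(6.0, 60.0) - 0.1) < 1e-9
```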
```python
time_speech_total_per_sample = speech_lengths / 16000
time_speech_total_all_samples += time_speech_total_per_sample

pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)

all_segments = []
for j, _ in enumerate(range(0, n)):
    pbar_sample.update(1)
    batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
    if j < n - 1 and (
            batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
            ...):  # remainder of the condition truncated in the diff
        continue
    batch_size_ms_cum = 0
    end_idx = j + 1
    speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
    results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
    if self.spk_model is not None:
        # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
        for _b in range(len(speech_j)):
            vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
                             sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
                             speech_j[_b]]]
            segments = sv_chunk(vad_segments)
            all_segments.extend(segments)
    # ...
    results_sorted.extend(results)

pbar_total.update(1)

end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                            f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                            f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")

restored_data = [0] * n
for j in range(n):
    # ...
```
```python
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

# dataset
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
                            **kwargs.get("dataset_conf"))

# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
batch_sampler_val = None
if batch_sampler is not None:
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
    batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                            collate_fn=dataset_tr.collator,
                                            batch_sampler=batch_sampler,
                                            num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                            pin_memory=True)

dataloader_val = torch.utils.data.DataLoader(dataset_val,
                                             collate_fn=dataset_val.collator,
                                             batch_sampler=batch_sampler_val,
                                             num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                             pin_memory=True)
trainer = Trainer(
    model=model,
    optim=optim,
    scheduler=scheduler,
    dataloader_train=dataloader_tr,
    dataloader_val=dataloader_val,
    local_rank=local_rank,
    use_ddp=use_ddp,
    use_fsdp=use_fsdp,
    # ...
```
```python
    return len(self.contents)

def __getitem__(self, index):
    try:
        data = self.contents[index]
    except IndexError:
        print(index)
        raise
    return data

def get_source_len(self, data_dict):
    return data_dict["source_len"]
```
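`get_source_len` reads a `source_len` field from each sample, so the jsonl training lists are expected to carry per-sample lengths. A hypothetical entry (field names other than `source_len` are assumptions, not taken from the diff):

```json
{"key": "sample_0001", "source": "/path/to/sample_0001.wav", "source_len": 52, "target": "...", "target_len": 15}
```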
```python
             buffer_size: int = 30,
             drop_last: bool = False,
             shuffle: bool = True,
             is_training: bool = True,
             **kwargs):

    self.drop_last = drop_last
    # ...
    self.buffer_size = buffer_size
    self.max_token_length = kwargs.get("max_token_length", 5000)
    self.shuffle_idx = np.arange(self.total_samples)
    self.shuffle = shuffle and is_training

def __len__(self):
    return self.total_samples
```
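The sampler keeps a `shuffle_idx` array and only shuffles when training (`shuffle and is_training`). Below is a standalone sketch of buffer-local shuffling, the idea behind `DynamicBatchLocalShuffleSampler`'s `buffer_size`; the helper is illustrative, not the actual sampler:

```python
import numpy as np

def local_shuffle(indices: np.ndarray, buffer_size: int, seed: int = 0) -> np.ndarray:
    """Shuffle indices only within consecutive windows of `buffer_size`,
    preserving the coarse (e.g. length-sorted) global order."""
    rng = np.random.default_rng(seed)
    out = indices.copy()
    for beg in range(0, len(out), buffer_size):
        rng.shuffle(out[beg:beg + buffer_size])  # shuffles the view in place
    return out

print(local_shuffle(np.arange(10), buffer_size=4))
```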
```python
    self.use_1st_decoder_loss = use_1st_decoder_loss
    self.length_normalized_loss = length_normalized_loss
    self.beam_search = None
    self.error_calculator = None

def forward(
        self,
        # ...
```

```yaml
    - acc
    - max
keep_nbest_models: 10
avg_nbest_model: 5
log_interval: 50

optim: adam
```
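For context, `- acc` / `- max` above are the tail of a best-model criterion entry, and the keys sit in the training section of the config. A plausible surrounding fragment; the `best_model_criterion` key name is inferred from the averaging code later in this diff, so treat it as an assumption:

```yaml
train_conf:
  best_model_criterion:
  - - valid
    - acc
    - max
  keep_nbest_models: 10
  avg_nbest_model: 5
  log_interval: 50

optim: adam
```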
```python
import os
import re
import torch
from typing import Collection
from collections import OrderedDict
from functools import cmp_to_key

from funasr.train.reporter import Reporter
```
```python
def _get_checkpoint_paths(output_dir: str, last_n: int = 5):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
    """
    # List all files in the output directory
    files = os.listdir(output_dir)
    # Filter out checkpoint files and extract epoch numbers
    checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
    # Sort files by epoch number in descending order
    checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
    # Get the last 'last_n' checkpoint paths
    checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
    return checkpoint_paths
```
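Given the checkpoint naming used by the trainer below (`model.pt.ep{epoch}`), a quick illustration of what this helper returns; the paths are hypothetical:

```python
# With model.pt.ep1 ... model.pt.ep8 in output_dir:
paths = _get_checkpoint_paths("outputs/debug/ckpt/funasr2/exp2", last_n=5)
# -> [".../model.pt.ep8", ".../model.pt.ep7", ".../model.pt.ep6",
#     ".../model.pt.ep5", ".../model.pt.ep4"]
```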
```python
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
    state_dicts = []

    # Load state_dicts from checkpoints
    for path in checkpoint_paths:
        if os.path.isfile(path):
            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
        else:
            print(f"Checkpoint file {path} not found.")
            continue

    # Check if we have any state_dicts to average
    if not state_dicts:
        raise RuntimeError("No checkpoints found for averaging.")

    # Average or sum weights
    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        # Check the type of the tensor
        if str(tensors[0].dtype).startswith("torch.int"):
            # Perform sum for integer tensors
            summed_tensor = sum(tensors)
            avg_state_dict[key] = summed_tensor
        else:
            # Perform average for other types of tensors
            stacked_tensors = torch.stack(tensors)
            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)

    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
    return avg_state_dict
```
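A hypothetical end-of-training usage, matching how the Trainer calls it below:

```python
# Average the 5 most recent epoch checkpoints; the result is returned and
# also written to <output_dir>/model.pt.avg5.
avg_state = average_checkpoints("outputs/debug/ckpt/funasr2/exp2", last_n=5)
```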
```python
from contextlib import nullcontext
from tensorboardX import SummaryWriter
from pathlib import Path

from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
```
```python
class Trainer:
    """
    ...
    """
    # ...
    self.use_ddp = use_ddp
    self.use_fsdp = use_fsdp
    self.device = next(model.parameters()).device
    self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
    self.kwargs = kwargs

    try:
        rank = dist.get_rank()
```
```python
    }
    # Create output directory if it does not exist
    os.makedirs(self.output_dir, exist_ok=True)
    filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
    torch.save(state, filename)

    print(f'Checkpoint saved to {filename}')
    latest = Path(os.path.join(self.output_dir, 'model.pt'))
    try:
        latest.unlink()
    except FileNotFoundError:
        pass

    latest.symlink_to(filename)
```
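The checkpoint naming here is exactly what `_get_checkpoint_paths` parses. After a few epochs the output directory would look roughly like this (illustrative):

```python
# outputs/debug/ckpt/funasr2/exp2/
# ├── model.pt.ep1
# ├── model.pt.ep2
# ├── model.pt.ep3
# └── model.pt -> model.pt.ep3   # symlink to the latest checkpoint
```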
```python
def _resume_checkpoint(self, resume_path):
    """
    ...

    Args:
        resume_path (str): Directory containing the 'model.pt' checkpoint to resume from.
    """
    ckpt = os.path.join(resume_path, "model.pt")
    if os.path.isfile(ckpt):
        checkpoint = torch.load(ckpt)
        self.start_epoch = checkpoint['epoch'] + 1
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        print(f"Checkpoint loaded successfully from '{ckpt}'")
    else:
        print(f"No checkpoint found at '{ckpt}', starting from scratch")

    if self.use_ddp or self.use_fsdp:
        dist.barrier()
```
```python
def run(self):
    """
    Starts the training process, iterating over epochs, training the model,
    validating it, and saving checkpoints at the end of each epoch.
    """
    if self.resume:
        self._resume_checkpoint(self.output_dir)

    for epoch in range(self.start_epoch, self.max_epoch + 1):

        self._train_epoch(epoch)

        self._validate_epoch(epoch)

        if self.rank == 0:
            self._save_checkpoint(epoch)

        if self.use_ddp or self.use_fsdp:
            dist.barrier()

        self.scheduler.step()

    if self.rank == 0:
        average_checkpoints(self.output_dir, self.avg_nbest_model)

    if self.use_ddp or self.use_fsdp:
        dist.barrier()
    self.writer.close()
```
```python
def _train_epoch(self, epoch):
    """
    ...
    """
    # ...
    for batch_idx, batch in enumerate(self.dataloader_train):
        time1 = time.perf_counter()
        speed_stats["data_load"] = f"{time1 - time5:0.3f}"

        batch = to_device(batch, self.device)

        my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
        # ...
        speed_stats["optim_time"] = f"{time5 - time4:0.3f}"

        speed_stats["total_time"] = total_time

        pbar.update(1)
        if self.local_rank == 0:
            description = (
                f"Epoch: {epoch}/{self.max_epoch}, "
                f"step {batch_idx}/{len(self.dataloader_train)}, "
                f"{speed_stats}, "
                f"(loss: {loss.detach().cpu().item():.3f}), "
```
```python
    """
    self.model.eval()
    with torch.no_grad():
        pbar = tqdm(colour="red", desc=f"Validation Epoch: {epoch}", total=len(self.dataloader_val),
                    dynamic_ncols=True)
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_val):
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1 - time5:0.3f}"
            batch = to_device(batch, self.device)
            time2 = time.perf_counter()
            retval = self.model(**batch)
            time3 = time.perf_counter()
            speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
            loss, stats, weight = retval
            stats = {k: v for k, v in stats.items() if v is not None}
            if self.use_ddp or self.use_fsdp:
                # Apply weighted averaging for loss and stats
                loss = (loss * weight.type(loss.dtype)).sum()
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed=True)
                # Now weight is summation over all workers
                loss /= weight
                # Multiply world_size because DistributedDataParallel
                # automatically normalizes the gradient by world_size.
                loss *= self.world_size
            time4 = time.perf_counter()

            pbar.update(1)
            if self.local_rank == 0:
                description = (
                    f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_val)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                )
                pbar.set_description(description)
                if self.writer:
                    self.writer.add_scalar('Loss/val', loss.item(),
                                           epoch * len(self.dataloader_val) + batch_idx)
                    for key, var in stats.items():
                        self.writer.add_scalar(f'{key}/val', var.item(),
                                               epoch * len(self.dataloader_val) + batch_idx)
                    for key, var in speed_stats.items():
                        self.writer.add_scalar(f'{key}/val', float(var),
                                               epoch * len(self.dataloader_val) + batch_idx)
```
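The DDP branch computes a weighted average of per-rank losses: each rank contributes its loss with a weight (e.g. its batch size), and the global loss is the weight-normalized sum. A tiny standalone illustration of that arithmetic:

```python
import torch

losses = torch.tensor([0.9, 1.1])   # per-rank mean losses (illustrative)
weights = torch.tensor([3.0, 1.0])  # per-rank sample counts (illustrative)

# Global loss = sum(loss_i * weight_i) / sum(weight_i)
global_loss = (losses * weights).sum() / weights.sum()
print(global_loss)  # tensor(0.9500)
```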