import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

import re
from funasr.models.scama.utils import sequence_mask
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC

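# Build the LLM inputs for one multimodal chat sample: plain text is tokenized
# as usual, while <|startofspeech|>...<|endofspeech|> spans are replaced by
# zero-valued placeholder tokens that are later overwritten with audio-encoder
# embeddings. Names such as `contents` and `tokenizer` are assumed to come
# from the enclosing method.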
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []

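# The lists above accumulate, across conversation turns:
#   input_ids  - text tokens plus zero placeholders at speech positions
#   labels     - -100 over prompt positions, target token ids over the reply
#   fbank / fbank_lens - filterbank features and their frame counts
#   fbank_mask - 1 wherever a position must be filled by an audio embedding
#   fbank_beg  - start offset of each speech segment inside the token stream
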
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):

    source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

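    # source_input follows the Qwen-style ChatML template; the assistant turn
    # is left open so the reply tokens can be appended as training labels.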
    splits = pattern.split(source_input)
    source_ids_i = []
    fbank_mask_i = []
    fbank_beg_i = []
    fbank_lens_i = []
    for k, sub_str in enumerate(splits):
        if not sub_str.startswith("<|startofspeech|>"):
            # Plain text span: tokenize it and mark every position as non-audio.
            sub_token = tokenizer.encode(sub_str)
            source_ids_i += sub_token
            fbank_mask_i += [0] * len(sub_token)
        else:
            # Speech span: strip the delimiters to obtain the audio reference.
            sub_str = sub_str.replace("<|startofspeech|>", "").replace("<|endofspeech|>", "")
            # Audio loading / fbank extraction is elided in the source; it is
            # assumed to fill `fbank` / `fbank_lens` and yield the frame count
            # `olens` used below.
            # Two stride-2 downsampling stages of the audio encoder: a conv
            # layer (kernel 3, stride 2, padding 1) followed by a further 2x
            # reduction, so one placeholder token covers roughly 4 fbank frames.
            olens = 1 + (olens - 3 + 2 * 1) // 2
            sub_token_len = (olens - 1) // 2 + 1
            sub_token = [0] * sub_token_len
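            # Worked example: for a 100-frame fbank (~1 s at a 10 ms hop),
            # olens = 1 + (100 - 3 + 2) // 2 = 50 and
            # sub_token_len = (50 - 1) // 2 + 1 = 25 placeholder tokens.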
            # Record where this speech segment starts in the token stream, then
            # reserve placeholder positions for the audio embeddings.
            fbank_beg_i.append(len(source_ids_i))
            source_ids_i += sub_token
            fbank_mask_i += [1] * len(sub_token)

    # Prompt tokens are excluded from the loss with -100; only the assistant
    # reply contributes supervised labels.
    source_mask = [-100] * len(source_ids_i)
    target_out = f"{target_out}<|im_end|>"
    target_ids = tokenizer.encode(target_out)
    input_ids += source_ids_i + target_ids
    labels += source_mask + target_ids
    fbank_mask += fbank_mask_i
    fbank_beg.append(fbank_beg_i)
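
# A minimal sketch (an assumption, not the source's own batching code) of how
# the accumulated lists could be converted to tensors for the LLM forward pass:
import torch

input_ids_tensor = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0)  # (1, T)
labels_tensor = torch.tensor(labels, dtype=torch.int64).unsqueeze(0)        # (1, T)
# The fbank features (padded to a common length) would then be encoded, and the
# encoder outputs written into the token-embedding sequence over the spans
# marked by fbank_mask, starting at the offsets stored in fbank_beg.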