| | |
| | | ) # speech: [b, T, d] |
| | | if self.permute: |
| | | speech = speech.permute(0, 2, 1) |
| | | if speech_lengths > self.batch_size: |
| | | continue |
| | | # if speech_lengths > self.batch_size: |
| | | # continue |
| | | |
| | | fbank_lens = speech_lengths[0].item() |
| | | olens = 1 + (fbanks_len - 3 + 2 * 1) // 2 |
| | | olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 |
| | | olens = 1 + (olens - 3 + 2 * 1) // 2 |
| | | sub_token_len = (olens - 1) // 2 + 1 |
| | | sub_token = [0] * sub_token_len[0] |
| | | sub_token = [0] * sub_token_len |
| | | fbank_beg_i = [len(source_ids)] |
| | | source_ids += sub_token |
| | | fbank_mask_i += [1] * len(sub_token) |
| | | |
| | | source_mask = [-100] * len(source_ids) |
| | | target_out = f"{target_out}<|im_end|>" |
| | | target_ids = tokenizer.encode(target_out) |
| | | target_ids = self.tokenizer.encode(target_out) |
| | | input_ids += source_ids + target_ids |
| | | labels += source_mask + target_ids |
| | | fbank_mask += fbank_mask_i |