| | |
| | | |
| | | super().__init__() |
| | | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | |
| | | # audio encoder |
| | | hub = audio_encoder_conf.get("hub", None) |
| | | if hub == "ms": |
| | |
| | | # llm |
| | | hub = llm_conf.get("hub", "hf") |
| | | self.llm = None |
| | | # if hub == "hf": |
| | | # from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
| | | # |
| | | # init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") |
| | | # |
| | | # model = AutoModelForCausalLM.from_pretrained( |
| | | # init_param_path, |
| | | # load_in_8bit=None, |
| | | # device_map=None, |
| | | # use_cache=None, |
| | | # ) |
| | | # freeze = llm_conf.get("freeze", True) |
| | | # if freeze: |
| | | # for name, param in model.named_parameters(): |
| | | # param.requires_grad = False |
| | | # model.eval() |
| | | # self.llm = model |
| | | if hub == "hf": |
| | | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
| | | |
| | | init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") |
| | | |
| | | model = AutoModelForCausalLM.from_pretrained( |
| | | init_param_path, |
| | | load_in_8bit=None, |
| | | device_map=None, |
| | | use_cache=None, |
| | | ) |
| | | freeze = llm_conf.get("freeze", True) |
| | | if freeze: |
| | | for name, param in model.named_parameters(): |
| | | param.requires_grad = False |
| | | model.eval() |
| | | self.llm = model |
| | | |
| | | # adaptor |
| | | adaptor_class = tables.adaptor_classes.get(audio_adaptor) |
| | |
| | | audio_adaptor = adaptor_class(**audio_adaptor_conf) |
| | | |
| | | self.audio_adaptor = audio_adaptor |
| | | |
| | | self.blank_id = blank_id |
| | | self.sos = sos if sos is not None else vocab_size - 1 |
| | | self.eos = eos if eos is not None else vocab_size - 1 |
| | | self.vocab_size = vocab_size |
| | | self.ignore_id = ignore_id |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | | smoothing=lsm_weight, |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | |
| | | self.error_calculator = None |
| | | |
| | |
| | | batch_size = speech.shape[0] |
| | | |
| | | # audio encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths) |
| | | |
| | | # audio_adaptor |
| | | encoder_out = self.audio_adaptor(encoder_out) |
| | | encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) |
| | | |
| | | input_ids[input_ids == -1] = 0 |
| | | input_ids[input_ids == -100] = 0 |
| | |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | batch_size = int((labels_ids > 0 + 1).sum()) |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
):
    """Encode a padded batch of audio features with ``self.audio_encoder``.

    Args:
        speech: Batched feature tensor. Axes 1 and 2 are swapped with
            ``permute(0, 2, 1)`` before encoding, so it must be 3-D.
            (Presumably (batch, frames, feat_dim) in, with the encoder
            expecting (batch, feat_dim, frames) -- TODO confirm against the
            configured encoder.)
        speech_lengths: Per-utterance lengths. This is NOT passed to the
            encoder. It is only returned unchanged when the encoder does not
            report lengths of its own.
        **kwargs: Accepted for interface compatibility; ignored here.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)``. If the encoder returns a
        list or tuple, its first two elements are used and any further
        elements are discarded. Otherwise the encoder's output is paired with
        the incoming ``speech_lengths``.
    """
    encoder_input = speech.permute(0, 2, 1)
    output = self.audio_encoder(encoder_input)

    # A bare (non-list/tuple) result means the encoder did not report
    # lengths, so fall back to the lengths the caller supplied.
    if not isinstance(output, (list, tuple)):
        return output, speech_lengths

    # Sequence result: treat it as (features, lengths, ...).
    return output[0], output[1]
| | | |
| | | def inference( |
| | | self, |