| | |
| | | from torch import Tensor |
| | | from torch import nn |
| | | from torch.cuda.amp import autocast |
| | | from funasr.metrics.compute_acc import compute_accuracy |
| | | from funasr.metrics.compute_acc import compute_accuracy, th_accuracy |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from . import whisper_lib as whisper |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.ctc.ctc import CTC |
| | | |
| | | from funasr.register import tables |
| | | |
| | |
| | | else: |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask |
| | | ) |
| | | with autocast(False): |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask |
| | | ) |
| | | |
| | | loss = loss_att |
| | | stats = {} |
| | | stats["acc"] = acc_att |
| | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | self.activation_checkpoint = kwargs.get("activation_checkpoint", False) |
| | | self.encoder_output_size = encoder_output_size |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | if isinstance(task, str): |
| | | task = [task] |
| | | task = "".join([f"<|{x}|>" for x in task]) |
| | | |
| | | |
| | | sos = kwargs.get("model_conf").get("sos") |
| | | if isinstance(sos, str): |
| | | initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") |
| | |
| | | language = DecodingOptions.get("language", None) |
| | | language = None if language == "auto" else language |
| | | initial_prompt = kwargs.get("initial_prompt", f"{task}") |
| | | initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt |
| | | initial_prompt_lid = ( |
| | | f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt |
| | | ) |
| | | initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all") |
| | | sos_int = [sos] + initial_prompt_lid_int |
| | | eos = kwargs.get("model_conf").get("eos") |
| | |
| | | ) |
| | | self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1]) |
| | | |
| | | encoder_out, encoder_out_lens = self.encode( |
| | | speech[None, :, :], speech_lengths |
| | | ) |
| | | encoder_out, encoder_out_lens = self.encode(speech[None, :, :], speech_lengths) |
| | | |
| | | if text_token_int is not None: |
| | | i = 0 |
| | |
| | | ibest_writer["text"][key[i]] = text |
| | | |
| | | return results, meta_data |
| | | |
| | | |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.utils import postprocess_utils |