| | |
| | | from funasr.models.ctc.ctc import CTC |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import to_device |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.models.paraformer.cif_predictor import mae_loss |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | |
| | |
| | | self.predictor_bias = predictor_bias |
| | | self.sampling_ratio = sampling_ratio |
| | | self.criterion_pre = mae_loss(normalize_length=length_normalized_loss) |
| | | # self.step_cur = 0 |
| | | # |
| | | |
| | | |
| | | self.share_embedding = share_embedding |
| | | if self.share_embedding: |
| | | self.decoder.embed = None |
| | |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | |
| | | stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None |
| | | |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["batch_size"] = batch_size |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | # Encoder |
| | | if kwargs.get("fp16", False): |
| | | speech = speech.half() |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |
| | |
| | | |
| | | return results, meta_data |
| | | |
def export(self, **kwargs):
    """Rebuild this model for export.

    ``max_seq_len`` falls back to 512 when the caller does not supply it.
    All keyword arguments (including that default) are forwarded to
    ``export_rebuild_model`` together with this model instance.

    Returns:
        Whatever ``export_rebuild_model`` returns for this model.
    """
    # Imported at call time, inside the method, rather than at module load.
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
| | | |