| | |
| | | model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) |
| | | model.export_name = types.MethodType(export_name, model) |
| | | |
| | | model.export_name = "emotion2vec" |
| | | return model |
| | | |
| | | |
| | | def export_forward( |
| | | self, x: torch.Tensor |
| | | ): |
| | | def export_forward(self, x: torch.Tensor): |
| | | with torch.no_grad(): |
| | | if self.cfg.normalize: |
| | | mean = torch.mean(x, dim=1, keepdim=True) |
| | |
| | | # Call the original forward directly just like extract_features |
| | | # Cannot directly use self.extract_features since it is being replaced by export_forward |
| | | res = self._original_forward( |
| | | source=x, |
| | | padding_mask=None, |
| | | mask=False, |
| | | features_only=True, |
| | | remove_extra_tokens=True |
| | | source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True |
| | | ) |
| | | |
| | | x = res["x"] |