#!/usr/bin/env python3
|
# -*- encoding: utf-8 -*-
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
# MIT License (https://opensource.org/licenses/MIT)
|
|
import types
|
import torch
|
import torch.nn.functional as F
|
|
|
def export_rebuild_model(model, **kwargs):
|
model.device = kwargs.get("device")
|
|
# store original forward since self.extract_features is calling it
|
model._original_forward = model.forward
|
|
model.forward = types.MethodType(export_forward, model)
|
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
|
model.export_input_names = types.MethodType(export_input_names, model)
|
model.export_output_names = types.MethodType(export_output_names, model)
|
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
|
model.export_name = types.MethodType(export_name, model)
|
|
return model
|
|
|
def export_forward(self, x: torch.Tensor):
|
with torch.no_grad():
|
if self.cfg.normalize:
|
mean = torch.mean(x, dim=1, keepdim=True)
|
var = torch.var(x, dim=1, keepdim=True, unbiased=False)
|
x = (x - mean) / torch.sqrt(var + 1e-5)
|
x = x.view(x.shape[0], -1)
|
|
# Call the original forward directly just like extract_features
|
# Cannot directly use self.extract_features since it is being replaced by export_forward
|
res = self._original_forward(
|
source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True
|
)
|
|
x = res["x"]
|
|
return x
|
|
|
def export_dummy_inputs(self):
|
return (torch.randn(1, 16000),)
|
|
|
def export_input_names(self):
|
return ["input"]
|
|
|
def export_output_names(self):
|
return ["output"]
|
|
|
def export_dynamic_axes(self):
|
return {
|
"input": {
|
0: "batch_size",
|
1: "sequence_length",
|
},
|
"output": {0: "batch_size", 1: "sequence_length"},
|
}
|
|
|
def export_name(self):
|
return "emotion2vec"
|