Kyle He
2025-08-14 82a07e2f6ec60aa25a3931e9ee0d99ead642484a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
 
import types
import torch
import torch.nn.functional as F
 
 
def export_rebuild_model(model, **kwargs):
    """Patch ``model`` in place with the ONNX-export hooks and return it.

    Args:
        model: The model instance to patch. ``forward`` and the
            ``export_*`` hook attributes are rebound on this instance.
        **kwargs: Only ``device`` is read; it is stored on ``model.device``
            (``None`` when not supplied).

    Returns:
        The same ``model`` object, now patched.
    """
    model.device = kwargs.get("device")

    # self.extract_features calls self.forward, which is about to be replaced,
    # so keep a handle to the original. Capture it only once: on a repeated
    # call model.forward is already export_forward, and storing that would make
    # export_forward call itself with keyword arguments it does not accept.
    if getattr(model, "_original_forward", None) is None:
        model._original_forward = model.forward

    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)

    return model
 
 
def export_forward(self, x: torch.Tensor):
    """Export-time forward: waveform batch in, backbone features out.

    Args:
        x: Batched input. Every dimension after the first is flattened into a
            single sample axis, so ``(batch, samples)`` and
            ``(batch, 1, samples)`` are both accepted.

    Returns:
        The ``"x"`` entry of the original forward's result, computed with
        ``features_only=True`` and ``remove_extra_tokens=True``.
    """
    with torch.no_grad():
        # Flatten to (batch, samples) first and unconditionally. Previously the
        # reshape lived inside the normalize branch and ran after it. That made
        # the accepted input shape depend on cfg.normalize. For
        # (batch, 1, samples) input it also normalized over the singleton axis,
        # zeroing the signal.
        x = x.reshape(x.shape[0], -1)

        if self.cfg.normalize:
            # Per-row zero-mean / unit-variance over the sample axis
            # (biased variance, eps = 1e-5).
            mean = torch.mean(x, dim=1, keepdim=True)
            var = torch.var(x, dim=1, keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + 1e-5)

        # Call the original forward directly, just like extract_features does;
        # self.forward itself has been replaced by export_forward.
        res = self._original_forward(
            source=x, padding_mask=None, mask=False, features_only=True, remove_extra_tokens=True
        )

        return res["x"]
 
 
def export_dummy_inputs(self):
    """Return the example inputs used when tracing ``export_forward``.

    A 1-tuple holding one random tensor of shape ``(1, 16000)``, matching
    the single ``x`` argument of ``export_forward``.
    """
    waveform = torch.randn(1, 16000)
    return (waveform,)
 
 
def export_input_names(self):
    """Return the graph input names; ``"input"`` is also the key used in export_dynamic_axes."""
    input_names = ["input"]
    return input_names
 
 
def export_output_names(self):
    """Return the graph output names; ``"output"`` is also the key used in export_dynamic_axes."""
    output_names = ["output"]
    return output_names
 
 
def export_dynamic_axes(self):
    """Return the ``dynamic_axes`` mapping for the ONNX export.

    Axis 0 of both the input and the output is the shared dynamic batch axis.
    Axis 1 gets a *different* symbol on each side. ONNX treats identical
    ``dim_param`` names within a graph as equal sizes, and the encoder's output
    time axis is not guaranteed to match the input sample count. Reusing
    ``"sequence_length"`` for both would declare a false shape equality.
    """
    return {
        "input": {
            0: "batch_size",
            1: "sequence_length",
        },
        "output": {
            0: "batch_size",
            1: "num_frames",
        },
    }
 
 
def export_name(self):
    """Return the identifier reported for this export: ``"emotion2vec"``."""
    model_name = "emotion2vec"
    return model_name