From 4daea3711063c64485be3c00eaa9727404549f51 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 24 二月 2023 17:55:00 +0800
Subject: [PATCH] onnx
---
funasr/export/models/predictor/cif.py | 53 ++---------------------------------------------------
1 files changed, 2 insertions(+), 51 deletions(-)
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 32a3c13..fcfcd5f 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -109,60 +109,11 @@
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
- max_label_len = len_labels.max()
+ max_label_len = len_labels.max().item()
+ print("type: {}".format(type(max_label_len)))
for b in range(batch_size):
fire = fires[b, :]
l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
list_ls.append(torch.cat([l, pad_l], 0))
return torch.stack(list_ls, 0), fires
-
-
-def CifPredictorV2_test():
- x = torch.rand([2, 21, 2])
- x_len = torch.IntTensor([6, 21])
-
- mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
- x = x * mask[:, :, None]
-
- predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
- # cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
- predictor_scripts.save('test.pt')
- loaded = torch.jit.load('test.pt')
- cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
- # print(cif_output)
- print(predictor_scripts.code)
- # predictor = CifPredictorV2(2, 1, 1)
- # cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
- print(cif_output)
-
-
-def CifPredictorV2_export_test():
- x = torch.rand([2, 21, 2])
- x_len = torch.IntTensor([6, 21])
-
- mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
- x = x * mask[:, :, None]
-
- # predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
- # cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
- predictor = CifPredictorV2(2, 1, 1)
- predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
- predictor_trace.save('test_trace.pt')
- loaded = torch.jit.load('test_trace.pt')
-
- x = torch.rand([3, 30, 2])
- x_len = torch.IntTensor([6, 20, 30])
- mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
- x = x * mask[:, :, None]
- cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
- print(cif_output)
- # print(predictor_trace.code)
- # predictor = CifPredictorV2(2, 1, 1)
- # cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
- # print(cif_output)
-
-
-if __name__ == '__main__':
- # CifPredictorV2_test()
- CifPredictorV2_export_test()
\ No newline at end of file
--
Gitblit v1.9.1