From 6d17715edfbcdf9b2cdd888d7cfd0ab5b6c12008 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 15 二月 2023 17:30:26 +0800
Subject: [PATCH] Merge branch 'dev_wjm' of https://github.com/alibaba-damo-academy/FunASR into dev_wjm

---
 funasr/export/models/predictor/cif.py |   50 --------------------------------------------------
 1 files changed, 0 insertions(+), 50 deletions(-)

diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 32a3c13..5518cb8 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -116,53 +116,3 @@
 		pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
 		list_ls.append(torch.cat([l, pad_l], 0))
 	return torch.stack(list_ls, 0), fires
-
-
-def CifPredictorV2_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor_scripts.save('test.pt')
-	loaded = torch.jit.load('test.pt')
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	# print(cif_output)
-	print(predictor_scripts.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	print(cif_output)
-
-
-def CifPredictorV2_export_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	# predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor = CifPredictorV2(2, 1, 1)
-	predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
-	predictor_trace.save('test_trace.pt')
-	loaded = torch.jit.load('test_trace.pt')
-	
-	x = torch.rand([3, 30, 2])
-	x_len = torch.IntTensor([6, 20, 30])
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	print(cif_output)
-	# print(predictor_trace.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	# print(cif_output)
-
-
-if __name__ == '__main__':
-	# CifPredictorV2_test()
-	CifPredictorV2_export_test()
\ No newline at end of file

--
Gitblit v1.9.1