From 6254b0b7641f251109d7b81b4cacfcd67c0a5407 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 07 Feb 2023 16:01:08 +0800
Subject: [PATCH] export: enable ONNX export in export_model test entry point; remove cif_test.py

---
 /dev/null                     |  212 -----------------------------------------------------
 funasr/export/export_model.py |    3 
 2 files changed, 2 insertions(+), 213 deletions(-)

diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 17bc138..2441509 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -87,5 +87,6 @@
         )
 
 if __name__ == '__main__':
-    export_model = ASRModelExportParaformer()
+    output_dir = "../export"
+    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
     export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
\ No newline at end of file
diff --git a/funasr/export/models/predictor/cif_test.py b/funasr/export/models/predictor/cif_test.py
deleted file mode 100644
index 954c434..0000000
--- a/funasr/export/models/predictor/cif_test.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import torch
-from torch import nn
-import logging
-import numpy as np
-
-
-def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
-	if maxlen is None:
-		maxlen = lengths.max()
-	row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
-	matrix = torch.unsqueeze(lengths, dim=-1)
-	mask = row_vector < matrix
-	mask = mask.detach()
-	
-	return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
-
-
-def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
-
-	if length_dim == 0:
-		raise ValueError("length_dim cannot be 0: {}".format(length_dim))
-	
-	if not isinstance(lengths, list):
-		lengths = lengths.tolist()
-	bs = int(len(lengths))
-	if maxlen is None:
-		if xs is None:
-			maxlen = int(max(lengths))
-		else:
-			maxlen = xs.size(length_dim)
-	else:
-		assert xs is None
-		assert maxlen >= int(max(lengths))
-	
-	seq_range = torch.arange(0, maxlen, dtype=torch.int64)
-	seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
-	seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
-	mask = seq_range_expand >= seq_length_expand
-	
-	if xs is not None:
-		assert xs.size(0) == bs, (xs.size(0), bs)
-		
-		if length_dim < 0:
-			length_dim = xs.dim() + length_dim
-		# ind = (:, None, ..., None, :, , None, ..., None)
-		ind = tuple(
-			slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
-		)
-		mask = mask[ind].expand_as(xs).to(xs.device)
-	return mask
-
-
-
-class CifPredictorV2(nn.Module):
-	def __init__(self,
-	             idim: int,
-	             l_order: int,
-	             r_order: int,
-	             threshold: float = 1.0,
-	             dropout: float = 0.1,
-	             smooth_factor: float = 1.0,
-	             noise_threshold: float = 0,
-	             tail_threshold: float = 0.0,
-	             ):
-		super(CifPredictorV2, self).__init__()
-		
-		self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
-		self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
-		self.cif_output = nn.Linear(idim, 1)
-		self.dropout = torch.nn.Dropout(p=dropout)
-		self.threshold = threshold
-		self.smooth_factor = smooth_factor
-		self.noise_threshold = noise_threshold
-		self.tail_threshold = tail_threshold
-	
-	def forward(self, hidden: torch.Tensor,
-	            mask: torch.Tensor,
-	            ):
-		h = hidden
-		context = h.transpose(1, 2)
-		queries = self.pad(context)
-		output = torch.relu(self.cif_conv1d(queries))
-		output = output.transpose(1, 2)
-		
-		output = self.cif_output(output)
-		alphas = torch.sigmoid(output)
-		alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-		mask = mask.transpose(-1, -2).float()
-		alphas = alphas * mask
-		
-		alphas = alphas.squeeze(-1)
-		
-		token_num = alphas.sum(-1)
-		
-		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-		
-		return acoustic_embeds, token_num, alphas, cif_peak
-	
-	def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
-		b, t, d = hidden.size()
-		tail_threshold = self.tail_threshold
-		
-		zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
-		ones_t = torch.ones_like(zeros_t)
-		mask_1 = torch.cat([mask, zeros_t], dim=1)
-		mask_2 = torch.cat([ones_t, mask], dim=1)
-		mask = mask_2 - mask_1
-		tail_threshold = mask * tail_threshold
-		alphas = torch.cat([alphas, tail_threshold], dim=1)
-		
-		zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
-		hidden = torch.cat([hidden, zeros], dim=1)
-		token_num = alphas.sum(dim=-1)
-		token_num_floor = torch.floor(token_num)
-		
-		return hidden, alphas, token_num_floor
-
-@torch.jit.script
-def cif(hidden, alphas, threshold: float):
-	batch_size, len_time, hidden_size = hidden.size()
-	threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-	
-	# loop varss
-	integrate = torch.zeros([batch_size], device=hidden.device)
-	frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
-	# intermediate vars along time
-	list_fires = []
-	list_frames = []
-	
-	for t in range(len_time):
-		alpha = alphas[:, t]
-		distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
-		
-		integrate += alpha
-		list_fires.append(integrate)
-		
-		fire_place = integrate >= threshold
-		integrate = torch.where(fire_place,
-		                        integrate - torch.ones([batch_size], device=hidden.device),
-		                        integrate)
-		cur = torch.where(fire_place,
-		                  distribution_completion,
-		                  alpha)
-		remainds = alpha - cur
-		
-		frame += cur[:, None] * hidden[:, t, :]
-		list_frames.append(frame)
-		frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
-		                    remainds[:, None] * hidden[:, t, :],
-		                    frame)
-	
-	fires = torch.stack(list_fires, 1)
-	frames = torch.stack(list_frames, 1)
-	list_ls = []
-	len_labels = torch.round(alphas.sum(-1)).int()
-	max_label_len = len_labels.max()
-	for b in range(batch_size):
-		fire = fires[b, :]
-		l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
-		pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
-		list_ls.append(torch.cat([l, pad_l], 0))
-	return torch.stack(list_ls, 0), fires
-
-
-def CifPredictorV2_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor_scripts.save('test.pt')
-	loaded = torch.jit.load('test.pt')
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	# print(cif_output)
-	print(predictor_scripts.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	print(cif_output)
-
-
-def CifPredictorV2_export_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	# predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor = CifPredictorV2(2, 1, 1)
-	predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
-	predictor_trace.save('test_trace.pt')
-	loaded = torch.jit.load('test_trace.pt')
-	
-	x = torch.rand([3, 30, 2])
-	x_len = torch.IntTensor([6, 20, 30])
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	print(cif_output)
-	# print(predictor_trace.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	# print(cif_output)
-
-
-if __name__ == '__main__':
-	# CifPredictorV2_test()
-	CifPredictorV2_export_test()
\ No newline at end of file

--
Gitblit v1.9.1