From e0fa63765bfb4a36bde7047c2a6066ca5a80e90f Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 21 八月 2023 10:37:42 +0800
Subject: [PATCH] Dev hw (#878)
---
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 47 insertions(+), 0 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index f1fc9a0..cf74200 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
+import torch
import numpy as np
import yaml
try:
@@ -22,6 +23,52 @@
logger_initialized = {}
+def pad_list(xs, pad_value, max_len=None):
+ n_batch = len(xs)
+ if max_len is None:
+ max_len = max(x.size(0) for x in xs)
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+ for i in range(n_batch):
+ pad[i, : xs[i].size(0)] = xs[i]
+
+ return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if not isinstance(lengths, list):
+ lengths = lengths.tolist()
+ bs = int(len(lengths))
+ if maxlen is None:
+ if xs is None:
+ maxlen = int(max(lengths))
+ else:
+ maxlen = xs.size(length_dim)
+ else:
+ assert xs is None
+ assert maxlen >= int(max(lengths))
+
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ if xs is not None:
+ assert xs.size(0) == bs, (xs.size(0), bs)
+
+ if length_dim < 0:
+ length_dim = xs.dim() + length_dim
+ # ind = (:, None, ..., None, :, , None, ..., None)
+ ind = tuple(
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+ )
+ mask = mask[ind].expand_as(xs).to(xs.device)
+ return mask
+
+
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
--
Gitblit v1.9.1