From e0fa63765bfb4a36bde7047c2a6066ca5a80e90f Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 21 八月 2023 10:37:42 +0800
Subject: [PATCH] Dev hw (#878)

---
 funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py |   47 +++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 47 insertions(+), 0 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index f1fc9a0..cf74200 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
 import re
+import torch
 import numpy as np
 import yaml
 try:
@@ -22,6 +23,52 @@
 logger_initialized = {}
 
 
+def pad_list(xs, pad_value, max_len=None):
+    n_batch = len(xs)
+    if max_len is None:
+        max_len = max(x.size(0) for x in xs)
+    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+    for i in range(n_batch):
+        pad[i, : xs[i].size(0)] = xs[i]
+
+    return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+    if length_dim == 0:
+        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+    bs = int(len(lengths))
+    if maxlen is None:
+        if xs is None:
+            maxlen = int(max(lengths))
+        else:
+            maxlen = xs.size(length_dim)
+    else:
+        assert xs is None
+        assert maxlen >= int(max(lengths))
+
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    if xs is not None:
+        assert xs.size(0) == bs, (xs.size(0), bs)
+
+        if length_dim < 0:
+            length_dim = xs.dim() + length_dim
+        # ind = (:, None, ..., None, :, , None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+        )
+        mask = mask[ind].expand_as(xs).to(xs.device)
+    return mask
+
+
 class TokenIDConverter():
     def __init__(self, token_list: Union[List, str],
                  ):

--
Gitblit v1.9.1