From 02667efa7528e5350e0351358a4336aa91662fc7 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 03 六月 2024 15:18:13 +0800
Subject: [PATCH] update utils
---
runtime/python/libtorch/funasr_torch/utils/utils.py | 24 ++++++++++++++----------
1 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/runtime/python/libtorch/funasr_torch/utils/utils.py b/runtime/python/libtorch/funasr_torch/utils/utils.py
index f85d4e9..ee43852 100644
--- a/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,21 +1,25 @@
# -*- encoding: utf-8 -*-
-
-import functools
+import yaml
import logging
-import pickle
+import functools
+import numpy as np
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
-import numpy as np
-import yaml
-
-
-import warnings
-
root_dir = Path(__file__).resolve().parent
-
logger_initialized = {}
+def pad_list(xs, pad_value, max_len=None):
+ n_batch = len(xs)
+ if max_len is None:
+ max_len = max(x.size(0) for x in xs)
+ # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+ # numpy format
+ pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
+ for i in range(n_batch):
+ pad[i, : xs[i].shape[0]] = xs[i]
+
+ return pad
class TokenIDConverter:
def __init__(
--
Gitblit v1.9.1