From 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 18 六月 2024 10:01:56 +0800
Subject: [PATCH] Merge pull request #1825 from modelscope/dev_libt
---
runtime/python/libtorch/funasr_torch/utils/utils.py | 24 ++++++++++++++----------
1 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/runtime/python/libtorch/funasr_torch/utils/utils.py b/runtime/python/libtorch/funasr_torch/utils/utils.py
index f85d4e9..ee43852 100644
--- a/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,21 +1,25 @@
# -*- encoding: utf-8 -*-
-
-import functools
+import yaml
import logging
-import pickle
+import functools
+import numpy as np
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
-import numpy as np
-import yaml
-
-
-import warnings
-
root_dir = Path(__file__).resolve().parent
-
logger_initialized = {}
+def pad_list(xs, pad_value, max_len=None):
+ n_batch = len(xs)
+ if max_len is None:
+ max_len = max(x.size(0) for x in xs)
+ # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+ # numpy format
+ pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
+ for i in range(n_batch):
+ pad[i, : xs[i].shape[0]] = xs[i]
+
+ return pad
class TokenIDConverter:
def __init__(
--
Gitblit v1.9.1