From 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 18 六月 2024 10:01:56 +0800
Subject: [PATCH] Merge pull request #1825 from modelscope/dev_libt

---
 runtime/python/libtorch/funasr_torch/utils/utils.py |   24 ++++++++++++++----------
 1 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/runtime/python/libtorch/funasr_torch/utils/utils.py b/runtime/python/libtorch/funasr_torch/utils/utils.py
index f85d4e9..ee43852 100644
--- a/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,21 +1,25 @@
 # -*- encoding: utf-8 -*-
-
-import functools
+import yaml
 import logging
-import pickle
+import functools
+import numpy as np
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
-import numpy as np
-import yaml
-
-
-import warnings
-
 root_dir = Path(__file__).resolve().parent
-
 logger_initialized = {}
 
+def pad_list(xs, pad_value, max_len=None):
+    n_batch = len(xs)
+    if max_len is None:
+        max_len = max(x.size(0) for x in xs)
+    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+    # numpy format
+    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
+    for i in range(n_batch):
+        pad[i, : xs[i].shape[0]] = xs[i]
+
+    return pad
 
 class TokenIDConverter:
     def __init__(

--
Gitblit v1.9.1