From 23e7ddebccd3b05cf7ef89809bcfe565ad6dfa1f Mon Sep 17 00:00:00 2001
From: majic31 <majic31@163.com>
Date: 星期二, 24 十二月 2024 10:00:14 +0800
Subject: [PATCH] Fix the variable name (#2328)
---
runtime/python/libtorch/funasr_torch/utils/utils.py | 24 ++++++++++++++----------
1 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/runtime/python/libtorch/funasr_torch/utils/utils.py b/runtime/python/libtorch/funasr_torch/utils/utils.py
index f85d4e9..ee43852 100644
--- a/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,21 +1,25 @@
# -*- encoding: utf-8 -*-
-
-import functools
+import yaml
import logging
-import pickle
+import functools
+import numpy as np
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
-import numpy as np
-import yaml
-
-
-import warnings
-
root_dir = Path(__file__).resolve().parent
-
logger_initialized = {}
+def pad_list(xs, pad_value, max_len=None):
+ n_batch = len(xs)
+ if max_len is None:
+ max_len = max(x.size(0) for x in xs)
+ # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+ # numpy format
+ pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
+ for i in range(n_batch):
+ pad[i, : xs[i].shape[0]] = xs[i]
+
+ return pad
class TokenIDConverter:
def __init__(
--
Gitblit v1.9.1