From f14f9f8d15037c7b81cbdc880d61d05e23382a8f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 09 一月 2024 00:13:51 +0800
Subject: [PATCH] funasr1.0 infer url modelscope

---
 /dev/null                  |  359 -----------------------------
 .gitignore                 |    1 
 funasr/download/file.py    |  328 +++++++++++++++++++++++++++
 funasr/utils/load_utils.py |   29 ++
 4 files changed, 355 insertions(+), 362 deletions(-)

diff --git a/.gitignore b/.gitignore
index dea4634..4023869 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,4 @@
 samples
 .ipynb_checkpoints
 outputs*
+emotion2vec*
diff --git a/funasr/download/file.py b/funasr/download/file.py
new file mode 100644
index 0000000..d93f24c
--- /dev/null
+++ b/funasr/download/file.py
@@ -0,0 +1,328 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+
+
+class Storage(metaclass=ABCMeta):
+    """Abstract class of storage.
+
+    All backends need to implement two apis: ``read()`` and ``read_text()``.
+    ``read()`` reads the file as a byte stream and ``read_text()`` reads
+    the file as texts.
+    """
+
+    @abstractmethod
+    def read(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def read_text(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        pass
+
+    @abstractmethod
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        pass
+
+
+class LocalStorage(Storage):
+    """Local hard disk storage"""
+
+    def read(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, 'rb') as f:
+            content = f.read()
+        return content
+
+    def read_text(self,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        with open(filepath, 'r', encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'w', encoding=encoding) as f:
+            f.write(obj)
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        yield filepath
+
+
+class HTTPStorage(Storage):
+    """HTTP and HTTPS storage."""
+
+    def read(self, url):
+        # TODO @wenmeng.zwm add progress bar if file is too large
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.content
+
+    def read_text(self, url):
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.text
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exists from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = HTTPStorage()
+            >>> # After existing from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.get_local_path('http://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, url: Union[str, Path]) -> None:
+        raise NotImplementedError('write is not supported by HTTP Storage')
+
+    def write_text(self,
+                   obj: str,
+                   url: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'write_text is not supported by HTTP Storage')
+
+
+class OSSStorage(Storage):
+    """OSS storage."""
+
+    def __init__(self, oss_config_file=None):
+        # read from config file or env var
+        raise NotImplementedError(
+            'OSSStorage.__init__ to be implemented in the future')
+
+    def read(self, filepath):
+        raise NotImplementedError(
+            'OSSStorage.read to be implemented in the future')
+
+    def read_text(self, filepath, encoding='utf-8'):
+        raise NotImplementedError(
+            'OSSStorage.read_text to be implemented in the future')
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exists from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = OSSStorage()
+            >>> # After existing from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.get_local_path('http://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        raise NotImplementedError(
+            'OSSStorage.write to be implemented in the future')
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'OSSStorage.write_text to be implemented in the future')
+
+
+G_STORAGES = {}
+
+
+class File(object):
+    _prefix_to_storage: dict = {
+        'oss': OSSStorage,
+        'http': HTTPStorage,
+        'https': HTTPStorage,
+        'local': LocalStorage,
+    }
+
+    @staticmethod
+    def _get_storage(uri):
+        assert isinstance(uri,
+                          str), f'uri should be str type, but got {type(uri)}'
+
+        if '://' not in uri:
+            # local path
+            storage_type = 'local'
+        else:
+            prefix, _ = uri.split('://')
+            storage_type = prefix
+
+        assert storage_type in File._prefix_to_storage, \
+            f'Unsupported uri {uri}, valid prefixs: '\
+            f'{list(File._prefix_to_storage.keys())}'
+
+        if storage_type not in G_STORAGES:
+            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+        return G_STORAGES[storage_type]
+
+    @staticmethod
+    def read(uri: str) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        storage = File._get_storage(uri)
+        return storage.read(uri)
+
+    @staticmethod
+    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        storage = File._get_storage(uri)
+        return storage.read_text(uri)
+
+    @staticmethod
+    def write(obj: bytes, uri: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        storage = File._get_storage(uri)
+        return storage.write(obj, uri)
+
+    @staticmethod
+    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        storage = File._get_storage(uri)
+        return storage.write_text(obj, uri)
+
+    @contextlib.contextmanager
+    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        storage = File._get_storage(uri)
+        with storage.as_local_path(uri) as local_path:
+            yield local_path
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
deleted file mode 100644
index 364746a..0000000
--- a/funasr/utils/asr_utils.py
+++ /dev/null
@@ -1,359 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import os
-import struct
-from typing import Any, Dict, List, Union
-
-import torchaudio
-import librosa
-import numpy as np
-import pkg_resources
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
-green_color = '\033[1;32m'
-red_color = '\033[0;31;40m'
-yellow_color = '\033[0;33;40m'
-end_color = '\033[0m'
-
-global_asr_language = 'zh-cn'
-
-SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
-
-def get_version():
-    return float(pkg_resources.get_distribution('easyasr').version)
-
-
-def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
-    r_audio_fs = None
-
-    if audio_format == 'wav' or audio_format == 'scp':
-        r_audio_fs = get_sr_from_wav(audio_in)
-    elif audio_format == 'pcm' and isinstance(audio_in, bytes):
-        r_audio_fs = get_sr_from_bytes(audio_in)
-
-    return r_audio_fs
-
-
-def type_checking(audio_in: Union[str, bytes],
-                  audio_fs: int = None,
-                  recog_type: str = None,
-                  audio_format: str = None):
-    r_recog_type = recog_type
-    r_audio_format = audio_format
-    r_wav_path = audio_in
-
-    if isinstance(audio_in, str):
-        assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
-    elif isinstance(audio_in, bytes):
-        assert len(audio_in) > 0, 'audio in is empty'
-        r_audio_format = 'pcm'
-        r_recog_type = 'wav'
-
-    if audio_in is None:
-        # for raw_inputs
-        r_recog_type = 'wav'
-        r_audio_format = 'pcm'
-
-    if r_recog_type is None and audio_in is not None:
-        # audio_in is wav, recog_type is wav_file
-        if os.path.isfile(audio_in):
-            audio_type = os.path.basename(audio_in).lower()
-            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
-                    r_recog_type = 'wav'
-                    r_audio_format = 'wav'
-            if audio_type.rfind(".scp") >= 0:
-                r_recog_type = 'wav'
-                r_audio_format = 'scp'
-            if r_recog_type is None:
-                raise NotImplementedError(
-                    f'Not supported audio type: {audio_type}')
-
-        # recog_type is datasets_file
-        elif os.path.isdir(audio_in):
-            dir_name = os.path.basename(audio_in)
-            if 'test' in dir_name:
-                r_recog_type = 'test'
-            elif 'dev' in dir_name:
-                r_recog_type = 'dev'
-            elif 'train' in dir_name:
-                r_recog_type = 'train'
-
-    if r_audio_format is None:
-        if find_file_by_ends(audio_in, '.ark'):
-            r_audio_format = 'kaldi_ark'
-        elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
-                audio_in, '.WAV'):
-            r_audio_format = 'wav'
-        elif find_file_by_ends(audio_in, '.records'):
-            r_audio_format = 'tfrecord'
-
-    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
-        # datasets with kaldi_ark file
-        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
-    elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
-        # datasets with tensorflow records file
-        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
-    elif r_audio_format == 'wav' and r_recog_type != 'wav':
-        # datasets with waveform files
-        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
-
-    return r_recog_type, r_audio_format, r_wav_path
-
-
-def get_sr_from_bytes(wav: bytes):
-    sr = None
-    data = wav
-    if len(data) > 44:
-        try:
-            header_fields = {}
-            header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
-            header_fields['Format'] = str(data[8:12], 'UTF-8')
-            header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
-            if header_fields['ChunkID'] == 'RIFF' and header_fields[
-                    'Format'] == 'WAVE' and header_fields[
-                        'Subchunk1ID'] == 'fmt ':
-                header_fields['SampleRate'] = struct.unpack('<I',
-                                                            data[24:28])[0]
-                sr = header_fields['SampleRate']
-        except Exception:
-            # no treatment
-            pass
-    else:
-        logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
-
-    return sr
-
-
-def get_sr_from_wav(fname: str):
-    fs = None
-    if os.path.isfile(fname):
-        audio_type = os.path.basename(fname).lower()
-        for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
-            if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
-                if support_audio_type == "pcm":
-                    fs = None
-                else:
-                    try:
-                        audio, fs = torchaudio.load(fname)
-                    except:
-                        audio, fs = librosa.load(fname)
-                break
-        if audio_type.rfind(".scp") >= 0:
-            with open(fname, encoding="utf-8") as f:
-                for line in f:
-                    wav_path = line.split()[1]
-                    fs = get_sr_from_wav(wav_path)
-                    if fs is not None:
-                        break
-        return fs
-    elif os.path.isdir(fname):
-        dir_files = os.listdir(fname)
-        for file in dir_files:
-            file_path = os.path.join(fname, file)
-            if os.path.isfile(file_path):
-                fs = get_sr_from_wav(file_path)
-            elif os.path.isdir(file_path):
-                fs = get_sr_from_wav(file_path)
-
-            if fs is not None:
-                break
-
-    return fs
-
-
-def find_file_by_ends(dir_path: str, ends: str):
-    dir_files = os.listdir(dir_path)
-    for file in dir_files:
-        file_path = os.path.join(dir_path, file)
-        if os.path.isfile(file_path):
-            if ends == ".wav" or ends == ".WAV":
-                audio_type = os.path.basename(file_path).lower()
-                for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                    if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
-                        return True
-                raise NotImplementedError(
-                    f'Not supported audio type: {audio_type}')
-            elif file_path.endswith(ends):
-                return True
-        elif os.path.isdir(file_path):
-            if find_file_by_ends(file_path, ends):
-                return True
-
-    return False
-
-
-def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
-    dir_files = os.listdir(dir_path)
-    for file in dir_files:
-        file_path = os.path.join(dir_path, file)
-        if os.path.isfile(file_path):
-            audio_type = os.path.basename(file_path).lower()
-            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
-                    wav_list.append(file_path)
-        elif os.path.isdir(file_path):
-            recursion_dir_all_wav(wav_list, file_path)
-
-    return wav_list
-
-def compute_wer(hyp_list: List[Any],
-                ref_list: List[Any],
-                lang: str = None) -> Dict[str, Any]:
-    assert len(hyp_list) > 0, 'hyp list is empty'
-    assert len(ref_list) > 0, 'ref list is empty'
-
-    rst = {
-        'Wrd': 0,
-        'Corr': 0,
-        'Ins': 0,
-        'Del': 0,
-        'Sub': 0,
-        'Snt': 0,
-        'Err': 0.0,
-        'S.Err': 0.0,
-        'wrong_words': 0,
-        'wrong_sentences': 0
-    }
-
-    if lang is None:
-        lang = global_asr_language
-
-    for h_item in hyp_list:
-        for r_item in ref_list:
-            if h_item['key'] == r_item['key']:
-                out_item = compute_wer_by_line(h_item['value'],
-                                               r_item['value'],
-                                               lang)
-                rst['Wrd'] += out_item['nwords']
-                rst['Corr'] += out_item['cor']
-                rst['wrong_words'] += out_item['wrong']
-                rst['Ins'] += out_item['ins']
-                rst['Del'] += out_item['del']
-                rst['Sub'] += out_item['sub']
-                rst['Snt'] += 1
-                if out_item['wrong'] > 0:
-                    rst['wrong_sentences'] += 1
-                    print_wrong_sentence(key=h_item['key'],
-                                         hyp=h_item['value'],
-                                         ref=r_item['value'])
-                else:
-                    print_correct_sentence(key=h_item['key'],
-                                           hyp=h_item['value'],
-                                           ref=r_item['value'])
-
-                break
-
-    if rst['Wrd'] > 0:
-        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
-    if rst['Snt'] > 0:
-        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
-
-    return rst
-
-
-def compute_wer_by_line(hyp: List[str],
-                        ref: List[str],
-                        lang: str = 'zh-cn') -> Dict[str, Any]:
-    if lang != 'zh-cn':
-        hyp = hyp.split()
-        ref = ref.split()
-
-    hyp = list(map(lambda x: x.lower(), hyp))
-    ref = list(map(lambda x: x.lower(), ref))
-
-    len_hyp = len(hyp)
-    len_ref = len(ref)
-
-    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
-
-    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
-
-    for i in range(len_hyp + 1):
-        cost_matrix[i][0] = i
-    for j in range(len_ref + 1):
-        cost_matrix[0][j] = j
-
-    for i in range(1, len_hyp + 1):
-        for j in range(1, len_ref + 1):
-            if hyp[i - 1] == ref[j - 1]:
-                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
-            else:
-                substitution = cost_matrix[i - 1][j - 1] + 1
-                insertion = cost_matrix[i - 1][j] + 1
-                deletion = cost_matrix[i][j - 1] + 1
-
-                compare_val = [substitution, insertion, deletion]
-
-                min_val = min(compare_val)
-                operation_idx = compare_val.index(min_val) + 1
-                cost_matrix[i][j] = min_val
-                ops_matrix[i][j] = operation_idx
-
-    match_idx = []
-    i = len_hyp
-    j = len_ref
-    rst = {
-        'nwords': len_ref,
-        'cor': 0,
-        'wrong': 0,
-        'ins': 0,
-        'del': 0,
-        'sub': 0
-    }
-    while i >= 0 or j >= 0:
-        i_idx = max(0, i)
-        j_idx = max(0, j)
-
-        if ops_matrix[i_idx][j_idx] == 0:  # correct
-            if i - 1 >= 0 and j - 1 >= 0:
-                match_idx.append((j - 1, i - 1))
-                rst['cor'] += 1
-
-            i -= 1
-            j -= 1
-
-        elif ops_matrix[i_idx][j_idx] == 2:  # insert
-            i -= 1
-            rst['ins'] += 1
-
-        elif ops_matrix[i_idx][j_idx] == 3:  # delete
-            j -= 1
-            rst['del'] += 1
-
-        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
-            i -= 1
-            j -= 1
-            rst['sub'] += 1
-
-        if i < 0 and j >= 0:
-            rst['del'] += 1
-        elif j < 0 and i >= 0:
-            rst['ins'] += 1
-
-    match_idx.reverse()
-    wrong_cnt = cost_matrix[len_hyp][len_ref]
-    rst['wrong'] = wrong_cnt
-
-    return rst
-
-
-def print_wrong_sentence(key: str, hyp: str, ref: str):
-    space = len(key)
-    print(key + yellow_color + ' ref: ' + ref)
-    print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
-
-
-def print_correct_sentence(key: str, hyp: str, ref: str):
-    space = len(key)
-    print(key + yellow_color + ' ref: ' + ref)
-    print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
-
-
-def print_progress(percent):
-    if percent > 1:
-        percent = 1
-    res = int(50 * percent) * '#'
-    print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 4fb27c0..c5c3ffc 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -9,7 +9,12 @@
 import time
 import logging
 from torch.nn.utils.rnn import pad_sequence
-
+try:
+	from urllib.parse import urlparse
+	from funasr.download.file import HTTPStorage
+	import tempfile
+except:
+	print("urllib is not installed, if you infer from url, please install it first.")
 # def load_audio(data_or_path_or_list, fs: int=16000, audio_fs: int=16000):
 #
 # 	if isinstance(data_or_path_or_list, (list, tuple)):
@@ -43,7 +48,8 @@
 			return data_or_path_or_list_ret
 		else:
 			return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs) for audio in data_or_path_or_list]
-	
+	if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):
+		data_or_path_or_list = download_from_url(data_or_path_or_list)
 	if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):
 		data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
 		data_or_path_or_list = data_or_path_or_list[0, :]
@@ -99,4 +105,21 @@
 	
 	if isinstance(data_len, (list, tuple)):
 		data_len = torch.tensor([data_len])
-	return data.to(torch.float32), data_len.to(torch.int32)
\ No newline at end of file
+	return data.to(torch.float32), data_len.to(torch.int32)
+
+def download_from_url(url):
+	
+	result = urlparse(url)
+	file_path = None
+	if result.scheme is not None and len(result.scheme) > 0:
+		storage = HTTPStorage()
+		# bytes
+		data = storage.read(url)
+		work_dir = tempfile.TemporaryDirectory().name
+		if not os.path.exists(work_dir):
+			os.makedirs(work_dir)
+		file_path = os.path.join(work_dir, os.path.basename(url))
+		with open(file_path, 'wb') as fb:
+			fb.write(data)
+	assert file_path is not None, f"failed to download: {url}"
+	return file_path
\ No newline at end of file

--
Gitblit v1.9.1