fix
九耳
2023-03-30 0b15e6ea5cccbea3c590958d60e623800bbe3dfb
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -32,8 +32,7 @@
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = 1
        self.encoder_conf = config["encoder_conf"]
        self.punc_list = config.punc_list
        self.punc_list = config['punc_list']
        self.period = 0
        for i in range(len(self.punc_list)):
            if self.punc_list[i] == ",":
@@ -44,13 +43,13 @@
                self.period = i
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=config.token_type,
            token_list=config.token_list,
            bpemodel=config.bpemodel,
            text_cleaner=config.cleaner,
            g2p_type=config.g2p,
            token_type=config['token_type'],
            token_list=config['token_list'],
            bpemodel=config['bpemodel'],
            text_cleaner=config['cleaner'],
            g2p_type=config['g2p'],
            text_name="text",
            non_linguistic_symbols=config.non_linguistic_symbols,
            non_linguistic_symbols=config['non_linguistic_symbols'],
        )
    def __call__(self, text: Union[list, str], split_size=20):
@@ -71,15 +70,14 @@
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
            data = {
                "text": mini_sentence_id,
                "text_lengths": len(mini_sentence_id),
                "text": mini_sentence_id[None,:].astype(np.int64),
                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
            }
            try:
                outputs = self.infer(data['text'], data['text_lengths'])
                y = outputs[0]
                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
                punctuations = indices
                assert punctuations.size()[0] == len(mini_sentence)
                punctuations = np.argmax(y,axis=-1)[0]
                assert punctuations.size == len(mini_sentence)
            except ONNXRuntimeError:
                logging.warning("error")
@@ -103,8 +101,7 @@
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]
            punctuations_np = punctuations.cpu().numpy()
            new_mini_sentence_punc += [int(x) for x in punctuations_np]
            new_mini_sentence_punc += [int(x) for x in punctuations]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                if i > 0:
@@ -126,8 +123,8 @@
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        return new_mini_sentence_out, new_mini_sentence_punc_out
    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len])
        return outputs