sugarcase
2024-09-27 a8f0aad81de964941493c57351925071f3a8b733
funasr/models/fsmn_kws_mt/model.py
@@ -41,8 +41,7 @@
        encoder_conf: Optional[Dict] = None,
        ctc_conf: Optional[Dict] = None,
        input_size: int = 360,
        vocab_size: int = -1,
        vocab_size2: int = -1,
        vocab_size: list = [],
        ignore_id: int = -1,
        blank_id: int = 0,
        **kwargs,
@@ -63,14 +62,13 @@
        encoder_output_size2 = encoder.output_size2()
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            odim=vocab_size[0], encoder_output_size=encoder_output_size, **ctc_conf
        )
        ctc2 = CTC(
            odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
            odim=vocab_size[1], encoder_output_size=encoder_output_size2, **ctc_conf
        )
        self.blank_id = blank_id
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # self.frontend = frontend
@@ -208,7 +206,6 @@
        data_lengths=None,
        key: list=None,
        tokenizer=None,
        tokenizer2=None,
        frontend=None,
        **kwargs,
    ):
@@ -217,14 +214,14 @@
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
            token_list=tokenizer[0].token_list,
            seg_dict=tokenizer[0].seg_dict,
        )
        self.kws_decoder2 = KwsCtcPrefixDecoder(
            ctc=self.ctc2,
            keywords=keywords,
            token_list=tokenizer2.token_list,
            seg_dict=tokenizer2.seg_dict,
            token_list=tokenizer[1].token_list,
            seg_dict=tokenizer[1].seg_dict,
        )
        meta_data = {}
@@ -314,12 +311,9 @@
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        vocab_size: int = -1,
        vocab_size2: int = -1,
        blank_id: int = 0,
        **kwargs,
    ):
@@ -328,18 +322,8 @@
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        encoder_output_size = encoder.output_size()
        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
        )
        self.blank_id = blank_id
        self.vocab_size = vocab_size
        self.ctc_weight = ctc_weight
        self.encoder = encoder
        self.ctc = ctc
        self.error_calculator = None