liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/whisper_lid/lid_predictor.py
@@ -1,25 +1,31 @@
from funasr.register import tables
from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF
from funasr.models.whisper_lid.eres2net.ResNet import (
    ERes2Net,
    BasicBlockERes2Net,
    BasicBlockERes2Net_diff_AFF,
)
@tables.register("lid_predictor_classes", "LidPredictor")
class LidPredictor(ERes2Net):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
    def __init__(
        self,
        block=BasicBlockERes2Net,
        block_fuse=BasicBlockERes2Net_diff_AFF,
        num_blocks=[3, 4, 6, 3],
        m_channels=32,
        feat_dim=80,
        embedding_size=192,
        pooling_func="TSTP",
        two_emb_layer=False,
    ):
        super(LidPredictor, self).__init__(
                block=block,
                block_fuse=block_fuse,
                num_blocks=num_blocks,
                m_channels=m_channels,
                feat_dim=feat_dim,
                embedding_size=embedding_size,
                pooling_func=pooling_func,
                two_emb_layer=two_emb_layer
        )
            block=block,
            block_fuse=block_fuse,
            num_blocks=num_blocks,
            m_channels=m_channels,
            feat_dim=feat_dim,
            embedding_size=embedding_size,
            pooling_func=pooling_func,
            two_emb_layer=two_emb_layer,
        )