游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/models/sense_voice/model.py
@@ -74,8 +74,6 @@
    ):
        target_mask = kwargs.get("target_mask", None)
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
@@ -304,8 +302,6 @@
    ):
        target_mask = kwargs.get("target_mask", None)
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
@@ -649,8 +645,6 @@
    ):
        target_mask = kwargs.get("target_mask", None)
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
@@ -1054,8 +1048,6 @@
    ):
        target_mask = kwargs.get("target_mask", None)
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
@@ -1594,15 +1586,25 @@
        language = kwargs.get("language", None)
        if language is not None:
            language_query = self.embed(torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
            language_query = self.embed(
                torch.LongTensor(
                    [[self.lid_dict[language] if language in self.lid_dict else 0]]
                ).to(speech.device)
            ).repeat(speech.size(0), 1, 1)
        else:
            language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
            language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(
                speech.size(0), 1, 1
            )
        textnorm = kwargs.get("text_norm", "wotextnorm")
        textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(speech.size(0), 1, 1)
        textnorm_query = self.embed(
            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)
        speech = torch.cat((textnorm_query, speech), dim=1)
        speech_lengths += 1
        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        speech_lengths += 3