# ...
from funasr.register import tables


class BranchformerEncoderLayer(torch.nn.Module):
    """Branchformer encoder layer module.

    # ...
        stochastic_depth_rate: float = 0.0,
    ):
        super().__init__()
        assert (attn is not None) or (cgmlp is not None), "At least one branch should be valid"

        self.size = size
        self.attn = attn
        # ...
                self.merge_proj = torch.nn.Linear(size, size)

            elif merge_method == "fixed_ave":
                assert 0.0 <= cgmlp_weight <= 1.0, "cgmlp weight should be between 0.0 and 1.0"

                # remove the other branch if only one branch is used
                if cgmlp_weight == 0.0:
                    # ...
                    )  # (batch, 1, time)
                    if mask is not None:
                        # mask out padded frames: fill with dtype min before the softmax,
                        # then zero them exactly afterwards
                        min_value = float(
                            numpy.finfo(torch.tensor(0, dtype=score1.dtype).numpy().dtype).min
                        )
                        score1 = score1.masked_fill(mask.eq(0), min_value)
                        score1 = torch.softmax(score1, dim=-1).masked_fill(mask.eq(0), 0.0)
                    else:
                        score1 = torch.softmax(score1, dim=-1)
                    pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
                    # ...
                    )  # (batch, 1, time)
                    if mask is not None:
                        min_value = float(
                            numpy.finfo(torch.tensor(0, dtype=score2.dtype).numpy().dtype).min
                        )
                        score2 = score2.masked_fill(mask.eq(0), min_value)
                        score2 = torch.softmax(score2, dim=-1).masked_fill(mask.eq(0), 0.0)
                    else:
                        score2 = torch.softmax(score2, dim=-1)
                    pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
                    # ...
                    merge_weights = torch.softmax(
                        torch.cat([weight1, weight2], dim=-1), dim=-1
                    )  # (batch, 2)
                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(-1)  # (batch, 2, 1, 1)
                    w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)

                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(w1 * x1 + w2 * x2))
            elif self.merge_method == "fixed_ave":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj((1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2)
                )
            else:
                raise RuntimeError(f"unknown merge method: {self.merge_method}")
        # ...
            return (x, pos_emb), mask

        return x, mask
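
# --- Illustrative sketch (not part of the original file) ---
# A minimal, self-contained rendition of the "learned_ave" merge used in the
# layer above, assuming two already-computed branch outputs x1 (attention) and
# x2 (cgMLP) of shape (batch, time, size). The projection modules are created
# ad hoc here purely for illustration; mask handling for padded frames (done
# with masked_fill in the layer) is omitted for brevity.
import torch


def learned_ave_merge_sketch(x1, x2, size):
    pooling_proj1 = torch.nn.Linear(size, 1)
    pooling_proj2 = torch.nn.Linear(size, 1)
    weight_proj1 = torch.nn.Linear(size, 1)
    weight_proj2 = torch.nn.Linear(size, 1)
    merge_proj = torch.nn.Linear(size, size)

    # attention-style pooling over time for each branch
    score1 = torch.softmax(pooling_proj1(x1).transpose(1, 2) / size**0.5, dim=-1)  # (batch, 1, time)
    score2 = torch.softmax(pooling_proj2(x2).transpose(1, 2) / size**0.5, dim=-1)  # (batch, 1, time)
    pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
    pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)

    # one scalar weight per branch, normalized across the two branches
    weights = torch.softmax(
        torch.cat([weight_proj1(pooled1), weight_proj2(pooled2)], dim=-1), dim=-1
    )  # (batch, 2)
    w1, w2 = weights[:, 0:1, None], weights[:, 1:2, None]  # (batch, 1, 1) each

    # weighted average of the two branches, then the merge projection
    return merge_proj(w1 * x1 + w2 * x2)  # (batch, time, size)


# usage: learned_ave_merge_sketch(torch.randn(2, 50, 256), torch.randn(2, 50, 256), 256)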


@tables.register("encoder_classes", "BranchformerEncoder")
class BranchformerEncoder(nn.Module):
        # ...
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # ...
                output_size,
                attention_dropout_rate,
            )
            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
        # ...
            num_blocks,
            lambda lnum: BranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args) if use_attn else None,
                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                dropout_rate,
                merge_method,