Update chatNT.py
chatNT.py CHANGED

@@ -721,6 +721,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
+        logits = logits.to(torch.float32)
 
         outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
 
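The added line upcasts the model's output logits before they are packed into the return dict. When the backbone runs in bfloat16 or float16, downstream softmax or cross-entropy on half-precision logits can lose precision, so casting up to float32 at the boundary is a common safeguard. A minimal standalone sketch of the effect (the bf16 setting and tensor shapes are illustrative assumptions, not taken from chatNT.py):

    import torch

    # Assume a half-precision backbone: logits come back as bfloat16.
    logits = torch.randn(2, 8, 32000, dtype=torch.bfloat16)  # (batch, seq, vocab)
    labels = torch.randint(0, 32000, (2, 8))

    # Mirroring the diff: upcast before anything consumes the logits.
    logits = logits.to(torch.float32)
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1)
    )
    # loss.dtype is torch.float32
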
@@ -927,9 +928,6 @@ class TorchGptGroupedQueryAttention(nn.Module):
         attention_weights = nn.functional.softmax(attention_logits, dim=-1)
         attention_weights = attention_weights.to(values.dtype)
 
-        print(f"Attention weights type : ", attention_weights.dtype)
-        print(f"Values type : ", values.dtype)
-
         values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
         values = values.contiguous().view(batch_size, seq_len, -1)
 
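The deleted lines were leftover debug prints of the operand dtypes; the attention math itself is unchanged. The preceding cast, attention_weights.to(values.dtype), matters because einsum requires both operands in the same dtype, and softmax may have produced a higher-precision result than the values tensor. For reference, the einsum contracts the key axis T: the weights are (batch, heads, t, T) and the values (batch, T, heads, head_dim), giving (batch, t, heads, head_dim) before the final flatten. A standalone shape sketch (all dimension sizes are illustrative assumptions):

    import torch

    batch_size, num_heads, seq_len, head_dim = 2, 4, 6, 8

    # (b, h, t, T): one weight per (query position t, key position T) pair.
    attention_logits = torch.randn(batch_size, num_heads, seq_len, seq_len)
    attention_weights = torch.softmax(attention_logits, dim=-1)

    # (b, T, h, d): values carry the key/time axis first, as in the module.
    values = torch.randn(batch_size, seq_len, num_heads, head_dim)
    attention_weights = attention_weights.to(values.dtype)  # match dtypes for einsum

    out = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
    out = out.contiguous().view(batch_size, seq_len, -1)
    assert out.shape == (batch_size, seq_len, num_heads * head_dim)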