Add pytorch model
Browse files- config.json +2 -1
- flax_to_pytorch.py +26 -0
- pytorch_model.bin +3 -0
config.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "
|
| 3 |
"architectures": [
|
| 4 |
"T5ForConditionalGeneration"
|
| 5 |
],
|
|
@@ -21,6 +21,7 @@
|
|
| 21 |
"pad_token_id": 0,
|
| 22 |
"relative_attention_num_buckets": 32,
|
| 23 |
"tie_word_embeddings": false,
|
|
|
|
| 24 |
"transformers_version": "4.13.0",
|
| 25 |
"use_cache": true,
|
| 26 |
"vocab_size": 32103
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": ".",
|
| 3 |
"architectures": [
|
| 4 |
"T5ForConditionalGeneration"
|
| 5 |
],
|
|
|
|
| 21 |
"pad_token_id": 0,
|
| 22 |
"relative_attention_num_buckets": 32,
|
| 23 |
"tie_word_embeddings": false,
|
| 24 |
+
"torch_dtype": "float32",
|
| 25 |
"transformers_version": "4.13.0",
|
| 26 |
"use_cache": true,
|
| 27 |
"vocab_size": 32103
|
flax_to_pytorch.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
from transformers import FlaxT5ForConditionalGeneration
|
| 6 |
+
from transformers import T5ForConditionalGeneration
|
| 7 |
+
tokenizer = AutoTokenizer.from_pretrained(".")
|
| 8 |
+
model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
|
| 9 |
+
model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
|
| 10 |
+
model_pt.save_pretrained("./")
|
| 11 |
+
text = "Hoe gaat het?"
|
| 12 |
+
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
|
| 13 |
+
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id
|
| 14 |
+
e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
|
| 15 |
+
d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id
|
| 16 |
+
print(e_input_ids_fx)
|
| 17 |
+
print(d_input_ids_fx)
|
| 18 |
+
print()
|
| 19 |
+
encoder_pt = model_fx.encode(**e_input_ids_pt)
|
| 20 |
+
decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
|
| 21 |
+
logits_pt = decoder_pt.logits
|
| 22 |
+
print(logits_pt)
|
| 23 |
+
encoder_fx = model_fx.encode(**e_input_ids_fx)
|
| 24 |
+
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
|
| 25 |
+
logits_fx = decoder_fx.logits
|
| 26 |
+
print(logits_fx)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa8b87f8bb924ddaf9823ed6c9ed8f57adbee415b398049da58ddbe36997cf9a
|
| 3 |
+
size 990280781
|