"""Quick check that transformers, datasets and JAX/Flax work together.

Takes the first English document from the OSCAR ``train`` split and
tokenizes it with the ``roberta-base`` fast tokenizer. It keeps only the
first few token ids and runs a single forward pass through a
``FlaxRobertaModel`` checkpoint.
"""
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax  # NOTE(review): not referenced below; presumably imported so a missing JAX install fails early — confirm

# Identifiers of the tokenizer/model checkpoints, and how many token ids to
# feed the model.
TOKENIZER_NAME = "roberta-base"
MODEL_NAME = "julien-c/dummy-unknown"
NUM_TOKENS = 10

# streaming=True: iterate the split lazily rather than preparing it in full first.
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
# Raw text of the first example only.
dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_NAME)
# The encoded ids come back as a 2-D NumPy array; slicing only the second
# axis keeps the batch dimension and the first NUM_TOKENS ids.
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :NUM_TOKENS]

model = FlaxRobertaModel.from_pretrained(MODEL_NAME)
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
z = model(input_ids)