Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| from model import load_model | |
# Image preprocessing shared by every prediction:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved, since a
#      (height, width) tuple is given),
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels with the mean/std below.
# NOTE(review): the statistics look like rounded CLIP image statistics —
# confirm they match the preprocessing the checkpoint was trained with.
_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.4815, 0.4578, 0.4082],  # per-channel mean
            [0.2686, 0.2613, 0.2758],  # per-channel std
        ),
    ]
)
def load_for_inference(repo_id, filename="model.pt", device=None):
    """Load a model and its tokenizer, ready for inference.

    Args:
        repo_id: Repository identifier, forwarded unchanged to ``load_model``.
        filename: Checkpoint file name, forwarded unchanged to ``load_model``.
        device: Optional ``torch.device`` or device string to load onto.
            When ``None`` (the default), CUDA is used if available,
            otherwise CPU.

    Returns:
        A ``(model, tokenizer, device)`` tuple. ``tokenizer`` is the model's
        own ``tokenizer`` attribute; ``device`` is the resolved
        ``torch.device`` that ``predict`` should move inputs to.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Accept either a string such as "cpu" or an existing torch.device.
        device = torch.device(device)

    model = load_model(repo_id=repo_id, filename=filename, device=device)

    # Disable training-only behaviour (dropout, batch-norm statistic updates)
    # before inference. Guarded because load_model's return type is not
    # visible here; eval() is harmless if the loader already called it.
    if isinstance(model, torch.nn.Module):
        model.eval()

    tokenizer = model.tokenizer
    return model, tokenizer, device
def predict(model, tokenizer, device, image: Image.Image, question: str):
    """Generate an answer to ``question`` about ``image``.

    Args:
        model: Model as returned by ``load_for_inference``; must expose
            ``generate(images=..., input_ids=..., attention_mask=..., ...)``.
        tokenizer: Tokenizer paired with ``model``; called to encode the
            question and to decode the generated ids.
        device: Device the model lives on; all inputs are moved there.
        image: Input PIL image. Images whose mode is not ``"RGB"``
            (e.g. ``"L"``, ``"P"``, ``"RGBA"``, ``"CMYK"``) are converted to
            RGB before preprocessing.
        question: Question text; truncated to at most 64 tokens.

    Returns:
        The first sequence returned by ``model.generate`` (beam search,
        4 beams, ``max_length=64``), decoded with special tokens skipped.
    """
    # _transform normalizes with 3-channel mean/std, so 1-channel (L, P) or
    # 4-channel (RGBA, CMYK) images would fail at normalization without this.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add a leading batch dimension of size 1.
    image_tensor = _transform(image).unsqueeze(0).to(device)

    q = tokenizer(
        question,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=64,
    ).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            images=image_tensor,
            input_ids=q.input_ids,
            attention_mask=q.attention_mask,
            max_length=64,
            num_beams=4,
        )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)