import gc
import os

import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms

# Define data transformations for training and validation
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Ensure all images are 224x224
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),  # Standard for ResNet
    ]
)


# Custom dataset class for loading chess piece images
class ChessPieceDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all the images and subdirectories (class labels).
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            [
                d
                for d in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, d))
            ]
        )
        self.image_paths = []
        self.labels = []
        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(root_dir, class_name)
            for image_name in os.listdir(class_folder):
                img_path = os.path.join(class_folder, image_name)
                # Only include valid image files
                if img_path.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
                    try:
                        # Verify the image can be opened
                        with Image.open(img_path) as img:
                            img.verify()  # Verify image integrity
                        self.image_paths.append(img_path)
                        self.labels.append(label)
                    except Exception as e:
                        print(f"Skipping corrupted image {img_path}: {e}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image to avoid crashing
            image = Image.new("RGB", (224, 224), (0, 0, 0))
        if self.transform:
            try:
                image = self.transform(image)
                # Verify the image size after transformation
                if image.shape != (3, 224, 224):
                    print(
                        f"Unexpected image size after transform for {img_path}: {image.shape}"
                    )
            except Exception as e:
                print(f"Error applying transform to {img_path}: {e}")
                image = self.transform(Image.new("RGB", (224, 224), (0, 0, 0)))
        return image, label
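

# Optional sanity check (a minimal sketch, not part of the original pipeline):
# uncomment to confirm one sample loads as a (3, 224, 224) tensor with an
# integer label. Assumes the "train" directory used below exists.
# _img, _lbl = ChessPieceDataset("train", transform=transform)[0]
# print(_img.shape, _lbl)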


# Training loop with per-epoch validation and best-model checkpointing
def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device="cpu"
):
    best_accuracy = 0.0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_train_accuracy = 100 * correct / total
        epoch_val_accuracy = 100 * val_correct / val_total
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
            f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
        )
        if epoch_val_accuracy > best_accuracy:
            best_accuracy = epoch_val_accuracy
            torch.save(model.state_dict(), "best_chess_piece_model.pth")
    print("Training completed.")


# Path to dataset folder
dataset_path = "train"  # Ensure this path is correct

# Create dataset
full_dataset = ChessPieceDataset(dataset_path, transform=transform)

# Check if dataset is empty
if len(full_dataset) == 0:
    raise ValueError(
        "Dataset is empty. Check dataset_path and ensure it contains valid images."
    )

# Split the dataset into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained ResNet18 model and modify the final layer
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(full_dataset.classes))
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
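# Note: optimizing model.parameters() updates the pretrained backbone as well as
# the new final layer, i.e. the whole network is fine-tuned at a low learning rate.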

# Train the model
train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device
)

# After training, load the best model for inference
model.load_state_dict(torch.load("best_chess_piece_model.pth", map_location=device))
model.eval()
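
# Minimal inference sketch (commented out; assumptions flagged): shows how the
# best checkpoint could classify a single image. "example_piece.jpg" is a
# hypothetical path, not a file provided by this project.
# sample = transform(Image.open("example_piece.jpg").convert("RGB")).unsqueeze(0).to(device)
# with torch.no_grad():
#     predicted_idx = model(sample).argmax(dim=1).item()
# print(f"Predicted class: {full_dataset.classes[predicted_idx]}")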

gc.collect()
del model
torch.cuda.empty_cache()
gc.collect()