---
license: apache-2.0
tags:
- object-detection
- rfdetr
- computer-vision
library_name: rfdetr
pipeline_tag: object-detection
---

# RF-DETR Medium Model

This is an RF-DETR model (a real-time, transformer-based object detector developed by Roboflow) fine-tuned on a custom dataset to detect drones.

## Model Details

- **Model Architecture**: RFDETRMedium
- **Framework**: RF-DETR (PyTorch)
- **Image Resolution**: 640x640
- **Training Epochs**: 50
- **Batch Size**: 4 per step with 4 gradient accumulation steps (effective batch size of 16 per device)
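
With gradient accumulation, gradients from several micro-batches are combined before each optimizer update, so a batch size of 4 with 4 accumulation steps yields one update per 16 images. The sketch below illustrates that equivalence on a toy one-parameter model with a mean squared error loss. The model, data, and helper names are hypothetical, and this is not RF-DETR's training loop.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error for the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, samples, batch_size, accum_steps):
    """Combine the gradients of `accum_steps` micro-batches of `batch_size` samples.

    Each micro-batch gradient is divided by `accum_steps` before it is added,
    so the result matches the gradient of one batch of batch_size * accum_steps.
    """
    total = 0.0
    for step in range(accum_steps):
        micro = samples[step * batch_size:(step + 1) * batch_size]
        total += grad_mse(w, micro) / accum_steps
    return total


# Hypothetical data: 16 (x, y) pairs, exactly one update at 4 x 4
samples = [(float(i), 2.0 * i + 1.0) for i in range(16)]
w = 0.5

full = grad_mse(w, samples)                  # one batch of 16
accum = accumulated_grad(w, samples, 4, 4)   # 4 micro-batches of 4
print(f"full-batch grad: {full:.6f}, accumulated grad: {accum:.6f}")
print("effective batch size:", 4 * 4)
```

The equivalence is exact only when every micro-batch has the same size; a short final micro-batch would be weighted slightly differently.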

## Classes

The model is trained to detect one class: `drone`.

## Usage

### Using the RF-DETR Library (Recommended)

```python
from huggingface_hub import hf_hub_download
from PIL import Image
from rfdetr import RFDETRMedium

# Download the fine-tuned checkpoint from the Hugging Face Hub to a local path
checkpoint_path = hf_hub_download(
    repo_id="rujutashashikanjoshi/rfdetr-drone-detection-0205",
    filename="checkpoint_best_total.pth",
)

# Load the model with the fine-tuned weights
model = RFDETRMedium(pretrain_weights=checkpoint_path)
model.optimize_for_inference()

# Run inference; detections below the 0.5 confidence threshold are discarded
image = Image.open("your_image.jpg").convert("RGB")
detections = model.predict(image, threshold=0.5)

print(f"Found {len(detections)} detections")
```
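
In recent rfdetr releases, `model.predict` returns a `supervision` `Detections` object whose `xyxy`, `confidence`, and `class_id` arrays hold the boxes, scores, and class indices. As a minimal sketch, the hypothetical helper below turns those arrays into readable records and optionally applies a stricter confidence cutoff. The `{0: "drone"}` mapping is an assumption: verify which class ID this model assigns to `drone` before relying on it.

```python
import numpy as np


def summarize_detections(xyxy, confidence, class_id, class_names, min_conf=0.5):
    """Turn detection arrays into readable records, keeping only confident boxes.

    Assumes `xyxy` is an (N, 4) array of [x_min, y_min, x_max, y_max] pixel
    coordinates and `class_names` maps class IDs to names (hypothetical mapping).
    """
    results = []
    for box, conf, cls in zip(xyxy, confidence, class_id):
        if conf < min_conf:
            continue
        x1, y1, x2, y2 = (float(v) for v in box)
        results.append({
            "label": class_names.get(int(cls), str(int(cls))),
            "confidence": round(float(conf), 3),
            "box_xyxy": (x1, y1, x2, y2),
            "area": (x2 - x1) * (y2 - y1),
        })
    # Highest-confidence detections first
    return sorted(results, key=lambda r: r["confidence"], reverse=True)


# Stand-in arrays shaped like a prediction; with the real model, pass
# detections.xyxy, detections.confidence and detections.class_id instead.
xyxy = np.array([[10, 20, 110, 80], [200, 50, 260, 90]], dtype=float)
confidence = np.array([0.42, 0.91])
class_id = np.array([0, 0])

for record in summarize_detections(xyxy, confidence, class_id, {0: "drone"}):
    print(record)
```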

### Using PyTorch Directly

```python
import torch
from huggingface_hub import hf_hub_download

# Download the model weights
model_path = hf_hub_download(
    repo_id="rujutashashikanjoshi/rfdetr-drone-detection-0205",
    filename="pytorch_model.bin",
)

# Load the state dict: a mapping of parameter names to tensors, not a model.
# To run inference, build a matching RF-DETR Medium architecture first and
# pass these weights to its load_state_dict() method.
state_dict = torch.load(model_path, map_location="cpu")
```
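
Before loading the weights into a model, it can help to check what the file actually contains. The hypothetical helpers below unwrap a checkpoint that stores its weights under a `"model"` key (a common pattern in DETR-style training scripts, assumed here rather than confirmed for these files) and count entries and elements from each value's shape, which works the same way for PyTorch tensors and NumPy arrays. The demo uses NumPy stand-ins so it runs without downloading anything.

```python
import math
from collections.abc import Mapping

import numpy as np


def unwrap_state_dict(checkpoint):
    """Return the parameter mapping from a loaded checkpoint.

    Assumption: a wrapped checkpoint looks like {"model": state_dict, ...};
    anything else is treated as a plain state dict and returned unchanged.
    """
    if isinstance(checkpoint, Mapping) and isinstance(checkpoint.get("model"), Mapping):
        return checkpoint["model"]
    return checkpoint


def summarize_state_dict(state_dict):
    """Return (number of entries, total element count) for array-like values."""
    total = sum(math.prod(v.shape) for v in state_dict.values() if hasattr(v, "shape"))
    return len(state_dict), total


# Hypothetical stand-in checkpoint built from NumPy arrays
fake_checkpoint = {"model": {"layer.weight": np.zeros((2, 4)), "layer.bias": np.zeros(2)}}
entries, elements = summarize_state_dict(unwrap_state_dict(fake_checkpoint))
print(entries, elements)
```

With the real file, call `summarize_state_dict(unwrap_state_dict(state_dict))` on the object returned by `torch.load`.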

## Training Details

- **Optimizer**: AdamW
- **Learning Rate Schedule**: Cosine with warmup
- **Data Format**: COCO JSON
- **Framework Version**: rfdetr>=1.2.1
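
A cosine schedule with warmup ramps the learning rate up linearly for a short warmup period, then decays it along half a cosine wave. The sketch below is a generic version of that schedule, not necessarily the exact formula rfdetr uses; the step counts and base learning rate in the example are hypothetical.

```python
import math


def lr_at_step(step, total_steps, warmup_steps, base_lr, min_lr=0.0):
    """Learning rate for a linear-warmup + cosine-decay schedule.

    During warmup the rate grows linearly to base_lr; afterwards it follows
    half a cosine wave from base_lr down to min_lr, reached at total_steps.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical values: 50 epochs x 100 optimizer steps, 1 warmup epoch, lr 1e-4
total, warmup, base = 50 * 100, 100, 1e-4
for s in (0, 99, 100, 2550, 4999):
    print(s, f"{lr_at_step(s, total, warmup, base):.2e}")
```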

## Files in this Repository

- `checkpoint_best_total.pth`: Original RF-DETR checkpoint (best model)
- `pytorch_model.bin`: Standard PyTorch weights for compatibility
- `config.json`: Model configuration
- `class_names.txt`: List of detection classes
- `results.json`: Training results and metrics
- `metrics_plot.png`: Visualization of training metrics
- `log.json`: Full training log
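
`class_names.txt` can be used to map predicted class indices back to names. A minimal parser is sketched below, assuming the file lists one class name per line in class-index order; the helper name is hypothetical, and the ordering assumption should be checked against the file and the model's outputs before relying on it.

```python
def parse_class_names(text):
    """Parse class_names.txt content, assuming one class name per line.

    Blank lines and surrounding whitespace are ignored; the returned dict maps
    each name's zero-based position in the file to the name.
    """
    names = [line.strip() for line in text.splitlines() if line.strip()]
    return dict(enumerate(names))


# Stand-in for the file contents; replace with the text of the downloaded file.
print(parse_class_names("drone\n"))
```

For the real file, download it with `hf_hub_download(repo_id="rujutashashikanjoshi/rfdetr-drone-detection-0205", filename="class_names.txt")` and pass the file's text to `parse_class_names`.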

## License

This model is released under the Apache 2.0 License.