# bichuche0705's picture
# Add/Update test_model.py
# e5bcf26 verified
"""
Simple test script to run the model on new data
This script uses the correct model paths from the Hugging Face repository
"""
from ultralytics import YOLO
import os
import sys
def _result_image_path(image_path):
    """Return the path where the annotated copy of ``image_path`` is written.

    ``_result`` is inserted in front of the file extension. Only the final
    path component's extension is touched, so a directory that happens to
    contain ``.jpg`` or ``.png`` in its name is left alone. JPEG/PNG
    extensions are kept (case-insensitively); anything else, including a
    missing extension, falls back to ``.jpg``.
    """
    root, ext = os.path.splitext(image_path)
    if ext.lower() not in (".jpg", ".jpeg", ".png"):
        ext = ".jpg"
    return f"{root}_result{ext}"


def test_model_on_image(image_path, model_path=None):
    """Run the detection model on a single image and report what it found.

    Every detection is printed (class name, confidence, integer box corners)
    and, when at least one box was found, an annotated copy of the image is
    saved next to the input with a ``_result`` suffix.

    Args:
        image_path: Path to the test image.
        model_path: Path to the model weights. When None, the first existing
            file among the known default locations is used.

    Returns:
        True when inference ran (whether or not anything was detected);
        False when the model or image file is missing, or when loading the
        model or running inference raised an exception.
    """
    # Auto-detect the model: the first candidate that exists on disk wins.
    if model_path is None:
        possible_paths = [
            "models/yolov8m_stage2_improved_best.pt",
            "training_runs/yolov8m_stage2_improved/weights/best.pt",
            "training_runs/yolov8m_stage1_smart/weights/best.pt",
        ]
        model_path = next(
            (path for path in possible_paths if os.path.exists(path)), None
        )
        if model_path is None:
            print("ERROR: Model file not found!")
            print("Please download the model from Hugging Face repository.")
            print("The model should be at one of these locations:")
            for path in possible_paths:
                print(f" - {path}")
            return False
        print(f"Found model at: {model_path}")

    if not os.path.exists(model_path):
        print(f"ERROR: Model file not found at: {model_path}")
        return False
    if not os.path.exists(image_path):
        print(f"ERROR: Image file not found at: {image_path}")
        return False

    print(f"\nLoading model from: {model_path}")
    try:
        model = YOLO(model_path)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"ERROR loading model: {e}")
        return False

    print(f"\nRunning inference on: {image_path}")
    try:
        results = model(image_path)
        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                print("\nNo vehicles detected in the image.")
                continue

            print(f"\nDetected {len(boxes)} vehicle(s):")
            for i, box in enumerate(boxes, start=1):
                # float()/int()/tolist() pull plain Python scalars out of the
                # box tensors without a manual .cpu().numpy() round trip.
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
                conf = float(box.conf[0])
                class_name = model.names[int(box.cls[0])]
                print(f" {i}. {class_name}: {conf:.2f} confidence at [{x1}, {y1}, {x2}, {y2}]")

            # Save the annotated image alongside the input.
            output_path = _result_image_path(image_path)
            result.save(output_path)
            print(f"\nResult saved to: {output_path}")
        return True
    except Exception as e:
        # Top-level boundary of the script: report the failure with a full
        # traceback and signal it through the return value.
        print(f"ERROR during inference: {e}")
        import traceback
        traceback.print_exc()
        return False


# The name starts with "test_" but this is a CLI helper, not a pytest test;
# without this marker pytest collects it and fails on the "image_path" fixture.
test_model_on_image.__test__ = False
def _move_saved_video(save_dir, video_path, output_path):
    """Move the annotated video found in ``save_dir`` to ``output_path``.

    The annotated file is looked up by base name: it is expected to carry
    the same stem as ``video_path`` (the extension may differ, since the
    writer picks the container) -- TODO confirm against the installed
    Ultralytics version. The move is best-effort: any filesystem error is
    reported and the file is left where it was written.

    Returns:
        ``output_path`` when the file was moved, otherwise None.
    """
    import shutil

    wanted_stem = os.path.splitext(os.path.basename(video_path))[0]
    try:
        names = sorted(os.listdir(save_dir))
    except OSError:
        return None

    for name in names:
        produced = os.path.join(save_dir, name)
        produced_stem, produced_ext = os.path.splitext(name)
        if produced_stem != wanted_stem or not os.path.isfile(produced):
            continue
        try:
            parent = os.path.dirname(output_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            shutil.move(produced, output_path)
        except OSError as e:
            print(f"WARNING: could not move {produced} to {output_path}: {e}")
            return None
        requested_ext = os.path.splitext(output_path)[1]
        if requested_ext and requested_ext.lower() != produced_ext.lower():
            # Renaming does not re-encode: the container stays what was written.
            print(f"NOTE: video was encoded as {produced_ext}; "
                  f"the file content at {output_path} is still {produced_ext} data.")
        return output_path
    return None


def test_model_on_video(video_path, model_path=None, output_path=None):
    """Run the detection model on a video file and save the annotated video.

    Frames are processed with ``stream=True`` so per-frame results are not
    all held in memory for long videos. The location printed at the end is
    the directory reported by the results (run directories can be numbered,
    e.g. ``predict2``), falling back to the default location when the
    results do not expose it.

    Args:
        video_path: Path to the test video.
        model_path: Path to the model weights. When None, the first existing
            file among the known default locations is used.
        output_path: Optional destination file for the annotated video.
            When given, the run is written under ``output/detection`` and the
            annotated file is then moved to ``output_path`` (best-effort; if
            it cannot be located or moved it stays in the run directory).

    Returns:
        True when processing completed; False when the model or video file
        is missing, or when loading the model or processing raised.
    """
    # Auto-detect the model: the first candidate that exists on disk wins.
    if model_path is None:
        possible_paths = [
            "models/yolov8m_stage2_improved_best.pt",
            "training_runs/yolov8m_stage2_improved/weights/best.pt",
        ]
        model_path = next(
            (path for path in possible_paths if os.path.exists(path)), None
        )
        if model_path is None:
            print("ERROR: Model file not found!")
            return False
        print(f"Found model at: {model_path}")

    if not os.path.exists(model_path):
        print(f"ERROR: Model file not found at: {model_path}")
        return False
    if not os.path.exists(video_path):
        print(f"ERROR: Video file not found at: {video_path}")
        return False

    print(f"\nLoading model from: {model_path}")
    try:
        model = YOLO(model_path)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"ERROR loading model: {e}")
        return False

    print(f"\nProcessing video: {video_path}")
    print("This may take a while...")
    try:
        predict_args = {"save": True, "stream": True}
        if output_path:
            predict_args.update(project="output", name="detection")

        # With stream=True the call returns a generator; it must be consumed
        # for inference (and the video writer) to actually run to completion.
        save_dir = None
        for result in model(video_path, **predict_args):
            save_dir = getattr(result, "save_dir", None) or save_dir
        if save_dir is None:
            save_dir = "output/detection/" if output_path else "runs/detect/predict/"

        final_path = None
        if output_path:
            final_path = _move_saved_video(save_dir, video_path, output_path)
        print(f"\nOutput saved to: {final_path or save_dir}")
        print("Video processing completed!")
        return True
    except Exception as e:
        # Top-level boundary of the script: report the failure with a full
        # traceback and signal it through the return value.
        print(f"ERROR during video processing: {e}")
        import traceback
        traceback.print_exc()
        return False


# The name starts with "test_" but this is a CLI helper, not a pytest test;
# without this marker pytest collects it and fails on the "video_path" fixture.
test_model_on_video.__test__ = False
def main():
    """Command-line entry point: run the model on an image or a video.

    Usage::

        python test_model.py <image_path>
        python test_model.py <video_path> --video [--output <file>]

    Flags may appear before or after the input path. With no arguments the
    usage text is printed and the function returns normally. A malformed
    command line (unknown flag, ``--output`` without a value) makes argparse
    exit with status 2; a failed test run exits with status 1.
    """
    import argparse

    banner = "=" * 70
    print(banner)
    print("Highway Vehicle Detection - Model Test Script")
    print(banner)

    if len(sys.argv) < 2:
        print("\nUsage:")
        print(" For image: python test_model.py <image_path>")
        print(" For video: python test_model.py <video_path> --video")
        print("\nExample:")
        print(" python test_model.py test_image.jpg")
        print(" python test_model.py test_video.mp4 --video")
        print(" python test_model.py test_video.mp4 --video --output output_video.mp4")
        return

    # argparse replaces positional guessing on sys.argv: the input path is
    # found wherever it appears, and a missing --output value is an error
    # rather than being silently dropped.
    parser = argparse.ArgumentParser(
        prog="test_model.py",
        description="Run the vehicle detection model on an image or a video.",
        allow_abbrev=False,
    )
    parser.add_argument("input_path", help="path to the image or video to test")
    parser.add_argument("--video", action="store_true",
                        help="treat the input as a video file")
    parser.add_argument("--output", dest="output_path", default=None, metavar="PATH",
                        help="where to save the annotated video (with --video)")
    args = parser.parse_args()

    if args.video:
        success = test_model_on_video(args.input_path, output_path=args.output_path)
    else:
        if args.output_path:
            print("NOTE: --output is only used together with --video; ignoring it.")
        success = test_model_on_image(args.input_path)

    print("\n" + banner)
    if success:
        print("TEST COMPLETED SUCCESSFULLY!")
        print(banner)
    else:
        print("TEST FAILED - Please check the errors above")
        print(banner)
        sys.exit(1)
# Script entry point: run the CLI only when this file is executed directly,
# so importing the module for its helper functions has no side effects.
if __name__ == "__main__":
    main()