Sharris's picture
Upload folder using huggingface_hub
de3c81a verified
"""
Load a trained age regression model and run a prediction on a single image.
Usage: python predict.py --model_path saved_model_age_regressor --image_path some_image.jpg
"""
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import tensorflow as tf
def parse_args(argv=None):
    """Parse the command-line options for single-image age prediction.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``model_path`` (str), ``image_path`` (str),
        ``img_size`` (int) and ``output_key`` (str or None).

    Exits (via ``SystemExit``) on invalid arguments, including a
    non-positive ``--img_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='saved_model_age_regressor',
                        help='Path to a .h5/.keras model file or a SavedModel directory.')
    parser.add_argument('--image_path', type=str, required=True,
                        help='Image file to run the prediction on.')
    parser.add_argument('--img_size', type=int, default=224,
                        help='Square side length (pixels) the image is resized to before inference.')
    parser.add_argument('--output_key', type=str, default=None,
                        help='If the model returns a dict, select this key for the numeric prediction. If omitted the first numeric output will be used.')
    args = parser.parse_args(argv)
    # Fail at the CLI boundary instead of deep inside PIL's resize.
    if args.img_size <= 0:
        parser.error(f"--img_size must be a positive integer, got {args.img_size}")
    return args
def load_image(path, img_size):
    """Read an image file and return it as a normalised RGB float array.

    Args:
        path: Filesystem path (str or pathlib.Path) of the image to read.
        img_size: Target side length in pixels; the image is resized to a
            square ``img_size x img_size`` (aspect ratio is not preserved).

    Returns:
        float32 ndarray of shape ``(img_size, img_size, 3)`` with values in
        ``[0, 1]`` (8-bit RGB channels divided by 255).
    """
    # Use a context manager so the underlying file handle is closed
    # deterministically; convert() produces an independent, fully loaded copy.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    # Resampling filter stays at PIL's default, matching the original code
    # (NOTE(review): confirm it matches the training-time preprocessing).
    rgb = rgb.resize((img_size, img_size))
    return np.asarray(rgb, dtype=np.float32) / 255.0
def _load_model(model_path, img_size):
    """Load a Keras model file, or a model directory, ready for ``predict``.

    Args:
        model_path: pathlib.Path to a ``.h5``/``.keras`` file or a directory.
        img_size: Square input side length; only used when a SavedModel
            directory must be wrapped in a ``TFSMLayer``.

    Returns:
        A Keras model accepting a batch of ``(img_size, img_size, 3)`` images.

    Raises:
        RuntimeError: a directory could be neither loaded by
            ``tf.keras.models.load_model`` nor wrapped with ``TFSMLayer``;
            both underlying errors are included in the message.
    """
    if model_path.is_file() and model_path.suffix.lower() in ('.h5', '.keras'):
        model = tf.keras.models.load_model(str(model_path), compile=False)
        print(f"Loaded Keras model file: {model_path}")
        return model
    if not model_path.is_dir():
        # Unknown path type: let load_model raise its own descriptive error.
        model = tf.keras.models.load_model(str(model_path), compile=False)
        print(f"Loaded model from path: {model_path}")
        return model
    # Some SavedModel directories are not loadable with tf.keras load_model in
    # Keras 3; try it first, then fall back to wrapping the serving signature.
    try:
        model = tf.keras.models.load_model(str(model_path), compile=False)
    except Exception as load_err:  # deliberate fallback; error is kept for reporting
        try:
            tf_layer = tf.keras.layers.TFSMLayer(str(model_path), call_endpoint='serving_default')
            model = tf.keras.Sequential([
                tf.keras.Input(shape=(img_size, img_size, 3)),
                tf_layer,
            ])
        except Exception as wrap_err:
            raise RuntimeError(
                f"Failed to load or wrap SavedModel directory '{model_path}': {wrap_err} "
                f"(tf.keras.models.load_model error: {load_err})"
            ) from wrap_err
        print(f"Wrapped TensorFlow SavedModel at {model_path} with TFSMLayer (serving_default).")
        return model
    print(f"Loaded Keras-compatible model from directory: {model_path}")
    return model


def _select_output(pred, output_key=None):
    """Reduce a raw ``model.predict`` result to a single ndarray.

    Args:
        pred: Result of ``model.predict``: an array-like, a list/tuple of
            per-output arrays, or a dict keyed by output name (as produced by
            a wrapped SavedModel serving signature).
        output_key: Dict key to select. When falsy, the first dict entry
            whose values have a numeric dtype is used.

    Returns:
        The selected prediction as an ndarray (possibly empty; the caller
        checks that).

    Raises:
        KeyError: ``output_key`` was given but is not a key of the dict.
        ValueError: the dict contains no numeric output, or a list/tuple
            output is empty.
    """
    if isinstance(pred, dict):
        if output_key:
            if output_key not in pred:
                raise KeyError(f"Requested output key '{output_key}' not found. Available keys: {list(pred.keys())}")
            return np.asarray(pred[output_key])
        # Honour the --output_key help text: pick the first *numeric* output
        # rather than blindly taking the first key.
        for key, value in pred.items():
            candidate = np.asarray(value)
            if np.issubdtype(candidate.dtype, np.number):
                print(f"No --output_key provided; using first numeric output key: '{key}'")
                return candidate
        raise ValueError(f"Model returned no numeric output. Available keys: {list(pred.keys())}")
    if isinstance(pred, (list, tuple)):
        # Multi-output models yield one array per output; np.asarray on a
        # ragged list would fail, so take the first output explicitly.
        if not pred:
            raise ValueError("Model returned an empty prediction.")
        pred = pred[0]
    return np.asarray(pred)


def main():
    """CLI entry point: load the model, preprocess one image, print the predicted age.

    Raises:
        FileNotFoundError: ``--image_path`` does not exist.
        KeyError: ``--output_key`` is not among the model's dict outputs.
        ValueError: the model returned an empty or non-numeric prediction.
        RuntimeError: a model directory could not be loaded or wrapped.
    """
    args = parse_args()
    image_path = Path(args.image_path)
    # Validate the cheap input before paying for a potentially slow model load.
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")
    model = _load_model(Path(args.model_path), args.img_size)
    # Add a leading batch axis of size 1 for predict().
    batch = np.expand_dims(load_image(image_path, args.img_size), axis=0)
    arr = _select_output(model.predict(batch), args.output_key)
    if arr.size == 0:
        raise ValueError("Model returned an empty prediction.")
    # Take the first value (assumes the regressor emits one scalar per image).
    age_pred = float(arr.flatten()[0])
    print(f"Predicted age: {age_pred:.2f} years")
# Run the command-line interface only when executed directly
# (e.g. ``python predict.py ...``), never on import.
if __name__ == "__main__":
    main()