|
|
""" |
|
|
Load a trained age regression model and run a prediction on a single image. |
|
|
|
|
|
Usage: python predict.py --model_path saved_model_age_regressor --image_path some_image.jpg |
|
|
""" |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``model_path`` (str),
        ``image_path`` (str, required), ``img_size`` (int) and
        ``output_key`` (str or None).
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    add('--model_path', type=str, default='saved_model_age_regressor')
    add('--image_path', type=str, required=True)
    add('--img_size', type=int, default=224)
    add(
        '--output_key',
        type=str,
        default=None,
        help=(
            'If the model returns a dict, select this key for the numeric prediction. '
            'If omitted the first numeric output will be used.'
        ),
    )

    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
|
|
def load_image(path, img_size):
    """Load an image file as a normalized RGB float32 array.

    The image is converted to RGB, resized to ``img_size`` x ``img_size``
    pixels and scaled from the 0-255 range to [0.0, 1.0].

    Args:
        path: Path (str or path-like) of the image file to read.
        img_size: Target edge length in pixels for both width and height.

    Returns:
        np.ndarray of dtype float32 with shape (img_size, img_size, 3).
    """
    # Close the underlying file deterministically. ``convert`` returns a new,
    # fully loaded image, so it stays usable after the source is closed.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    rgb = rgb.resize((img_size, img_size))
    return np.array(rgb, dtype=np.float32) / 255.0
|
|
|
|
|
|
|
|
def _load_model(model_path, img_size):
    """Load a Keras model file or SavedModel directory for inference.

    Args:
        model_path: ``pathlib.Path`` to a ``.h5``/``.keras`` file, a model
            directory, or any other existing path that
            ``tf.keras.models.load_model`` accepts.
        img_size: Edge length used to declare the input shape when a
            SavedModel directory has to be wrapped in ``TFSMLayer``.

    Returns:
        A Keras model whose ``predict`` accepts a batch of images.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If a directory can be loaded neither by ``load_model``
            nor through its SavedModel ``serving_default`` endpoint.
    """
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_path}")

    if model_path.is_file() and model_path.suffix.lower() in ('.h5', '.keras'):
        model = tf.keras.models.load_model(str(model_path), compile=False)
        print(f"Loaded Keras model file: {model_path}")
        return model

    if model_path.is_dir():
        try:
            model = tf.keras.models.load_model(str(model_path), compile=False)
        except Exception as keras_err:
            # load_model may reject a plain TF SavedModel directory (e.g.
            # under Keras 3); fall back to wrapping its serving signature.
            # keras_err is reported if the fallback fails too, instead of
            # being silently discarded.
            try:
                tf_layer = tf.keras.layers.TFSMLayer(str(model_path), call_endpoint='serving_default')
                model = tf.keras.Sequential([
                    tf.keras.Input(shape=(img_size, img_size, 3)),
                    tf_layer,
                ])
            except Exception as wrap_err:
                raise RuntimeError(
                    f"Failed to load or wrap SavedModel directory '{model_path}': {wrap_err} "
                    f"(load_model error was: {keras_err})"
                ) from wrap_err
            print(f"Wrapped TensorFlow SavedModel at {model_path} with TFSMLayer (serving_default).")
            return model
        print(f"Loaded Keras-compatible model from directory: {model_path}")
        return model

    # Existing file with some other suffix: let Keras decide whether it can load it.
    model = tf.keras.models.load_model(str(model_path), compile=False)
    print(f"Loaded model from path: {model_path}")
    return model


def _select_prediction(pred, output_key=None):
    """Reduce a raw ``model.predict`` result to a single numeric array.

    Args:
        pred: Result of ``model.predict``. It may be an array, a list/tuple of
            arrays (one per output), or a dict keyed by output name (as
            produced by the TFSMLayer-wrapped SavedModel path).
        output_key: Dict key to select. If falsy, the first dict entry whose
            values have a numeric dtype is used.

    Returns:
        np.ndarray holding the selected output.

    Raises:
        KeyError: If ``output_key`` is given but is missing from a dict result.
        ValueError: If the result is an empty dict/list, or a dict has no
            numeric entry.
    """
    if isinstance(pred, dict):
        if not pred:
            raise ValueError("Model returned an empty prediction.")
        if output_key:
            if output_key not in pred:
                raise KeyError(f"Requested output key '{output_key}' not found. Available keys: {list(pred.keys())}")
            return np.asarray(pred[output_key])
        # Honour the --output_key help text: pick the first *numeric* output.
        for key, value in pred.items():
            candidate = np.asarray(value)
            if np.issubdtype(candidate.dtype, np.number):
                print(f"No --output_key provided; using first numeric output key: '{key}'")
                return candidate
        raise ValueError(f"Model returned no numeric output. Available keys: {list(pred.keys())}")

    if isinstance(pred, (list, tuple)):
        if not pred:
            raise ValueError("Model returned an empty prediction.")
        # Outputs may differ in shape, so take the first one instead of
        # stacking them all with np.asarray (which fails on ragged shapes).
        return np.asarray(pred[0])

    return np.asarray(pred)


def main():
    """Command-line entry point: predict the age shown in a single image.

    Parses CLI arguments, loads the model, runs it on the preprocessed
    image, and prints the first value of the selected output as the age.

    Raises:
        FileNotFoundError: If the image or model path does not exist.
        KeyError: If ``--output_key`` names a missing dict output.
        ValueError: If the model returns no usable numeric prediction.
        RuntimeError: If a model directory cannot be loaded.
    """
    args = parse_args()

    # Check the cheap input first so a typo in the image path fails fast,
    # before the potentially slow model load.
    image_path = Path(args.image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    model = _load_model(Path(args.model_path), args.img_size)

    # Add a leading batch dimension of size 1.
    x = np.expand_dims(load_image(image_path, args.img_size), axis=0)
    pred = model.predict(x)

    arr = _select_prediction(pred, args.output_key)
    if arr.size == 0:
        raise ValueError("Model returned an empty prediction.")

    # Batch size is 1. This assumes the model emits a single scalar age
    # per image, so the first element is the prediction.
    age_pred = float(arr.reshape(-1)[0])
    print(f"Predicted age: {age_pred:.2f} years")
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|