Inception V3 model variables for Fréchet Inception Distance in Jax/Flax
This repository provides the model parameters and batch normalization statistics of the Inception V3 model.
The parameters are ported from the original checkpoint http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.
Example
To run Inception V3 in inference mode, use the following code:
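import jax
import jax.numpy as jnp

import inception  # module that defines InceptionV3 (assumed to be importable)

# params and batch_stats hold the variables from this repository; image is the input batch
network = inception.InceptionV3(
    num_classes=1008,
    last_block_max_pool=True,
    with_aux_logits=False,
    dtype=jnp.float32,  # NOTE: recommended
    param_dtype=jnp.float32,  # NOTE: recommended
    precision=jax.lax.Precision.HIGHEST,  # NOTE: this is required for reproducibility
)
outputs = network.apply(
    variables={"params": params, "batch_stats": batch_stats},
    inputs=image,
    deterministic=True,  # inference mode
    with_head=False,
    with_aux_logits=False,
    rngs=rng,
)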
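Since these variables are intended for Fréchet Inception Distance, the sketch below shows one way to turn Inception features into an FID score, FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). It is a minimal sketch, not part of this repository: the helper names `feature_statistics` and `frechet_distance` are hypothetical, and it assumes that the `outputs` returned with `with_head=False` are pooled features that can be flattened to shape `(N, D)`, which this card does not state.

```python
import jax
import jax.numpy as jnp


def _sqrtm_psd(mat: jax.Array) -> jax.Array:
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    eigvals = jnp.clip(eigvals, 0.0, None)  # clamp tiny negative values from round-off
    return (eigvecs * jnp.sqrt(eigvals)) @ eigvecs.T


def feature_statistics(features: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Mean and unbiased covariance of a feature matrix with shape (N, D)."""
    mu = jnp.mean(features, axis=0)
    sigma = jnp.cov(features, rowvar=False)
    return mu, sigma


def frechet_distance(mu1, sigma1, mu2, sigma2) -> jax.Array:
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))."""
    sqrt_sigma1 = _sqrtm_psd(sigma1)
    # sigma1 @ sigma2 is similar to sqrt_sigma1 @ sigma2 @ sqrt_sigma1, so both have
    # the same eigenvalues; the latter is symmetric PSD and safe to use with eigvalsh.
    inner = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    inner = 0.5 * (inner + inner.T)  # re-symmetrize against round-off
    eigvals = jnp.clip(jnp.linalg.eigvalsh(inner), 0.0, None)
    tr_covmean = jnp.sum(jnp.sqrt(eigvals))  # Tr((sigma1 sigma2)^(1/2))
    diff = mu1 - mu2
    return diff @ diff + jnp.trace(sigma1) + jnp.trace(sigma2) - 2.0 * tr_covmean
```

For example, with features `feats_real` and `feats_fake` obtained as `outputs.reshape(outputs.shape[0], -1)` for real and generated images, the score would be `frechet_distance(*feature_statistics(feats_real), *feature_statistics(feats_fake))`. The eigenvalue-based trace avoids `jax.scipy.linalg.sqrtm`, which returns complex values and is CPU-only. Reported FID values are also sensitive to sample count and image preprocessing.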
For training, please refer to the Flax documentation.