Image-to-Video
zzwustc's picture
Upload folder using huggingface_hub
ef296aa verified
from .dense_optical_tracking import DenseOpticalTracker
from .optical_flow import OpticalFlow
from .point_tracking import PointTracker
def create_model(args):
    """Build and return the model selected by ``args.model``.

    Recognised names, tried in this order:
      * ``"dot"`` -- ``DenseOpticalTracker`` built from the height/width,
        tracker, estimator and refiner config/path settings.
      * ``"pt"``  -- ``PointTracker`` built from the height/width, tracker
        and estimator config/path settings.
      * ``"ofe"`` -- ``OpticalFlow`` loaded from the estimator config/path.
      * ``"ofr"`` -- ``OpticalFlow`` loaded from the refiner config/path.

    Args:
        args: Attribute-style options object (presumably an
            ``argparse.Namespace`` -- confirm against callers). It must expose
            ``model`` plus whichever attributes the selected branch reads
            (``height``, ``width``, ``tracker_*``, ``estimator_*``,
            ``refiner_*``).

    Returns:
        The freshly constructed model instance.

    Raises:
        ValueError: If ``args.model`` matches none of the names above.
    """
    model_name = args.model

    # Each entry pairs a model name with a zero-argument builder. Lambdas defer
    # both the class lookup and the reads of ``args`` attributes until the
    # matching entry is actually invoked, so only the chosen model is touched.
    builders = (
        ("dot", lambda: DenseOpticalTracker(
            height=args.height,
            width=args.width,
            tracker_config=args.tracker_config,
            tracker_path=args.tracker_path,
            estimator_config=args.estimator_config,
            estimator_path=args.estimator_path,
            refiner_config=args.refiner_config,
            refiner_path=args.refiner_path,
        )),
        ("pt", lambda: PointTracker(
            height=args.height,
            width=args.width,
            tracker_config=args.tracker_config,
            tracker_path=args.tracker_path,
            estimator_config=args.estimator_config,
            estimator_path=args.estimator_path,
        )),
        # The two optical-flow variants share a class and differ only in
        # which config/checkpoint pair they are loaded from.
        ("ofe", lambda: OpticalFlow(
            height=args.height,
            width=args.width,
            config=args.estimator_config,
            load_path=args.estimator_path,
        )),
        ("ofr", lambda: OpticalFlow(
            height=args.height,
            width=args.width,
            config=args.refiner_config,
            load_path=args.refiner_path,
        )),
    )

    # Linear scan with ``==`` (not a dict lookup) keeps the original
    # comparison semantics, including for unhashable ``model`` values.
    for candidate, build in builders:
        if model_name == candidate:
            return build()

    raise ValueError(f"Unknown model name {model_name}")