Spaces:
Build error
Build error
| import re | |
| from transformers import AutoModel | |
def _clip_vision_tower(model):
    """Return the vision tower that a CLIP model nests under ``vision_model``."""
    return model.vision_model


def _identity_vision_tower(model):
    """Return ``model`` unchanged: the loaded model already is the vision model."""
    return model


# Supported vision backbones. Each key is a regex that is searched for in the
# lowercased model name/path; each value extracts the vision model from the
# loaded object. Some architectures wrap it (CLIP exposes it as
# `model.vision_model`), while others (ViT) are used as-is.
# Insertion order matters: lookups try the regexes in order and the first
# match wins, so "clip" must stay ahead of "vit" (a name such as
# "openai/clip-vit-base-patch32" matches both).
vision_model_name2model = {
    r"clip": _clip_vision_tower,
    r"vit": _identity_vision_tower,
}
def vision_model_name_to_model(model_name_or_path, model):
    """Return the vision sub-model of ``model`` for a supported backbone.

    The regexes in ``vision_model_name2model`` are searched, in insertion
    order, against the lowercased ``model_name_or_path``. The extractor paired
    with the first matching regex is applied to ``model`` and its result is
    returned.

    Args:
        model_name_or_path: Hub id or local path (``str`` or ``os.PathLike``)
            of the vision model; used only to decide which architecture
            ``model`` is.
        model: The loaded model object to extract the vision model from.

    Returns:
        Whatever the matching extractor returns (with the current table:
        ``model.vision_model`` for CLIP names, ``model`` itself for ViT names).

    Raises:
        ValueError: If no supported regex matches ``model_name_or_path``.
    """
    # str() lets path-like names work too; it is a no-op for plain strings.
    model_name_lowcase = str(model_name_or_path).lower()
    for rx, lookup in vision_model_name2model.items():
        if re.search(rx, model_name_lowcase):
            return lookup(model)
    # Only reached when no regex matched: every match returns above.
    raise ValueError(
        f"Unknown type of backbone vision model. Got {model_name_or_path}, supported regexes:"
        f" {list(vision_model_name2model.keys())}."
    )
def get_vision_model(config):
    """Load the configured vision backbone and return its vision sub-model.

    Args:
        config: Object exposing ``vision_model_name`` (Hub id or local path
            passed to ``AutoModel.from_pretrained``) and
            ``vision_model_params``. The latter holds the extra keyword
            arguments for ``from_pretrained``: either a string that evaluates
            to a mapping (presumably a dict literal such as ``"{}"``; confirm
            against the config schema), an already-built dict, or
            ``None``/empty for no extra arguments.

    Returns:
        The vision model extracted by ``vision_model_name_to_model``.

    Raises:
        ValueError: If ``vision_model_name`` matches no supported backbone.
    """
    vision_model_name = config.vision_model_name
    vision_model_params = config.vision_model_params
    if vision_model_params is None:
        # Nothing configured: eval(None) would raise TypeError.
        vision_model_params = {}
    elif isinstance(vision_model_params, str):
        # An empty/blank string means "no extra kwargs"; eval("") would raise
        # SyntaxError.
        # SECURITY NOTE(review): eval executes arbitrary code. If the config
        # can come from an untrusted source (e.g. a downloaded config file),
        # prefer ast.literal_eval. It accepts only Python literals, so it
        # would reject non-literal values such as `torch.float16`.
        vision_model_params = (
            eval(vision_model_params) if vision_model_params.strip() else {}
        )
    # Any other value (e.g. a dict) is used as-is.
    model = AutoModel.from_pretrained(vision_model_name, **vision_model_params)
    return vision_model_name_to_model(vision_model_name, model)