Niharmahesh committed on
Commit
b87a05f
·
verified ·
1 Parent(s): faa3867

Update stage_1/model_setup.py

Browse files
Files changed (1) hide show
  1. stage_1/model_setup.py +14 -17
stage_1/model_setup.py CHANGED
@@ -31,50 +31,47 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  MIN_PIXELS = 256 * 28 * 28
32
  MAX_PIXELS = 256 * 28 * 28
33
 
 
34
  def setup_model():
35
  """
36
- Stage 2 configuration: Unfreeze LLM + connector while keeping vision encoder frozen
 
37
  """
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  print(f"Using device: {device}")
40
 
41
  # Initialize model
42
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
- MODEL_ID,
44
  torch_dtype=torch.bfloat16,
45
  attn_implementation="flash_attention_2",
46
  device_map="auto" if torch.cuda.is_available() else None
47
  )
48
 
49
- # Freeze entire model first
50
  # Freeze entire model first
51
  for param in model.parameters():
52
  param.requires_grad = False
53
 
54
- # 1. Unfreeze vision merger (connector)
55
  for name, param in model.visual.named_parameters():
56
  if "merger" in name:
57
- param.requires_grad = True
58
-
59
- # 2. Unfreeze LLM (model + lm_head) WITHOUT affecting visual.merger
60
- for name, param in model.named_parameters():
61
- if any(k in name for k in ("model", "lm_head")):
62
- param.requires_grad = True # Only modifies LLM params
63
-
64
- # 3. Training modes (rotary_emb auto-included)
65
  model.visual.merger.train()
66
- model.model.train()
67
- model.lm_head.train()
68
-
69
- # Print trainable parameters
70
- print("\n✅ Stage 2 Trainable Parameters:")
71
  for name, param in model.named_parameters():
72
  if param.requires_grad:
73
  print(f"- {name}")
 
 
74
  print("\nModule training states:")
75
  for name, module in model.named_modules():
76
  state = "train" if module.training else "eval"
77
  print(f"{name}: {state}")
 
78
  return model
79
 
80
 
 
31
  MIN_PIXELS = 256 * 28 * 28
32
  MAX_PIXELS = 256 * 28 * 28
33
 
34
# Model setup: freeze everything except the vision merger (connector).
def setup_model(model_id: str = "Qwen/Qwen2.5-VL-3B-Instruct"):
    """
    Initialize and configure the Qwen2.5-VL model with selective parameter
    freezing: only the vision merger (connector) layers are trainable, while
    the vision encoder and the LLM stay frozen.

    Args:
        model_id: Hugging Face checkpoint to load.  Defaults to the 3B
            Instruct checkpoint this stage was written for, so existing
            callers are unaffected.

    Returns:
        The loaded model with requires_grad and module train/eval states
        configured for merger-only training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize model (bf16 + FlashAttention-2; shard across GPUs if any)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto" if torch.cuda.is_available() else None
    )

    # Freeze entire model first
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze only vision merger layers
    for name, param in model.visual.named_parameters():
        if "merger" in name:
            param.requires_grad = True  # Enable training for these parameters

    # Force the merger into train mode; other modules keep their default mode
    # (only merger parameters receive gradients regardless).
    model.visual.merger.train()

    # Print trainable parameter names so the freeze can be verified by eye
    print("\n✅ Verified trainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"- {name}")

    # Print out the training state of all modules:
    print("\nModule training states:")
    for name, module in model.named_modules():
        state = "train" if module.training else "eval"
        print(f"{name}: {state}")

    return model
76
 
77