RoadToNowhere committed
Commit e4689a1 · verified · 1 parent: ba670c5

fix enforced bf16 data type on SM75 and lower devices

Files changed (1):
  1. modeling_dots_vision.py +7 -3
modeling_dots_vision.py CHANGED
@@ -489,9 +489,13 @@ class DotsVisionTransformer(PreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True) -> torch.Tensor:
-        if bf16:
-            hidden_states = hidden_states.bfloat16()
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=None) -> torch.Tensor:
+        # Attempt to fix the failure on SM75 and earlier devices, which do not support BF16
+        # If bf16 is not explicitly specified, infer it from the weight dtype
+        if bf16 is None:
+            bf16 = (self.dtype == torch.bfloat16)
+        # Always cast the input explicitly to this module's compute precision, avoiding input/bias dtype mismatches
+        hidden_states = hidden_states.to(torch.bfloat16 if bf16 else self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
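
For context, a minimal caller-side sketch of how this fix is meant to be used: pick a compute dtype the GPU actually supports (BF16 requires SM80/Ampere or newer; SM75/Turing and older fall back to FP16), cast the module to it, and let the patched forward (bf16=None) infer the precision from the weights. The helper name pick_compute_dtype and the commented-out construction of DotsVisionTransformer are hypothetical; only torch.cuda.is_bf16_supported() and torch.cuda.is_available() are real PyTorch APIs.

import torch

def pick_compute_dtype(device: str = "cuda") -> torch.dtype:
    # BF16 needs SM80 (Ampere) or newer; on SM75 (Turing) and older, fall back to FP16.
    # torch.cuda.is_bf16_supported() encapsulates the capability check.
    if device.startswith("cuda") and torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32  # CPU: stay in full precision

dtype = pick_compute_dtype()
# Hypothetical usage with the patched vision tower: casting the whole module
# means forward() with bf16=None will read the same precision from self.dtype
# and cast the pixel inputs to match, so input and bias dtypes stay consistent.
# vision_tower = DotsVisionTransformer(config).to(device="cuda", dtype=dtype)
# features = vision_tower(pixel_values, grid_thw)  # bf16 inferred from weights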