fix enforced bf16 data type on SM75 and lower devices
modeling_dots_vision.py CHANGED (+7 -3)
@@ -489,9 +489,13 @@ class DotsVisionTransformer(PreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=…
-        …
-        …
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=None) -> torch.Tensor:
+        # Try to fix the error raised on SM75 and earlier devices, which do not support BF16
+        # If bf16 is not explicitly given, infer it from the weight dtype
+        if bf16 is None:
+            bf16 = (self.dtype == torch.bfloat16)
+        # Always cast the input explicitly to this module's compute dtype, avoiding input/bias dtype mismatches
+        hidden_states = hidden_states.to(torch.bfloat16 if bf16 else self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
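For context: BF16 arithmetic requires compute capability 8.0 (Ampere) or newer, so SM75 cards (Turing, e.g. T4 / RTX 20xx) must run in FP16 or FP32. Below is a minimal usage sketch, not from this commit, showing how a caller could pick a supported dtype and let the patched forward() infer the bf16 flag from the loaded weights; the checkpoint id is a hypothetical placeholder.

import torch
from transformers import AutoModelForCausalLM

# BF16 needs an SM80+ (Ampere or newer) GPU; on SM75 and older, fall back to FP16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16

model = AutoModelForCausalLM.from_pretrained(
    "org/dots-model",        # hypothetical checkpoint id, for illustration only
    torch_dtype=dtype,
    trust_remote_code=True,  # modeling_dots_vision.py ships as remote code
).to("cuda")

# With bf16 left at None, the patched forward() infers the flag from model.dtype
# and casts the pixel inputs to match, so no explicit bf16 argument is needed.

Defaulting bf16 to None instead of a hard-coded value lets the module follow whatever dtype the weights were loaded in, so FP16 checkpoints no longer hit a dtype mismatch in patch_embed on pre-Ampere GPUs.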
