arcaputo3 committed · Commit 1e3401a · Parent: 5951289

Add Apple Silicon (MPS) backend support

Enables DeepSeek-OCR to run on Apple Silicon (M1/M2/M3/M4) via the MPS backend while preserving OCR output quality.

Key changes:
- Replace masked_scatter_ with row-wise boolean assignment on MPS (fixes silent embedding injection failure)
- Use fp32 precision for images and inference on MPS (bfloat16 causes numerical issues)
- Disable autocast on MPS backend
- Make tensor placement device-agnostic (.to(self.device) instead of .cuda())
- Add NaN guards for vision tower outputs on MPS

All changes are applied conditionally based on self.device.type == "mps"; the CUDA code path remains unchanged for full backward compatibility.

Tested on: macOS 26.0.1, Apple M4 Max, PyTorch 2.9.0, Transformers 4.46.3
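
For reference, a minimal sketch of how the MPS path can be exercised end to end; the hub id, file paths, and the exact infer() keyword arguments below are illustrative assumptions, not part of this commit:

# Minimal sketch: load DeepSeek-OCR with the remote code from this repo and
# run it on Apple Silicon. Hub id, prompt, file paths, and the exact infer()
# signature are assumptions; adjust them to the revision you are using.
import torch
from transformers import AutoModel, AutoTokenizer

device = "mps" if torch.backends.mps.is_available() else "cpu"

model_id = "deepseek-ai/DeepSeek-OCR"  # assumed hub id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # fp32 on MPS; the CUDA path keeps bfloat16
)
model = model.to(device).eval()

# infer() is the entry point defined in modeling_deepseekocr.py; its exact
# keyword arguments may differ between revisions.
model.infer(
    tokenizer,
    prompt="<image>\nFree OCR.",
    image_file="page.png",
    output_path="./output",
)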

Files changed (1)
  1. modeling_deepseekocr.py +42 -17
modeling_deepseekocr.py CHANGED

@@ -3,6 +3,7 @@ from .configuration_deepseek_v2 import DeepseekV2Config
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from typing import List, Optional, Tuple, Union
 from transformers.cache_utils import Cache
+from contextlib import nullcontext
 import requests
 from PIL import Image, ImageOps, ImageDraw, ImageFont
 from io import BytesIO
@@ -502,7 +503,23 @@ class DeepseekOCRModel(DeepseekV2Model):
                     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                     # exit()

-                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                    # MPS compatibility: use row-wise assignment; CUDA: keep original masked_scatter_
+                    if self.device.type == "mps":
+                        # MPS-safe: row-wise boolean assignment instead of broadcasted masked_scatter_
+                        mask = images_seq_mask[idx].to(self.device)
+                        feats = images_in_this_batch.to(dtype=inputs_embeds.dtype, device=self.device)
+                        # Basic sanity: number of rows must match
+                        if mask.sum().item() != feats.shape[0]:
+                            raise RuntimeError(
+                                f"image token count mismatch: mask={mask.sum().item()} vs feats={feats.shape[0]}"
+                            )
+                        # Guard against NaNs from upstream vision tower (seen on some MPS builds)
+                        feats = torch.nan_to_num(feats)
+                        # Deterministic row write
+                        inputs_embeds[idx][mask] = feats
+                    else:
+                        # Original CUDA path (unchanged)
+                        inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)

                 idx += 1

@@ -799,7 +816,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):



-                    images_list.append(image_transform(global_view).to(torch.bfloat16))
+                    # MPS needs fp32, CUDA can use bfloat16
+                    image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+                    images_list.append(image_transform(global_view).to(image_dtype))

                     # global_view_tensor = image_transform(global_view).to(torch.bfloat16)

@@ -810,9 +829,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

                 if width_crop_num > 1 or height_crop_num > 1:
                     """process the local views"""
-
+
                     for i in range(len(images_crop_raw)):
-                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+                        images_crop_list.append(image_transform(images_crop_raw[i]).to(image_dtype))

                 if image_size == 640:
                     valid_img_tokens += len(images_crop_list) * 100
@@ -846,7 +865,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
             # else:
                 global_view = ImageOps.pad(image, (image_size, image_size),
                                            color=tuple(int(x * 255) for x in image_transform.mean))
-                images_list.append(image_transform(global_view).to(torch.bfloat16))
+                # MPS needs fp32, CUDA can use bfloat16
+                image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+                images_list.append(image_transform(global_view).to(image_dtype))

                 if base_size == 1024:
                     valid_img_tokens += int(256 * ratio)
@@ -911,12 +932,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):

         if not eval_mode:
             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
+            with autocast_ctx:
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(self.device),
+                        images=[(images_crop.to(self.device), images_ori.to(self.device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(self.device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -929,12 +952,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     )

         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
+            with autocast_ctx:
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(self.device),
+                        images=[(images_crop.to(self.device), images_ori.to(self.device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(self.device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -944,10 +969,10 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                         no_repeat_ngram_size = 35,
                         use_cache = True
                     )
-
+

         if '<image>' in conversation[0]['content'] and eval_mode:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
             stop_str = '<|end▁of▁sentence|>'
             if outputs.endswith(stop_str):
                 outputs = outputs[:-len(stop_str)]
@@ -957,7 +982,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
             return outputs

         if '<image>' in conversation[0]['content'] and test_compress:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
             pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
             print('='*50)
             print('image size: ', (w, h))
@@ -968,7 +993,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):


         if '<image>' in conversation[0]['content'] and save_results:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
             stop_str = '<|end▁of▁sentence|>'

             print('='*15 + 'save results:' + '='*15)
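
As a sanity check on the masked_scatter_ replacement, here is a small standalone sketch (CPU tensors, illustrative shapes) showing that the row-wise boolean assignment used on MPS writes the same values as the original call:

# Standalone check: boolean row assignment and masked_scatter_ with a
# row-broadcast mask inject identical image embeddings. Shapes are illustrative.
import torch

torch.manual_seed(0)
seq_len, hidden = 8, 4
inputs_embeds = torch.randn(seq_len, hidden)
images_seq_mask = torch.tensor([0, 1, 1, 0, 0, 1, 0, 0], dtype=torch.bool)
image_feats = torch.randn(int(images_seq_mask.sum()), hidden)

# Original CUDA-path behaviour (run on CPU here for comparison)
ref = inputs_embeds.clone()
ref.masked_scatter_(images_seq_mask.unsqueeze(-1), image_feats)

# MPS-path behaviour from this commit: assign features to the masked rows
out = inputs_embeds.clone()
out[images_seq_mask] = image_feats

assert torch.equal(ref, out)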