Add Apple Silicon (MPS) backend support
Enables DeepSeek-OCR to run on Apple Silicon (M1/M2/M3/M4) using the MPS backend while preserving OCR output quality.
Key changes:
- Replace masked_scatter_ with row-wise boolean assignment on MPS (fixes silent embedding injection failure)
- Use fp32 precision for image tensors and inference on MPS (bfloat16 causes numerical issues there)
- Disable autocast on the MPS path (the precision/autocast selection is sketched after this list)
- Make tensor placement device-agnostic (.to(self.device) instead of .cuda())
- Add NaN guards for vision tower outputs on MPS
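
The fp32/no-autocast choice boils down to one conditional per call site. Below is a minimal standalone sketch of that pattern, assuming only `torch` and the standard library; `pick_precision` is an illustrative helper name, not something introduced by this diff:

```python
import torch
from contextlib import nullcontext


def pick_precision(device: torch.device):
    """Illustrative helper mirroring the patch: fp32 and no autocast on MPS,
    bfloat16 autocast on CUDA. Not part of the diff itself."""
    if device.type == "mps":
        return torch.float32, nullcontext()
    return torch.bfloat16, torch.autocast("cuda", dtype=torch.bfloat16)


# Example: choose the image-tensor dtype and the inference context for the current device.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
image_dtype, autocast_ctx = pick_precision(device)
with autocast_ctx, torch.no_grad():
    pass  # run generation here
```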
All changes are conditionally applied based on self.device.type == "mps".
CUDA code path remains completely unchanged for full backwards compatibility.
Tested on: macOS 26.0.1, Apple M4 Max, PyTorch 2.9.0, Transformers 4.46.3
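
For reviewers who want to exercise the MPS path locally, here is a minimal sketch of how this branch can be driven, assuming the usual `AutoModel`/`AutoTokenizer` loading with `trust_remote_code=True`. The `model.infer(...)` keywords follow the upstream README and are assumptions, not part of this diff:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Prefer MPS when present, otherwise fall back to CUDA/CPU.
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-OCR", trust_remote_code=True)
model = AutoModel.from_pretrained("deepseek-ai/DeepSeek-OCR", trust_remote_code=True, use_safetensors=True)

# Keep weights in fp32 on MPS; bfloat16 is what triggers the numerical issues this PR works around.
model = model.eval().to(device, dtype=torch.float32 if device == "mps" else torch.bfloat16)

# Inference call as in the upstream README; prompt/image_file/output_path values here are placeholders.
result = model.infer(
    tokenizer,
    prompt="<image>\n<|grounding|>Convert the document to markdown.",
    image_file="sample.png",
    output_path="./output",
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
)
```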
modeling_deepseekocr.py (+42 -17)

--- a/modeling_deepseekocr.py
+++ b/modeling_deepseekocr.py
@@ -3,6 +3,7 @@ from .configuration_deepseek_v2 import DeepseekV2Config
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from typing import List, Optional, Tuple, Union
 from transformers.cache_utils import Cache
+from contextlib import nullcontext
 import requests
 from PIL import Image, ImageOps, ImageDraw, ImageFont
 from io import BytesIO
@@ -502,7 +503,23 @@ class DeepseekOCRModel(DeepseekV2Model):
                     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                     # exit()
 
-                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                    # MPS compatibility: use row-wise assignment; CUDA: keep original masked_scatter_
+                    if self.device.type == "mps":
+                        # MPS-safe: row-wise boolean assignment instead of broadcasted masked_scatter_
+                        mask = images_seq_mask[idx].to(self.device)
+                        feats = images_in_this_batch.to(dtype=inputs_embeds.dtype, device=self.device)
+                        # Basic sanity: number of rows must match
+                        if mask.sum().item() != feats.shape[0]:
+                            raise RuntimeError(
+                                f"image token count mismatch: mask={mask.sum().item()} vs feats={feats.shape[0]}"
+                            )
+                        # Guard against NaNs from upstream vision tower (seen on some MPS builds)
+                        feats = torch.nan_to_num(feats)
+                        # Deterministic row write
+                        inputs_embeds[idx][mask] = feats
+                    else:
+                        # Original CUDA path (unchanged)
+                        inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
 
                 idx += 1
 
@@ -799,7 +816,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-                    images_list.append(image_transform(global_view).to(torch.bfloat16))
+                    # MPS needs fp32, CUDA can use bfloat16
+                    image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+                    images_list.append(image_transform(global_view).to(image_dtype))
 
                     # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
 
@@ -810,9 +829,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
                 if width_crop_num > 1 or height_crop_num > 1:
                     """process the local views"""
-
+
                     for i in range(len(images_crop_raw)):
-                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+                        images_crop_list.append(image_transform(images_crop_raw[i]).to(image_dtype))
 
                     if image_size == 640:
                         valid_img_tokens += len(images_crop_list) * 100
@@ -846,7 +865,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
             # else:
                 global_view = ImageOps.pad(image, (image_size, image_size),
                                            color=tuple(int(x * 255) for x in image_transform.mean))
-                images_list.append(image_transform(global_view).to(torch.bfloat16))
+                # MPS needs fp32, CUDA can use bfloat16
+                image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+                images_list.append(image_transform(global_view).to(image_dtype))
 
                 if base_size == 1024:
                     valid_img_tokens += int(256 * ratio)
@@ -911,12 +932,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
         if not eval_mode:
            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
+            with autocast_ctx:
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(self.device),
+                        images=[(images_crop.to(self.device), images_ori.to(self.device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(self.device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -929,12 +952,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                         )
 
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
+            with autocast_ctx:
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(self.device),
+                        images=[(images_crop.to(self.device), images_ori.to(self.device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(self.device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -944,10 +969,10 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                        no_repeat_ngram_size = 35,
                        use_cache = True
                        )
-
+
 
         if '<image>' in conversation[0]['content'] and eval_mode:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
            stop_str = '<|end▁of▁sentence|>'
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
@@ -957,7 +982,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
            return outputs
 
         if '<image>' in conversation[0]['content'] and test_compress:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
@@ -968,7 +993,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
         if '<image>' in conversation[0]['content'] and save_results:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
            stop_str = '<|end▁of▁sentence|>'
 
            print('='*15 + 'save results:' + '='*15)
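
For context on the `masked_scatter_` hunk: with a row-wise mask, boolean-index assignment writes exactly the rows that `masked_scatter_` would fill, so the MPS branch is a drop-in replacement whenever the masked-row count matches the feature-row count. A toy-shape sketch (not the model's real tensors) that checks the two paths agree:

```python
import torch

# Toy stand-ins for one batch element: 8 sequence positions, hidden size 4,
# with 3 of the positions reserved for image tokens.
torch.manual_seed(0)
inputs_embeds = torch.zeros(8, 4)
images_seq_mask = torch.tensor([0, 1, 1, 0, 0, 1, 0, 0], dtype=torch.bool)
image_feats = torch.randn(3, 4)  # one feature row per image token

# Original CUDA-path semantics: flatten feats into the masked positions.
scattered = inputs_embeds.clone()
scattered.masked_scatter_(images_seq_mask.unsqueeze(-1), image_feats)

# MPS-path semantics from this PR: plain boolean-index row write.
assigned = inputs_embeds.clone()
assigned[images_seq_mask] = image_feats

assert torch.equal(scattered, assigned)  # identical result when row counts match
```

The explicit row-count check and `torch.nan_to_num` guard in the patch only apply on MPS, which is why the CUDA branch keeps the original one-liner.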