Update modeling_deepseekocr.py
Browse files- modeling_deepseekocr.py +6 -7
modeling_deepseekocr.py
CHANGED
|
@@ -353,6 +353,7 @@ class DeepseekOCRConfig(DeepseekV2Config):
|
|
| 353 |
|
| 354 |
class DeepseekOCRModel(DeepseekV2Model):
|
| 355 |
config_class = DeepseekOCRConfig
|
|
|
|
| 356 |
|
| 357 |
def __init__(self, config: DeepseekV2Config):
|
| 358 |
super(DeepseekOCRModel, self).__init__(config)
|
|
@@ -383,7 +384,6 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 383 |
images_seq_mask: Optional[torch.FloatTensor] = None,
|
| 384 |
images_spatial_crop: Optional[torch.FloatTensor] = None,
|
| 385 |
return_dict: Optional[bool] = None,
|
| 386 |
-
verbose: Optional[bool] = None,
|
| 387 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 388 |
|
| 389 |
|
|
@@ -433,7 +433,7 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 433 |
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
| 434 |
global_features = self.projector(global_features)
|
| 435 |
|
| 436 |
-
if verbose:
|
| 437 |
print('=====================')
|
| 438 |
print('BASE: ', global_features.shape)
|
| 439 |
print('PATCHES: ', local_features.shape)
|
|
@@ -478,7 +478,7 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 478 |
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
| 479 |
global_features = self.projector(global_features)
|
| 480 |
|
| 481 |
-
if verbose:
|
| 482 |
print('=====================')
|
| 483 |
print('BASE: ', global_features.shape)
|
| 484 |
print('NO PATCHES')
|
|
@@ -706,6 +706,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 706 |
|
| 707 |
def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
|
| 708 |
self.disable_torch_init()
|
|
|
|
| 709 |
|
| 710 |
if len(output_path) > 0 :
|
| 711 |
os.makedirs(output_path, exist_ok=True)
|
|
@@ -930,8 +931,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 930 |
streamer=streamer,
|
| 931 |
max_new_tokens=8192,
|
| 932 |
no_repeat_ngram_size = 20,
|
| 933 |
-
use_cache = True
|
| 934 |
-
verbose = verbose
|
| 935 |
)
|
| 936 |
|
| 937 |
else:
|
|
@@ -948,8 +948,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 948 |
eos_token_id=tokenizer.eos_token_id,
|
| 949 |
max_new_tokens=8192,
|
| 950 |
no_repeat_ngram_size = 35,
|
| 951 |
-
use_cache = True
|
| 952 |
-
verbose = verbose
|
| 953 |
)
|
| 954 |
|
| 955 |
|
|
|
|
| 353 |
|
| 354 |
class DeepseekOCRModel(DeepseekV2Model):
|
| 355 |
config_class = DeepseekOCRConfig
|
| 356 |
+
verbose = True
|
| 357 |
|
| 358 |
def __init__(self, config: DeepseekV2Config):
|
| 359 |
super(DeepseekOCRModel, self).__init__(config)
|
|
|
|
| 384 |
images_seq_mask: Optional[torch.FloatTensor] = None,
|
| 385 |
images_spatial_crop: Optional[torch.FloatTensor] = None,
|
| 386 |
return_dict: Optional[bool] = None,
|
|
|
|
| 387 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 388 |
|
| 389 |
|
|
|
|
| 433 |
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
| 434 |
global_features = self.projector(global_features)
|
| 435 |
|
| 436 |
+
if self.verbose:
|
| 437 |
print('=====================')
|
| 438 |
print('BASE: ', global_features.shape)
|
| 439 |
print('PATCHES: ', local_features.shape)
|
|
|
|
| 478 |
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
|
| 479 |
global_features = self.projector(global_features)
|
| 480 |
|
| 481 |
+
if self.verbose:
|
| 482 |
print('=====================')
|
| 483 |
print('BASE: ', global_features.shape)
|
| 484 |
print('NO PATCHES')
|
|
|
|
| 706 |
|
| 707 |
def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
|
| 708 |
self.disable_torch_init()
|
| 709 |
+
self.model.verbose = verbose
|
| 710 |
|
| 711 |
if len(output_path) > 0 :
|
| 712 |
os.makedirs(output_path, exist_ok=True)
|
|
|
|
| 931 |
streamer=streamer,
|
| 932 |
max_new_tokens=8192,
|
| 933 |
no_repeat_ngram_size = 20,
|
| 934 |
+
use_cache = True
|
|
|
|
| 935 |
)
|
| 936 |
|
| 937 |
else:
|
|
|
|
| 948 |
eos_token_id=tokenizer.eos_token_id,
|
| 949 |
max_new_tokens=8192,
|
| 950 |
no_repeat_ngram_size = 35,
|
| 951 |
+
use_cache = True
|
|
|
|
| 952 |
)
|
| 953 |
|
| 954 |
|