add streamer and verbose params and io.BytesIO / PIL.Image.Image image_file support for the infer method
#23 by weege007 - opened

- modeling_deepseekocr.py +22 -14
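Usage sketch for the new parameters. The `infer` signature below matches this diff; the model/tokenizer loading pattern and the file names are assumptions for illustration, not part of the change:

```python
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Assumed loading pattern; adjust model id / dtype / device to your setup.
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model = model.eval().cuda().to(torch.bfloat16)

# image_file may now be a path, an io.BytesIO buffer, or a PIL.Image.Image.
pil_image = Image.open("sample.jpg")            # placeholder path
with open("sample.jpg", "rb") as f:
    byte_buffer = BytesIO(f.read())

# verbose=False silences the BASE/PATCHES shape prints in the vision path;
# streamer=None keeps the default NoEOSTextStreamer behaviour;
# output_path='' now skips directory creation thanks to the len(output_path) guard.
model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",
    image_file=pil_image,   # or byte_buffer, or "sample.jpg"
    output_path="",
    verbose=False,
    streamer=None,
)
```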
    	
modeling_deepseekocr.py (CHANGED)

```diff
@@ -27,7 +27,9 @@ import time
 def load_image(image_path):
 
     try:
-        image = Image.open(image_path)
+        image = image_path
+        if not isinstance(image_path, Image.Image):
+            image = Image.open(image_path)
 
         corrected_image = ImageOps.exif_transpose(image)
 
@@ -353,6 +355,7 @@ class DeepseekOCRConfig(DeepseekV2Config):
 
 class DeepseekOCRModel(DeepseekV2Model):
     config_class = DeepseekOCRConfig
+    verbose = True
 
     def __init__(self, config: DeepseekV2Config):
         super(DeepseekOCRModel, self).__init__(config)
@@ -432,10 +435,11 @@ class DeepseekOCRModel(DeepseekV2Model):
                     global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                     global_features = self.projector(global_features)
 
-                    print('=====================')
-                    print('BASE: ', global_features.shape)
-                    print('PATCHES: ', local_features.shape)
-                    print('=====================')
+                    if self.verbose:
+                        print('=====================')
+                        print('BASE: ', global_features.shape)
+                        print('PATCHES: ', local_features.shape)
+                        print('=====================')
 
                     _, hw, n_dim = global_features.shape
                     h = w = int(hw ** 0.5)
@@ -475,10 +479,12 @@ class DeepseekOCRModel(DeepseekV2Model):
                     global_features_2 = vision_model(image_ori, global_features_1)
                     global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                     global_features = self.projector(global_features)
-                    print('=====================')
-                    print('BASE: ', global_features.shape)
-                    print('NO PATCHES')
-                    print('=====================')
+
+                    if self.verbose:
+                        print('=====================')
+                        print('BASE: ', global_features.shape)
+                        print('NO PATCHES')
+                        print('=====================')
                     _, hw, n_dim = global_features.shape
                     h = w = int(hw ** 0.5)
 
@@ -700,11 +706,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
+    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
         self.disable_torch_init()
+        self.model.verbose = verbose
 
-        os.makedirs(output_path, exist_ok=True)
-        os.makedirs(f'{output_path}/images', exist_ok=True)
+        if len(output_path) > 0 :
+            os.makedirs(output_path, exist_ok=True)
+            os.makedirs(f'{output_path}/images', exist_ok=True)
 
         if prompt and image_file:
             conversation = [
@@ -716,7 +724,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     # "content": "<image>\nFree OCR. ",
                     # "content": "<image>\nParse the figure. ",
                     # "content": "<image>\nExtract the text in the image. ",
-                    "images": [f'{image_file}'],
+                    "images": [image_file] if isinstance(image_file, (BytesIO, Image.Image)) else [f'{image_file}'],
                 },
                 {"role": "<|Assistant|>", "content": ""},
             ]
@@ -910,7 +918,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
         if not eval_mode:
-            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+            streamer = streamer or NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 with torch.no_grad():
                     output_ids = self.generate(
```
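Because `streamer` is now injectable, a caller could pass a `transformers.TextIteratorStreamer` and consume decoded text incrementally. A minimal sketch, assuming `infer` forwards the streamer to `generate()` on the non-eval path and that `model`/`tokenizer` are loaded as in the earlier snippet; the worker-thread pattern and file name are illustrative only:

```python
import threading

from transformers import TextIteratorStreamer

# model / tokenizer loaded as in the earlier snippet (assumption).
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# infer() blocks while generating, so run it in a worker thread
# and read decoded chunks from the streamer in the main thread.
worker = threading.Thread(
    target=model.infer,
    kwargs=dict(
        tokenizer=tokenizer,
        prompt="<image>\nFree OCR. ",
        image_file="sample.jpg",   # placeholder path
        streamer=streamer,
        verbose=False,
    ),
)
worker.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
worker.join()
```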
