ianpan committed
Commit 1a29f83 · verified · 1 parent: 7261dc0

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +26 -0
  3. configuration.py +33 -0
  4. model.safetensors +3 -0
  5. modeling.py +533 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "architectures": [
+     "TotalClassifierModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration.TotalClassifierConfig",
+     "AutoModel": "modeling.TotalClassifierModel"
+   },
+   "backbone": "tf_efficientnetv2_b0",
+   "cnn_dropout": 0.1,
+   "feature_dim": 192,
+   "image_size": [
+     256,
+     256
+   ],
+   "in_chans": 1,
+   "linear_dropout": 0.1,
+   "model_type": "total_classifier",
+   "num_classes": 117,
+   "rnn_dropout": 0.0,
+   "rnn_num_layers": 1,
+   "rnn_type": "GRU",
+   "seq_len": 512,
+   "torch_dtype": "float32",
+   "transformers_version": "4.47.0"
+ }
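The `auto_map` above wires `AutoConfig`/`AutoModel` to the custom `configuration.py` and `modeling.py` in this repository, so the checkpoint is meant to be loaded with `trust_remote_code=True`. A minimal sketch (the repo id is a placeholder, and `timm`, `opencv-python`, and the `label2index.json` referenced by `modeling.py` are assumed to be available):

```python
import torch
from transformers import AutoModel

# Placeholder repo id -- substitute the actual Hub repository for this commit.
model = AutoModel.from_pretrained("ianpan/ct-total-classifier", trust_remote_code=True)
model = model.eval()

# Dummy study: batch of 1 with 64 single-channel 256x256 slices, pixel values in [0, 255].
x = torch.randint(0, 256, (1, 64, 1, 256, 256)).float()
with torch.no_grad():
    probas = model(x)                        # (1, 64, 117) per-slice sigmoid probabilities
    logits = model(x, return_logits=True)    # raw logits instead of probabilities
print(probas.shape, logits.shape)
```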
configuration.py ADDED
@@ -0,0 +1,33 @@
+ from transformers import PretrainedConfig
+
+
+ class TotalClassifierConfig(PretrainedConfig):
+     model_type = "total_classifier"
+
+     def __init__(
+         self,
+         backbone: str = "tf_efficientnetv2_b0",
+         feature_dim: int = 192,
+         cnn_dropout: float = 0.1,
+         in_chans: int = 1,
+         rnn_type: str = "GRU",
+         rnn_num_layers: int = 1,
+         rnn_dropout: float = 0.0,
+         num_classes: int = 117,
+         seq_len: int = 512,
+         linear_dropout: float = 0.1,
+         image_size: tuple[int, int] = (256, 256),
+         **kwargs,
+     ):
+         self.backbone = backbone
+         self.feature_dim = feature_dim
+         self.cnn_dropout = cnn_dropout
+         self.in_chans = in_chans
+         self.rnn_type = rnn_type
+         self.rnn_num_layers = rnn_num_layers
+         self.rnn_dropout = rnn_dropout
+         self.num_classes = num_classes
+         self.seq_len = seq_len
+         self.linear_dropout = linear_dropout
+         self.image_size = image_size
+         super().__init__(**kwargs)
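The defaults above mirror `config.json`; a config can also be constructed directly and individual fields overridden, e.g. when retraining the head (a sketch assuming the file is importable as `configuration.py`):

```python
from configuration import TotalClassifierConfig

# Defaults reproduce config.json: tf_efficientnetv2_b0 backbone, 117 classes, 256x256 input.
config = TotalClassifierConfig()
print(config.backbone, config.num_classes, config.image_size)

# Override selected fields while keeping the rest of the defaults.
config_2layer = TotalClassifierConfig(rnn_num_layers=2, rnn_dropout=0.1)
```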
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:235e8d4cb902c53b68c91e0e61c7837dcda376508eae7b9896a5631ca75b3a0b
+ size 23472996
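The entry above is a Git LFS pointer rather than the weights themselves; the actual `model.safetensors` (about 23 MB) is resolved on download. A quick integrity check against the pointer's `oid`/`size` might look like this (a sketch using `huggingface_hub`; the repo id is a placeholder):

```python
import hashlib
import os
from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the actual Hub repository.
path = hf_hub_download("ianpan/ct-total-classifier", filename="model.safetensors")

with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# Expect sha256 235e8d4c... and 23472996 bytes, as recorded in the LFS pointer above.
print(digest, os.path.getsize(path))
```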
modeling.py ADDED
@@ -0,0 +1,533 @@
+ import cv2
+ import glob
+ import json
+ import numpy as np
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from einops import rearrange
+ from pathlib import Path
+ from transformers import PreTrainedModel
+ from timm import create_model
+
+ from .configuration import TotalClassifierConfig
+
+ _PYDICOM_AVAILABLE = False
+ try:
+     from pydicom import dcmread
+
+     _PYDICOM_AVAILABLE = True
+ except ModuleNotFoundError:
+     pass
+
+ _PANDAS_AVAILABLE = False
+ try:
+     import pandas as pd
+
+     _PANDAS_AVAILABLE = True
+ except ModuleNotFoundError:
+     pass
+
+
+ class RNNHead(nn.Module):
+     def __init__(
+         self,
+         rnn_type: str,
+         rnn_num_layers: int,
+         rnn_dropout: float,
+         feature_dim: int,
+         linear_dropout: float,
+         num_classes: int,
+     ):
+         super().__init__()
+         self.rnn = getattr(nn, rnn_type)(
+             input_size=feature_dim,
+             hidden_size=feature_dim // 2,
+             num_layers=rnn_num_layers,
+             dropout=rnn_dropout,
+             batch_first=True,
+             bidirectional=True,
+         )
+         self.dropout = nn.Dropout(linear_dropout)
+         self.linear = nn.Linear(feature_dim, num_classes)
+
+     @staticmethod
+     def convert_seq_and_mask_to_packed_sequence(
+         seq: torch.Tensor, mask: torch.Tensor
+     ) -> nn.utils.rnn.PackedSequence:
+         assert seq.shape[0] == mask.shape[0]
+         lengths = mask.sum(1)
+         seq = nn.utils.rnn.pack_padded_sequence(
+             seq, lengths.cpu().int(), batch_first=True, enforce_sorted=False
+         )
+         return seq
+
+     def forward(
+         self, x: torch.Tensor, mask: torch.Tensor | None = None
+     ) -> torch.Tensor:
+         skip = x
+         if mask is not None:
+             # convert to PackedSequence
+             L = x.shape[1]
+             x = self.convert_seq_and_mask_to_packed_sequence(x, mask)
+
+         x, _ = self.rnn(x)
+
+         if mask is not None:
+             # convert back to tensor
+             x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=L)[0]
+
+         x = x + skip
+         return self.linear(self.dropout(x))
+
+
+ class TotalClassifierModel(PreTrainedModel):
+     config_class = TotalClassifierConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.image_size = config.image_size
+         self.backbone = create_model(
+             model_name=config.backbone,
+             pretrained=False,
+             num_classes=0,
+             global_pool="",
+             features_only=True,
+             in_chans=config.in_chans,
+         )
+         self.cnn_dropout = nn.Dropout(p=config.cnn_dropout)
+         self.head = RNNHead(
+             rnn_type=config.rnn_type,
+             rnn_num_layers=config.rnn_num_layers,
+             rnn_dropout=config.rnn_dropout,
+             feature_dim=config.feature_dim,
+             linear_dropout=config.linear_dropout,
+             num_classes=config.num_classes,
+         )
+         with open(
+             os.path.join(Path(__file__).parent.absolute(), "label2index.json"), "r"
+         ) as f:
+             self.label2index = json.load(f)
+
+         self.index2label = {v: k for k, v in self.label2index.items()}
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         mask: torch.Tensor | None = None,
+         return_logits: bool = False,
+         return_as_dict: bool = False,
+         return_as_df: bool = False,
+     ) -> torch.Tensor:
+         if return_as_df:
+             assert (
+                 _PANDAS_AVAILABLE
+             ), "`return_as_df=True` requires pandas to be installed"
+         # x.shape = (b, n, c, h, w)
+         b, n, c, h, w = x.shape
+         # x = rearrange(x, "b n c h w -> (b n) c h w")
+         x = x.reshape(b * n, c, h, w)
+         x = self.normalize(x)
+         # avg pooling
+         features = self.backbone(x)
+         # take last feature map
+         features = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
+         features = self.cnn_dropout(features)
+         # features = rearrange(features, "(b n) d -> b n d", b=b, n=n)
+         features = features.reshape(b, n, -1)
+         logits = self.head(features, mask=mask)
+         if return_logits:
+             # return raw logits
+             return logits
+         probas = logits.sigmoid()
+         if return_as_dict or return_as_df:
+             # list_of_dictionaries
+             batch_list = []
+             for i in range(probas.shape[0]):
+                 dict_for_batch = {}
+                 probas_i = probas[i]
+                 for each_class in range(probas_i.shape[1]):
+                     dict_for_batch[self.index2label[each_class]] = probas_i[
+                         :, each_class
+                     ]
+                 if return_as_df:
+                     batch_list.append(
+                         pd.DataFrame(
+                             {k: v.cpu().numpy() for k, v in dict_for_batch.items()}
+                         )
+                     )
+                 else:
+                     batch_list.append(dict_for_batch)
+             return batch_list
+         return probas
+
+     def normalize(self, x: torch.Tensor) -> torch.Tensor:
+         # [0, 255] -> [-1, 1]
+         mini, maxi = 0.0, 255.0
+         x = (x - mini) / (maxi - mini)
+         x = (x - 0.5) * 2.0
+         return x
+
+     @staticmethod
+     def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray[np.uint8]:
+         # applying windowing to CT
+         lower, upper = WL - WW // 2, WL + WW // 2
+         x = np.clip(x, lower, upper)
+         x = (x - lower) / (upper - lower)
+         return (x * 255.0).astype("uint8")
+
+     @staticmethod
+     def validate_windows_type(windows):
+         assert isinstance(windows, tuple) or isinstance(windows, list)
+         if isinstance(windows, tuple):
+             assert len(windows) == 2
+             assert all(isinstance(_, int) for _ in windows)
+         elif isinstance(windows, list):
+             assert all([isinstance(_, tuple) for _ in windows])
+             assert all([len(_) == 2 for _ in windows])
+             assert all([isinstance(__, int) for _ in windows for __ in _])
+
+     @staticmethod
+     def determine_dicom_orientation(ds) -> int:
+         iop = ds.ImageOrientationPatient
+
+         # Calculate the direction cosine for the normal vector of the plane
+         normal_vector = np.cross(iop[:3], iop[3:])
+
+         # Determine the plane based on the largest component of the normal vector
+         abs_normal = np.abs(normal_vector)
+         if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
+             return 0  # sagittal
+         elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
+             return 1  # coronal
+         else:
+             return 2  # axial
+
+     def load_image_from_dicom(
+         self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
+     ) -> np.ndarray:
+         # windows can be tuple of (WINDOW_LEVEL, WINDOW_WIDTH)
+         # or list of tuples if wishing to generate multi-channel image using
+         # > 1 window
+         if not _PYDICOM_AVAILABLE:
+             raise Exception("`pydicom` is not installed")
+         dicom = dcmread(path)
+         array = dicom.pixel_array.astype("float32")
+         m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
+         array = array * m + b
+         if windows is None:
+             return array
+
+         self.validate_windows_type(windows)
+         if isinstance(windows, tuple):
+             windows = [windows]
+
+         arr_list = []
+         for WL, WW in windows:
+             arr_list.append(self.window(array.copy(), WL, WW))
+
+         array = np.stack(arr_list, axis=-1)
+         if array.shape[-1] == 1:
+             array = np.squeeze(array, axis=-1)
+
+         return array
+
+     @staticmethod
+     def is_valid_dicom(
+         ds,
+         fname: str = "",
+         sort_by_instance_number: bool = False,
+         exclude_invalid_dicoms: bool = False,
+     ) -> bool:
+         attributes = [
+             "pixel_array",
+             "RescaleSlope",
+             "RescaleIntercept",
+         ]
+         if sort_by_instance_number:
+             attributes.append("InstanceNumber")
+         else:
+             attributes.append("ImagePositionPatient")
+             attributes.append("ImageOrientationPatient")
+         attributes_present = [hasattr(ds, attr) for attr in attributes]
+         valid = all(attributes_present)
+         if not valid and not exclude_invalid_dicoms:
+             raise Exception(
+                 f"invalid DICOM file [{fname}]: missing attributes: {list(np.array(attributes)[~np.array(attributes_present)])}"
+             )
+         return valid
+
+     @staticmethod
+     def most_common_element(lst):
+         return max(set(lst), key=lst.count)
+
+     @staticmethod
+     def center_crop_or_pad_borders(image, size):
+         height, width = image.shape[:2]
+         new_height, new_width = size
+         if new_height < height:
+             # crop top and bottom
+             crop_top = (height - new_height) // 2
+             crop_bottom = height - new_height - crop_top
+             image = image[crop_top:-crop_bottom]
+         elif new_height > height:
+             # pad top and bottom
+             pad_top = (new_height - height) // 2
+             pad_bottom = new_height - height - pad_top
+             image = np.pad(
+                 image,
+                 ((pad_top, pad_bottom), (0, 0)),
+                 mode="constant",
+                 constant_values=0,
+             )
+
+         if new_width < width:
+             # crop left and right
+             crop_left = (width - new_width) // 2
+             crop_right = width - new_width - crop_left
+             image = image[:, crop_left:-crop_right]
+         elif new_width > width:
+             # pad left and right
+             pad_left = (new_width - width) // 2
+             pad_right = new_width - width - pad_left
+             image = np.pad(
+                 image,
+                 ((0, 0), (pad_left, pad_right)),
+                 mode="constant",
+                 constant_values=0,
+             )
+
+         return image
+
+     def load_stack_from_dicom_folder(
+         self,
+         path: str,
+         windows: tuple[int, int] | list[tuple[int, int]] | None = None,
+         dicom_extension: str = ".dcm",
+         sort_by_instance_number: bool = False,
+         exclude_invalid_dicoms: bool = False,
+         fix_unequal_shapes: str = "crop_pad",
+         return_sorted_dicom_files: bool = False,
+     ) -> np.ndarray | tuple[np.ndarray, list[str]]:
+         if not _PYDICOM_AVAILABLE:
+             raise Exception("`pydicom` is not installed")
+         dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
+         if len(dicom_files) == 0:
+             raise Exception(
+                 f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
+             )
+         dicoms = [dcmread(f) for f in dicom_files]
+         dicoms = [
+             (d, dicom_files[idx])
+             for idx, d in enumerate(dicoms)
+             if self.is_valid_dicom(
+                 d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
+             )
+         ]
+         # handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
+         # by only including valid DICOM filenames
+         dicom_files = [_[1] for _ in dicoms]
+         dicoms = [_[0] for _ in dicoms]
+
+         slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
+         shapes = np.stack([s.shape for s in slices], axis=0)
+         if not np.all(shapes == shapes[0]):
+             unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
+             standard_shape = tuple(unique_shapes[np.argmax(counts)])
+             print(
+                 f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
+             )
+             if fix_unequal_shapes == "crop_pad":
+                 slices = [
+                     self.center_crop_or_pad_borders(s, standard_shape)
+                     if s.shape != standard_shape
+                     else s
+                     for s in slices
+                 ]
+             elif fix_unequal_shapes == "resize":
+                 slices = [
+                     cv2.resize(s, standard_shape[::-1]) if s.shape != standard_shape else s
+                     for s in slices
+                 ]
+         slices = np.stack(slices, axis=0)
+         # find orientation
+         orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
+         # use most common
+         orientation = self.most_common_element(orientation)
+
+         # sort using ImagePositionPatient
+         # orientation is index to use for sorting
+         if sort_by_instance_number:
+             positions = [float(d.InstanceNumber) for d in dicoms]
+         else:
+             positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
+         indices = np.argsort(positions)
+         slices = slices[indices]
+
+         # rescale
+         m, b = (
+             [float(d.RescaleSlope) for d in dicoms],
+             [float(d.RescaleIntercept) for d in dicoms],
+         )
+         m, b = self.most_common_element(m), self.most_common_element(b)
+         slices = slices * m + b
+         if windows is not None:
+             self.validate_windows_type(windows)
+             if isinstance(windows, tuple):
+                 windows = [windows]
+
+             arr_list = []
+             for WL, WW in windows:
+                 arr_list.append(self.window(slices.copy(), WL, WW))
+
+             slices = np.stack(arr_list, axis=-1)
+             if slices.shape[-1] == 1:
+                 slices = np.squeeze(slices, axis=-1)
+
+         if return_sorted_dicom_files:
+             return slices, [dicom_files[idx] for idx in indices]
+         return slices
+
+     def preprocess(self, x: np.ndarray, mode="2d") -> np.ndarray:
+         mode = mode.lower()
+         if mode == "2d":
+             x = cv2.resize(x, self.image_size)
+             if x.ndim == 2:
+                 x = x[:, :, np.newaxis]
+         elif mode == "3d":
+             x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
+             if x.ndim == 3:
+                 x = x[:, :, :, np.newaxis]
+         return x
+
+     def crop_single_plane(
+         self,
+         x: np.ndarray,
+         device: str | torch.device,
+         organ: str | list[str],
+         threshold: float = 0.5,
+         buffer: float | int = 0,
+         speed_up: str | None = None,
+     ) -> tuple[int, int]:
+         num_slices = x.shape[0]
+         if speed_up is not None:
+             assert speed_up in ["fast", "faster", "fastest"]
+             if speed_up == "fast":
+                 # 75% of slices
+                 reduce_num_slices = 3 * num_slices // 4
+             elif speed_up == "faster":
+                 # 50% of slices
+                 reduce_num_slices = num_slices // 2
+             elif speed_up == "fastest":
+                 # 33% of slices
+                 reduce_num_slices = num_slices // 3
+             indices = np.linspace(0, num_slices - 1, reduce_num_slices).astype(int)
+             x = x[indices]
+         x = self.preprocess(x, mode="3d")
+         x = torch.from_numpy(x)
+         x = rearrange(x, "n h w c -> n c h w").float().to(device)
+         x = rearrange(x, "n c h w -> 1 n c h w")
+         if x.size(2) > 1:
+             # if multi-channel, take mean
+             x = x.mean(2, keepdim=True)
+         organ_cls = self.forward(x)[0]
+         if speed_up is not None:
+             # organ_cls.shape = (num_slices, num_classes)
+             organ_cls = (
+                 F.interpolate(
+                     organ_cls.transpose(1, 0).unsqueeze(0),
+                     size=(num_slices,),
+                     mode="linear",
+                 )
+                 .squeeze(0)
+                 .transpose(1, 0)
+             )
+             assert organ_cls.shape[0] == num_slices
+         slices = []
+         for each_organ in organ:
+             slices.append(
+                 torch.where(organ_cls[:, self.label2index[each_organ]] >= threshold)[0]
+             )
+         slices = torch.cat(slices)
+         slice_min, slice_max = slices.min().item(), slices.max().item()
+         if buffer > 0:
+             if isinstance(buffer, float):
+                 # % buffer
+                 diff = slice_max - slice_min
+                 buf = int(buffer * diff)
+             else:
+                 # absolute slice buffer
+                 buf = buffer
+             slice_min = max(0, slice_min - buf)
+             slice_max = min(num_slices - 1, slice_max + buf)
+         return slice_min, slice_max
+
+     @torch.no_grad()
+     def crop(
+         self,
+         x: np.ndarray,
+         organ: str | list[str],
+         crop_dims: int | list[int] = 0,
+         device: str | torch.device | None = None,
+         raw_hu: bool = False,
+         threshold: float = 0.5,
+         buffer: float | int = 0,
+         speed_up: str | None = None,
+     ) -> (
+         np.ndarray
+         | tuple[np.ndarray, list[int]]
+         | tuple[np.ndarray, list[int], list[int]]
+     ):
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         assert isinstance(x, np.ndarray)
+         assert x.ndim in {
+             3,
+             4,
+         }, f"x should be a 3D or 4D array, but got {x.ndim} dimensions"
+
+         if raw_hu:
+             # if input is in Hounsfield units, apply soft tissue window
+             x = self.window(x, WL=50, WW=400)
+
+         x0 = x
+         if not isinstance(organ, list):
+             organ = [organ]
+         if not isinstance(crop_dims, list):
+             crop_dims = [crop_dims]
+
+         assert max(crop_dims) <= 2
+         assert min(crop_dims) >= 0
+
+         if isinstance(buffer, float):
+             # percentage of cropped axis dimension
+             assert buffer < 1
+
+         if 0 in crop_dims:
+             smin0, smax0 = self.crop_single_plane(
+                 x0, device, organ, threshold, buffer, speed_up
+             )
+         else:
+             smin0, smax0 = 0, x0.shape[0]
+
+         if 1 in crop_dims:
+             # swap plane
+             x = x0.transpose(1, 0, 2)
+             smin1, smax1 = self.crop_single_plane(
+                 x, device, organ, threshold, buffer, speed_up
+             )
+         else:
+             smin1, smax1 = 0, x0.shape[1]
+
+         if 2 in crop_dims:
+             # swap plane
+             x = x0.transpose(2, 1, 0)
+             smin2, smax2 = self.crop_single_plane(
+                 x, device, organ, threshold, buffer, speed_up
+             )
+         else:
+             smin2, smax2 = 0, x0.shape[2]
+
+         return x0[smin0 : smax0 + 1, smin1 : smax1 + 1, smin2 : smax2 + 1]
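Taken together, `modeling.py` gives the checkpoint a small inference utility surface: `load_stack_from_dicom_folder` builds a sorted, rescaled volume, `window`/`preprocess` map it to model input, `forward` returns per-slice sigmoid probabilities over the 117 labels, and `crop` trims a volume to the slices containing a given structure. A hedged sketch of that flow (the repo id, folder path, and the `"liver"` label are placeholders; the label must exist in the repository's `label2index.json`, and `pydicom` is assumed to be installed):

```python
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder repo id -- substitute the actual Hub repository.
model = AutoModel.from_pretrained("ianpan/ct-total-classifier", trust_remote_code=True)
model = model.eval().to(device)

# Load a CT series as a (num_slices, H, W) array in Hounsfield units, sorted by position.
stack = model.load_stack_from_dicom_folder("/path/to/ct_series")

# Crop the axial (dim 0) extent to slices predicted to contain the (hypothetical) "liver" label.
cropped = model.crop(stack, organ="liver", crop_dims=0, raw_hu=True, buffer=0.1)
print(stack.shape, "->", cropped.shape)

# Or run the classifier directly on a windowed, resized stack.
windowed = model.window(stack, WL=50, WW=400)             # soft-tissue window -> uint8 in [0, 255]
x = model.preprocess(windowed, mode="3d")                 # (n, 256, 256, 1)
x = torch.from_numpy(x).permute(0, 3, 1, 2).unsqueeze(0)  # (1, n, 1, 256, 256)
with torch.no_grad():
    probas = model(x.float().to(device))                  # (1, n, 117)
```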