antoniaebner committed
Commit 35189e2 · 1 Parent(s): aae42ec

upload code

Files changed (15)
  1. .gitignore +4 -0
  2. Dockerfile +16 -0
  3. LICENSE +407 -0
  4. MODEL_CARD.md +26 -0
  5. README.md +97 -4
  6. app.py +78 -0
  7. config/config.json +36 -0
  8. data/tox_smarts.json +0 -0
  9. predict.py +101 -0
  10. preprocess.py +68 -0
  11. requirements.txt +12 -0
  12. src/__init__.py +0 -0
  13. src/model.py +126 -0
  14. src/preprocess.py +670 -0
  15. src/utils.py +525 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ hiddens/
+ logs/
+ checkpoints_/
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.11
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,407 @@
+ Attribution-NonCommercial 4.0 International
+
+ =======================================================================
+
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
+ does not provide legal services or legal advice. Distribution of
+ Creative Commons public licenses does not create a lawyer-client or
+ other relationship. Creative Commons makes its licenses and related
+ information available on an "as-is" basis. Creative Commons gives no
+ warranties regarding its licenses, any material licensed under their
+ terms and conditions, or any related information. Creative Commons
+ disclaims all liability for damages resulting from their use to the
+ fullest extent possible.
+
+ Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and
+ conditions that creators and other rights holders may use to share
+ original works of authorship and other material subject to copyright
+ and certain other rights specified in the public license below. The
+ following considerations are for informational purposes only, are not
+ exhaustive, and do not form part of our licenses.
+
+      Considerations for licensors: Our public licenses are
+      intended for use by those authorized to give the public
+      permission to use material in ways otherwise restricted by
+      copyright and certain other rights. Our licenses are
+      irrevocable. Licensors should read and understand the terms
+      and conditions of the license they choose before applying it.
+      Licensors should also secure all rights necessary before
+      applying our licenses so that the public can reuse the
+      material as expected. Licensors should clearly mark any
+      material not subject to the license. This includes other CC-
+      licensed material, or material used under an exception or
+      limitation to copyright. More considerations for licensors:
+      wiki.creativecommons.org/Considerations_for_licensors
+
+      Considerations for the public: By using one of our public
+      licenses, a licensor grants the public permission to use the
+      licensed material under specified terms and conditions. If
+      the licensor's permission is not necessary for any reason--for
+      example, because of any applicable exception or limitation to
+      copyright--then that use is not regulated by the license. Our
+      licenses grant only permissions under copyright and certain
+      other rights that a licensor has authority to grant. Use of
+      the licensed material may still be restricted for other
+      reasons, including because others have copyright or other
+      rights in the material. A licensor may make special requests,
+      such as asking that all changes be marked or described.
+      Although not required by our licenses, you are encouraged to
+      respect those requests where reasonable. More considerations
+      for the public:
+      wiki.creativecommons.org/Considerations_for_licensees
+
+ =======================================================================
+
+ Creative Commons Attribution-NonCommercial 4.0 International Public
+ License
+
+ By exercising the Licensed Rights (defined below), You accept and agree
+ to be bound by the terms and conditions of this Creative Commons
+ Attribution-NonCommercial 4.0 International Public License ("Public
+ License"). To the extent this Public License may be interpreted as a
+ contract, You are granted the Licensed Rights in consideration of Your
+ acceptance of these terms and conditions, and the Licensor grants You
+ such rights in consideration of benefits the Licensor receives from
+ making the Licensed Material available under these terms and
+ conditions.
+
+
+ Section 1 -- Definitions.
+
+   a. Adapted Material means material subject to Copyright and Similar
+      Rights that is derived from or based upon the Licensed Material
+      and in which the Licensed Material is translated, altered,
+      arranged, transformed, or otherwise modified in a manner requiring
+      permission under the Copyright and Similar Rights held by the
+      Licensor. For purposes of this Public License, where the Licensed
+      Material is a musical work, performance, or sound recording,
+      Adapted Material is always produced where the Licensed Material is
+      synched in timed relation with a moving image.
+
+   b. Adapter's License means the license You apply to Your Copyright
+      and Similar Rights in Your contributions to Adapted Material in
+      accordance with the terms and conditions of this Public License.
+
+   c. Copyright and Similar Rights means copyright and/or similar rights
+      closely related to copyright including, without limitation,
+      performance, broadcast, sound recording, and Sui Generis Database
+      Rights, without regard to how the rights are labeled or
+      categorized. For purposes of this Public License, the rights
+      specified in Section 2(b)(1)-(2) are not Copyright and Similar
+      Rights.
+   d. Effective Technological Measures means those measures that, in the
+      absence of proper authority, may not be circumvented under laws
+      fulfilling obligations under Article 11 of the WIPO Copyright
+      Treaty adopted on December 20, 1996, and/or similar international
+      agreements.
+
+   e. Exceptions and Limitations means fair use, fair dealing, and/or
+      any other exception or limitation to Copyright and Similar Rights
+      that applies to Your use of the Licensed Material.
+
+   f. Licensed Material means the artistic or literary work, database,
+      or other material to which the Licensor applied this Public
+      License.
+
+   g. Licensed Rights means the rights granted to You subject to the
+      terms and conditions of this Public License, which are limited to
+      all Copyright and Similar Rights that apply to Your use of the
+      Licensed Material and that the Licensor has authority to license.
+
+   h. Licensor means the individual(s) or entity(ies) granting rights
+      under this Public License.
+
+   i. NonCommercial means not primarily intended for or directed towards
+      commercial advantage or monetary compensation. For purposes of
+      this Public License, the exchange of the Licensed Material for
+      other material subject to Copyright and Similar Rights by digital
+      file-sharing or similar means is NonCommercial provided there is
+      no payment of monetary compensation in connection with the
+      exchange.
+
+   j. Share means to provide material to the public by any means or
+      process that requires permission under the Licensed Rights, such
+      as reproduction, public display, public performance, distribution,
+      dissemination, communication, or importation, and to make material
+      available to the public including in ways that members of the
+      public may access the material from a place and at a time
+      individually chosen by them.
+
+   k. Sui Generis Database Rights means rights other than copyright
+      resulting from Directive 96/9/EC of the European Parliament and of
+      the Council of 11 March 1996 on the legal protection of databases,
+      as amended and/or succeeded, as well as other essentially
+      equivalent rights anywhere in the world.
+
+   l. You means the individual or entity exercising the Licensed Rights
+      under this Public License. Your has a corresponding meaning.
+
+
+ Section 2 -- Scope.
+
+   a. License grant.
+
+        1. Subject to the terms and conditions of this Public License,
+           the Licensor hereby grants You a worldwide, royalty-free,
+           non-sublicensable, non-exclusive, irrevocable license to
+           exercise the Licensed Rights in the Licensed Material to:
+
+             a. reproduce and Share the Licensed Material, in whole or
+                in part, for NonCommercial purposes only; and
+
+             b. produce, reproduce, and Share Adapted Material for
+                NonCommercial purposes only.
+
+        2. Exceptions and Limitations. For the avoidance of doubt, where
+           Exceptions and Limitations apply to Your use, this Public
+           License does not apply, and You do not need to comply with
+           its terms and conditions.
+
+        3. Term. The term of this Public License is specified in Section
+           6(a).
+
+        4. Media and formats; technical modifications allowed. The
+           Licensor authorizes You to exercise the Licensed Rights in
+           all media and formats whether now known or hereafter created,
+           and to make technical modifications necessary to do so. The
+           Licensor waives and/or agrees not to assert any right or
+           authority to forbid You from making technical modifications
+           necessary to exercise the Licensed Rights, including
+           technical modifications necessary to circumvent Effective
+           Technological Measures. For purposes of this Public License,
+           simply making modifications authorized by this Section 2(a)
+           (4) never produces Adapted Material.
+
+        5. Downstream recipients.
+
+             a. Offer from the Licensor -- Licensed Material. Every
+                recipient of the Licensed Material automatically
+                receives an offer from the Licensor to exercise the
+                Licensed Rights under the terms and conditions of this
+                Public License.
+
+             b. No downstream restrictions. You may not offer or impose
+                any additional or different terms or conditions on, or
+                apply any Effective Technological Measures to, the
+                Licensed Material if doing so restricts exercise of the
+                Licensed Rights by any recipient of the Licensed
+                Material.
+
+        6. No endorsement. Nothing in this Public License constitutes or
+           may be construed as permission to assert or imply that You
+           are, or that Your use of the Licensed Material is, connected
+           with, or sponsored, endorsed, or granted official status by,
+           the Licensor or others designated to receive attribution as
+           provided in Section 3(a)(1)(A)(i).
+
+   b. Other rights.
+
+        1. Moral rights, such as the right of integrity, are not
+           licensed under this Public License, nor are publicity,
+           privacy, and/or other similar personality rights; however, to
+           the extent possible, the Licensor waives and/or agrees not to
+           assert any such rights held by the Licensor to the limited
+           extent necessary to allow You to exercise the Licensed
+           Rights, but not otherwise.
+
+        2. Patent and trademark rights are not licensed under this
+           Public License.
+
+        3. To the extent possible, the Licensor waives any right to
+           collect royalties from You for the exercise of the Licensed
+           Rights, whether directly or through a collecting society
+           under any voluntary or waivable statutory or compulsory
+           licensing scheme. In all other cases the Licensor expressly
+           reserves any right to collect such royalties, including when
+           the Licensed Material is used other than for NonCommercial
+           purposes.
+
+
+ Section 3 -- License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the
+ following conditions.
+
+   a. Attribution.
+
+        1. If You Share the Licensed Material (including in modified
+           form), You must:
+
+             a. retain the following if it is supplied by the Licensor
+                with the Licensed Material:
+
+                  i. identification of the creator(s) of the Licensed
+                     Material and any others designated to receive
+                     attribution, in any reasonable manner requested by
+                     the Licensor (including by pseudonym if
+                     designated);
+
+                 ii. a copyright notice;
+
+                iii. a notice that refers to this Public License;
+
+                 iv. a notice that refers to the disclaimer of
+                     warranties;
+
+                  v. a URI or hyperlink to the Licensed Material to the
+                     extent reasonably practicable;
+
+             b. indicate if You modified the Licensed Material and
+                retain an indication of any previous modifications; and
+
+             c. indicate the Licensed Material is licensed under this
+                Public License, and include the text of, or the URI or
+                hyperlink to, this Public License.
+
+        2. You may satisfy the conditions in Section 3(a)(1) in any
+           reasonable manner based on the medium, means, and context in
+           which You Share the Licensed Material. For example, it may be
+           reasonable to satisfy the conditions by providing a URI or
+           hyperlink to a resource that includes the required
+           information.
+
+        3. If requested by the Licensor, You must remove any of the
+           information required by Section 3(a)(1)(A) to the extent
+           reasonably practicable.
+
+        4. If You Share Adapted Material You produce, the Adapter's
+           License You apply must not prevent recipients of the Adapted
+           Material from complying with this Public License.
+
+
+ Section 4 -- Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that
+ apply to Your use of the Licensed Material:
+
+   a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+      to extract, reuse, reproduce, and Share all or a substantial
+      portion of the contents of the database for NonCommercial purposes
+      only;
+
+   b. if You include all or a substantial portion of the database
+      contents in a database in which You have Sui Generis Database
+      Rights, then the database in which You have Sui Generis Database
+      Rights (but not its individual contents) is Adapted Material; and
+
+   c. You must comply with the conditions in Section 3(a) if You Share
+      all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not
+ replace Your obligations under this Public License where the Licensed
+ Rights include other Copyright and Similar Rights.
+
+
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+   a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+      EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+      AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+      ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+      IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+      WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+      PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+      ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+      KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+      ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+   b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+      TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+      NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+      INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+      COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+      USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+      ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+      DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+      IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+   c. The disclaimer of warranties and limitation of liability provided
+      above shall be interpreted in a manner that, to the extent
+      possible, most closely approximates an absolute disclaimer and
+      waiver of all liability.
+
+
+ Section 6 -- Term and Termination.
+
+   a. This Public License applies for the term of the Copyright and
+      Similar Rights licensed here. However, if You fail to comply with
+      this Public License, then Your rights under this Public License
+      terminate automatically.
+
+   b. Where Your right to use the Licensed Material has terminated under
+      Section 6(a), it reinstates:
+
+        1. automatically as of the date the violation is cured, provided
+           it is cured within 30 days of Your discovery of the
+           violation; or
+
+        2. upon express reinstatement by the Licensor.
+
+      For the avoidance of doubt, this Section 6(b) does not affect any
+      right the Licensor may have to seek remedies for Your violations
+      of this Public License.
+
+   c. For the avoidance of doubt, the Licensor may also offer the
+      Licensed Material under separate terms or conditions or stop
+      distributing the Licensed Material at any time; however, doing so
+      will not terminate this Public License.
+
+   d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+      License.
+
+
+ Section 7 -- Other Terms and Conditions.
+
+   a. The Licensor shall not be bound by any additional or different
+      terms or conditions communicated by You unless expressly agreed.
+
+   b. Any arrangements, understandings, or agreements regarding the
+      Licensed Material not stated herein are separate from and
+      independent of the terms and conditions of this Public License.
+
+
+ Section 8 -- Interpretation.
+
+   a. For the avoidance of doubt, this Public License does not, and
+      shall not be interpreted to, reduce, limit, restrict, or impose
+      conditions on any use of the Licensed Material that could lawfully
+      be made without permission under this Public License.
+
+   b. To the extent possible, if any provision of this Public License is
+      deemed unenforceable, it shall be automatically reformed to the
+      minimum extent necessary to make it enforceable. If the provision
+      cannot be reformed, it shall be severed from this Public License
+      without affecting the enforceability of the remaining terms and
+      conditions.
+
+   c. No term or condition of this Public License will be waived and no
+      failure to comply consented to unless expressly agreed to by the
+      Licensor.
+
+   d. Nothing in this Public License constitutes or may be interpreted
+      as a limitation upon, or waiver of, any privileges and immunities
+      that apply to the Licensor or You, including from the legal
+      processes of any jurisdiction or authority.
+
+ =======================================================================
+
+ Creative Commons is not a party to its public
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
+ its public licenses to material it publishes and in those instances
+ will be considered the "Licensor." The text of the Creative Commons
+ public licenses is dedicated to the public domain under the CC0 Public
+ Domain Dedication. Except for the limited purpose of indicating that
+ material is shared under a Creative Commons public license or as
+ otherwise permitted by the Creative Commons policies published at
+ creativecommons.org/policies, Creative Commons does not authorize the
+ use of the trademark "Creative Commons" or any other trademark or logo
+ of Creative Commons without its prior written consent including,
+ without limitation, in connection with any unauthorized modifications
+ to any of its public licenses or any other arrangements,
+ understandings, or agreements concerning use of licensed material. For
+ the avoidance of doubt, this paragraph does not form part of the
+ public licenses.
+
+ Creative Commons may be contacted at creativecommons.org.
MODEL_CARD.md ADDED
@@ -0,0 +1,26 @@
+ # Model card - tox21_snn_classifier
+ ### Model details
+ - Model name: Self-Normalizing Neural Network Tox21 Baseline
+ - Developer: JKU (Linz)
+ - Paper URL: https://proceedings.neurips.cc/paper_files/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
+ - Model type / architecture:
+   - Self-Normalizing Neural Network implemented using PyTorch.
+   - Hyperparameters: https://huggingface.co/spaces/ml-jku/tox21_snn_classifier/blob/main/config/config.json
+   - A multitask network is trained for all Tox21 targets.
+   - Inference: Access via FastAPI endpoint. Upon receiving a Tox21 prediction request, the model generates and returns predictions for all Tox21 targets simultaneously.
+ - Model version: v0
+ - Model date: 14.10.2025
+ - Reproducibility: Code for full training is available and enables retraining from
+   scratch.
+
+ ### Intended use
+ This model serves as a baseline benchmark for evaluating and comparing toxicity prediction methods across the 12 pathway assays of the Tox21 dataset. It is not intended for clinical decision-making without experimental validation.
+
+ ### Metric
+ Each Tox21 task is evaluated using the area under the receiver operating characteristic curve (AUC). Overall performance is reported as the mean AUC across all individual tasks.
+
+ ### Training data
+ Tox21 training and validation sets.
+
+ ### Evaluation data
+ Tox21 test set.
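As a worked example of this metric, a minimal sketch using scikit-learn's `roc_auc_score` (scikit-learn is already in `requirements.txt`). The NaN masking of unmeasured assay labels is an assumption about how the sparse Tox21 labels are handled:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def mean_auc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """y_true: (n_samples, 12) 0/1 labels with NaN for unmeasured assays;
    y_pred: (n_samples, 12) predicted probabilities."""
    aucs = []
    for task in range(y_true.shape[1]):
        mask = ~np.isnan(y_true[:, task])  # evaluate each task only on labeled samples
        aucs.append(roc_auc_score(y_true[mask, task], y_pred[mask, task]))
    return float(np.mean(aucs))  # overall score: mean AUC over the 12 tasks
```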
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
- title: Tox21 Snn Classifier
- emoji: 🏢
- colorFrom: yellow
+ title: Tox21 SNN Classifier
+ emoji: 🌖
+ colorFrom: green
  colorTo: pink
  sdk: docker
  pinned: false
@@ -9,4 +9,97 @@ license: cc-by-nc-4.0
  short_description: Self-Normalizing Neural Network Baseline for Tox21
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Tox21 SNN Classifier
+
+ This repository hosts a Hugging Face Space that provides an API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/ml-jku/tox21_leaderboard).
+
+ Here, a [self-normalizing network (SNN)](https://arxiv.org/abs/1706.02515) is trained on the Tox21 dataset, and the trained models are provided for
+ inference. The model input is a SMILES string of a small molecule, and the output is 12 numeric values, one for
+ each of the toxic effects of the Tox21 dataset.
+
+
+ **Important:** For leaderboard submission, your Space needs to include training code. The file `train.py` should train the model using the config specified inside the `config/` folder and save the final model parameters into a file inside the `checkpoints/` folder. The model should be trained using the [Tox21_dataset](https://huggingface.co/datasets/tschouis/tox21) provided on Hugging Face. The datasets can be loaded like this:
+ ```python
+ from datasets import load_dataset
+ ds = load_dataset("ml-jku/tox21", token=token)
+ train_df = ds["train"].to_pandas()
+ val_df = ds["validation"].to_pandas()
+ ```
+
+ Additionally, the Space needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a nested prediction dictionary as output, with SMILES as keys and dictionaries of target-name/prediction pairs as values (a minimal sketch of this skeleton is provided right after this file's diff). Consequently, any preprocessing of SMILES strings must be executed on the fly during inference.
+
+ # Repository Structure
+ - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
+ - `app.py` - FastAPI application wrapper (can be used as-is).
+ - `preprocess.py` - Preprocesses SMILES strings to generate feature descriptors and saves the results as NPZ files in `data/`.
+ - `train.py` - Trains and saves a model using the config in the `config/` folder.
+ - `config/` - The config file used by `train.py`.
+ - `logs/` - All logs of `train.py`, the saved model, and predictions on the validation set.
+ - `data/` - The SNN uses numerical data. During preprocessing in `preprocess.py`, two NPZ files containing molecule features are created and saved here.
+ - `checkpoints/` - The saved model used by `predict.py` lives here.
+
+ - `src/` - Core model & preprocessing logic:
+   - `preprocess.py` - SMILES preprocessing logic
+   - `model.py` - SNN model class with processing, saving, and loading logic
+   - `utils.py` - utility functions
+
+ # Quickstart with Spaces
+
+ You can easily adapt this project in your own Hugging Face account:
+
+ - Open this Space on Hugging Face.
+
+ - Click "Duplicate this Space" (top-right corner).
+
+ - Modify `src/` for your preprocessing pipeline and model class.
+
+ - Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.
+
+ - Modify `train.py` and/or `preprocess.py` according to your model and preprocessing pipeline.
+
+ - Modify the file inside `config/` to contain all hyperparameters that are set in `train.py`.
+
+ That's it: your model will be available as an API endpoint for the Tox21 Leaderboard.
+
+ # Installation
+ To run (and train) the SNN, clone the repository and install the dependencies:
+
+ ```bash
+ git clone https://huggingface.co/spaces/ml-jku/tox21_snn_classifier
+ cd tox21_snn_classifier
+
+ conda create -n tox21_snn_cls python=3.11
+ conda activate tox21_snn_cls
+ pip install -r requirements.txt
+ ```
+
+ # Inference
+
+ For inference, you only need `predict.py`.
+
+ Example usage inside Python:
+
+ ```python
+ from predict import predict
+
+ smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
+ results = predict(smiles_list)
+
+ print(results)
+ ```
+
+ The output will be a nested dictionary of per-target scores in the format:
+
+ ```python
+ {
+     "CCO": {"target1": 0.02, "target2": 0.91, ..., "target12": 0.13},
+     "c1ccccc1": {"target1": 0.87, "target2": 0.04, ..., "target12": 0.65},
+     "CC(=O)O": {"target1": 0.01, "target2": 0.08, ..., "target12": 0.02}
+ }
+ ```
+
+ # Notes
+
+ - Adapting `predict.py`, `train.py`, `config/`, and `checkpoints/` is required for leaderboard submission.
+
+ - Preprocessing must be done inside `predict.py`, not just `train.py`.
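For reference, here is a minimal sketch of the skeleton that `predict()` must keep. The constant 0.5 score and the generic `target1`..`target12` names are placeholders, not the SNN's actual inference logic (see `predict.py` in this commit for the real implementation):

```python
def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Leaderboard entry point: list of SMILES in, nested score dict out."""
    targets = [f"target{i}" for i in range(1, 13)]  # placeholders; the real code uses the 12 Tox21 assay names
    predictions: dict[str, dict[str, float]] = {}
    for smiles in smiles_list:
        # a real implementation preprocesses `smiles` on the fly and scores it with the model here
        predictions[smiles] = {t: 0.5 for t in targets}
    return predictions
```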
app.py ADDED
@@ -0,0 +1,78 @@
+ """
+ This is the main entry point for the FastAPI application.
+ The app handles requests to predict toxicity for a list of SMILES strings.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies and global variable definition
+ import os
+ from typing import List, Dict, Optional
+ from fastapi import FastAPI, Header, HTTPException
+ from pydantic import BaseModel, Field
+
+ from predict import predict as predict_func
+
+ API_KEY = os.getenv("API_KEY")  # set via Space Secrets
+
+
+ # ---------------------------------------------------------------------------------------
+ class Request(BaseModel):
+     smiles: List[str] = Field(min_items=1, max_items=1000)
+
+
+ class Response(BaseModel):
+     predictions: dict
+     model_info: Dict[str, str] = {}
+
+
+ app = FastAPI(title="toxicity-api")
+
+
+ @app.get("/")
+ def root():
+     return {
+         "message": "Toxicity Prediction API",
+         "endpoints": {
+             "/metadata": "GET - API metadata and capabilities",
+             "/healthz": "GET - Health check",
+             "/predict": "POST - Predict toxicity for SMILES",
+         },
+         "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
+     }
+
+
+ @app.get("/metadata")
+ def metadata():
+     return {
+         "name": "SNN",
+         "version": "1.0.0",
+         "max_batch_size": 256,
+         "tox_endpoints": [
+             "NR-AR",
+             "NR-AR-LBD",
+             "NR-AhR",
+             "NR-Aromatase",
+             "NR-ER",
+             "NR-ER-LBD",
+             "NR-PPAR-gamma",
+             "SR-ARE",
+             "SR-ATAD5",
+             "SR-HSE",
+             "SR-MMP",
+             "SR-p53",
+         ],
+     }
+
+
+ @app.get("/healthz")
+ def healthz():
+     return {"ok": True}
+
+
+ @app.post("/predict", response_model=Response)
+ def predict(request: Request):
+     predictions = predict_func(request.smiles)
+     return {
+         "predictions": predictions,
+         "model_info": {"name": "SNN", "version": "1.0.0"},
+     }
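Once the Space (or a local `uvicorn app:app` server) is running, the endpoint can be exercised as sketched below. The base URL, the Bearer scheme, and the use of `requests` (a client-side dependency, not part of this Space's `requirements.txt`) are assumptions; note that this baseline app reads `API_KEY` but never actually enforces it:

```python
import requests

url = "http://localhost:7860/predict"  # a deployed Space has its own URL
payload = {"smiles": ["CCO", "c1ccccc1"]}
headers = {"Authorization": "Bearer <API_KEY>"}  # mirrors the usage hint in root(); not validated by this app

resp = requests.post(url, json=payload, headers=headers)
resp.raise_for_status()
print(resp.json()["predictions"]["CCO"])  # dict of 12 per-target scores for ethanol
```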
config/config.json ADDED
@@ -0,0 +1,36 @@
+ {
+     "seed": 0,
+     "debug": "false",
+     "device": "cpu",
+
+     "log_folder": "logs/",
+
+     "data_folder": "data/",
+     "cvfold": 4,
+     "ecfp": {
+         "radius": 3,
+         "fpsize": 8192
+     },
+     "merge_train_val": "false",
+     "descriptors": ["ecfps", "rdkit_descrs", "maccs", "tox"],
+     "feature_selection": {
+         "use": "true",
+         "min_var": 0.05,
+         "max_corr": 1,
+         "max_features": -1,
+         "min_var__feature_keys": ["ecfps", "tox"],
+         "max_corr__feature_keys": ["ecfps", "tox"],
+         "min_var__independent_keys": "true",
+         "max_corr__independent_keys": "true"
+     },
+     "feature_quantilization": {
+         "use": "true",
+         "feature_keys": ["rdkit_descrs"]
+     },
+     "max_samples": -1,
+     "scaler": "squash",
+     "preprocessor_path": "checkpoints/preprocessor.joblib",
+
+     "ckpt_path": "checkpoints/snn_ckpt.pth",
+     "model_config": "none"
+ }
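Note that the boolean-like values are stored as the strings "true"/"false"/"none". Both `predict.py` and `preprocess.py` pass the loaded dict through `normalize_config` before use, which evidently maps these strings to Python values; the helper below is a hedged sketch of that assumed behavior, not the actual `src/utils.py` implementation (which is not shown in this commit):

```python
import json

def normalize_config(config: dict) -> dict:
    """Recursively map "true"/"false"/"none" strings to True/False/None (assumed behavior)."""
    mapping = {"true": True, "false": False, "none": None}
    normalized = {}
    for key, value in config.items():
        if isinstance(value, dict):
            normalized[key] = normalize_config(value)
        elif isinstance(value, str) and value.lower() in mapping:
            normalized[key] = mapping[value.lower()]
        else:
            normalized[key] = value
    return normalized

with open("config/config.json") as f:
    config = normalize_config(json.load(f))
assert config["feature_selection"]["use"] is True
```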
data/tox_smarts.json ADDED
The diff for this file is too large to render. See raw diff
 
predict.py ADDED
@@ -0,0 +1,101 @@
+ """
+ This file includes the predict function for Tox21.
+ As input it takes a list of SMILES, and it outputs a nested dictionary with
+ SMILES and target names as keys.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from collections import defaultdict
+
+ import numpy as np
+
+ import json
+ import joblib
+ import torch
+
+ from src.model import Tox21SNNClassifier, SNNConfig
+ from src.preprocess import create_descriptors, FeaturePreprocessor
+ from src.utils import TASKS, normalize_config
+
+ # ---------------------------------------------------------------------------------------
+ CONFIG_FILE = "./config/config.json"
+
+
+ def predict(
+     smiles_list: list[str], default_prediction=0.5
+ ) -> dict[str, dict[str, float]]:
+     """Applies the classifier to a list of SMILES strings. Returns `default_prediction`
+     (0.5 by default) for any molecule that could not be cleaned.
+
+     Args:
+         smiles_list (list[str]): list of SMILES strings
+
+     Returns:
+         dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
+     """
+     print(f"Received {len(smiles_list)} SMILES strings")
+     # preprocessing pipeline
+     with open(CONFIG_FILE, "r") as f:
+         config = json.load(f)
+     config = normalize_config(config)
+
+     features, is_clean = create_descriptors(
+         smiles_list, config["descriptors"], **config["ecfp"]
+     )
+     print(f"Created descriptors for {sum(is_clean)} molecules.")
+     print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning")
+
+     # setup preprocessor
+     preprocessor = FeaturePreprocessor(
+         feature_selection_config=config["feature_selection"],
+         feature_quantilization_config=config["feature_quantilization"],
+         descriptors=config["descriptors"],
+         max_samples=config["max_samples"],
+         scaler=config["scaler"],
+     )
+
+     preprocessor_ckpt = joblib.load(config["preprocessor_path"])
+     preprocessor.set_state(preprocessor_ckpt["preprocessor"])
+     print(f"Loaded preprocessor from {config['preprocessor_path']}")
+
+     features = {descr: array[is_clean] for descr, array in features.items()}
+     features = preprocessor.transform(features)
+
+     dataset = torch.utils.data.TensorDataset(torch.FloatTensor(features))
+     loader = torch.utils.data.DataLoader(
+         dataset, batch_size=256, shuffle=False, num_workers=0
+     )
+
+     # setup model
+     cfg = SNNConfig(
+         hidden_dim=512,
+         n_layers=8,
+         dropout=0.05,
+         layer_form="rect",
+         in_features=features.shape[1],
+         out_features=12,
+     )
+
+     model = Tox21SNNClassifier(cfg)
+     model.load_model(config["ckpt_path"])
+     model.eval()
+     print(f"Loaded model from {config['ckpt_path']}")
+
+     predictions = defaultdict(dict)
+
+     print("Creating predictions:")
+     with torch.no_grad():
+         preds = np.concatenate([model.predict(batch[0]) for batch in loader], axis=0)
+
+     for i, target in enumerate(model.tasks):
+         target_preds = np.empty_like(is_clean, dtype=float)
+
+         # cleaned molecules get the model score, uncleanable ones the default
+         target_preds[~is_clean] = default_prediction
+         target_preds[is_clean] = preds[:, i]
+
+         for smiles, pred in zip(smiles_list, target_preds):
+             predictions[smiles][target] = float(pred)
+
+     return predictions
preprocess.py ADDED
@@ -0,0 +1,68 @@
+ """
+ This file includes the data preprocessing for Tox21.
+ It creates molecule descriptors and labels for the train and validation splits
+ and saves them as NPZ files inside the data folder.
+ """
+
+ import os
+ import json
+ import argparse
+
+ import numpy as np
+
+ from src.preprocess import create_descriptors, get_tox21_split
+ from src.utils import TASKS, HF_TOKEN, create_dir, normalize_config
+
+ parser = argparse.ArgumentParser(
+     description="Data preprocessing script for the Tox21 dataset"
+ )
+
+ parser.add_argument(
+     "--config",
+     type=str,
+     default="config/config.json",
+ )
+
+
+ def main(config):
+     """Create molecule descriptors for the HF Tox21 dataset"""
+     ds = get_tox21_split(HF_TOKEN, cvfold=config["cvfold"])
+
+     splits = ["train", "validation"]
+     for split in splits:
+
+         print(f"Preprocess {split} molecules")
+
+         ds_split = ds[split]
+         smiles = list(ds_split["smiles"])
+
+         features, clean_mol_mask = create_descriptors(
+             smiles, config["descriptors"], **config["ecfp"]
+         )
+
+         labels = []
+         for task in TASKS:
+             labels.append(ds_split[task].to_numpy())
+         labels = np.stack(labels, axis=1)
+
+         # note: the filename hardcodes cv4 even though cvfold is configurable
+         save_path = os.path.join(config["data_folder"], f"tox21_{split}_cv4.npz")
+         with open(save_path, "wb") as f:
+             np.savez(
+                 f,
+                 clean_mol_mask=clean_mol_mask,
+                 labels=labels,
+                 **features,
+             )
+         print(f"Saved preprocessed {split} split under {save_path}")
+     print("Preprocessing finished successfully")
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+
+     with open(args.config, "r") as f:
+         config = json.load(f)
+     config = normalize_config(config)
+
+     create_dir(config["data_folder"])
+     main(config)
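Run the script from the repository root; `--config` defaults to `config/config.json`:

```bash
python preprocess.py --config config/config.json
# writes data/tox21_train_cv4.npz and data/tox21_validation_cv4.npz
```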
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ fastapi
+ uvicorn[standard]
+ statsmodels==0.14.5
+ rdkit==2025.03.5
+ numpy==2.2.6
+ scikit-learn==1.6.1
+ joblib
+ tabulate
+ datasets==4.0.0
+ scipy==1.16.1
+ pandas==2.3.2
+ torch==2.8.0
src/__init__.py ADDED
File without changes
src/model.py ADDED
@@ -0,0 +1,126 @@
+ """
+ This file includes the SNN model for Tox21.
+ It maps preprocessed molecule feature vectors to toxicity predictions for the
+ 12 Tox21 targets.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from typing import Literal
+
+ from dataclasses import dataclass
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ from .utils import TASKS
+
+
+ # ---------------------------------------------------------------------------------------
+ @dataclass
+ class SNNConfig:
+     hidden_dim: int
+     n_layers: int
+     dropout: float
+     layer_form: Literal["conic", "rect"]
+     in_features: int
+     out_features: int
+
+
+ class Tox21SNNClassifier(nn.Module):
+     """An SNN classifier that assigns a toxicity score to a given SMILES string."""
+
+     def __init__(self, config: SNNConfig):
+         """Initialize a multitask SNN classifier for the 12 Tox21 tasks.
+
+         Args:
+             config (SNNConfig): architecture hyperparameters.
+         """
+         super(Tox21SNNClassifier, self).__init__()
+
+         self.tasks = TASKS
+         self.num_tasks = len(TASKS)
+
+         activation = nn.SELU()
+         dropout = nn.AlphaDropout(p=config.dropout)
+
+         # "conic": layer widths shrink geometrically from hidden_dim towards
+         # out_features; "rect": all hidden layers share the same width.
+         n_hidden = (
+             (
+                 config.hidden_dim
+                 * np.power(
+                     np.power(
+                         config.out_features / config.hidden_dim, 1 / (config.n_layers)
+                     ),
+                     range(-1, config.n_layers),
+                 )
+             ).astype(int)
+             if config.layer_form == "conic"
+             else [config.hidden_dim] * (config.n_layers + 1)
+         )
+
+         n_hidden[0] = config.in_features
+         n_hidden[config.n_layers] = config.out_features
+
+         layers = []
+         for l in range(config.n_layers + 1):
+             fc = nn.Linear(
+                 in_features=n_hidden[l],
+                 out_features=(
+                     n_hidden[config.n_layers]
+                     if l == config.n_layers
+                     else n_hidden[l + 1]
+                 ),
+             )
+             if l < config.n_layers:
+                 block = [
+                     fc,
+                     activation,
+                     dropout,
+                 ]
+             else:  # last layer
+                 block = [fc]
+             layers.extend(block)
+
+         self.model = nn.Sequential(*layers)
+         self.config = config
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         for param in self.model.parameters():
+             # biases zero
+             if len(param.shape) == 1:
+                 nn.init.constant_(param, 0)
+             # weights use lecun-normal initialization (fan-in, linear gain)
+             else:
+                 nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")
+
+     def forward(self, x) -> torch.Tensor:
+         x = self.model(x)
+         return x  # x.view(x.size(0), self.num_tasks)
+
+     def load_model(self, path: str):
+         state_dict = torch.load(
+             path, weights_only=False, map_location=torch.device("cpu")
+         )["model"]
+         self.load_state_dict(state_dict)
+         self.eval()
+
+     @torch.no_grad()
+     def predict(self, features: torch.Tensor) -> np.ndarray:
+         """Predicts labels for all Tox21 targets from molecule features.
+
+         Args:
+             features (torch.Tensor): 2D batch of molecule features
+
+         Returns:
+             np.ndarray: predicted probability of the positive class per target
+         """
+         assert (
+             len(features.shape) == 2
+         ), f"Function expects 2D torch.Tensor. Current shape: {features.shape}"
+
+         return torch.nn.functional.sigmoid(self.model(features)).detach().cpu().numpy()
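To make the "conic" width schedule concrete, here is a small sketch of the geometric progression the constructor computes. The values (`hidden_dim=512`, `n_layers=4`, `in_features=1024`) are hypothetical, and note that this Space's `predict.py` actually uses `layer_form="rect"`:

```python
import numpy as np

hidden_dim, n_layers, in_features, out_features = 512, 4, 1024, 12  # hypothetical values
ratio = (out_features / hidden_dim) ** (1 / n_layers)  # per-layer geometric shrink factor
n_hidden = (hidden_dim * ratio ** np.arange(-1, n_layers)).astype(int)
print(n_hidden)  # [1308  512  200   78   30]
n_hidden[0] = in_features          # first width is overwritten by the input dimension
n_hidden[n_layers] = out_features  # last width is overwritten by the output dimension
print(n_hidden)  # [1024  512  200   78   12]
```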
src/preprocess.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from datasets import load_dataset
9
+ from sklearn.base import BaseEstimator, TransformerMixin
10
+ from sklearn.feature_selection import VarianceThreshold
11
+ from sklearn.preprocessing import StandardScaler, FunctionTransformer
12
+ from statsmodels.distributions.empirical_distribution import ECDF
13
+
14
+ from rdkit import Chem, DataStructs
15
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
16
+ from rdkit.Chem.rdchem import Mol
17
+
18
+ from .utils import USED_200_DESCR, TOX_SMARTS_PATH, Standardizer, FeatureDictMixin
19
+
20
+
21
+ class SquashScaler(TransformerMixin, BaseEstimator):
22
+ """
23
+ Scaler that performs sequential standardization, nonlinearity (tanh), and
24
+ re-standardization. Inspired by DeepTox (Mayr et al., 2016)
25
+ """
26
+
27
+ def __init__(self):
28
+ self.scaler1 = StandardScaler()
29
+ self.scaler2 = StandardScaler()
30
+
31
+ def fit(self, X):
32
+ _X = X.copy()
33
+ _X = self.scaler1.fit_transform(_X)
34
+ _X = np.tanh(_X)
35
+ _X = self.scaler2.fit(_X)
36
+ self.is_fitted_ = True
37
+ return self
38
+
39
+ def transform(self, X):
40
+ _X = X.copy()
41
+ _X = self.scaler1.transform(_X)
42
+ _X = np.tanh(_X)
43
+ return self.scaler2.transform(_X)
44
+
45
+
46
+ SCALER_REGISTRY = {
47
+ None: FunctionTransformer,
48
+ "standard": StandardScaler,
49
+ "squash": SquashScaler,
50
+ }
51
+
52
+
53
+ class SubSampler(TransformerMixin, BaseEstimator):
54
+ """
55
+ Preprocessor that randomly samples `max_samples` from data.
56
+
57
+ Args:
58
+ max_samples (int): Maximum allowed samples. If -1, all samples are retained.
59
+
60
+ Input:
61
+ np.ndarray: A 2D NumPy array of shape (n_samples, n_features).
62
+
63
+ Output:
64
+ np.ndarray: Subsampled array of shape (min(n_samples, max_samples), n_features).
65
+ """
66
+
67
+ def __init__(self, *, max_samples=-1):
68
+ self.max_samples = max_samples
69
+ self.is_fitted_ = True
70
+
71
+ def fit(self, X: np.ndarray, y: np.ndarray | None = None):
72
+ return self
73
+
74
+ def transform(
75
+ self, X: np.ndarray, y: np.ndarray | None = None
76
+ ) -> np.ndarray | tuple[np.ndarray]:
77
+
78
+ _X = X.copy()
79
+ _y = y.copy() if y is not None else None
80
+
81
+ if self.max_samples > 0 and _X.shape[0] > self.max_samples:
82
+ resample_idxs = np.random.choice(
83
+ np.arange(_X.shape[0]), size=(self.max_samples,), replace=True
84
+ )
85
+ _X = _X[resample_idxs]
86
+ _y = _y[resample_idxs] if _y is not None else None
87
+
88
+ if _y is None:
89
+ return _X
90
+ return _X, _y
91
+
92
+
93
+ class FeatureSelector(FeatureDictMixin, TransformerMixin, BaseEstimator):
94
+ """
95
+ Preprocessor that performs feature selection based on variance and correlation.
96
+
97
+ This transformer selects features that:
98
+ 1. Have variance above a specified threshold.
99
+ 2. Are below a given pairwise correlation threshold.
100
+ 3. Among the remaining features, keeps only the top `max_features` with the highest variance.
101
+
102
+ The input and output are both dictionaries mapping feature types to their corresponding
103
+ feature matrices.
104
+
105
+ Args:
106
+ min_var (float): Minimum variance required for a feature to be retained.
107
+ max_corr (float): Maximum allowed correlation between features.
108
+ Features exceeding this threshold with others are removed.
109
+ max_features (int): Maximum number of features to keep after filtering.
110
+ If -1, all remaining features are retained.
111
+ feature_keys (list[str]): Features to apply feature selection to.
112
+ independent_keys (bool): Apply filtering only within features types.
113
+
114
+ Input:
115
+ dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
116
+ and each value is a 2D NumPy array of shape (n_samples, n_features).
117
+
118
+ Output:
119
+ dict[str, np.ndarray]: A dictionary with the same keys as the input,
120
+ containing only the selected features for each feature type.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *,
126
+ min_var=0.0,
127
+ max_corr=1.0,
128
+ max_features=-1,
129
+ feature_keys=None,
130
+ min_var__feature_keys=None,
131
+ max_corr__feature_keys=None,
132
+ max_features__feature_keys=None,
133
+ min_var__independent_keys=False,
134
+ max_corr__independent_keys=False,
135
+ max_features__independent_keys=False,
136
+ ):
137
+ self.min_var = min_var
138
+ self.max_corr = max_corr
139
+ self.max_features = max_features
140
+
141
+ self.min_var__feature_keys = min_var__feature_keys
142
+ self.max_corr__feature_keys = max_corr__feature_keys
143
+ self.max_features__feature_keys = max_features__feature_keys
144
+
145
+ self.min_var__independent_keys = min_var__independent_keys
146
+ self.max_corr__independent_keys = max_corr__independent_keys
147
+ self.max_features__independent_keys = max_features__independent_keys
148
+
149
+ super().__init__(feature_keys=feature_keys)
150
+
151
+ def _get_min_var_mask(self, X: np.ndarray, *args) -> np.ndarray:
152
+ var_thresh = VarianceThreshold(threshold=self.min_var)
153
+ return var_thresh.fit(X).get_support() # mask
154
+
155
+ def _get_max_corr_mask(
156
+ self, X: np.ndarray, prev_feature_mask: np.ndarray
157
+ ) -> np.ndarray:
158
+ _prev_feature_mask = prev_feature_mask.copy()
159
+ corr_matrix = np.corrcoef(X[:, _prev_feature_mask], rowvar=False)
160
+ upper_tri = np.triu(corr_matrix, k=1)
161
+ to_keep = np.ones((sum(_prev_feature_mask),), dtype=bool)
162
+ for i in range(upper_tri.shape[0]):
163
+ for j in range(upper_tri.shape[1]):
164
+ if upper_tri[i, j] > self.max_corr:
165
+ to_keep[j] = False
166
+
167
+ _prev_feature_mask[_prev_feature_mask] = to_keep
168
+ return _prev_feature_mask
169
+
170
+ def _get_max_features_mask(
171
+ self, X: np.ndarray, prev_feature_mask: np.ndarray
172
+ ) -> np.ndarray:
173
+ _prev_feature_mask = prev_feature_mask.copy()
174
+ # select features with at least max_var variation
175
+ feature_vars = np.nanvar(X[:, _prev_feature_mask], axis=0)
176
+ order = np.argsort(feature_vars)[: -(self.max_features + 1) : -1]
177
+ keep_feat_idx = np.arange(len(_prev_feature_mask))[order]
178
+ _prev_feature_mask = np.isin(
179
+ np.arange(len(_prev_feature_mask)), keep_feat_idx, assume_unique=True
180
+ )
181
+ return _prev_feature_mask
182
+
183
+ def apply_filter(self, filter, X, prev_feature_mask):
184
+ mask = prev_feature_mask.copy()
185
+ func = self.__getattribute__(f"_get_{filter}_mask")
186
+ feature_keys = self.__getattribute__(f"{filter}__feature_keys")
187
+
188
+ if self.__getattribute__(f"{filter}__independent_keys"):
189
+ for key in feature_keys:
190
+ key_mask = self._curr_keys == key
191
+ mask[key_mask] = func(X[:, key_mask], mask[key_mask])
192
+
193
+ else:
194
+ feature_key_mask = np.isin(self._curr_keys, feature_keys)
195
+ mask[feature_key_mask] = func(
196
+ X[:, feature_key_mask], mask[feature_key_mask]
197
+ )
198
+ return mask
199
+
200
+ def fit(self, X: dict[str, np.ndarray]):
201
+ _X = self.dict_to_array(X)
202
+ feature_mask = np.ones((_X.shape[1]), dtype=bool)
203
+
204
+ # select features with at least min_var variation
205
+ if self.min_var > 0.0:
206
+ if self.min_var__independent_keys:
207
+ for key in self.min_var__feature_keys:
208
+ key_mask = self._curr_keys == key
209
+ feature_mask[key_mask] = self._get_min_var_mask(_X[:, key_mask])
210
+
211
+ else:
212
+ feature_key_mask = np.isin(self._curr_keys, self.min_var__feature_keys)
213
+ feature_mask[feature_key_mask] = self._get_min_var_mask(
214
+ _X[:, feature_key_mask]
215
+ )
216
+
217
+ # select features with at least max_var variation
218
+ if self.max_corr < 1.0:
219
+ if self.max_corr__independent_keys:
220
+ for key in self.max_corr__feature_keys:
221
+ key_mask = self._curr_keys == key
222
+ subset = _X[:, key_mask]
223
+ feature_mask[key_mask] = self._get_max_corr_mask(
224
+ subset, feature_mask[key_mask]
225
+ )
226
+ else:
227
+ feature_key_mask = np.isin(self._curr_keys, self.max_corr__feature_keys)
228
+ feature_mask[feature_key_mask] = self._get_max_corr_mask(
229
+ _X[:, feature_key_mask], feature_mask[feature_key_mask]
230
+ )
231
+
232
+ if self.max_features == 0:
233
+ raise ValueError(
234
+ f"max_features (={self.max_features}) must be -1 or larger 0."
235
+ )
236
+ elif self.max_features > 0:
237
+ if self.max_features__independent_keys:
238
+ for key in self.max_features__feature_keys:
239
+ key_mask = self._curr_keys == key
240
+ feature_mask[key_mask] = self._get_max_features_mask(
241
+ _X[:, key_mask], feature_mask[key_mask]
242
+ )
243
+ else:
244
+ feature_key_mask = np.isin(
245
+ self._curr_keys, self.max_features__feature_keys
246
+ )
247
+ feature_mask[feature_key_mask] = self._get_max_features_mask(
248
+ _X[:, feature_key_mask], feature_mask[feature_key_mask]
249
+ )
250
+
251
+ self._feature_mask = feature_mask
252
+ self.is_fitted_ = True
253
+ return self
254
+
255
+ def transform(self, X: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
256
+ _X = self.dict_to_array(X)
257
+ _X = _X[:, self._feature_mask]
258
+ self._curr_keys = self._curr_keys[self._feature_mask]
259
+ return self.array_to_dict(_X)
260
+
261
+
262
+ class QuantileCreator(FeatureDictMixin, TransformerMixin, BaseEstimator):
263
+ """
264
+ Preprocessor that transforms features into empirical quantiles using ECDFs.
265
+
266
+ This transformer applies an Empirical Cumulative Distribution Function (ECDF)
267
+ to each feature and replaces feature values with their corresponding quantile
268
+ ranks. The transformation is applied independently to each feature type.
269
+
270
+ Both input and output are dictionaries mapping feature types to their
271
+ corresponding feature matrices.
272
+
273
+ Args:
274
+ feature_keys (list[str]): Features to apply quantile creation to.
275
+
276
+ Input:
277
+ dict[str, np.ndarray]: A dictionary where each key corresponds to a feature type
278
+ and each value is a 2D NumPy array of shape (n_samples, n_features).
279
+
280
+ Output:
281
+ dict[str, np.ndarray]: A dictionary with the same keys as the input,
282
+ where each feature value is replaced by its corresponding ECDF quantile rank.
283
+ """
284
+
285
+ def __init__(self, *, feature_keys=None):
286
+ self._ecdfs = None
287
+ super().__init__(feature_keys=feature_keys)
288
+
289
+ def fit(self, X: dict[str, np.ndarray]):
290
+ _X = self.dict_to_array(X)
291
+ ecdfs = []
292
+ for column in range(_X.shape[1]):
293
+ raw_values = _X[:, column].reshape(-1)
294
+ ecdfs.append(ECDF(raw_values))
295
+ self._ecdfs = ecdfs
296
+ self.is_fitted_ = True
297
+ return self
298
+
299
+ def transform(self, X: dict[str, np.ndarray]) -> np.ndarray:
300
+ _X = self.dict_to_array(X)
301
+
302
+ quantiles = np.zeros_like(_X)
303
+ for column in range(_X.shape[1]):
304
+ raw_values = _X[:, column].reshape(-1)
305
+ ecdf = self._ecdfs[column]
306
+ q = ecdf(raw_values)
307
+ quantiles[:, column] = q
308
+
309
+ return self.array_to_dict(quantiles)
310
+
311
+
312
+ class FeaturePreprocessor(TransformerMixin, BaseEstimator):
313
+ """This class implements the feature preprocessing from a dictionary of molecule features."""
314
+
315
+ def __init__(
316
+ self,
317
+ feature_selection_config: dict[str, Any],
318
+ feature_quantilization_config: dict[str, Any],
319
+ descriptors: list[str],
320
+ max_samples: int = -1,
321
+ scaler: str = "standard",
322
+ ):
323
+ self.descriptors = descriptors
324
+
325
+ self.feature_quantilization_config = copy.deepcopy(
326
+ feature_quantilization_config
327
+ )
328
+ self.use_feat_quant = self.feature_quantilization_config.pop("use")
329
+ self.quantile_creator = QuantileCreator(**self.feature_quantilization_config)
330
+
331
+ self.feature_selection_config = copy.deepcopy(feature_selection_config)
332
+ self.use_feat_selec = self.feature_selection_config.pop("use")
333
+ self.feature_selection_config["feature_keys"] = descriptors
334
+ self.feature_selector = FeatureSelector(**self.feature_selection_config)
335
+
336
+ self.max_samples = max_samples
337
+ self.sub_sampler = SubSampler(max_samples=max_samples)
338
+
339
+ self.scaler = SCALER_REGISTRY[scaler]()
340
+
341
+ def __getstate__(self):
342
+ state = super().__getstate__()
343
+ state["quantile_creator"] = self.quantile_creator.__getstate__()
344
+ state["feature_selector"] = self.feature_selector.__getstate__()
345
+ state["sub_sampler"] = self.sub_sampler.__getstate__()
346
+ state["scaler"] = self.scaler.__getstate__()
347
+ return state
348
+
349
+ def __setstate__(self, state):
350
+ _state = copy.deepcopy(state)
351
+ self.quantile_creator.__setstate__(_state.pop("quantile_creator"))
352
+ self.feature_selector.__setstate__(_state.pop("feature_selector"))
353
+ self.sub_sampler.__setstate__(_state.pop("sub_sampler"))
354
+ self.scaler.__setstate__(_state.pop("scaler"))
355
+ super().__setstate__(_state)
356
+
357
+ def get_state(self):
358
+ return self.__getstate__()
359
+
360
+ def set_state(self, state):
361
+ return self.__setstate__(state)
362
+
363
+ def fit(self, X: dict[str, np.ndarray]):
364
+ """Fit the processor transformers"""
365
+ _X = copy.deepcopy(X)
366
+
367
+ if self.use_feat_quant:
368
+ _X = self.quantile_creator.fit_transform(_X)
369
+
370
+ if self.use_feat_selec:
371
+ _X = self.feature_selector.fit_transform(_X)
372
+
373
+ _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
374
+ self.scaler.fit(_X)
375
+ return self
376
+
377
+ def transform(
378
+ self, X: np.ndarray, y: np.ndarray | None = None
379
+ ) -> np.ndarray | tuple[np.ndarray]:
380
+
381
+ _X = X.copy()
382
+ _y = y.copy() if y is not None else None
383
+
384
+ if self.use_feat_quant:
385
+ _X = self.quantile_creator.transform(_X)
386
+ if self.use_feat_selec:
387
+ _X = self.feature_selector.transform(_X)
388
+ _X = np.concatenate([_X[descr] for descr in self.descriptors], axis=1)
389
+ _X = self.scaler.transform(_X)
390
+
391
+ if _y is None:
392
+ _X = self.sub_sampler.transform(_X)
393
+ return _X
394
+
395
+ _X, _y = self.sub_sampler.transform(_X, _y)
396
+ return _X, _y
397
+
398
+
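For orientation, here is a minimal usage sketch of the pipeline above; it is not part of the commit. The config keys mirror the constructor, the feature arrays are random stand-ins, and FeatureSelector, SubSampler, and SCALER_REGISTRY are assumed to be the objects defined earlier in this file.

import numpy as np

# hypothetical feature dict: 100 molecules, two descriptor blocks
features = {
    "ecfps": np.random.rand(100, 2048),
    "rdkit_descrs": np.random.rand(100, 200),
}

preprocessor = FeaturePreprocessor(
    feature_selection_config={"use": False},
    feature_quantilization_config={"use": True, "feature_keys": ["rdkit_descrs"]},
    descriptors=["ecfps", "rdkit_descrs"],
    max_samples=-1,
    scaler="standard",
)
preprocessor.fit(features)
X = preprocessor.transform(features)  # (100, 2248) if nothing is selected out or subsampled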
399
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
400
+ """This function creates cleaned RDKit mol objects from a list of SMILES.
401
+ Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
402
+ Modification by Antonia Ebner:
403
+ - skip uncleanable molecules
404
+ - return clean molecule mask
405
+
406
+ Args:
407
+ smiles (list[str]): list of SMILES
408
+
409
+ Returns:
410
+ list[Mol]: list of cleaned molecules
411
+ np.ndarray[bool]: mask that contains False at index `i` if the molecule in `smiles` at
412
+ index `i` could not be cleaned and was removed.
413
+ """
414
+ sm = Standardizer(canon_taut=True)
415
+
416
+ clean_mol_mask = list()
417
+ mols = list()
418
+ for i, smile in enumerate(smiles):
419
+ mol = Chem.MolFromSmiles(smile)
420
+ standardized_mol, _ = sm.standardize_mol(mol)
421
+ is_cleaned = standardized_mol is not None
422
+ clean_mol_mask.append(is_cleaned)
423
+ if not is_cleaned:
424
+ continue
425
+ can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
426
+ mols.append(can_mol)
427
+
428
+ return mols, np.array(clean_mol_mask)
429
+
430
+
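A small illustrative sketch of the mask semantics (not part of the commit):

smiles = ["CCO", "not-a-smiles", "c1ccccc1"]
mols, mask = create_cleaned_mol_objects(smiles)
# mask -> array([ True, False,  True]); mols holds the two molecules
# whose entries in mask are True, in the original order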
431
+ def create_ecfp_fps(mols: list[Mol], radius=3, fpsize=2048, **kwargs) -> np.ndarray:
432
+ """This function ECFP fingerprints for a list of molecules.
433
+ Inspired by from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
434
+
435
+ Args:
436
+ mols (list[Mol]): list of molecules
437
+
438
+ Returns:
439
+ np.ndarray: ECFP fingerprints of molecules
440
+ """
441
+ # build the Morgan generator once and reuse it for all molecules
442
+ gen = rdFingerprintGenerator.GetMorganGenerator(
443
+ countSimulation=True, fpSize=fpsize, radius=radius
444
+ )
445
+ ecfps = list()
446
+ for mol in mols:
447
+ fp_sparse_vec = gen.GetCountFingerprint(mol)
448
+
449
+ fp = np.zeros((0,), np.int8)
450
+ DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
451
+
452
+ ecfps.append(fp)
453
+
454
+ return np.array(ecfps)
455
+
456
+
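Illustrative call (not part of the commit); countSimulation=True encodes substructure counts rather than plain presence bits:

from rdkit import Chem

mols = [Chem.MolFromSmiles(s) for s in ("CCO", "c1ccccc1O")]
fps = create_ecfp_fps(mols, radius=3, fpsize=2048)
print(fps.shape)  # (2, 2048)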
457
+ def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
458
+ """This function creates MACCS keys for a list of molecules.
459
+
460
+ Args:
461
+ mols (list[Mol]): list of molecules
462
+
463
+ Returns:
464
+ np.ndarray: MACCS keys of molecules
465
+ """
466
+ maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
467
+ return np.array(maccs)
468
+
469
+
470
+ def get_tox_patterns(filepath: str):
471
+ """This retrieves the tox features defined in filepath.
472
+ Args:
473
+ filepath (str): A list of tox features
474
+ """
475
+ # load patterns
476
+ with open(filepath) as f:
477
+ smarts_list = [s[1] for s in json.load(f)]
478
+
479
+ # The parsing below cannot handle smarts that combine AND and OR
480
+ assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
481
+
482
+ # Chem.MolFromSmarts takes a long time, so it pays off to parse all the smarts first
483
+ # and then reuse them for all molecules. This gives a huge speedup over the existing code.
484
+ # Each entry holds a list of patterns, whether to negate each match result, and how to join the results into one boolean value.
485
+ all_patterns = []
486
+ for smarts in smarts_list:
487
+ patterns = [] # list of smarts-patterns
488
+ # value for each of the patterns above. Negates the values of the above later.
489
+ negations = []
490
+
491
+ if " AND " in smarts:
492
+ smarts = smarts.split(" AND ")
493
+ merge_any = False  # If an ' AND ' is found, all 'subsmarts' have to match
494
+ else:
495
+ # If an ' OR ' is present, it's enough if any of the 'subsmarts' match.
496
+ # This branch also accumulates smarts where neither ' OR ' nor ' AND ' occurs.
497
+ smarts = smarts.split(" OR ")
498
+ merge_any = True
499
+
500
+ # for all subsmarts check if they are preceded by 'NOT '
501
+ for s in smarts:
502
+ neg = s.startswith("NOT ")
503
+ if neg:
504
+ s = s[4:]
505
+ patterns.append(Chem.MolFromSmarts(s))
506
+ negations.append(neg)
507
+
508
+ all_patterns.append((patterns, negations, merge_any))
509
+ return all_patterns
510
+
511
+
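As a worked example (the logic string is made up, not taken from data/tox_smarts.json), "NOT [OH] AND [#6]" parses into:

# patterns  -> [Chem.MolFromSmarts("[OH]"), Chem.MolFromSmarts("[#6]")]
# negations -> [True, False]     # the first match result is inverted
# merge_any -> False             # ' AND ': all sub-patterns must match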
512
+ def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
513
+ """Matches the tox patterns against a molecule. Returns a boolean array"""
514
+ tox_data = []
515
+ for mol in mols:
516
+ mol_features = []
517
+ for patts, negations, merge_any in patterns:
518
+ matches = [mol.HasSubstructMatch(p) for p in patts]
519
+ matches = [m != n for m, n in zip(matches, negations)]
520
+ if merge_any:
521
+ pres = any(matches)
522
+ else:
523
+ pres = all(matches)
524
+ mol_features.append(pres)
525
+
526
+ tox_data.append(np.array(mol_features))
527
+
528
+ return np.array(tox_data)
529
+
530
+
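Illustrative sketch of putting the two functions together (not part of the commit):

mols = [Chem.MolFromSmiles("CCO")]
patterns = get_tox_patterns(TOX_SMARTS_PATH)
tox = create_tox_features(mols, patterns)  # bool array of shape (1, len(patterns))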
531
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
532
+ """This function creates RDKit descriptors for a list of molecules.
533
+ Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
534
+
535
+ Args:
536
+ mols (list[Mol]): list of molecules
537
+
538
+ Returns:
539
+ np.ndarray: RDKit descriptors of molecules
540
+ """
541
+ rdkit_descriptors = list()
542
+
543
+ for mol in mols:
544
+ descrs = []
545
+ for _, descr_calc_fn in Descriptors._descList:
546
+ descrs.append(descr_calc_fn(mol))
547
+
548
+ descrs = np.array(descrs)
549
+ descrs = descrs[USED_200_DESCR]
550
+ rdkit_descriptors.append(descrs)
551
+
552
+ return np.array(rdkit_descriptors)
553
+
554
+
555
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
556
+ """Create quantile values for given features using the columns
557
+ Taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
558
+
559
+ Args:
560
+ raw_features (np.ndarray): values to put into quantiles
561
+ ecdfs (list): ECDFs to use
562
+
563
+ Returns:
564
+ np.ndarray: computed quantiles
565
+ """
566
+ quantiles = np.zeros_like(raw_features)
567
+
568
+ for column in range(raw_features.shape[1]):
569
+ raw_values = raw_features[:, column].reshape(-1)
570
+ ecdf = ecdfs[column]
571
+ q = ecdf(raw_values)
572
+ quantiles[:, column] = q
573
+
574
+ return quantiles
575
+
576
+
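A minimal numeric sketch, assuming the ECDF used above is statsmodels' empirical distribution function: an ECDF maps a value to the fraction of fitted samples less than or equal to it.

import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF

train_col = np.array([1.0, 2.0, 3.0, 4.0])
ecdfs = [ECDF(train_col)]
print(create_quantiles(np.array([[2.5]]), ecdfs))  # [[0.5]]: 2 of 4 fitted values are <= 2.5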
577
+ def fill(features, mask, value=np.nan):
578
+ n_mols = len(mask)
579
+ n_features = features.shape[1]
580
+
581
+ data = np.zeros(shape=(n_mols, n_features))
582
+ data.fill(value)
583
+ data[~mask] = features
584
+ return data
585
+
586
+
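A worked example of the fill semantics (not part of the commit): `mask` marks the rows that failed cleaning; those rows are filled with `value`, and `features` lands on the remaining rows.

import numpy as np

feats = np.array([[1.0, 2.0]])   # features of the single cleaned molecule
mask = np.array([False, True])   # the second molecule could not be cleaned
print(fill(feats, mask))
# [[ 1.  2.]
#  [nan nan]]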
587
+ def create_descriptors(
588
+ smiles,
589
+ descriptors,
590
+ **ecfp_kwargs,
591
+ ):
592
+ """Generate molecular descriptors for multiple SMILES strings.
593
+ Inspired by https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
594
+
595
+ Each SMILES is processed and sanitized using RDKit.
596
+ SMILES that cannot be sanitized are encoded with NaNs, and a corresponding boolean mask
597
+ is returned to indicate which inputs were successfully processed.
598
+
599
+ Args:
600
+ smiles (list[str]): List of SMILES strings for which to generate descriptors.
601
+ descriptors (list[str]): List of descriptor types to compute.
602
+ Supported values include:
603
+ ['ecfps', 'tox', 'maccs', 'rdkit_descrs'].
604
+
605
+ Returns:
606
+ tuple[dict[str, np.ndarray], np.ndarray]:
607
+ - A dictionary mapping descriptor names to their computed arrays.
608
+ - A boolean mask of shape (len(smiles),) indicating which SMILES
609
+ were successfully sanitized and processed.
610
+ """
611
+ # Create cleaned RDKit mol objects
612
+ mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
613
+ print(f"Cleaned molecules, {(~clean_mol_mask).sum()} could not be sanitized")
614
+
615
+ # Create fingerprints and descriptors
616
+ if "ecfps" in descriptors:
617
+ ecfps = create_ecfp_fps(mols, **ecfp_kwargs)
618
+ ecfps = fill(ecfps, ~clean_mol_mask)
619
+ print("Created ECFP fingerprints")
620
+
621
+ if "tox" in descriptors:
622
+ tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
623
+ tox = create_tox_features(mols, tox_patterns)
624
+ tox = fill(tox, ~clean_mol_mask)
625
+ print("Created Tox features")
626
+
627
+ if "maccs" in descriptors:
628
+ maccs = create_maccs_keys(mols)
629
+ maccs = fill(maccs, ~clean_mol_mask)
630
+ print("Created MACCS keys")
631
+
632
+ if "rdkit_descrs" in descriptors:
633
+ rdkit_descrs = create_rdkit_descriptors(mols)
634
+ rdkit_descrs = fill(rdkit_descrs, ~clean_mol_mask)
635
+ print("Created RDKit descriptors")
636
+
637
+ # collect the computed feature arrays under their descriptor names
638
+ features = {}
639
+ for descr in descriptors:
640
+ features[descr] = vars()[descr]  # look up the local variable computed above
641
+
642
+ return features, clean_mol_mask
643
+
644
+
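End-to-end sketch (not part of the commit); `radius` and `fpsize` are forwarded to create_ecfp_fps:

features, mask = create_descriptors(
    ["CCO", "c1ccccc1"],
    descriptors=["ecfps", "maccs"],
    radius=3,
    fpsize=2048,
)
# features["ecfps"].shape == (2, 2048); features["maccs"].shape == (2, 167)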
645
+ def get_tox21_split(token, cvfold=None):
646
+ """Retrieve Tox21 splits from HuggingFace with respect to given cvfold."""
647
+ ds = load_dataset("ml-jku/tox21", token=token)
648
+
649
+ train_df = ds["train"].to_pandas()
650
+ val_df = ds["validation"].to_pandas()
651
+
652
+ if cvfold is None:
653
+ return {"train": train_df, "validation": val_df}
654
+
655
+ combined_df = pd.concat([train_df, val_df], ignore_index=True)
656
+
657
+ # create new splits: hold out the rows whose CVfold equals `cvfold`
658
+ # (fold ids are stored as floats in the CVfold column)
659
+ cvfold = float(cvfold)
660
+ train_df = combined_df[combined_df.CVfold != cvfold]
661
+ val_df = combined_df[combined_df.CVfold == cvfold]
662
+
663
+ # exclude train mols that occur in the validation split
664
+ val_inchikeys = set(val_df["inchikey"])
665
+ train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
666
+
667
+ return {
668
+ "train": train_df.reset_index(drop=True),
669
+ "validation": val_df.reset_index(drop=True),
670
+ }
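Illustrative call (not part of the commit), assuming a valid HuggingFace token such as the HF_TOKEN read in src/utils.py below:

splits = get_tox21_split(token=HF_TOKEN, cvfold=0)
train_df, val_df = splits["train"], splits["validation"]
# fold 0 becomes the validation split; train rows sharing an inchikey
# with any validation row are dropped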
src/utils.py ADDED
@@ -0,0 +1,525 @@
 
1
+ ## This Standardizer class is due to Paolo Tosco.
2
+ ## It was taken from the FS-Mol GitHub repository
3
+ ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
4
+ ## standardizer.py)
5
+ ## It ensures that a sequence of standardization operations is applied:
6
+ ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
7
+
8
+ import os
9
+ import pickle
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+
14
+ from rdkit import Chem
15
+ from rdkit.Chem.MolStandardize import rdMolStandardize
16
+
17
+ HF_TOKEN = os.environ.get("HF_TOKEN")
18
+ PAD_VALUE = -100
19
+ TOX_SMARTS_PATH = "data/tox_smarts.json"
20
+
21
+ TASKS = [
22
+ "NR-AR",
23
+ "NR-AR-LBD",
24
+ "NR-AhR",
25
+ "NR-Aromatase",
26
+ "NR-ER",
27
+ "NR-ER-LBD",
28
+ "NR-PPAR-gamma",
29
+ "SR-ARE",
30
+ "SR-ATAD5",
31
+ "SR-HSE",
32
+ "SR-MMP",
33
+ "SR-p53",
34
+ ]
35
+
36
+ USED_200_DESCR = [
37
+ 0,
38
+ 1,
39
+ 2,
40
+ 3,
41
+ 4,
42
+ 5,
43
+ 6,
44
+ 7,
45
+ 8,
46
+ 9,
47
+ 10,
48
+ 11,
49
+ 12,
50
+ 13,
51
+ 14,
52
+ 15,
53
+ 16,
54
+ 25,
55
+ 26,
56
+ 27,
57
+ 28,
58
+ 29,
59
+ 30,
60
+ 31,
61
+ 32,
62
+ 33,
63
+ 34,
64
+ 35,
65
+ 36,
66
+ 37,
67
+ 38,
68
+ 39,
69
+ 40,
70
+ 41,
71
+ 42,
72
+ 43,
73
+ 44,
74
+ 45,
75
+ 46,
76
+ 47,
77
+ 48,
78
+ 49,
79
+ 50,
80
+ 51,
81
+ 52,
82
+ 53,
83
+ 54,
84
+ 55,
85
+ 56,
86
+ 57,
87
+ 58,
88
+ 59,
89
+ 60,
90
+ 61,
91
+ 62,
92
+ 63,
93
+ 64,
94
+ 65,
95
+ 66,
96
+ 67,
97
+ 68,
98
+ 69,
99
+ 70,
100
+ 71,
101
+ 72,
102
+ 73,
103
+ 74,
104
+ 75,
105
+ 76,
106
+ 77,
107
+ 78,
108
+ 79,
109
+ 80,
110
+ 81,
111
+ 82,
112
+ 83,
113
+ 84,
114
+ 85,
115
+ 86,
116
+ 87,
117
+ 88,
118
+ 89,
119
+ 90,
120
+ 91,
121
+ 92,
122
+ 93,
123
+ 94,
124
+ 95,
125
+ 96,
126
+ 97,
127
+ 98,
128
+ 99,
129
+ 100,
130
+ 101,
131
+ 102,
132
+ 103,
133
+ 104,
134
+ 105,
135
+ 106,
136
+ 107,
137
+ 108,
138
+ 109,
139
+ 110,
140
+ 111,
141
+ 112,
142
+ 113,
143
+ 114,
144
+ 115,
145
+ 116,
146
+ 117,
147
+ 118,
148
+ 119,
149
+ 120,
150
+ 121,
151
+ 122,
152
+ 123,
153
+ 124,
154
+ 125,
155
+ 126,
156
+ 127,
157
+ 128,
158
+ 129,
159
+ 130,
160
+ 131,
161
+ 132,
162
+ 133,
163
+ 134,
164
+ 135,
165
+ 136,
166
+ 137,
167
+ 138,
168
+ 139,
169
+ 140,
170
+ 141,
171
+ 142,
172
+ 143,
173
+ 144,
174
+ 145,
175
+ 146,
176
+ 147,
177
+ 148,
178
+ 149,
179
+ 150,
180
+ 151,
181
+ 152,
182
+ 153,
183
+ 154,
184
+ 155,
185
+ 156,
186
+ 157,
187
+ 158,
188
+ 159,
189
+ 160,
190
+ 161,
191
+ 162,
192
+ 163,
193
+ 164,
194
+ 165,
195
+ 166,
196
+ 167,
197
+ 168,
198
+ 169,
199
+ 170,
200
+ 171,
201
+ 172,
202
+ 173,
203
+ 174,
204
+ 175,
205
+ 176,
206
+ 177,
207
+ 178,
208
+ 179,
209
+ 180,
210
+ 181,
211
+ 182,
212
+ 183,
213
+ 184,
214
+ 185,
215
+ 186,
216
+ 187,
217
+ 188,
218
+ 189,
219
+ 190,
220
+ 191,
221
+ 192,
222
+ 193,
223
+ 194,
224
+ 195,
225
+ 196,
226
+ 197,
227
+ 198,
228
+ 199,
229
+ 200,
230
+ 201,
231
+ 202,
232
+ 203,
233
+ 204,
234
+ 205,
235
+ 206,
236
+ 207,
237
+ ]
238
+
239
+
240
+ class Standardizer:
241
+ """
242
+ Simple wrapper class around rdkit Standardizer.
243
+ """
244
+
245
+ DEFAULT_CANON_TAUT = False
246
+ DEFAULT_METAL_DISCONNECT = False
247
+ MAX_TAUTOMERS = 100
248
+ MAX_TRANSFORMS = 100
249
+ MAX_RESTARTS = 200
250
+ PREFER_ORGANIC = True
251
+
252
+ def __init__(
253
+ self,
254
+ metal_disconnect=None,
255
+ canon_taut=None,
256
+ ):
257
+ """
258
+ Constructor.
259
+ All parameters are optional.
260
+ :param metal_disconnect: if True, metallorganic complexes are
261
+ disconnected
262
+ :param canon_taut: if True, molecules are converted to their
263
+ canonical tautomer
264
+ """
265
+ super().__init__()
266
+ if metal_disconnect is None:
267
+ metal_disconnect = self.DEFAULT_METAL_DISCONNECT
268
+ if canon_taut is None:
269
+ canon_taut = self.DEFAULT_CANON_TAUT
270
+ self._canon_taut = canon_taut
271
+ self._metal_disconnect = metal_disconnect
272
+ self._taut_enumerator = None
273
+ self._uncharger = None
274
+ self._lfrag_chooser = None
275
+ self._metal_disconnector = None
276
+ self._normalizer = None
277
+ self._reionizer = None
278
+ self._params = None
279
+
280
+ @property
281
+ def params(self):
282
+ """Return the MolStandardize CleanupParameters."""
283
+ if self._params is None:
284
+ self._params = rdMolStandardize.CleanupParameters()
285
+ self._params.maxTautomers = self.MAX_TAUTOMERS
286
+ self._params.maxTransforms = self.MAX_TRANSFORMS
287
+ self._params.maxRestarts = self.MAX_RESTARTS
288
+ self._params.preferOrganic = self.PREFER_ORGANIC
289
+ self._params.tautomerRemoveSp3Stereo = False
290
+ return self._params
291
+
292
+ @property
293
+ def canon_taut(self):
294
+ """Return whether tautomer canonicalization will be done."""
295
+ return self._canon_taut
296
+
297
+ @property
298
+ def metal_disconnect(self):
299
+ """Return whether metallorganic complexes will be disconnected."""
300
+ return self._metal_disconnect
301
+
302
+ @property
303
+ def taut_enumerator(self):
304
+ """Return the TautomerEnumerator object."""
305
+ if self._taut_enumerator is None:
306
+ self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
307
+ return self._taut_enumerator
308
+
309
+ @property
310
+ def uncharger(self):
311
+ """Return the Uncharger object."""
312
+ if self._uncharger is None:
313
+ self._uncharger = rdMolStandardize.Uncharger()
314
+ return self._uncharger
315
+
316
+ @property
317
+ def lfrag_chooser(self):
318
+ """Return the LargestFragmentChooser object."""
319
+ if self._lfrag_chooser is None:
320
+ self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
321
+ self.params.preferOrganic
322
+ )
323
+ return self._lfrag_chooser
324
+
325
+ @property
326
+ def metal_disconnector(self):
327
+ """Return the MetalDisconnector object."""
328
+ if self._metal_disconnector is None:
329
+ self._metal_disconnector = rdMolStandardize.MetalDisconnector()
330
+ return self._metal_disconnector
331
+
332
+ @property
333
+ def normalizer(self):
334
+ """Return the Normalizer object."""
335
+ if self._normalizer is None:
336
+ self._normalizer = rdMolStandardize.Normalizer(
337
+ self.params.normalizationsFile, self.params.maxRestarts
338
+ )
339
+ return self._normalizer
340
+
341
+ @property
342
+ def reionizer(self):
343
+ """Return the Reionizer object."""
344
+ if self._reionizer is None:
345
+ self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
346
+ return self._reionizer
347
+
348
+ def charge_parent(self, mol_in):
349
+ """Sequentially apply a series of MolStandardize operations:
350
+ * MetalDisconnector
351
+ * Normalizer
352
+ * Reionizer
353
+ * LargestFragmentChooser
354
+ * Uncharger
355
+ The net result is that a desalted, normalized, neutral
356
+ molecule with implicit Hs is returned.
357
+ """
358
+ params = Chem.RemoveHsParameters()
359
+ params.removeAndTrackIsotopes = True
360
+ mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
361
+ if self._metal_disconnect:
362
+ mol_in = self.metal_disconnector.Disconnect(mol_in)
363
+ normalized = self.normalizer.normalize(mol_in)
364
+ Chem.SanitizeMol(normalized)
365
+ normalized = self.reionizer.reionize(normalized)
366
+ Chem.AssignStereochemistry(normalized)
367
+ normalized = self.lfrag_chooser.choose(normalized)
368
+ normalized = self.uncharger.uncharge(normalized)
369
+ # need this to reassess aromaticity on things like
370
+ # cyclopentadienyl, tropylium, azolium, etc.
371
+ Chem.SanitizeMol(normalized)
372
+ return Chem.RemoveHs(Chem.AddHs(normalized))
373
+
374
+ def standardize_mol(self, mol_in):
375
+ """
376
+ Standardize a single molecule.
377
+ :param mol_in: a Chem.Mol
378
+ :return: * (standardized Chem.Mol, n_taut) tuple
379
+ if success. n_taut will be negative if
380
+ tautomer enumeration was aborted due
381
+ to reaching a limit
382
+ * (None, error_msg) if failure
383
+ This calls self.charge_parent() and, if self._canon_taut
384
+ is True, runs tautomer canonicalization.
385
+ """
386
+ n_tautomers = 0
387
+ if isinstance(mol_in, Chem.Mol):
388
+ name = None
389
+ try:
390
+ name = mol_in.GetProp("_Name")
391
+ except KeyError:
392
+ pass
393
+ if not name:
394
+ name = "NONAME"
395
+ else:
396
+ error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
397
+ return None, error
398
+ try:
399
+ mol_out = self.charge_parent(mol_in)
400
+ except Exception as e:
401
+ error = f"charge_parent FAILED: {str(e).strip()}"
402
+ return None, error
403
+ if self._canon_taut:
404
+ try:
405
+ res = self.taut_enumerator.Enumerate(mol_out, False)
406
+ except TypeError:
407
+ # we are still on the pre-2021 RDKit API
408
+ res = self.taut_enumerator.Enumerate(mol_out)
409
+ except Exception as e:
410
+ # something else went wrong
411
+ error = f"canon_taut FAILED: {str(e).strip()}"
412
+ return None, error
413
+ n_tautomers = len(res)
414
+ if hasattr(res, "status"):
415
+ completed = (
416
+ res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
417
+ )
418
+ else:
419
+ # we are still on the pre-2021 RDKit API
420
+ completed = len(res) < 1000
421
+ if not completed:
422
+ n_tautomers = -n_tautomers
423
+ try:
424
+ mol_out = self.taut_enumerator.PickCanonical(res)
425
+ except AttributeError:
426
+ # we are still on the pre-2021 RDKit API
427
+ mol_out = max(
428
+ [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
429
+ )[1]
430
+ except Exception as e:
431
+ # something else went wrong
432
+ error = f"canon_taut FAILED: {str(e).strip()}"
433
+ return None, error
434
+ mol_out.SetProp("_Name", name)
435
+ return mol_out, n_tautomers
436
+
437
+
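Illustrative sketch (not part of the commit): standardizing a sodium acetate salt should yield the desalted, neutralized parent.

mol = Chem.MolFromSmiles("CC(=O)[O-].[Na+]")
std_mol, n_taut = Standardizer(canon_taut=True).standardize_mol(mol)
if std_mol is not None:               # on failure, the second value is an error message
    print(Chem.MolToSmiles(std_mol))  # CC(=O)O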
438
+ class FeatureDictMixin:
439
+ """
440
+ Mixin that enables bidirectional handling of dict-based multi-feature inputs.
441
+ Allows selective removal of columns directly from the combined array.
442
+
443
+ Example input:
444
+ {
445
+ "ecfps": np.ndarray,
446
+ "tox": np.ndarray,
447
+ }
448
+ """
449
+
450
+ def __init__(self, feature_keys=None):
451
+ self.feature_keys = feature_keys
452
+ self._curr_keys = None
453
+ self._unused_data = None
454
+
455
+ def dict_to_array(self, input: dict[Any, np.ndarray]) -> np.ndarray:
456
+ """Parse dict input and concatenate into a single array."""
457
+ if not isinstance(input, dict):
458
+ raise TypeError("Input must be a dict {feature_type: np.ndarray, ...}")
459
+
460
+ self._unused_data = {}
461
+ remaining_input = {}
462
+ for key in list(input.keys()):
463
+ if key not in self.feature_keys:
464
+ self._unused_data[key] = input[key]
465
+ else:
466
+ remaining_input[key] = input[key]
467
+
468
+ curr_keys = []
469
+ output = []
470
+ for key in self.feature_keys:
471
+ array = remaining_input.pop(key)
472
+ if array.ndim != 2:
473
+ raise ValueError(f"Feature '{key}' must be 2D, got shape {array.shape}")
474
+
475
+ curr_keys.extend([key] * array.shape[1])
476
+ output.append(array)
477
+
478
+ self._curr_keys = np.array(curr_keys)
479
+
480
+ return np.concatenate(output, axis=1)
481
+
482
+ def array_to_dict(self, input: np.ndarray) -> dict[Any, np.ndarray]:
483
+ """Reconstruct dict from a concatenated array."""
484
+ if self._curr_keys is None:
485
+ raise ValueError("No feature mapping stored. Did you call parse_input()?")
486
+
487
+ output = {key: input[:, self._curr_keys == key] for key in self.feature_keys}
488
+ output.update(self._unused_data)
489
+
490
+ self._curr_keys = None
491
+ self._unused_data = None
492
+ return output
493
+
494
+
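A round-trip sketch of the mixin (not part of the commit); `Concat` is a hypothetical subclass, and the "labels" entry shows how non-feature data is parked and restored:

import numpy as np

class Concat(FeatureDictMixin):
    pass

c = Concat(feature_keys=["ecfps", "tox"])
X = {"ecfps": np.ones((5, 4)), "tox": np.zeros((5, 2)), "labels": np.arange(5)}
arr = c.dict_to_array(X)    # shape (5, 6); "labels" is stored as unused data
out = c.array_to_dict(arr)  # dict with "ecfps", "tox", and "labels" restored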
495
+ def load_pickle(path: str):
496
+ with open(path, "rb") as file:
497
+ content = pickle.load(file)
498
+ return content
499
+
500
+
501
+ def write_pickle(path: str, obj: object):
502
+ with open(path, "wb") as file:
503
+ pickle.dump(obj, file)
504
+
505
+
506
+ def create_dir(path, is_file=False):
507
+ """Creates the parent directories if a path to a file is given, else create the given directory"""
508
+
509
+ to_create = os.path.dirname(path) if is_file else path
510
+ if not os.path.exists(to_create):
511
+ os.makedirs(to_create)
512
+
513
+
514
+ def normalize_config(config: dict):
515
+ """Normalizes a json config recursively by applying a mapping"""
516
+ mapping = {"none": None, "true": True, "false": False}
517
+ new_config = {}
518
+ for key, val in config.items():
519
+ if isinstance(val, dict):
520
+ new_config[key] = normalize_config(val)
521
+ elif isinstance(val, (int, float, str)) and val in mapping:
522
+ new_config[key] = mapping[val]
523
+ else:
524
+ new_config[key] = val
525
+ return new_config
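For example (not part of the commit):

cfg = {"scaler": "standard", "quantiles": {"use": "true", "feature_keys": "none"}}
print(normalize_config(cfg))
# {'scaler': 'standard', 'quantiles': {'use': True, 'feature_keys': None}}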