project-monai commited on
Commit
fd4ffa6
·
verified ·
1 Parent(s): e18e422

Upload vista2d version 0.3.1

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ download_preprocessor/cellpose_agreement.png filter=lfs diff=lfs merge=lfs -text
37
+ download_preprocessor/cellpose_links.png filter=lfs diff=lfs merge=lfs -text
38
+ download_preprocessor/kaggle_download.png filter=lfs diff=lfs merge=lfs -text
39
+ download_preprocessor/omnipose_download.png filter=lfs diff=lfs merge=lfs -text
40
+ download_preprocessor/tissuenet_download.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Code License
2
+
3
+ This license applies to all files except the model weights in the directory.
4
+
5
+ Apache License
6
+ Version 2.0, January 2004
7
+ http://www.apache.org/licenses/
8
+
9
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10
+
11
+ 1. Definitions.
12
+
13
+ "License" shall mean the terms and conditions for use, reproduction,
14
+ and distribution as defined by Sections 1 through 9 of this document.
15
+
16
+ "Licensor" shall mean the copyright owner or entity authorized by
17
+ the copyright owner that is granting the License.
18
+
19
+ "Legal Entity" shall mean the union of the acting entity and all
20
+ other entities that control, are controlled by, or are under common
21
+ control with that entity. For the purposes of this definition,
22
+ "control" means (i) the power, direct or indirect, to cause the
23
+ direction or management of such entity, whether by contract or
24
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
25
+ outstanding shares, or (iii) beneficial ownership of such entity.
26
+
27
+ "You" (or "Your") shall mean an individual or Legal Entity
28
+ exercising permissions granted by this License.
29
+
30
+ "Source" form shall mean the preferred form for making modifications,
31
+ including but not limited to software source code, documentation
32
+ source, and configuration files.
33
+
34
+ "Object" form shall mean any form resulting from mechanical
35
+ transformation or translation of a Source form, including but
36
+ not limited to compiled object code, generated documentation,
37
+ and conversions to other media types.
38
+
39
+ "Work" shall mean the work of authorship, whether in Source or
40
+ Object form, made available under the License, as indicated by a
41
+ copyright notice that is included in or attached to the work
42
+ (an example is provided in the Appendix below).
43
+
44
+ "Derivative Works" shall mean any work, whether in Source or Object
45
+ form, that is based on (or derived from) the Work and for which the
46
+ editorial revisions, annotations, elaborations, or other modifications
47
+ represent, as a whole, an original work of authorship. For the purposes
48
+ of this License, Derivative Works shall not include works that remain
49
+ separable from, or merely link (or bind by name) to the interfaces of,
50
+ the Work and Derivative Works thereof.
51
+
52
+ "Contribution" shall mean any work of authorship, including
53
+ the original version of the Work and any modifications or additions
54
+ to that Work or Derivative Works thereof, that is intentionally
55
+ submitted to Licensor for inclusion in the Work by the copyright owner
56
+ or by an individual or Legal Entity authorized to submit on behalf of
57
+ the copyright owner. For the purposes of this definition, "submitted"
58
+ means any form of electronic, verbal, or written communication sent
59
+ to the Licensor or its representatives, including but not limited to
60
+ communication on electronic mailing lists, source code control systems,
61
+ and issue tracking systems that are managed by, or on behalf of, the
62
+ Licensor for the purpose of discussing and improving the Work, but
63
+ excluding communication that is conspicuously marked or otherwise
64
+ designated in writing by the copyright owner as "Not a Contribution."
65
+
66
+ "Contributor" shall mean Licensor and any individual or Legal Entity
67
+ on behalf of whom a Contribution has been received by Licensor and
68
+ subsequently incorporated within the Work.
69
+
70
+ 2. Grant of Copyright License. Subject to the terms and conditions of
71
+ this License, each Contributor hereby grants to You a perpetual,
72
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73
+ copyright license to reproduce, prepare Derivative Works of,
74
+ publicly display, publicly perform, sublicense, and distribute the
75
+ Work and such Derivative Works in Source or Object form.
76
+
77
+ 3. Grant of Patent License. Subject to the terms and conditions of
78
+ this License, each Contributor hereby grants to You a perpetual,
79
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80
+ (except as stated in this section) patent license to make, have made,
81
+ use, offer to sell, sell, import, and otherwise transfer the Work,
82
+ where such license applies only to those patent claims licensable
83
+ by such Contributor that are necessarily infringed by their
84
+ Contribution(s) alone or by combination of their Contribution(s)
85
+ with the Work to which such Contribution(s) was submitted. If You
86
+ institute patent litigation against any entity (including a
87
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
88
+ or a Contribution incorporated within the Work constitutes direct
89
+ or contributory patent infringement, then any patent licenses
90
+ granted to You under this License for that Work shall terminate
91
+ as of the date such litigation is filed.
92
+
93
+ 4. Redistribution. You may reproduce and distribute copies of the
94
+ Work or Derivative Works thereof in any medium, with or without
95
+ modifications, and in Source or Object form, provided that You
96
+ meet the following conditions:
97
+
98
+ (a) You must give any other recipients of the Work or
99
+ Derivative Works a copy of this License; and
100
+
101
+ (b) You must cause any modified files to carry prominent notices
102
+ stating that You changed the files; and
103
+
104
+ (c) You must retain, in the Source form of any Derivative Works
105
+ that You distribute, all copyright, patent, trademark, and
106
+ attribution notices from the Source form of the Work,
107
+ excluding those notices that do not pertain to any part of
108
+ the Derivative Works; and
109
+
110
+ (d) If the Work includes a "NOTICE" text file as part of its
111
+ distribution, then any Derivative Works that You distribute must
112
+ include a readable copy of the attribution notices contained
113
+ within such NOTICE file, excluding those notices that do not
114
+ pertain to any part of the Derivative Works, in at least one
115
+ of the following places: within a NOTICE text file distributed
116
+ as part of the Derivative Works; within the Source form or
117
+ documentation, if provided along with the Derivative Works; or,
118
+ within a display generated by the Derivative Works, if and
119
+ wherever such third-party notices normally appear. The contents
120
+ of the NOTICE file are for informational purposes only and
121
+ do not modify the License. You may add Your own attribution
122
+ notices within Derivative Works that You distribute, alongside
123
+ or as an addendum to the NOTICE text from the Work, provided
124
+ that such additional attribution notices cannot be construed
125
+ as modifying the License.
126
+
127
+ You may add Your own copyright statement to Your modifications and
128
+ may provide additional or different license terms and conditions
129
+ for use, reproduction, or distribution of Your modifications, or
130
+ for any such Derivative Works as a whole, provided Your use,
131
+ reproduction, and distribution of the Work otherwise complies with
132
+ the conditions stated in this License.
133
+
134
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
135
+ any Contribution intentionally submitted for inclusion in the Work
136
+ by You to the Licensor shall be under the terms and conditions of
137
+ this License, without any additional terms or conditions.
138
+ Notwithstanding the above, nothing herein shall supersede or modify
139
+ the terms of any separate license agreement you may have executed
140
+ with Licensor regarding such Contributions.
141
+
142
+ 6. Trademarks. This License does not grant permission to use the trade
143
+ names, trademarks, service marks, or product names of the Licensor,
144
+ except as required for reasonable and customary use in describing the
145
+ origin of the Work and reproducing the content of the NOTICE file.
146
+
147
+ 7. Disclaimer of Warranty. Unless required by applicable law or
148
+ agreed to in writing, Licensor provides the Work (and each
149
+ Contributor provides its Contributions) on an "AS IS" BASIS,
150
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151
+ implied, including, without limitation, any warranties or conditions
152
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153
+ PARTICULAR PURPOSE. You are solely responsible for determining the
154
+ appropriateness of using or redistributing the Work and assume any
155
+ risks associated with Your exercise of permissions under this License.
156
+
157
+ 8. Limitation of Liability. In no event and under no legal theory,
158
+ whether in tort (including negligence), contract, or otherwise,
159
+ unless required by applicable law (such as deliberate and grossly
160
+ negligent acts) or agreed to in writing, shall any Contributor be
161
+ liable to You for damages, including any direct, indirect, special,
162
+ incidental, or consequential damages of any character arising as a
163
+ result of this License or out of the use or inability to use the
164
+ Work (including but not limited to damages for loss of goodwill,
165
+ work stoppage, computer failure or malfunction, or any and all
166
+ other commercial damages or losses), even if such Contributor
167
+ has been advised of the possibility of such damages.
168
+
169
+ 9. Accepting Warranty or Additional Liability. While redistributing
170
+ the Work or Derivative Works thereof, You may choose to offer,
171
+ and charge a fee for, acceptance of support, warranty, indemnity,
172
+ or other liability obligations and/or rights consistent with this
173
+ License. However, in accepting such obligations, You may act only
174
+ on Your own behalf and on Your sole responsibility, not on behalf
175
+ of any other Contributor, and only if You agree to indemnify,
176
+ defend, and hold each Contributor harmless for any liability
177
+ incurred by, or claims asserted against, such Contributor by reason
178
+ of your accepting any such warranty or additional liability.
179
+
180
+ END OF TERMS AND CONDITIONS
181
+
182
+ APPENDIX: How to apply the Apache License to your work.
183
+
184
+ To apply the Apache License to your work, attach the following
185
+ boilerplate notice, with the fields enclosed by brackets "[]"
186
+ replaced with your own identifying information. (Don't include
187
+ the brackets!) The text should be enclosed in the appropriate
188
+ comment syntax for the file format. We also recommend that a
189
+ file or class name and description of purpose be included on the
190
+ same "printed page" as the copyright notice for easier
191
+ identification within third-party archives.
192
+
193
+ Copyright [yyyy] [name of copyright owner]
194
+
195
+ Licensed under the Apache License, Version 2.0 (the "License");
196
+ you may not use this file except in compliance with the License.
197
+ You may obtain a copy of the License at
198
+
199
+ http://www.apache.org/licenses/LICENSE-2.0
200
+
201
+ Unless required by applicable law or agreed to in writing, software
202
+ distributed under the License is distributed on an "AS IS" BASIS,
203
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
204
+ See the License for the specific language governing permissions and
205
+ limitations under the License.
206
+
207
+ ------------------------------------------------------------------------------
208
+
209
+ Model Weights License
210
+
211
+ This license applies to model weights in the directory.
212
+
213
+ Attribution-NonCommercial-ShareAlike 4.0 International
214
+
215
+ =======================================================================
216
+
217
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
218
+ does not provide legal services or legal advice. Distribution of
219
+ Creative Commons public licenses does not create a lawyer-client or
220
+ other relationship. Creative Commons makes its licenses and related
221
+ information available on an "as-is" basis. Creative Commons gives no
222
+ warranties regarding its licenses, any material licensed under their
223
+ terms and conditions, or any related information. Creative Commons
224
+ disclaims all liability for damages resulting from their use to the
225
+ fullest extent possible.
226
+
227
+ Using Creative Commons Public Licenses
228
+
229
+ Creative Commons public licenses provide a standard set of terms and
230
+ conditions that creators and other rights holders may use to share
231
+ original works of authorship and other material subject to copyright
232
+ and certain other rights specified in the public license below. The
233
+ following considerations are for informational purposes only, are not
234
+ exhaustive, and do not form part of our licenses.
235
+
236
+ Considerations for licensors: Our public licenses are
237
+ intended for use by those authorized to give the public
238
+ permission to use material in ways otherwise restricted by
239
+ copyright and certain other rights. Our licenses are
240
+ irrevocable. Licensors should read and understand the terms
241
+ and conditions of the license they choose before applying it.
242
+ Licensors should also secure all rights necessary before
243
+ applying our licenses so that the public can reuse the
244
+ material as expected. Licensors should clearly mark any
245
+ material not subject to the license. This includes other CC-
246
+ licensed material, or material used under an exception or
247
+ limitation to copyright. More considerations for licensors:
248
+ wiki.creativecommons.org/Considerations_for_licensors
249
+
250
+ Considerations for the public: By using one of our public
251
+ licenses, a licensor grants the public permission to use the
252
+ licensed material under specified terms and conditions. If
253
+ the licensor's permission is not necessary for any reason--for
254
+ example, because of any applicable exception or limitation to
255
+ copyright--then that use is not regulated by the license. Our
256
+ licenses grant only permissions under copyright and certain
257
+ other rights that a licensor has authority to grant. Use of
258
+ the licensed material may still be restricted for other
259
+ reasons, including because others have copyright or other
260
+ rights in the material. A licensor may make special requests,
261
+ such as asking that all changes be marked or described.
262
+ Although not required by our licenses, you are encouraged to
263
+ respect those requests where reasonable. More considerations
264
+ for the public:
265
+ wiki.creativecommons.org/Considerations_for_licensees
266
+
267
+ =======================================================================
268
+
269
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
270
+ Public License
271
+
272
+ By exercising the Licensed Rights (defined below), You accept and agree
273
+ to be bound by the terms and conditions of this Creative Commons
274
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
275
+ ("Public License"). To the extent this Public License may be
276
+ interpreted as a contract, You are granted the Licensed Rights in
277
+ consideration of Your acceptance of these terms and conditions, and the
278
+ Licensor grants You such rights in consideration of benefits the
279
+ Licensor receives from making the Licensed Material available under
280
+ these terms and conditions.
281
+
282
+
283
+ Section 1 -- Definitions.
284
+
285
+ a. Adapted Material means material subject to Copyright and Similar
286
+ Rights that is derived from or based upon the Licensed Material
287
+ and in which the Licensed Material is translated, altered,
288
+ arranged, transformed, or otherwise modified in a manner requiring
289
+ permission under the Copyright and Similar Rights held by the
290
+ Licensor. For purposes of this Public License, where the Licensed
291
+ Material is a musical work, performance, or sound recording,
292
+ Adapted Material is always produced where the Licensed Material is
293
+ synched in timed relation with a moving image.
294
+
295
+ b. Adapter's License means the license You apply to Your Copyright
296
+ and Similar Rights in Your contributions to Adapted Material in
297
+ accordance with the terms and conditions of this Public License.
298
+
299
+ c. BY-NC-SA Compatible License means a license listed at
300
+ creativecommons.org/compatiblelicenses, approved by Creative
301
+ Commons as essentially the equivalent of this Public License.
302
+
303
+ d. Copyright and Similar Rights means copyright and/or similar rights
304
+ closely related to copyright including, without limitation,
305
+ performance, broadcast, sound recording, and Sui Generis Database
306
+ Rights, without regard to how the rights are labeled or
307
+ categorized. For purposes of this Public License, the rights
308
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
309
+ Rights.
310
+
311
+ e. Effective Technological Measures means those measures that, in the
312
+ absence of proper authority, may not be circumvented under laws
313
+ fulfilling obligations under Article 11 of the WIPO Copyright
314
+ Treaty adopted on December 20, 1996, and/or similar international
315
+ agreements.
316
+
317
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
318
+ any other exception or limitation to Copyright and Similar Rights
319
+ that applies to Your use of the Licensed Material.
320
+
321
+ g. License Elements means the license attributes listed in the name
322
+ of a Creative Commons Public License. The License Elements of this
323
+ Public License are Attribution, NonCommercial, and ShareAlike.
324
+
325
+ h. Licensed Material means the artistic or literary work, database,
326
+ or other material to which the Licensor applied this Public
327
+ License.
328
+
329
+ i. Licensed Rights means the rights granted to You subject to the
330
+ terms and conditions of this Public License, which are limited to
331
+ all Copyright and Similar Rights that apply to Your use of the
332
+ Licensed Material and that the Licensor has authority to license.
333
+
334
+ j. Licensor means the individual(s) or entity(ies) granting rights
335
+ under this Public License.
336
+
337
+ k. NonCommercial means not primarily intended for or directed towards
338
+ commercial advantage or monetary compensation. For purposes of
339
+ this Public License, the exchange of the Licensed Material for
340
+ other material subject to Copyright and Similar Rights by digital
341
+ file-sharing or similar means is NonCommercial provided there is
342
+ no payment of monetary compensation in connection with the
343
+ exchange.
344
+
345
+ l. Share means to provide material to the public by any means or
346
+ process that requires permission under the Licensed Rights, such
347
+ as reproduction, public display, public performance, distribution,
348
+ dissemination, communication, or importation, and to make material
349
+ available to the public including in ways that members of the
350
+ public may access the material from a place and at a time
351
+ individually chosen by them.
352
+
353
+ m. Sui Generis Database Rights means rights other than copyright
354
+ resulting from Directive 96/9/EC of the European Parliament and of
355
+ the Council of 11 March 1996 on the legal protection of databases,
356
+ as amended and/or succeeded, as well as other essentially
357
+ equivalent rights anywhere in the world.
358
+
359
+ n. You means the individual or entity exercising the Licensed Rights
360
+ under this Public License. Your has a corresponding meaning.
361
+
362
+
363
+ Section 2 -- Scope.
364
+
365
+ a. License grant.
366
+
367
+ 1. Subject to the terms and conditions of this Public License,
368
+ the Licensor hereby grants You a worldwide, royalty-free,
369
+ non-sublicensable, non-exclusive, irrevocable license to
370
+ exercise the Licensed Rights in the Licensed Material to:
371
+
372
+ a. reproduce and Share the Licensed Material, in whole or
373
+ in part, for NonCommercial purposes only; and
374
+
375
+ b. produce, reproduce, and Share Adapted Material for
376
+ NonCommercial purposes only.
377
+
378
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
379
+ Exceptions and Limitations apply to Your use, this Public
380
+ License does not apply, and You do not need to comply with
381
+ its terms and conditions.
382
+
383
+ 3. Term. The term of this Public License is specified in Section
384
+ 6(a).
385
+
386
+ 4. Media and formats; technical modifications allowed. The
387
+ Licensor authorizes You to exercise the Licensed Rights in
388
+ all media and formats whether now known or hereafter created,
389
+ and to make technical modifications necessary to do so. The
390
+ Licensor waives and/or agrees not to assert any right or
391
+ authority to forbid You from making technical modifications
392
+ necessary to exercise the Licensed Rights, including
393
+ technical modifications necessary to circumvent Effective
394
+ Technological Measures. For purposes of this Public License,
395
+ simply making modifications authorized by this Section 2(a)
396
+ (4) never produces Adapted Material.
397
+
398
+ 5. Downstream recipients.
399
+
400
+ a. Offer from the Licensor -- Licensed Material. Every
401
+ recipient of the Licensed Material automatically
402
+ receives an offer from the Licensor to exercise the
403
+ Licensed Rights under the terms and conditions of this
404
+ Public License.
405
+
406
+ b. Additional offer from the Licensor -- Adapted Material.
407
+ Every recipient of Adapted Material from You
408
+ automatically receives an offer from the Licensor to
409
+ exercise the Licensed Rights in the Adapted Material
410
+ under the conditions of the Adapter's License You apply.
411
+
412
+ c. No downstream restrictions. You may not offer or impose
413
+ any additional or different terms or conditions on, or
414
+ apply any Effective Technological Measures to, the
415
+ Licensed Material if doing so restricts exercise of the
416
+ Licensed Rights by any recipient of the Licensed
417
+ Material.
418
+
419
+ 6. No endorsement. Nothing in this Public License constitutes or
420
+ may be construed as permission to assert or imply that You
421
+ are, or that Your use of the Licensed Material is, connected
422
+ with, or sponsored, endorsed, or granted official status by,
423
+ the Licensor or others designated to receive attribution as
424
+ provided in Section 3(a)(1)(A)(i).
425
+
426
+ b. Other rights.
427
+
428
+ 1. Moral rights, such as the right of integrity, are not
429
+ licensed under this Public License, nor are publicity,
430
+ privacy, and/or other similar personality rights; however, to
431
+ the extent possible, the Licensor waives and/or agrees not to
432
+ assert any such rights held by the Licensor to the limited
433
+ extent necessary to allow You to exercise the Licensed
434
+ Rights, but not otherwise.
435
+
436
+ 2. Patent and trademark rights are not licensed under this
437
+ Public License.
438
+
439
+ 3. To the extent possible, the Licensor waives any right to
440
+ collect royalties from You for the exercise of the Licensed
441
+ Rights, whether directly or through a collecting society
442
+ under any voluntary or waivable statutory or compulsory
443
+ licensing scheme. In all other cases the Licensor expressly
444
+ reserves any right to collect such royalties, including when
445
+ the Licensed Material is used other than for NonCommercial
446
+ purposes.
447
+
448
+
449
+ Section 3 -- License Conditions.
450
+
451
+ Your exercise of the Licensed Rights is expressly made subject to the
452
+ following conditions.
453
+
454
+ a. Attribution.
455
+
456
+ 1. If You Share the Licensed Material (including in modified
457
+ form), You must:
458
+
459
+ a. retain the following if it is supplied by the Licensor
460
+ with the Licensed Material:
461
+
462
+ i. identification of the creator(s) of the Licensed
463
+ Material and any others designated to receive
464
+ attribution, in any reasonable manner requested by
465
+ the Licensor (including by pseudonym if
466
+ designated);
467
+
468
+ ii. a copyright notice;
469
+
470
+ iii. a notice that refers to this Public License;
471
+
472
+ iv. a notice that refers to the disclaimer of
473
+ warranties;
474
+
475
+ v. a URI or hyperlink to the Licensed Material to the
476
+ extent reasonably practicable;
477
+
478
+ b. indicate if You modified the Licensed Material and
479
+ retain an indication of any previous modifications; and
480
+
481
+ c. indicate the Licensed Material is licensed under this
482
+ Public License, and include the text of, or the URI or
483
+ hyperlink to, this Public License.
484
+
485
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
486
+ reasonable manner based on the medium, means, and context in
487
+ which You Share the Licensed Material. For example, it may be
488
+ reasonable to satisfy the conditions by providing a URI or
489
+ hyperlink to a resource that includes the required
490
+ information.
491
+ 3. If requested by the Licensor, You must remove any of the
492
+ information required by Section 3(a)(1)(A) to the extent
493
+ reasonably practicable.
494
+
495
+ b. ShareAlike.
496
+
497
+ In addition to the conditions in Section 3(a), if You Share
498
+ Adapted Material You produce, the following conditions also apply.
499
+
500
+ 1. The Adapter's License You apply must be a Creative Commons
501
+ license with the same License Elements, this version or
502
+ later, or a BY-NC-SA Compatible License.
503
+
504
+ 2. You must include the text of, or the URI or hyperlink to, the
505
+ Adapter's License You apply. You may satisfy this condition
506
+ in any reasonable manner based on the medium, means, and
507
+ context in which You Share Adapted Material.
508
+
509
+ 3. You may not offer or impose any additional or different terms
510
+ or conditions on, or apply any Effective Technological
511
+ Measures to, Adapted Material that restrict exercise of the
512
+ rights granted under the Adapter's License You apply.
513
+
514
+
515
+ Section 4 -- Sui Generis Database Rights.
516
+
517
+ Where the Licensed Rights include Sui Generis Database Rights that
518
+ apply to Your use of the Licensed Material:
519
+
520
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
521
+ to extract, reuse, reproduce, and Share all or a substantial
522
+ portion of the contents of the database for NonCommercial purposes
523
+ only;
524
+
525
+ b. if You include all or a substantial portion of the database
526
+ contents in a database in which You have Sui Generis Database
527
+ Rights, then the database in which You have Sui Generis Database
528
+ Rights (but not its individual contents) is Adapted Material,
529
+ including for purposes of Section 3(b); and
530
+
531
+ c. You must comply with the conditions in Section 3(a) if You Share
532
+ all or a substantial portion of the contents of the database.
533
+
534
+ For the avoidance of doubt, this Section 4 supplements and does not
535
+ replace Your obligations under this Public License where the Licensed
536
+ Rights include other Copyright and Similar Rights.
537
+
538
+
539
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
540
+
541
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
542
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
543
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
544
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
545
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
546
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
547
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
548
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
549
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
550
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
551
+
552
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
553
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
554
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
555
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
556
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
557
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
558
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
559
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
560
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
561
+
562
+ c. The disclaimer of warranties and limitation of liability provided
563
+ above shall be interpreted in a manner that, to the extent
564
+ possible, most closely approximates an absolute disclaimer and
565
+ waiver of all liability.
566
+
567
+
568
+ Section 6 -- Term and Termination.
569
+
570
+ a. This Public License applies for the term of the Copyright and
571
+ Similar Rights licensed here. However, if You fail to comply with
572
+ this Public License, then Your rights under this Public License
573
+ terminate automatically.
574
+
575
+ b. Where Your right to use the Licensed Material has terminated under
576
+ Section 6(a), it reinstates:
577
+
578
+ 1. automatically as of the date the violation is cured, provided
579
+ it is cured within 30 days of Your discovery of the
580
+ violation; or
581
+
582
+ 2. upon express reinstatement by the Licensor.
583
+
584
+ For the avoidance of doubt, this Section 6(b) does not affect any
585
+ right the Licensor may have to seek remedies for Your violations
586
+ of this Public License.
587
+
588
+ c. For the avoidance of doubt, the Licensor may also offer the
589
+ Licensed Material under separate terms or conditions or stop
590
+ distributing the Licensed Material at any time; however, doing so
591
+ will not terminate this Public License.
592
+
593
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
594
+ License.
595
+
596
+
597
+ Section 7 -- Other Terms and Conditions.
598
+
599
+ a. The Licensor shall not be bound by any additional or different
600
+ terms or conditions communicated by You unless expressly agreed.
601
+
602
+ b. Any arrangements, understandings, or agreements regarding the
603
+ Licensed Material not stated herein are separate from and
604
+ independent of the terms and conditions of this Public License.
605
+
606
+
607
+ Section 8 -- Interpretation.
608
+
609
+ a. For the avoidance of doubt, this Public License does not, and
610
+ shall not be interpreted to, reduce, limit, restrict, or impose
611
+ conditions on any use of the Licensed Material that could lawfully
612
+ be made without permission under this Public License.
613
+
614
+ b. To the extent possible, if any provision of this Public License is
615
+ deemed unenforceable, it shall be automatically reformed to the
616
+ minimum extent necessary to make it enforceable. If the provision
617
+ cannot be reformed, it shall be severed from this Public License
618
+ without affecting the enforceability of the remaining terms and
619
+ conditions.
620
+
621
+ c. No term or condition of this Public License will be waived and no
622
+ failure to comply consented to unless expressly agreed to by the
623
+ Licensor.
624
+
625
+ d. Nothing in this Public License constitutes or may be interpreted
626
+ as a limitation upon, or waiver of, any privileges and immunities
627
+ that apply to the Licensor or You, including from the legal
628
+ processes of any jurisdiction or authority.
629
+
630
+ =======================================================================
631
+
632
+ Creative Commons is not a party to its public
633
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
634
+ its public licenses to material it publishes and in those instances
635
+ will be considered the “Licensor.” The text of the Creative Commons
636
+ public licenses is dedicated to the public domain under the CC0 Public
637
+ Domain Dedication. Except for the limited purpose of indicating that
638
+ material is shared under a Creative Commons public license or as
639
+ otherwise permitted by the Creative Commons policies published at
640
+ creativecommons.org/policies, Creative Commons does not authorize the
641
+ use of the trademark "Creative Commons" or any other trademark or logo
642
+ of Creative Commons without its prior written consent including,
643
+ without limitation, in connection with any unauthorized modifications
644
+ to any of its public licenses or any other arrangements,
645
+ understandings, or agreements concerning use of licensed material. For
646
+ the avoidance of doubt, this paragraph does not form part of the
647
+ public licenses.
648
+
649
+ Creative Commons may be contacted at creativecommons.org.
configs/hyper_parameters.yaml ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ imports:
2
+ - $import os
3
+
4
+ # seed: 28022024 # uncommend for deterministic results (but slower)
5
+ seed: null
6
+
7
+ bundle_root: "."
8
+ ckpt_path: $os.path.join(@bundle_root, "models") # location to save checkpoints
9
+ output_dir: $os.path.join(@bundle_root, "eval") # location to save events and logs
10
+ log_output_file: $os.path.join(@output_dir, "vista_cell.log")
11
+
12
+ mlflow_tracking_uri: null # enable mlflow logging, e.g. $@ckpt_path + '/mlruns/ or "http://127.0.0.1:8080" or a remote url
13
+ mlflow_log_system_metrics: true # log system metrics to mlflow (requires: pip install psutil pynvml)
14
+ mlflow_run_name: null # optional name of the current run
15
+
16
+ ckpt_save: true # save checkpoints periodically
17
+ amp: true
18
+ amp_dtype: "float16" #float16 or bfloat16 (Ampere or newer)
19
+ channels_last: true
20
+ compile: false # complie the model for faster processing
21
+
22
+ start_epoch: 0
23
+ run_final_testing: true
24
+ use_weighted_sampler: false # only applicable when using several dataset jsons for data_list_files
25
+
26
+ pretrained_ckpt_name: null
27
+ pretrained_ckpt_path: null
28
+
29
+ # for commandline setting of a single dataset
30
+ datalist: datalists/cellpose_datalist.json
31
+ basedir: /cellpose_dataset
32
+ data_list_files:
33
+ - {datalist: "@datalist", basedir: "@basedir"}
34
+
35
+
36
+ fold: 0
37
+ learning_rate: 0.01 # try 1.0e-4 if using AdamW
38
+ quick: false # whether to use a small subset of data for quick testing
39
+ roi_size: [256, 256]
40
+
41
+ train:
42
+ skip: false
43
+ handlers: []
44
+ trainer:
45
+ num_warmup_epochs: 3
46
+ max_epochs: 200
47
+ num_epochs_per_saving: 1
48
+ num_epochs_per_validation: null
49
+ num_workers: 4
50
+ batch_size: 1
51
+ dataset:
52
+ preprocessing:
53
+ roi_size: "@roi_size"
54
+ data:
55
+ key: null # set to 'testing' to use this subset in periodic validations, instead of the the validation set
56
+ data_list_files: "@data_list_files"
57
+
58
+ dataset:
59
+ data:
60
+ key: "testing"
61
+ data_list_files: "@data_list_files"
62
+
63
+ validate:
64
+ grouping: true
65
+ evaluator:
66
+ postprocessing: "@postprocessing"
67
+ dataset:
68
+ data: "@dataset#data"
69
+ batch_size: 1
70
+ num_workers: 4
71
+ preprocessing: null
72
+ postprocessing: null
73
+ inferer: null
74
+ handlers: null
75
+ key_metric: null
76
+
77
+ infer:
78
+ evaluator:
79
+ postprocessing: "@postprocessing"
80
+ dataset:
81
+ data: "@dataset#data"
82
+
83
+
84
+ device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
85
+ network_def:
86
+ _target_: monai.networks.nets.cell_sam_wrapper.CellSamWrapper
87
+ checkpoint: $os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth")
88
+ network: $@network_def.to(@device)
89
+
90
+ loss_function:
91
+ _target_: scripts.components.CellLoss
92
+
93
+ key_metric:
94
+ _target_: scripts.components.CellAcc
95
+
96
+ # optimizer:
97
+ # _target_: torch.optim.AdamW
98
+ # params: [email protected]()
99
+ # lr: "@learning_rate"
100
+ # weight_decay: 1.0e-5
101
+
102
+ optimizer:
103
+ _target_: torch.optim.SGD
104
+ params: [email protected]()
105
+ momentum: 0.9
106
+ lr: "@learning_rate"
107
+ weight_decay: 1.0e-5
108
+
109
+ lr_scheduler:
110
+ _target_: monai.optimizers.lr_scheduler.WarmupCosineSchedule
111
+ optimizer: "@optimizer"
112
+ warmup_steps: "@train#trainer#num_warmup_epochs"
113
+ warmup_multiplier: 0.1
114
+ t_total: "@train#trainer#max_epochs"
115
+
116
+ inferer:
117
+ sliding_inferer:
118
+ _target_: monai.inferers.SlidingWindowInfererAdapt
119
+ roi_size: "@roi_size"
120
+ sw_batch_size: 1
121
+ overlap: 0.625
122
+ mode: "gaussian"
123
+ cache_roi_weight_map: true
124
+ progress: false
125
+
126
+ image_saver:
127
+ _target_: scripts.components.SaveTiffd
128
+ keys: "seg"
129
+ output_dir: "@output_dir"
130
+ nested_folder: false
131
+
132
+ postprocessing:
133
+ _target_: monai.transforms.Compose
134
+ transforms:
135
+ - "@image_saver"
configs/inference.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import numpy as np"
4
+ ],
5
+ "bundle_root": ".",
6
+ "ckpt_dir": "$@bundle_root + '/models'",
7
+ "output_dir": "$@bundle_root + '/eval'",
8
+ "output_ext": ".tif",
9
+ "output_postfix": "trans",
10
+ "roi_size": [
11
+ 256,
12
+ 256
13
+ ],
14
+ "input_dict": "${'image': '/cellpose_dataset/test/001_img.png'}",
15
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
16
+ "sam_ckpt_path": "$@ckpt_dir + '/sam_vit_b_01ec64.pth'",
17
+ "pretrained_ckpt_path": "$@ckpt_dir + '/model.pt'",
18
+ "image_key": "image",
19
+ "channels_last": true,
20
+ "use_amp": true,
21
+ "amp_dtype": "$torch.float",
22
+ "network_def": {
23
+ "_target_": "monai.networks.nets.cell_sam_wrapper.CellSamWrapper",
24
+ "checkpoint": "@sam_ckpt_path"
25
+ },
26
+ "network": "$@network_def.to(@device)",
27
+ "preprocessing_transforms": [
28
+ {
29
+ "_target_": "scripts.components.LoadTiffd",
30
+ "keys": "@image_key"
31
+ },
32
+ {
33
+ "_target_": "EnsureTyped",
34
+ "keys": "@image_key",
35
+ "data_type": "tensor",
36
+ "dtype": "$torch.float"
37
+ },
38
+ {
39
+ "_target_": "ScaleIntensityd",
40
+ "keys": "@image_key",
41
+ "minv": 0,
42
+ "maxv": 1,
43
+ "channel_wise": true
44
+ },
45
+ {
46
+ "_target_": "ScaleIntensityRangePercentilesd",
47
+ "keys": "image",
48
+ "lower": 1,
49
+ "upper": 99,
50
+ "b_min": 0.0,
51
+ "b_max": 1.0,
52
+ "channel_wise": true,
53
+ "clip": true
54
+ }
55
+ ],
56
+ "preprocessing": {
57
+ "_target_": "Compose",
58
+ "transforms": "$@preprocessing_transforms "
59
+ },
60
+ "dataset": {
61
+ "_target_": "Dataset",
62
+ "data": "$[@input_dict]",
63
+ "transform": "@preprocessing"
64
+ },
65
+ "dataloader": {
66
+ "_target_": "ThreadDataLoader",
67
+ "dataset": "@dataset",
68
+ "batch_size": 1,
69
+ "shuffle": false,
70
+ "num_workers": 0
71
+ },
72
+ "inferer": {
73
+ "_target_": "SlidingWindowInfererAdapt",
74
+ "roi_size": "@roi_size",
75
+ "sw_batch_size": 1,
76
+ "overlap": 0.625,
77
+ "mode": "gaussian",
78
+ "cache_roi_weight_map": true,
79
+ "progress": false
80
+ },
81
+ "postprocessing": {
82
+ "_target_": "Compose",
83
+ "transforms": [
84
+ {
85
+ "_target_": "ToDeviced",
86
+ "keys": "pred",
87
+ "device": "cpu"
88
+ },
89
+ {
90
+ "_target_": "scripts.components.LogitsToLabelsd",
91
+ "keys": "pred"
92
+ },
93
+ {
94
+ "_target_": "scripts.components.SaveTiffExd",
95
+ "keys": "pred",
96
+ "output_dir": "@output_dir",
97
+ "output_ext": "@output_ext",
98
+ "output_postfix": "@output_postfix"
99
+ }
100
+ ]
101
+ },
102
+ "handlers": [
103
+ {
104
+ "_target_": "StatsHandler",
105
+ "iteration_log": false
106
+ }
107
+ ],
108
+ "checkpointloader": {
109
+ "_target_": "CheckpointLoader",
110
+ "load_path": "@pretrained_ckpt_path",
111
+ "map_location": "cpu",
112
+ "load_dict": {
113
+ "state_dict": "@network"
114
+ }
115
+ },
116
+ "evaluator": {
117
+ "_target_": "SupervisedEvaluator",
118
+ "device": "@device",
119
+ "val_data_loader": "@dataloader",
120
+ "network": "@network",
121
+ "inferer": "@inferer",
122
+ "postprocessing": "@postprocessing",
123
+ "val_handlers": "@handlers",
124
+ "amp": true
125
+ },
126
+ "initialize": [
127
+ "$monai.utils.set_determinism(seed=123)",
128
+ "$@checkpointloader(@evaluator)"
129
+ ],
130
+ "run": [
131
132
+ ]
133
+ }
configs/inference_trt.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import numpy",
4
+ "$from monai.networks import trt_compile"
5
+ ],
6
+ "trt_args": {
7
+ "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]"
8
+ },
9
+ "network": "$trt_compile(@network_def.to(@device), @pretrained_ckpt_path, args=@trt_args)"
10
+ }
configs/metadata.json ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
3
+ "version": "0.3.1",
4
+ "changelog": {
5
+ "0.3.1": "update to huggingface hosting",
6
+ "0.3.0": "update readme",
7
+ "0.2.9": "fix unsupported data dtype in findContours",
8
+ "0.2.8": "remove relative path in readme",
9
+ "0.2.7": "enhance readme",
10
+ "0.2.6": "update tensorrt benchmark results",
11
+ "0.2.5": "add tensorrt benchmark results",
12
+ "0.2.4": "enable tensorrt inference",
13
+ "0.2.3": "update weights link",
14
+ "0.2.2": "update to use monai components",
15
+ "0.2.1": "initial OSS version"
16
+ },
17
+ "monai_version": "1.4.0",
18
+ "pytorch_version": "2.4.0",
19
+ "numpy_version": "1.24.4",
20
+ "required_packages_version": {
21
+ "einops": "0.7.0",
22
+ "scikit-image": "0.23.2",
23
+ "cucim-cu12": "24.6.0",
24
+ "gdown": "5.2.0",
25
+ "fire": "0.6.0",
26
+ "pyyaml": "6.0.1",
27
+ "tensorboard": "2.17.0",
28
+ "opencv-python": "4.7.0.68",
29
+ "numba": "0.59.1",
30
+ "torchvision": "0.19.0",
31
+ "cellpose": "3.0.8",
32
+ "natsort": "8.4.0",
33
+ "roifile": "2024.5.24",
34
+ "tifffile": "2024.7.2",
35
+ "fastremap": "1.15.0",
36
+ "imagecodecs": "2024.6.1",
37
+ "segment_anything": "1.0"
38
+ },
39
+ "optional_packages_version": {
40
+ "mlflow": "2.14.3",
41
+ "pynvml": "11.4.1",
42
+ "psutil": "5.9.8"
43
+ },
44
+ "supported_apps": {},
45
+ "name": "VISTA-Cell",
46
+ "task": "cell image segmentation",
47
+ "description": "VISTA2D bundle for cell image analysis",
48
+ "authors": "MONAI team",
49
+ "copyright": "Copyright (c) MONAI Consortium",
50
+ "data_type": "tiff",
51
+ "image_classes": "1 channel data, intensity scaled to [0, 1]",
52
+ "label_classes": "3-channel data",
53
+ "pred_classes": "3 channels",
54
+ "eval_metrics": {
55
+ "mean_dice": 0.0
56
+ },
57
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
58
+ "references": [],
59
+ "network_data_format": {
60
+ "inputs": {
61
+ "image": {
62
+ "type": "image",
63
+ "num_channels": 3,
64
+ "spatial_shape": [
65
+ 256,
66
+ 256
67
+ ],
68
+ "format": "RGB",
69
+ "value_range": [
70
+ 0,
71
+ 255
72
+ ],
73
+ "dtype": "float32",
74
+ "is_patch_data": true,
75
+ "channel_def": {
76
+ "0": "image"
77
+ }
78
+ }
79
+ },
80
+ "outputs": {
81
+ "pred": {
82
+ "type": "image",
83
+ "format": "segmentation",
84
+ "num_channels": 3,
85
+ "dtype": "float32",
86
+ "value_range": [
87
+ 0,
88
+ 1
89
+ ],
90
+ "spatial_shape": [
91
+ 256,
92
+ 256
93
+ ]
94
+ }
95
+ }
96
+ }
97
+ }
datalists.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b60b29bb320a2ddf2d534d150cc88fc9e2a4825044e7382ede78a2f0a557c9a9
3
+ size 619548
docs/README.md ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Overview
2
+
3
+ The **VISTA2D** is a cell segmentation training and inference pipeline for cell imaging [[`Blog`](https://developer.nvidia.com/blog/advancing-cell-segmentation-and-morphology-analysis-with-nvidia-ai-foundation-model-vista-2d/)].
4
+
5
+ A pretrained model was trained on collection of 15K public microscopy images. The data collection and training can be reproduced following the `download_preprocessor/`. Alternatively, the model can be retrained on your own dataset. The pretrained vista2d model achieves good performance on diverse set of cell types, microscopy image modalities, and can be further finetuned if necessary. The codebase utilizes several components from other great works including [SegmentAnything](https://github.com/facebookresearch/segment-anything) and [Cellpose](https://www.cellpose.org/), which must be pip installed as dependencies. Vista2D codebase follows MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).
6
+
7
+ <div align="center"> <img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/04/magnified-cells-1.png" width="800"/> </div>
8
+
9
+ ### Model highlights
10
+
11
+ - Robust deep learning algorithm based on transformers
12
+ - Generalist model as compared to specialist models
13
+ - Multiple dataset sources and file formats supported
14
+ - Multiple modalities of imaging data collectively supported
15
+ - Multi-GPU and multinode training support
16
+
17
+ ### Generalization performance
18
+
19
+ Evaluation was performed for the VISTA2D model with multiple public datasets, such as TissueNet, LIVECell, Omnipose, DeepBacs, Cellpose, and more. For more details about dataset licenses, please refer to `/docs/data_license.txt`. A total of ~15K annotated cell images were collected to train the generalist VISTA2D model. This ensured broad coverage of many different types of cells, which were acquired by various imaging acquisition types. The benchmark results of the experiment were performed on held-out test sets for each public dataset that were already defined by the dataset contributors. Average precision at an IoU threshold of 0.5 was used for evaluating performance. The benchmark results are reported in comparison with the best numbers found in the literature, in addition to a specialist VISTA2D model trained only on a particular dataset or a subset of data.
20
+
21
+ <div align="center"> <img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/04/vista-2d-model-precision-versus-specialist-model-baseline-performance.png" width="800"/> </div>
22
+
23
+ ### TensorRT speedup
24
+ The `vista2d` bundle supports acceleration with TensorRT. The table below displays the speedup ratios observed on an A100 80G GPU. Please note that 32-bit precision models are benchmarked with tf32 weight format.
25
+
26
+ | method | torch_tf32(ms) | torch_amp(ms) | trt_tf32(ms) | trt_fp16(ms) | speedup amp | speedup tf32 | speedup fp16 | amp vs fp16|
27
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
28
+ | model computation | 39.72 | 39.68 | 26.13 | 17.32 | 1.00 | 1.52 | 2.29 | 2.29 |
29
+ | end2end | 1562 | 1903 | 1494 | 1440 | 0.82 | 1.05 | 1.08 | 1.32|
30
+
31
+ Where:
32
+ - `model computation` means the speedup ratio of model's inference with a random input without preprocessing and postprocessing
33
+ - `end2end` means run the bundle end-to-end with the TensorRT based model.
34
+ - `torch_tf32` and `torch_amp` are for the PyTorch models with or without `amp` mode.
35
+ - `trt_tf32` and `trt_fp16` are for the TensorRT based models converted in corresponding precision.
36
+ - `speedup amp`, `speedup tf32` and `speedup fp16` are the speedup ratios of corresponding models versus the PyTorch float32 model
37
+ - `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16 based model.
38
+
39
+ This result is benchmarked under:
40
+ - TensorRT: 10.3.0+cuda12.6
41
+ - Torch-TensorRT Version: 2.4.0
42
+ - CPU Architecture: x86-64
43
+ - OS: ubuntu 20.04
44
+ - Python version:3.10.12
45
+ - CUDA version: 12.6
46
+ - GPU models and configuration: A100 80G
47
+
48
+ ### Prepare Data Lists and Datasets
49
+
50
+ The default dataset for training, validation, and inference is the [Cellpose](https://www.cellpose.org/) dataset. Please follow the `download_preprocessor/` to prepare the dataset before executing any commands below.
51
+
52
+ Additionally, all data lists are available in the `datalists.zip` file located in the root directory of the bundle. Extract the contents of the `.zip` file to access the data lists.
53
+
54
+ ### Dependencies
55
+ Please refer to the `required_packages_version` section in `configs/metadata.json` to install all necessary dependencies before execution. If you’re using the MONAI container, you can simply run the commands below and ignore any "opencv-python-headless not installed" error message, as this package is already included in the container.
56
+
57
+ ```
58
+ pip install fastremap==1.15.0 roifile==2024.5.24 natsort==8.4.0
59
+ pip install --no-deps cellpose
60
+ ```
61
+
62
+ Important Note: if your environment already contains OpenCV, installing `cellpose` may lead to conflicts and produce errors such as:
63
+
64
+ ```
65
+ AttributeError: partially initialized module 'cv2' has no attribute 'dnn' (most likely due to a circular import)
66
+ ```
67
+
68
+ To resolve this, uninstall `OpenCV` first, and then install `cellpose` using the following commands:
69
+
70
+ ```Bash
71
+ pip uninstall -y opencv && rm /usr/local/lib/python3.*/dist-packages/cv2
72
+ ```
73
+ Make sure to replace 3.* with your actual Python version (e.g., 3.10).
74
+
75
+ Alternatively, you can install `cellpose` without its dependencies to avoid potential conflicts:
76
+
77
+ ```
78
+ pip install --no-deps cellpose
79
+ ```
80
+
81
+ ### Execute training
82
+ ```bash
83
+ python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml
84
+ ```
85
+
86
+ You can override the `basedir` to specify a different dataset directory by using the following command:
87
+
88
+ ```bash
89
+ python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --basedir <actual dataset ditectory>
90
+ ```
91
+
92
+ #### Quick run with a few data points
93
+ ```bash
94
+ python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --quick True --train#trainer#max_epochs 3
95
+ ```
96
+
97
+ ### Execute multi-GPU training
98
+ ```bash
99
+ torchrun --nproc_per_node=gpu -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml
100
+ ```
101
+
102
+ ### Execute validation
103
+ ```bash
104
+ python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --pretrained_ckpt_name model.pt --mode eval
105
+ ```
106
+ (can append `--quick True` for quick demoing)
107
+
108
+ ### Execute multi-GPU validation
109
+ ```bash
110
+ torchrun --nproc_per_node=gpu -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --mode eval
111
+ ```
112
+
113
+ ### Execute inference
114
+ ```bash
115
+ python -m monai.bundle run --config_file configs/inference.json
116
+ ```
117
+
118
+ Please note that the data used in this config file is: "/cellpose_dataset/test/001_img.png", if the dataset path is different or you want to do inference on another file, please modify in `configs/inference.json` accordingly.
119
+
120
+ #### Execute inference with the TensorRT model:
121
+
122
+ ```
123
+ python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']"
124
+ ```
125
+
126
+ ### Execute multi-GPU inference
127
+ ```bash
128
+ torchrun --nproc_per_node=gpu -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --mode infer --pretrained_ckpt_name model.pt
129
+ ```
130
+ (can append `--quick True` for quick demoing)
131
+
132
+
133
+
134
+ #### Finetune starting from a trained checkpoint
135
+ (we use a smaller learning rate, small number of epochs, and initialize from a checkpoint)
136
+ ```bash
137
+ python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --learning_rate=0.001 --train#trainer#max_epochs 20 --pretrained_ckpt_path /path/to/saved/model.pt
138
+ ```
139
+
140
+
141
+ #### Configuration options
142
+
143
+ To disable the segmentation writing:
144
+ ```
145
+ --postprocessing []
146
+ ```
147
+
148
+ Load a checkpoint for validation or inference (relative path within results directory):
149
+ ```
150
+ --pretrained_ckpt_name "model.pt"
151
+ ```
152
+
153
+ Load a checkpoint for validation or inference (absolute path):
154
+ ```
155
+ --pretrained_ckpt_path "/path/to/another/location/model.pt"
156
+ ```
157
+
158
+ `--mode eval` or `--mode infer`will use the corresponding configurations from the `validate` or `infer`
159
+ of the `configs/hyper_parameters.yaml`.
160
+
161
+ By default the generated `model.pt` corresponds to the checkpoint at the best validation score,
162
+ `model_final.pt` is the checkpoint after the latest training epoch.
163
+
164
+
165
+ ### Development
166
+
167
+ For development purposes it's possible to run the script directly (without monai bundle calls)
168
+
169
+ ```bash
170
+ python scripts/workflow.py --config_file configs/hyper_parameters.yaml ...
171
+ torchrun --nproc_per_node=gpu -m scripts/workflow.py --config_file configs/hyper_parameters.yaml ..
172
+ ```
173
+
174
+ ### MLFlow support
175
+
176
+ Enable MLFlow logging by specifying "mlflow_tracking_uri" (can be local or remote URL).
177
+
178
+ ```bash
179
+ python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --mlflow_tracking_uri=http://127.0.0.1:8080
180
+ ```
181
+
182
+ Optionally use "--mlflow_run_name=.." to specify MLFlow experiment name, and "--mlflow_log_system_metrics=True/False" to enable logging of CPU/GPU resources (requires pip install psutil pynvml)
183
+
184
+
185
+
186
+ ### Unit tests
187
+
188
+ Test single GPU training:
189
+ ```
190
+ python unit_tests/test_vista2d.py
191
+ ```
192
+
193
+ Test multi-GPU training (may need to uncomment the `"--standalone"` in the `unit_tests/utils.py` file):
194
+ ```
195
+ python unit_tests/test_vista2d_mgpu.py
196
+ ```
197
+
198
+ ## Compute Requirements
199
+ Min GPU memory requirements 16Gb.
200
+
201
+
202
+ ## Contributing
203
+ Vista2D codebase follows MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).
204
+ Make sure to run pre-commit before committing code changes to git
205
+ ```bash
206
+ pip install pre-commit
207
+ python3 -m pre_commit run --all-files
208
+ ```
209
+
210
+
211
+ ## Community
212
+
213
+ Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join
214
+ our [Slack channel](https://projectmonai.slack.com/archives/C031QRE0M1C).
215
+
216
+ Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://github.com/Project-MONAI/VISTA/discussions).
217
+
218
+ ## License
219
+
220
+ The codebase is under Apache 2.0 Licence. The model weight is released under CC-BY-NC-SA-4.0. For various public data licenses please see `data_license.txt`.
221
+
222
+ ## Acknowledgement
223
+ - [segment-anything](https://github.com/facebookresearch/segment-anything)
224
+ - [Cellpose](https://www.cellpose.org/)
docs/data_license.txt ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
5
+ i.Cellpose dataset
6
+
7
+ https://www.cellpose.org/dataset
8
+ The user agrees to the listed conditions of Cellpose dataset by default that are cited below:
9
+
10
+ Howard Hughes Medical Institute
11
+
12
+ Research Content Terms and Conditions
13
+
14
+ Please read these Research Content Terms and Conditions (“Terms and Conditions”) carefully before you download or use
15
+ any images in any format from the cellpose.org website (“Content”), and do not download or use Content if you do not
16
+ agree with these Terms and Conditions. The Howard Hughes Medical Institute (“HHMI”, “we”, “us” and “our”) may at any
17
+ time revise these Terms and Conditions by updating this posting. You are bound by any such revisions and should
18
+ therefore periodically visit this page to review the then-current Terms and Conditions.
19
+
20
+ BY ACCEPTING THESE TERMS AND CONDITIONS, DOWNLOADING THE CONTENT OR USING THE CONTENT, YOU ARE CONFIRMING YOUR AGREEMENT
21
+ TO BE BOUND BY THESE TERMS AND CONDITIONS INCLUDING THE WARRANTY DISCLAIMERS, LIMITATIONS OF LIABILITY AND TERMINATION
22
+ PROVISIONS BELOW. IF ANY OF THESE TERMS AND CONDITIONS OR ANY FUTURE CHANGES ARE UNACCEPTABLE TO YOU, DO NOT DOWNLOAD
23
+ OR USE THE CONTENT AT THE CELLPOSE.ORG WEBSITE. BY DOWNLOADING OR USING CONTENT FROM CELLPOSE.ORG YOU ACCEPT AND AGREE
24
+ TO THESE TERMS AND CONDITIONS WITHOUT ANY RESERVATIONS, MODIFICATIONS, ADDITIONS, OR DELETIONS. IF YOU DO NOT AGREE TO
25
+ THESE TERMS AND CONDITIONS, YOU ARE NOT AUTHORIZED TO DOWNLOAD OR USE THE CONTENT. IF YOU REPRESENT A CORPORATION,
26
+ PARTNERSHIP, OR OTHER NON-INDIVIDUAL ENTITY, THE PERSON ACCEPTING THESE TERMS AND CONDITIONS ON BEHALF OF THAT ENTITY
27
+ REPRESENTS AND WARRANTS THAT THEY HAVE ALL NECESSARY AUTHORITY TO BIND THAT ENTITY.
28
+
29
+ Ownership
30
+ All Content is protected by copyright, and such copyrights and other proprietary rights may be held by individuals or
31
+ entities other than, or in addition to, us.
32
+ Use and Restrictions
33
+ The Content is made available for limited non-commercial, educational, research and personal use only, and for fair use
34
+ as defined under United States copyright laws. You may download and use Content only for your own non-commercial,
35
+ educational, research and personal use only, subject to any additional terms or restrictions which may be applicable to
36
+ an individual file as part of the Content. Copying or redistribution of the Content in any manner for commercial use,
37
+ including commercial publication, or for personal gain, or making any other use of the Content beyond that allowed by
38
+ “fair use,” as such term is understood under the United States Copyright Act and applicable law, is strictly prohibited.
39
+ HHMI may terminate these Terms and Conditions, and your right to use the Content at any time upon notice to you (which
40
+ notice may be to your email address of record with HHMI). Upon any termination by HHMI, you agree that you will promptly
41
+ delete all copies of Content and, upon request by HHMI, certify in writing your deletion of all copies of Content.
42
+ Indemnity
43
+ You agree to indemnify, defend, and hold harmless HHMI and our affiliates, and our trustees, officers, members,
44
+ directors, employees, representatives and agents from and against all claims, losses, expenses, damages, costs and other
45
+ liability (including without limitation attorneys’ fees), arising or resulting from your use of the Content (including,
46
+ without limitation, any copies and derivative works of any Content), or any violation or alleged violation by you of
47
+ these Terms and Conditions, including for any violation of any applicable law, rule, or regulation. We reserve the
48
+ right to assume, at our sole expense, the exclusive defense and control of any matter subject to indemnification by
49
+ you, in which event you will fully cooperate with us.
50
+
51
+ Disclaimers
52
+ WE MAKE NO EXPRESS WARRANTIES OR REPRESENTATIONS AS TO THE QUALITY, COMPREHENSIVENESS, AND ACCURACY OF THE CONTENT, AND
53
+ WE DISCLAIM ANY IMPLIED WARRANTIES OR REPRESENTATIONS, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY,
54
+ FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT, TO THE FULL EXTENT PERMISSIBLE UNDER APPLICABLE LAW. WE OFFER THE
55
+ CONTENT ON AN "AS IS” BASIS AND DO NOT ACCEPT RESPONSIBILITY FOR ANY USE OF OR RELIANCE ON THE CONTENT. IN ADDITION, WE
56
+ DO NOT MAKE ANY REPRESENTATIONS AS TO THE ACCURACY, COMPREHENSIVENESS, COMPLETENESS, QUALITY, CURRENCY, ERROR-FREE NATURE,
57
+ COMPATIBILITY, OR FITNESS FOR ANY PARTICULAR PURPOSE OF THE CONTENT. WE ASSUME NO LIABILITY, AND SHALL NOT BE LIABLE FOR,
58
+ ANY DAMAGES TO, OR VIRUSES OR OTHER MALWARE THAT MAY AFFECT, YOUR COMPUTER EQUIPMENT OR OTHER PROPERTY AS A RESULT OF
59
+ YOUR DOWNLOADING OF, AND USE OF, ANY CONTENT.
60
+
61
+ Limitation of Liability
62
+ TO THE FULLEST EXTENT PERMITTED UNDER APPLICABLE LAW, IN NO EVENT WILL WE, OR ANY OF OUR EMPLOYEES, AGENTS, OFFICERS, OR
63
+ TRUSTEES, BE LIABLE FOR DAMAGES OF ANY KIND, UNDER ANY LEGAL THEORY, ARISING OUT OF OR IN CONNECTION WITH YOUR USE, OR
64
+ INABILITY TO USE, THE CONTENT, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, CONSEQUENTIAL OR PUNITIVE DAMAGES,
65
+ INCLUDING BUT NOT LIMITED TO, LOSS OF REVENUE, LOSS OF PROFITS, LOSS OF BUSINESS OR ANTICIPATED SAVINGS, LOSS OF USE,
66
+ LOSS OF GOODWILL, LOSS OF DATA, AND WHETHER CAUSED BY TORT (INCLUDING NEGLIGENCE), BREACH OF CONTRACT OR OTHERWISE,
67
+ EVEN IF FORESEEABLE. BECAUSE SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR LIMITATION OF LIABLITY FOR CONSEQUENTIAL
68
+ OR INCIDENTAL DAMAGES, ALL OR A PORTION OF THE ABOVE LIMITATION MAY NOT APPLY TO YOU.
69
+
70
+ General
71
+ If any provision of these Terms and Conditions is held to be invalid, illegal, or unenforceable, then such provision
72
+ shall be eliminated or limited to the minimum extent such that the remaining provisions of the Terms and Conditions
73
+ will continue in full force and effect. All matters relating to and arising from the Content or these Terms and Conditions
74
+ shall be governed by and construed in accordance with the internal laws of the State of Maryland without giving effect
75
+ to any choice or conflict of law provision or rule. If you choose to download or access the Content from locations
76
+ outside the United States, you do so at your own risk and you are responsible for compliance with any local laws.
77
+
78
+ /*********************************************************************/
79
+
80
+ /*********************************************************************/
81
+ ii. TissueNet dataset
82
+
83
+ https://datasets.deepcell.org/
84
+
85
+ Modified Apache License
86
+ Version 2.0, January 2004
87
+
88
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
89
+
90
+ 1. Definitions.
91
+
92
+ "License" shall mean the terms and conditions for use, reproduction,
93
+ and distribution as defined by Sections 1 through 9 of this document.
94
+
95
+ "Licensor" shall mean the copyright owner or entity authorized by
96
+ the copyright owner that is granting the License.
97
+
98
+ "Legal Entity" shall mean the union of the acting entity and all
99
+ other entities that control, are controlled by, or are under common
100
+ control with that entity. For the purposes of this definition,
101
+ "control" means (i) the power, direct or indirect, to cause the
102
+ direction or management of such entity, whether by contract or
103
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
104
+ outstanding shares, or (iii) beneficial ownership of such entity.
105
+
106
+ "You" (or "Your") shall mean an individual or Legal Entity
107
+ exercising permissions granted by this License.
108
+
109
+ "Source" form shall mean the preferred form for making modifications,
110
+ including but not limited to software source code, documentation
111
+ source, and configuration files.
112
+
113
+ "Object" form shall mean any form resulting from mechanical
114
+ transformation or translation of a Source form, including but
115
+ not limited to compiled object code, generated documentation,
116
+ and conversions to other media types.
117
+
118
+ "Work" shall mean the work of authorship, whether in Source or
119
+ Object form, made available under the License, as indicated by a
120
+ copyright notice that is included in or attached to the work
121
+ (an example is provided in the Appendix below).
122
+
123
+ "Derivative Works" shall mean any work, whether in Source or Object
124
+ form, that is based on (or derived from) the Work and for which the
125
+ editorial revisions, annotations, elaborations, or other modifications
126
+ represent, as a whole, an original work of authorship. For the purposes
127
+ of this License, Derivative Works shall not include works that remain
128
+ separable from, or merely link (or bind by name) to the interfaces of,
129
+ the Work and Derivative Works thereof.
130
+
131
+ "Contribution" shall mean any work of authorship, including
132
+ the original version of the Work and any modifications or additions
133
+ to that Work or Derivative Works thereof, that is intentionally
134
+ submitted to Licensor for inclusion in the Work by the copyright owner
135
+ or by an individual or Legal Entity authorized to submit on behalf of
136
+ the copyright owner. For the purposes of this definition, "submitted"
137
+ means any form of electronic, verbal, or written communication sent
138
+ to the Licensor or its representatives, including but not limited to
139
+ communication on electronic mailing lists, source code control systems,
140
+ and issue tracking systems that are managed by, or on behalf of, the
141
+ Licensor for the purpose of discussing and improving the Work, but
142
+ excluding communication that is conspicuously marked or otherwise
143
+ designated in writing by the copyright owner as "Not a Contribution."
144
+
145
+ "Contributor" shall mean Licensor and any individual or Legal Entity
146
+ on behalf of whom a Contribution has been received by Licensor and
147
+ subsequently incorporated within the Work.
148
+
149
+ 2. Grant of Copyright License. Subject to the terms and conditions of
150
+ this License, each Contributor hereby grants to You a non-commercial,
151
+ academic perpetual, worldwide, non-exclusive, no-charge, royalty-free,
152
+ irrevocable copyright license to reproduce, prepare Derivative Works
153
+ of, publicly display, publicly perform, sublicense, and distribute the
154
+ Work and such Derivative Works in Source or Object form. For any other
155
+ use, including commercial use, please contact: [email protected].
156
+
157
+ 3. Grant of Patent License. Subject to the terms and conditions of
158
+ this License, each Contributor hereby grants to You a non-commercial,
159
+ academic perpetual, worldwide, non-exclusive, no-charge, royalty-free,
160
+ irrevocable (except as stated in this section) patent license to make,
161
+ have made, use, offer to sell, sell, import, and otherwise transfer the
162
+ Work, where such license applies only to those patent claims licensable
163
+ by such Contributor that are necessarily infringed by their
164
+ Contribution(s) alone or by combination of their Contribution(s)
165
+ with the Work to which such Contribution(s) was submitted. If You
166
+ institute patent litigation against any entity (including a
167
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
168
+ or a Contribution incorporated within the Work constitutes direct
169
+ or contributory patent infringement, then any patent licenses
170
+ granted to You under this License for that Work shall terminate
171
+ as of the date such litigation is filed.
172
+
173
+ 4. Redistribution. You may reproduce and distribute copies of the
174
+ Work or Derivative Works thereof in any medium, with or without
175
+ modifications, and in Source or Object form, provided that You
176
+ meet the following conditions:
177
+
178
+ (a) You must give any other recipients of the Work or
179
+ Derivative Works a copy of this License; and
180
+
181
+ (b) You must cause any modified files to carry prominent notices
182
+ stating that You changed the files; and
183
+
184
+ (c) You must retain, in the Source form of any Derivative Works
185
+ that You distribute, all copyright, patent, trademark, and
186
+ attribution notices from the Source form of the Work,
187
+ excluding those notices that do not pertain to any part of
188
+ the Derivative Works; and
189
+
190
+ (d) If the Work includes a "NOTICE" text file as part of its
191
+ distribution, then any Derivative Works that You distribute must
192
+ include a readable copy of the attribution notices contained
193
+ within such NOTICE file, excluding those notices that do not
194
+ pertain to any part of the Derivative Works, in at least one
195
+ of the following places: within a NOTICE text file distributed
196
+ as part of the Derivative Works; within the Source form or
197
+ documentation, if provided along with the Derivative Works; or,
198
+ within a display generated by the Derivative Works, if and
199
+ wherever such third-party notices normally appear. The contents
200
+ of the NOTICE file are for informational purposes only and
201
+ do not modify the License. You may add Your own attribution
202
+ notices within Derivative Works that You distribute, alongside
203
+ or as an addendum to the NOTICE text from the Work, provided
204
+ that such additional attribution notices cannot be construed
205
+ as modifying the License.
206
+
207
+ You may add Your own copyright statement to Your modifications and
208
+ may provide additional or different license terms and conditions
209
+ for use, reproduction, or distribution of Your modifications, or
210
+ for any such Derivative Works as a whole, provided Your use,
211
+ reproduction, and distribution of the Work otherwise complies with
212
+ the conditions stated in this License.
213
+
214
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
215
+ any Contribution intentionally submitted for inclusion in the Work
216
+ by You to the Licensor shall be under the terms and conditions of
217
+ this License, without any additional terms or conditions.
218
+ Notwithstanding the above, nothing herein shall supersede or modify
219
+ the terms of any separate license agreement you may have executed
220
+ with Licensor regarding such Contributions.
221
+
222
+ 6. Trademarks. This License does not grant permission to use the trade
223
+ names, trademarks, service marks, or product names of the Licensor,
224
+ except as required for reasonable and customary use in describing the
225
+ origin of the Work and reproducing the content of the NOTICE file.
226
+
227
+ 7. Disclaimer of Warranty. Unless required by applicable law or
228
+ agreed to in writing, Licensor provides the Work (and each
229
+ Contributor provides its Contributions) on an "AS IS" BASIS,
230
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
231
+ implied, including, without limitation, any warranties or conditions
232
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
233
+ PARTICULAR PURPOSE. You are solely responsible for determining the
234
+ appropriateness of using or redistributing the Work and assume any
235
+ risks associated with Your exercise of permissions under this License.
236
+
237
+ 8. Limitation of Liability. In no event and under no legal theory,
238
+ whether in tort (including negligence), contract, or otherwise,
239
+ unless required by applicable law (such as deliberate and grossly
240
+ negligent acts) or agreed to in writing, shall any Contributor be
241
+ liable to You for damages, including any direct, indirect, special,
242
+ incidental, or consequential damages of any character arising as a
243
+ result of this License or out of the use or inability to use the
244
+ Work (including but not limited to damages for loss of goodwill,
245
+ work stoppage, computer failure or malfunction, or any and all
246
+ other commercial damages or losses), even if such Contributor
247
+ has been advised of the possibility of such damages.
248
+
249
+ 9. Accepting Warranty or Additional Liability. While redistributing
250
+ the Work or Derivative Works thereof, You may choose to offer,
251
+ and charge a fee for, acceptance of support, warranty, indemnity,
252
+ or other liability obligations and/or rights consistent with this
253
+ License. However, in accepting such obligations, You may act only
254
+ on Your own behalf and on Your sole responsibility, not on behalf
255
+ of any other Contributor, and only if You agree to indemnify,
256
+ defend, and hold each Contributor harmless for any liability
257
+ incurred by, or claims asserted against, such Contributor by reason
258
+ of your accepting any such warranty or additional liability.
259
+
260
+ 10. Neither the name of Caltech nor the names of its contributors may be
261
+ used to endorse or promote products derived from this software without
262
+ specific prior written permission.
263
+
264
+ END OF TERMS AND CONDITIONS
265
+
266
+ APPENDIX: How to apply the Apache License to your work.
267
+
268
+ To apply the Apache License to your work, attach the following
269
+ boilerplate notice, with the fields enclosed by brackets "[]"
270
+ replaced with your own identifying information. (Don't include
271
+ the brackets!) The text should be enclosed in the appropriate
272
+ comment syntax for the file format. We also recommend that a
273
+ file or class name and description of purpose be included on the
274
+ same "printed page" as the copyright notice for easier
275
+ identification within third-party archives.
276
+
277
+ Copyright [yyyy] [name of copyright owner]
278
+
279
+ Licensed under the Apache License, Version 2.0 (the "License");
280
+ you may not use this file except in compliance with the License.
281
+ You may obtain a copy of the License at
282
+
283
+ http://www.apache.org/licenses/LICENSE-2.0
284
+
285
+ Unless required by applicable law or agreed to in writing, software
286
+ distributed under the License is distributed on an "AS IS" BASIS,
287
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
288
+ See the License for the specific language governing permissions and
289
+ limitations under the License.
290
+
291
+ /*********************************************************************/
292
+
293
+ /*********************************************************************/
294
+ iii. Kaggle Nuclei Segmentation
295
+ https://www.nature.com/articles/s41592-019-0612-7#rightslink
296
+
297
+ CC BY 4.0
298
+ http://creativecommons.org/licenses/by/4.0/
299
+
300
+ /*********************************************************************/
301
+
302
+ /*********************************************************************/
303
+ iv. Omnipose
304
+
305
+ https://github.com/kevinjohncutler/omnipose/blob/main/LICENSE
306
+
307
+ Omnipose NonCommercial License
308
+ Copyright (c) 2021 University of Washington.
309
+
310
+ Redistribution and use for noncommercial purposes in source and binary forms, with or without modification, are permitted
311
+ provided that the following conditions are met:
312
+ 1. The software is used solely for noncommercial purposes. For commercial use rights, contact University of Washington,
313
+ CoMotion, at [email protected].
314
+ 2. Redistributions of source code must retain the above copyright notice, this list of conditions and the below
315
+ disclaimer.
316
+ 3. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
317
+ disclaimer in the documentation and/or other materials provided with the distribution.
318
+ 4. Redistributions, with or without modifications, shall only be licensed under this NonCommercial License.
319
+ 5. Neither the name of the University of Washington nor the names of its contributors may be used to endorse or promote
320
+ products derived from this software without specific prior written permission.
321
+
322
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
323
+ INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
324
+ DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
325
+ INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
326
+ OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
327
+ WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
328
+ THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
329
+
330
+ /*********************************************************************/
331
+
332
+ /*********************************************************************/
333
+ v. NIPS Cell Segmentation Challenge
334
+
335
+ https://neurips22-cellseg.grand-challenge.org/dataset/
336
+
337
+ CC BY-NC-ND
338
+ https://creativecommons.org/licenses/by-nc-nd/4.0/deed.en
339
+
340
+ /*********************************************************************/
341
+
342
+ /*********************************************************************/
343
+ vi. LiveCell
344
+
345
+ https://sartorius-research.github.io/LIVECell/
346
+
347
+ CC BY-NC 4.0
348
+ https://creativecommons.org/licenses/by-nc/4.0/
349
+
350
+ /*********************************************************************/
351
+
352
+ /*********************************************************************/
353
+ vii. Deepbacs
354
+
355
+ https://github.com/HenriquesLab/DeepBacs/blob/main/LICENSE
356
+
357
+ CC0 1.0
358
+ https://creativecommons.org/publicdomain/zero/1.0/deed.en
359
+
360
+ /*********************************************************************/
361
+ Data Usage Agreement / Citations
download_preprocessor/all_file_downloader.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ import os
14
+
15
+ import requests
16
+ from tqdm import tqdm
17
+
18
+
19
+ def download_files(url_dict, directory):
20
+ if not os.path.exists(directory):
21
+ os.makedirs(directory)
22
+
23
+ for key, url in url_dict.items():
24
+ if key == "nips_train.zip" or key == "nips_test.zip":
25
+ if not os.path.exists(os.path.join(directory, "nips_dataset")):
26
+ os.mkdir(os.path.join(directory, "nips_dataset"))
27
+ base_dir = os.path.join(directory, "nips_dataset")
28
+ elif key == "deepbacs.zip":
29
+ if not os.path.exists(os.path.join(directory, "deepbacs_dataset")):
30
+ os.mkdir(os.path.join(directory, "deepbacs_dataset"))
31
+ base_dir = os.path.join(directory, "deepbacs_dataset")
32
+ elif key == "livecell":
33
+ if not os.path.exists(os.path.join(directory, "livecell_dataset")):
34
+ os.mkdir(os.path.join(directory, "livecell_dataset"))
35
+ base_dir = os.path.join(directory, "livecell_dataset")
36
+ print(f"Downloading from {key}: {url}")
37
+ os.system(url + base_dir)
38
+ continue
39
+
40
+ try:
41
+ print(f"Downloading from {key}: {url}")
42
+ response = requests.get(url, stream=True, allow_redirects=True)
43
+ total_size = int(response.headers.get("content-length", 0))
44
+
45
+ # Extract the filename from the URL or use the key as the filename
46
+ filename = os.path.basename(key)
47
+ file_path = os.path.join(base_dir, filename)
48
+
49
+ # Write the content to a file in the specified directory with progress
50
+ with open(file_path, "wb") as file, tqdm(
51
+ desc=filename, total=total_size, unit="iB", unit_scale=True, unit_divisor=1024
52
+ ) as bar:
53
+ for data in response.iter_content(chunk_size=1024):
54
+ size = file.write(data)
55
+ bar.update(size)
56
+
57
+ print(f"Saved to {file_path}")
58
+ except Exception as e:
59
+ print(f"Failed to download from {key} ({url}). Reason: {str(e)}")
60
+
61
+
62
+ def main():
63
+ parser = argparse.ArgumentParser(description="Process some integers.")
64
+ parser.add_argument("--dir", type=str, help="Directory to download files to", default="/set/the/path")
65
+
66
+ args = parser.parse_args()
67
+ directory = os.path.normpath(args.dir)
68
+
69
+ url_dict = {
70
+ "deepbacs.zip": "https://zenodo.org/records/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
71
+ "nips_test.zip": "https://zenodo.org/records/10719375/files/Testing.zip?download=1",
72
+ "nips_train.zip": "https://zenodo.org/records/10719375/files/Training-labeled.zip?download=1",
73
+ "livecell": "wget --recursive --no-parent --cut-dirs=0 --timestamping -i urls.txt --directory-prefix=",
74
+ # Add URLs with keys here
75
+ }
76
+ download_files(url_dict, directory)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
download_preprocessor/cellpose_agreement.png ADDED

Git LFS Details

  • SHA256: d96d87817e2d7def9b435bbed93cd08eb6d50f57fc6f8f165d3cc596dce8e93b
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
download_preprocessor/cellpose_links.png ADDED

Git LFS Details

  • SHA256: fcb78176512c2db402c8779e91580bd5b786b3601c2a8cfb4923aa1b73fdf13b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
download_preprocessor/data_tree.png ADDED
download_preprocessor/generate_json.py ADDED
@@ -0,0 +1,993 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ import gc
14
+ import json
15
+ import os
16
+ import shutil
17
+ import time
18
+ import warnings
19
+ import zipfile
20
+
21
+ import imageio.v3 as imageio
22
+ import numpy as np
23
+ from PIL import Image
24
+ from pycocotools.coco import COCO
25
+ from sklearn.model_selection import KFold
26
+
27
+ # from skimage.io import imsave
28
+ # from skimage.measure import label
29
+ # import imageio
30
+
31
+
32
+ def min_label_precision(label):
33
+ lm = label.max()
34
+
35
+ if lm <= 255:
36
+ label = label.astype(np.uint8)
37
+ elif lm <= 65535:
38
+ label = label.astype(np.uint16)
39
+ else:
40
+ label = label.astype(np.uint32)
41
+
42
+ return label
43
+
44
+
45
+ def guess_convert_to_uint16(img, margin=30):
46
+ """
47
+ Guess a multiplier that makes all pixels integers.
48
+ The input img (each channel) is already in the range 0..1, they must have been converted from uint16 integers as image / scale,
49
+ where scale was the unknown max intensity.
50
+ We could guess the scale by looking at unique values: 1/np.min(np.diff(np.unique(im)).
51
+ the hypothesis is that it will be more accurate recovery of the original image,
52
+ instead of doing a simple (img*65535).astype(np.uint16)
53
+ """
54
+
55
+ for i in range(img.shape[0]):
56
+ im = img[i]
57
+
58
+ if im.any():
59
+ start = time.time()
60
+ imsmall = im[::4, ::4] # subsample
61
+ # imsmall = im
62
+
63
+ scale = int(np.round(1 / np.min(np.diff(np.unique(imsmall))))) # guessing scale
64
+ test = [
65
+ (np.sum((imsmall * k) % 1)) for k in range(scale - margin, scale + margin)
66
+ ] # finetune, guess a multiplier that makes all pixels integers
67
+ sid = np.argmin(test) # fine tune scale
68
+
69
+ if scale < 16000 or scale > 16400:
70
+ warnings.warn("scale not in expected range")
71
+ print(
72
+ "guessing scale",
73
+ scale,
74
+ test[margin],
75
+ "fine tuning scale",
76
+ scale - margin + sid,
77
+ "dif",
78
+ test[sid],
79
+ "time",
80
+ time.time() - start,
81
+ )
82
+
83
+ scale = 16384
84
+ else:
85
+ scale = scale - margin + sid
86
+ # all the recovered scale values seems to be up to 16384,
87
+ # we can stretch to 65535(for better visualization, most tiff viewers expect that range)
88
+ scale = min(65535, scale * 4)
89
+ img[i] = im * scale
90
+
91
+ img = img.astype(np.uint16)
92
+ return img
93
+
94
+
95
+ def concatenate_masks(mask_dir):
96
+ labeled_mask = None
97
+ i = 0
98
+ for filename in sorted(os.listdir(mask_dir)):
99
+ if filename.endswith(".png"):
100
+ mask = imageio.imread(os.path.join(mask_dir, filename))
101
+ if labeled_mask is None:
102
+ labeled_mask = np.zeros(shape=mask.shape, dtype=np.uint16)
103
+ labeled_mask[mask > 0] = i
104
+ i = i + 1
105
+
106
+ if i <= 255:
107
+ labeled_mask = labeled_mask.astype(np.uint8)
108
+
109
+ return labeled_mask
110
+
111
+
112
+ # def concatenate_masks(mask_dir):
113
+ # masks = []
114
+ # for filename in sorted(os.listdir(mask_dir)):
115
+ # if filename.endswith('.png'):
116
+ # mask = imageio.imread(os.path.join(mask_dir, filename))
117
+ # masks.append(mask)
118
+ # concatenated_mask = np.any(masks, axis=0).astype(np.uint8)
119
+ # labeled_mask = label(concatenated_mask)
120
+ # return labeled_mask
121
+
122
+ # def normalize_image(image):
123
+ # # Convert to float and normalize each channel
124
+ # image = image.astype(np.float32)
125
+ # for i in range(3):
126
+ # channel = image[..., i]
127
+ # channel_min = np.min(channel)
128
+ # channel_max = np.max(channel)
129
+ # if channel_max - channel_min != 0:
130
+ # image[..., i] = (channel - channel_min) / (channel_max - channel_min)
131
+ # return image
132
+
133
+
134
+ def get_filenames_exclude_masks(dir1, target_string):
135
+ filenames = []
136
+ # Combine lists of files from both directories
137
+ files = os.listdir(dir1)
138
+ # Filter files that contain the target string but exclude 'masks'
139
+ filenames = [f for f in files if target_string in f and "masks" not in f]
140
+
141
+ return filenames
142
+
143
+
144
+ def remove_overlaps(masks, medians, overlap_threshold=0.75):
145
+ """replace overlapping mask pixels with mask id of closest mask
146
+ if mask fully within another mask, remove it
147
+ masks = Nmasks x Ly x Lx
148
+ """
149
+ cellpix = masks.sum(axis=0)
150
+ igood = np.ones(masks.shape[0], "bool")
151
+ for i in masks.sum(axis=(1, 2)).argsort():
152
+ npix = float(masks[i].sum())
153
+ noverlap = float(masks[i][cellpix > 1].sum())
154
+ if noverlap / npix >= overlap_threshold:
155
+ igood[i] = False
156
+ cellpix[masks[i] > 0] -= 1
157
+ # print(cellpix.min())
158
+ print(f"removing {(~igood).sum()} masks")
159
+ masks = masks[igood]
160
+ medians = medians[igood]
161
+ cellpix = masks.sum(axis=0)
162
+ overlaps = np.array(np.nonzero(cellpix > 1.0)).T
163
+ dists = ((overlaps[:, :, np.newaxis] - medians.T) ** 2).sum(axis=1)
164
+ tocell = np.argmin(dists, axis=1)
165
+ masks[:, overlaps[:, 0], overlaps[:, 1]] = 0
166
+ masks[tocell, overlaps[:, 0], overlaps[:, 1]] = 1
167
+
168
+ # labels should be 1 to mask.shape[0]
169
+ masks = masks.astype(int) * np.arange(1, masks.shape[0] + 1, 1, int)[:, np.newaxis, np.newaxis]
170
+ masks = masks.sum(axis=0)
171
+ gc.collect()
172
+ return masks
173
+
174
+
175
+ def livecell_json_files(dataset_dir, json_f_path):
176
+ """
177
+ This function takes in the directory of livecell extracted dataset as input and
178
+ creates 7 json lists with 5 folds. Separate testing set is recorded in the json list.
179
+ Please note that there are some hard-coded directory names as per the original dataset.
180
+ At the time of creation, the livecell zipfile had 'images' and 'LIVECell_dataset_2021' directories
181
+ """
182
+
183
+ # "A172", "BT474", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"
184
+ # TODO "BV2" is being skipped
185
+ cell_type_list = ["A172", "BT474", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
186
+ for each_cell_tp in cell_type_list:
187
+ for split in ["train", "val", "test"]:
188
+ print(f"Working on split: {split}")
189
+
190
+ if split == "test":
191
+ img_path = os.path.join(dataset_dir, "images", "livecell_test_images", each_cell_tp)
192
+ msk_path = os.path.join(dataset_dir, "images", "livecell_test_images", each_cell_tp + "_masks")
193
+ else:
194
+ img_path = os.path.join(dataset_dir, "images", "livecell_train_val_images", each_cell_tp)
195
+ msk_path = os.path.join(dataset_dir, "images", "livecell_train_val_images", each_cell_tp + "_masks")
196
+ if not os.path.exists(msk_path):
197
+ os.makedirs(msk_path)
198
+
199
+ # annotation path
200
+ path = os.path.join(
201
+ dataset_dir,
202
+ "livecell-dataset.s3.eu-central-1.amazonaws.com",
203
+ "LIVECell_dataset_2021",
204
+ "annotations",
205
+ "LIVECell_single_cells",
206
+ each_cell_tp.lower(),
207
+ split + ".json",
208
+ )
209
+ annotation = COCO(path)
210
+ # Convert COCO format segmentation to binary mask
211
+ images = annotation.loadImgs(annotation.getImgIds())
212
+ height = []
213
+ width = []
214
+ for index, im in enumerate(images):
215
+ print("Status: {}/{}, Process image: {}".format(index, len(images), im["file_name"]))
216
+ if (
217
+ im["file_name"] == "BV2_Phase_C4_2_03d00h00m_1.tif"
218
+ or im["file_name"] == "BV2_Phase_C4_2_03d00h00m_3.tif"
219
+ ):
220
+ print("Skipping the file: BV2_Phase_C4_2_03d00h00m_1.tif, as it is troublesome")
221
+ continue
222
+ # load image
223
+ img = Image.open(os.path.join(img_path, im["file_name"])).convert("L")
224
+ height.append(img.size[0])
225
+ width.append(img.size[1])
226
+ # arr = np.asarray(img) #? not used
227
+ # msk = np.zeros(arr.shape)
228
+ # load and display instance annotations
229
+ annids = annotation.getAnnIds(imgIds=im["id"], iscrowd=None)
230
+ anns = annotation.loadAnns(annids)
231
+ idx = 1
232
+ medians = []
233
+ masks = []
234
+ k = 0
235
+ for ann in anns:
236
+ # convert segmentation to binary mask
237
+ mask = annotation.annToMask(ann)
238
+ masks.append(mask)
239
+ ypix, xpix = mask.nonzero()
240
+ medians.append(np.array([ypix.mean().astype(np.float32), xpix.mean().astype(np.float32)]))
241
+ k += 1
242
+ # add instance mask to image mask
243
+ # msk = np.add(msk, mask*idx)
244
+ # idx += 1
245
+
246
+ masks = np.array(masks).astype(np.int8)
247
+ medians = np.array(medians)
248
+ masks = remove_overlaps(masks, medians, overlap_threshold=0.75)
249
+ gc.collect()
250
+
251
+ # ## Create new name for the image and also for the mask and save them as .tif format
252
+ # masks_int32 = masks.astype(np.int32)
253
+ # mask_pil = Image.fromarray(masks_int32, 'I')
254
+
255
+ t_filename = im["file_name"]
256
+ # cell_type = t_filename.split('_')[0] #? not used
257
+ new_mask_name = t_filename[:-4] + "_masks.tif"
258
+ # mask_pil.save(os.path.join(msk_path, new_mask_name))
259
+ imageio.imwrite(os.path.join(msk_path, new_mask_name), min_label_precision(masks))
260
+ gc.collect()
261
+
262
+ print(f"In total {len(images)} images")
263
+
264
+ # The directory containing your files
265
+ # cell_type = 'BV2'
266
+ json_save_path = os.path.join(json_f_path, f"lc_{each_cell_tp}.json")
267
+ directory = os.path.join(dataset_dir, "images", "livecell_train_val_images", each_cell_tp)
268
+ mask_directory = os.path.join(dataset_dir, "images", "livecell_train_val_images", each_cell_tp + "_masks")
269
+ test_directory = os.path.join(dataset_dir, "images", "livecell_test_images", each_cell_tp)
270
+ mask_test_directory = os.path.join(dataset_dir, "images", "livecell_test_images", each_cell_tp + "_masks")
271
+ # List to hold all image-mask pairs
272
+ data_pairs = []
273
+ test_data_pairs = []
274
+ all_data = {}
275
+
276
+ # Scan the directory for image files and create pairs
277
+ for filename in os.listdir(directory):
278
+ if filename.endswith(".tif"):
279
+ # Construct the corresponding mask filename
280
+ mask_filename = filename.replace(".tif", "_masks.tif")
281
+
282
+ # Check if the corresponding mask file exists
283
+ if os.path.exists(os.path.join(mask_directory, mask_filename)):
284
+ # Add the pair to the list
285
+ data_pairs.append(
286
+ {
287
+ "image": os.path.join(
288
+ "livecell_dataset", "images", "livecell_train_val_images", each_cell_tp, filename
289
+ ),
290
+ "label": os.path.join(
291
+ "livecell_dataset",
292
+ "images",
293
+ "livecell_train_val_images",
294
+ f"{each_cell_tp}_masks",
295
+ mask_filename,
296
+ ),
297
+ }
298
+ )
299
+
300
+ # Convert data_pairs to a numpy array for easy indexing by KFold
301
+ data_pairs_array = np.array(data_pairs)
302
+
303
+ # Initialize KFold
304
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
305
+
306
+ # Assign fold numbers
307
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
308
+ for idx in val_index:
309
+ data_pairs_array[idx]["fold"] = fold
310
+
311
+ # Convert the array back to a list and sort by fold
312
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
313
+
314
+ print(sorted_data_pairs)
315
+
316
+ # Scan the directory for image files and create pairs
317
+ for filename in os.listdir(test_directory):
318
+ if filename.endswith(".tif"):
319
+ # Construct the corresponding mask filename
320
+ mask_filename = filename.replace(".tif", "_masks.tif")
321
+
322
+ # Check if the corresponding mask file exists
323
+ if os.path.exists(os.path.join(mask_test_directory, mask_filename)):
324
+ # Add the pair to the list
325
+ test_data_pairs.append(
326
+ {
327
+ "image": os.path.join(
328
+ "livecell_dataset", "images", "livecell_test_images", each_cell_tp, filename
329
+ ),
330
+ "label": os.path.join(
331
+ "livecell_dataset",
332
+ "images",
333
+ "livecell_test_images",
334
+ f"{each_cell_tp}_masks",
335
+ mask_filename,
336
+ ),
337
+ }
338
+ )
339
+
340
+ all_data["training"] = sorted_data_pairs
341
+ all_data["testing"] = test_data_pairs
342
+
343
+ with open(json_save_path, "w") as j_file:
344
+ json.dump(all_data, j_file, indent=4)
345
+ j_file.close()
346
+
347
+
348
+ def tissuenet_json_files(dataset_dir, json_f_path):
349
+ """
350
+ This function takes in the directory of TissueNet extracted dataset as input and
351
+ creates 13 json lists with 5 folds each. Separate testing set is recorded in the json list per subset.
352
+ Please note that there are some hard-coded directory names as per the original dataset.
353
+ At the time of creation, the tissuenet 1.0 zipfile had 'train', 'val' and 'test' directories that
354
+ images with paired labels.
355
+ """
356
+
357
+ for folder in ["train", "val", "test"]:
358
+ if not os.path.exists(os.path.join(dataset_dir, "tissuenet_1.0", folder)):
359
+ os.mkdir(os.path.join(dataset_dir, "tissuenet_1.0", folder))
360
+
361
+ for folder in ["train", "val", "test"]:
362
+ print(f"Working on {folder} directory of tissuenet")
363
+ f_name = f"tissuenet_1.0/tissuenet_v1.0_{folder}.npz"
364
+ dat = np.load(os.path.join(dataset_dir, f_name))
365
+ data = dat["X"]
366
+ labels = dat["y"]
367
+ tissues = dat["tissue_list"]
368
+ platforms = dat["platform_list"]
369
+ tlabels = np.unique(tissues)
370
+ plabels = np.unique(platforms)
371
+ tp = 0
372
+ for t in tlabels:
373
+ for p in plabels:
374
+ ix = ((tissues == t) * (platforms == p)).nonzero()[0]
375
+ tp += 1
376
+ if len(ix) > 0:
377
+ print(f"Working on {t} {p}")
378
+
379
+ for k, i in enumerate(ix):
380
+ print(f"Status: {k}/{len(ix)} {tp}/{len(tlabels) * len(plabels)} {t} {p}")
381
+ img = data[i].transpose(2, 0, 1)
382
+ label = labels[i][:, :, 0]
383
+
384
+ img = guess_convert_to_uint16(img) # guess inverse scale and convert to uint16
385
+ label = min_label_precision(label)
386
+
387
+ if folder == "train":
388
+ img = img.reshape(2, 2, 256, 2, 256).transpose(0, 1, 3, 2, 4).reshape(2, 4, 256, 256)
389
+ label = label.reshape(2, 256, 2, 256).transpose(0, 2, 1, 3).reshape(4, 256, 256)
390
+
391
+ zero_channel = np.zeros((1, img.shape[1], img.shape[2], img.shape[3]), dtype=img.dtype)
392
+
393
+ # Concatenate the zero channel with the original array along the first dimension
394
+ new_array = np.concatenate([img, zero_channel], axis=0)
395
+ # reshaped_array = np.transpose(new_array, (1, 2, 3, 0))
396
+ for j in range(4):
397
+ img_name = f"{folder}/{t}_{p}_{k}_{j}.tif"
398
+ mask_name = f"{folder}/{t}_{p}_{k}_{j}_masks.tif"
399
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", img_name), new_array[:, j])
400
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", mask_name), label[j])
401
+ else:
402
+ zero_channel = np.zeros((1, img.shape[1], img.shape[2]), dtype=img.dtype)
403
+ new_array = np.concatenate([img, zero_channel], axis=0)
404
+ # reshaped_array = np.transpose(new_array, (1, 2, 0))
405
+ img_name = f"{folder}/{t}_{p}_{k}.tif"
406
+ mask_name = f"{folder}/{t}_{p}_{k}_masks.tif"
407
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", img_name), new_array)
408
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", mask_name), label)
409
+
410
+ t_p_combos = [
411
+ ["breast", "imc"],
412
+ ["breast", "mibi"],
413
+ ["breast", "vectra"],
414
+ ["gi", "codex"],
415
+ ["gi", "mibi"],
416
+ ["gi", "mxif"],
417
+ ["immune", "cycif"],
418
+ ["immune", "mibi"],
419
+ ["immune", "vectra"],
420
+ ["lung", "cycif"],
421
+ ["lung", "mibi"],
422
+ ["pancreas", "codex"],
423
+ ["pancreas", "vectra"],
424
+ ["skin", "mibi"],
425
+ ]
426
+
427
+ for each_t_p in t_p_combos:
428
+ json_f_name = "tn_" + each_t_p[0] + "_" + each_t_p[1] + ".json"
429
+ json_f_subset_path = os.path.join(json_f_path, json_f_name)
430
+
431
+ tp_match = each_t_p[0] + "_" + each_t_p[1]
432
+ train_filenames = get_filenames_exclude_masks(os.path.join(dataset_dir, "tissuenet_1.0", "train"), tp_match)
433
+ val_filenames = get_filenames_exclude_masks(os.path.join(dataset_dir, "tissuenet_1.0", "val"), tp_match)
434
+ test_filenames = get_filenames_exclude_masks(os.path.join(dataset_dir, "tissuenet_1.0", "test"), tp_match)
435
+
436
+ train_data_list = []
437
+ test_data_list = []
438
+
439
+ for each_tf in train_filenames:
440
+ t_dict = {
441
+ "image": os.path.join("tissuenet_dataset", "tissuenet_1.0", "train", each_tf),
442
+ "label": os.path.join("tissuenet_dataset", "tissuenet_1.0", "train", each_tf[:-4] + "_masks.tif"),
443
+ }
444
+ train_data_list.append(t_dict)
445
+
446
+ for each_vf in val_filenames:
447
+ t_dict = {
448
+ "image": os.path.join("tissuenet_dataset", "tissuenet_1.0", "val", each_vf),
449
+ "label": os.path.join("tissuenet_dataset", "tissuenet_1.0", "val", each_vf[:-4] + "_masks.tif"),
450
+ }
451
+ train_data_list.append(t_dict)
452
+
453
+ for each_tf in test_filenames:
454
+ t_dict = {
455
+ "image": os.path.join("tissuenet_dataset", "tissuenet_1.0", "test", each_tf),
456
+ "label": os.path.join("tissuenet_dataset", "tissuenet_1.0", "test", each_tf[:-4] + "_masks.tif"),
457
+ }
458
+ test_data_list.append(t_dict)
459
+
460
+ # print(train_data_list)
461
+ # print(test_data_list)
462
+
463
+ # Convert data_pairs to a numpy array for easy indexing by KFold
464
+ data_pairs_array = np.array(train_data_list)
465
+
466
+ # Initialize KFold
467
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
468
+
469
+ # Assign fold numbers
470
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
471
+ for idx in val_index:
472
+ data_pairs_array[idx]["fold"] = fold
473
+
474
+ # Convert the array back to a list and sort by fold
475
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
476
+
477
+ print(sorted_data_pairs)
478
+
479
+ all_data = {}
480
+ all_data["training"] = sorted_data_pairs
481
+ all_data["testing"] = test_data_list
482
+
483
+ with open(json_f_subset_path, "w") as j_file:
484
+ json.dump(all_data, j_file, indent=4)
485
+ j_file.close()
486
+
487
+
488
+ def omnipose_json_file(dataset_dir, json_path):
489
+ """
490
+ This function takes in the directory of extracted Omnipose dataset as input
491
+ and creates a json list with 5 folds. Please note that only 'bact_phase' and 'bact_fluor' were
492
+ used for creating datasets as they have bacteria the other directiories are worms. Each directory
493
+ has 'train_sorted' and 'test_sorted'.Separate testing set is recorded in the json list.
494
+ Please note that there are some hard-coded directory names as per the original dataset.
495
+ """
496
+ # Define the folders
497
+ op_list = ["bact_fluor", "bact_phase"]
498
+ for each_op in op_list:
499
+ print(f"Working on {each_op} ...")
500
+ images_folder = os.path.join(dataset_dir, each_op, "train_sorted")
501
+ test_images_folder = os.path.join(dataset_dir, each_op, "test_sorted")
502
+ json_f_path = os.path.join(json_path, f"op_{each_op}.json")
503
+
504
+ # Initialize the list for training data
505
+ training_data = []
506
+
507
+ # Loop through each image file to find its corresponding label file
508
+ sub_dirs = os.listdir(images_folder)
509
+ # Likely Omnipose dataset was created using a Mac and hence the spare filename
510
+ sub_dirs.remove(".DS_Store")
511
+ for each_sub in sub_dirs:
512
+ # List files in the images folder
513
+ image_files = os.listdir(os.path.join(images_folder, each_sub))
514
+ for image_file in image_files:
515
+ # Extract the name without the extension
516
+ base_name = os.path.splitext(image_file)[0]
517
+
518
+ # Construct the label file name by adding '_label' before the extension
519
+ label_file = base_name + "_masks.tif" # + os.path.splitext(image_file)[1]
520
+ flows_file = base_name + "_flows.tif"
521
+ # Check if the corresponding label file exists in the labels folder
522
+ if label_file in os.listdir(os.path.join(images_folder, each_sub)):
523
+ # Add the file names to the training data list
524
+ training_data.append(
525
+ {
526
+ "image": os.path.join("omnipose_dataset", each_op, "train_sorted", each_sub, image_file),
527
+ "label": os.path.join("omnipose_dataset", each_op, "train_sorted", each_sub, label_file),
528
+ "flows": os.path.join("omnipose_dataset", each_op, "train_sorted", each_sub, flows_file),
529
+ }
530
+ )
531
+
532
+ # Convert data_pairs to a numpy array for easy indexing by KFold
533
+ data_pairs_array = np.array(training_data)
534
+
535
+ # Initialize KFold
536
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
537
+
538
+ # Assign fold numbers
539
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
540
+ for idx in val_index:
541
+ data_pairs_array[idx]["fold"] = fold
542
+
543
+ # Convert the array back to a list and sort by fold
544
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
545
+
546
+ # Initialize the list for testing data
547
+ testing_data = []
548
+
549
+ test_sub_dirs = os.listdir(test_images_folder)
550
+ # Likely Omnipose dataset was created using a Mac and hence the spare filename
551
+ test_sub_dirs.remove(".DS_Store")
552
+ # Loop through each image file to find its corresponding label file
553
+ for each_test_sub in test_sub_dirs:
554
+ # List files in the images folder
555
+ test_image_files = os.listdir(os.path.join(test_images_folder, each_test_sub))
556
+ for image_file in test_image_files:
557
+ # Extract the name without the extension
558
+ base_name = os.path.splitext(image_file)[0]
559
+
560
+ # Construct the label file name by adding '_label' before the extension
561
+ label_file = base_name + "_masks.tif" # + os.path.splitext(image_file)[1]
562
+
563
+ # Check if the corresponding label file exists in the labels folder
564
+ if label_file in os.listdir(os.path.join(test_images_folder, each_test_sub)):
565
+ # Add the file names to the training data list
566
+ testing_data.append(
567
+ {
568
+ "image": os.path.join(
569
+ "omnipose_dataset", each_op, "test_sorted", each_test_sub, image_file
570
+ ),
571
+ "label": os.path.join(
572
+ "omnipose_dataset", each_op, "test_sorted", each_test_sub, label_file
573
+ ),
574
+ }
575
+ )
576
+
577
+ all_data = {}
578
+ all_data["training"] = sorted_data_pairs
579
+ all_data["testing"] = testing_data
580
+
581
+ # Save the training data list to a JSON file
582
+ with open(json_f_path, "w") as json_file:
583
+ json.dump(all_data, json_file, indent=4)
584
+
585
+
586
+ def nips_json_file(dataset_dir, json_f_path):
587
+ """
588
+ This function takes in the directory of extracted NIPS cell segmentation challenge as input
589
+ and creates a json list with 5 folds. Separate testing set is recorded in the json list.
590
+ Please note that there are some hard-coded directory names as per the original dataset.
591
+ At the time of creation, the NIPS zipfile had 'Training-labeled' and 'Testing' directories that
592
+ both contained 'images' and 'labels' directories
593
+ """
594
+ # The directory containing your files
595
+ json_save_path = os.path.normpath(json_f_path)
596
+ directory = os.path.join(dataset_dir, "Training-labeled")
597
+ test_directory = os.path.join(dataset_dir, "Testing", "Public")
598
+ # List to hold all image-mask pairs
599
+ data_pairs = []
600
+ test_data_pairs = []
601
+ all_data = {}
602
+
603
+ # Scan the directory for image files and create pairs
604
+ for filename in os.listdir(os.path.join(directory, "images")):
605
+ if os.path.exists(os.path.join(directory, "images", filename)):
606
+ # Extract the name without the extension
607
+ base_name = os.path.splitext(filename)[0]
608
+
609
+ # Construct the label file name by adding '_label' before the extension
610
+ label_file = base_name + "_label.tiff" # + os.path.splitext(image_file)[1]
611
+
612
+ # Check if the corresponding label file exists in the labels folder
613
+ if label_file in os.listdir(os.path.join(directory, "labels")):
614
+ # Add the file names to the training data list
615
+ data_pairs.append(
616
+ {
617
+ "image": os.path.join("nips_dataset", "Training-labeled", "images", filename),
618
+ "label": os.path.join("nips_dataset", "Training-labeled", "labels", label_file),
619
+ }
620
+ )
621
+
622
+ # Convert data_pairs to a numpy array for easy indexing by KFold
623
+ data_pairs_array = np.array(data_pairs)
624
+
625
+ # Initialize KFold
626
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
627
+
628
+ # Assign fold numbers
629
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
630
+ for idx in val_index:
631
+ data_pairs_array[idx]["fold"] = fold
632
+
633
+ # Convert the array back to a list and sort by fold
634
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
635
+
636
+ print(sorted_data_pairs)
637
+
638
+ # Scan the directory for image files and create pairs
639
+ for filename in os.listdir(os.path.join(test_directory, "images")):
640
+ if os.path.exists(os.path.join(test_directory, "images", filename)):
641
+ # Extract the name without the extension
642
+ base_name = os.path.splitext(filename)[0]
643
+
644
+ # Construct the label file name by adding '_label' before the extension
645
+ label_file = base_name + "_label.tiff" # + os.path.splitext(image_file)[1]
646
+
647
+ # Check if the corresponding label file exists in the labels folder
648
+ if label_file in os.listdir(os.path.join(test_directory, "labels")):
649
+ # Add the file names to the training data list
650
+ test_data_pairs.append(
651
+ {
652
+ "image": os.path.join("nips_dataset", "Testing", "Public", "images", filename),
653
+ "label": os.path.join("nips_dataset", "Testing", "Public", "labels", label_file),
654
+ }
655
+ )
656
+
657
+ all_data["training"] = sorted_data_pairs
658
+ all_data["testing"] = test_data_pairs
659
+
660
+ with open(json_save_path, "w") as j_file:
661
+ json.dump(all_data, j_file, indent=4)
662
+ j_file.close()
663
+
664
+
665
+ def kaggle_json_file(dataset_dir, json_f_path):
666
+ """
667
+ This function takes in the directory of kaggle nuclei extracted dataset as input and
668
+ creates a json list with 5 folds.
669
+ Please note that there are some hard-coded directory names as per the original dataset.
670
+ The function creates an instance processed dataset and then a 5 fold json file based on
671
+ the instance processed dataset
672
+ """
673
+ data_dir = os.path.join(dataset_dir, "stage1_train")
674
+ saving_path = os.path.join(dataset_dir, "instance_processed_data")
675
+ if not os.path.exists(saving_path):
676
+ os.mkdir(saving_path)
677
+
678
+ # Process the images and create instance masks first
679
+ for idx, subdir in enumerate(os.listdir(data_dir)):
680
+ subdir_path = os.path.join(data_dir, subdir)
681
+ if os.path.isdir(subdir_path):
682
+ images_dir = os.path.join(subdir_path, "images")
683
+ masks_dir = os.path.join(subdir_path, "masks")
684
+ if os.path.isdir(images_dir) and os.path.isdir(masks_dir):
685
+ image_file = os.path.join(images_dir, os.listdir(images_dir)[0])
686
+ filename_prefix = f"kg_bowl_{idx}_"
687
+
688
+ mask_data = concatenate_masks(masks_dir)
689
+
690
+ # ## Apply channel-wise normalization and use only the first three channels
691
+ # image_data = imageio.imread(image_file)
692
+ # normalized_image = normalize_image(image_data[..., :3])
693
+ # imageio.imwrite(os.path.join(saving_path, f"{filename_prefix}img.tiff"), normalized_image)
694
+ shutil.copyfile(image_file, os.path.join(saving_path, f"{filename_prefix}img.png"))
695
+ imageio.imwrite(os.path.join(saving_path, f"{filename_prefix}img_masks.tiff"), mask_data)
696
+
697
+ directory = saving_path
698
+
699
+ # List to hold all image-mask pairs
700
+ data_pairs = []
701
+ all_data = {}
702
+
703
+ # Scan the directory for image files and create pairs
704
+ for filename in os.listdir(directory):
705
+ if filename.endswith("_img.png"):
706
+ # Construct the corresponding mask filename
707
+ mask_filename = filename.replace("_img.png", "_img_masks.tiff")
708
+
709
+ # Check if the corresponding mask file exists
710
+ if os.path.exists(os.path.join(directory, mask_filename)):
711
+ # Add the pair to the list
712
+ data_pairs.append(
713
+ {
714
+ "image": os.path.join("kaggle_dataset", "instance_processed_data", filename),
715
+ "label": os.path.join("kaggle_dataset", "instance_processed_data", mask_filename),
716
+ }
717
+ )
718
+
719
+ # Convert data_pairs to a numpy array for easy indexing by KFold
720
+ data_pairs_array = np.array(data_pairs)
721
+
722
+ # Initialize KFold
723
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
724
+
725
+ # Assign fold numbers
726
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
727
+ for idx in val_index:
728
+ data_pairs_array[idx]["fold"] = fold
729
+
730
+ # Convert the array back to a list and sort by fold
731
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
732
+
733
+ print(sorted_data_pairs)
734
+
735
+ all_data["training"] = sorted_data_pairs
736
+
737
+ with open(json_f_path, "w") as j_file:
738
+ json.dump(all_data, j_file, indent=4)
739
+ j_file.close()
740
+
741
+
742
+ def deepbacs_json_file(dataset_dir, json_f_path):
743
+ """
744
+ This function takes in the directory of deepbacs extracted dataset as input and
745
+ creates a json list with 5 folds. Separate testing set is recorded in the json list.
746
+ Please note that there are some hard-coded directory names as per the original dataset.
747
+ At the time of creation, the deepbacs zipfile had 'training' and 'test' directories that
748
+ both contained 'source' and 'target' directories
749
+ """
750
+ # The directory containing your files
751
+ json_save_path = os.path.normpath(json_f_path)
752
+ directory = os.path.join(dataset_dir, "training")
753
+ test_directory = os.path.join(dataset_dir, "test")
754
+ # List to hold all image-mask pairs
755
+ data_pairs = []
756
+ test_data_pairs = []
757
+ all_data = {}
758
+
759
+ # Scan the directory for image files and create pairs
760
+ for filename in os.listdir(os.path.join(directory, "source")):
761
+ if os.path.exists(os.path.join(directory, "source", filename)):
762
+ # Construct the corresponding mask filename
763
+ mask_filename = filename
764
+
765
+ # Check if the corresponding mask file exists
766
+ if os.path.exists(os.path.join(directory, "target", mask_filename)):
767
+ # Add the pair to the list
768
+ data_pairs.append(
769
+ {
770
+ "image": os.path.join("deepbacs_dataset", "training", "source", filename),
771
+ "label": os.path.join("deepbacs_dataset", "training", "target", mask_filename),
772
+ }
773
+ )
774
+
775
+ # Convert data_pairs to a numpy array for easy indexing by KFold
776
+ data_pairs_array = np.array(data_pairs)
777
+
778
+ # Initialize KFold
779
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
780
+
781
+ # Assign fold numbers
782
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
783
+ for idx in val_index:
784
+ data_pairs_array[idx]["fold"] = fold
785
+
786
+ # Convert the array back to a list and sort by fold
787
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
788
+
789
+ print(sorted_data_pairs)
790
+
791
+ # Scan the directory for image files and create pairs
792
+ for filename in os.listdir(os.path.join(test_directory, "source")):
793
+ if os.path.exists(os.path.join(test_directory, "source", filename)):
794
+ # Construct the corresponding mask filename
795
+ mask_filename = filename
796
+
797
+ # Check if the corresponding mask file exists
798
+ if os.path.exists(os.path.join(test_directory, "target", mask_filename)):
799
+ # Add the pair to the list
800
+ test_data_pairs.append(
801
+ {
802
+ "image": os.path.join("deepbacs_dataset", "test", "source", filename),
803
+ "label": os.path.join("deepbacs_dataset", "test", "target", filename),
804
+ }
805
+ )
806
+
807
+ all_data["training"] = sorted_data_pairs
808
+ all_data["testing"] = test_data_pairs
809
+
810
+ with open(json_save_path, "w") as j_file:
811
+ json.dump(all_data, j_file, indent=4)
812
+ j_file.close()
813
+
814
+
815
+ def cellpose_json_file(dataset_dir, json_f_path):
816
+ """
817
+ This function takes in the directory of cellpose extracted dataset as input and
818
+ creates a json list with 5 folds. Separate testing set is recorded in the json list.
819
+ Please note that there are some hard-coded directory names as per the original dataset.
820
+ At the time of creation, the cellpose dataset had 'train.zip' and 'test.zip' that
821
+ extracted as 'train' and 'test' directories
822
+ """
823
+ # The directory containing your files
824
+ json_save_path = os.path.normpath(json_f_path)
825
+ directory = os.path.join(dataset_dir, "train")
826
+ test_directory = os.path.join(dataset_dir, "test")
827
+
828
+ # List to hold all image-mask pairs
829
+ data_pairs = []
830
+ test_data_pairs = []
831
+ all_data = {}
832
+
833
+ # Scan the directory for image files and create pairs
834
+ for filename in os.listdir(directory):
835
+ if filename.endswith("_img.png"):
836
+ # Construct the corresponding mask filename
837
+ mask_filename = filename.replace("_img.png", "_masks.png")
838
+
839
+ # Check if the corresponding mask file exists
840
+ if os.path.exists(os.path.normpath(os.path.join(directory, mask_filename))):
841
+ # Add the pair to the list
842
+ data_pairs.append(
843
+ {
844
+ "image": os.path.join("cellpose_dataset", "train", filename),
845
+ "label": os.path.join("cellpose_dataset", "train", mask_filename),
846
+ }
847
+ )
848
+
849
+ # Convert data_pairs to a numpy array for easy indexing by KFold
850
+ data_pairs_array = np.array(data_pairs)
851
+
852
+ # Initialize KFold
853
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
854
+
855
+ # Assign fold numbers
856
+ for fold, (_train_index, val_index) in enumerate(kf.split(data_pairs_array)):
857
+ for idx in val_index:
858
+ data_pairs_array[idx]["fold"] = fold
859
+
860
+ # Convert the array back to a list and sort by fold
861
+ sorted_data_pairs = sorted(data_pairs_array.tolist(), key=lambda x: x["fold"])
862
+
863
+ print(sorted_data_pairs)
864
+
865
+ # Scan the directory for image files and create pairs
866
+ for filename in os.listdir(test_directory):
867
+ if filename.endswith("_img.png"):
868
+ # Construct the corresponding mask filename
869
+ mask_filename = filename.replace("_img.png", "_masks.png")
870
+
871
+ # Check if the corresponding mask file exists
872
+ if os.path.exists(os.path.join(directory, mask_filename)):
873
+ # Add the pair to the list
874
+ test_data_pairs.append(
875
+ {
876
+ "image": os.path.join("cellpose_dataset", "test", filename),
877
+ "label": os.path.join("cellpose_dataset", "test", mask_filename),
878
+ }
879
+ )
880
+
881
+ all_data["training"] = sorted_data_pairs
882
+ all_data["testing"] = test_data_pairs
883
+
884
+ with open(json_save_path, "w") as j_file:
885
+ json.dump(all_data, j_file, indent=4)
886
+ j_file.close()
887
+
888
+
889
+ def extract_zip(zip_path, extract_to):
890
+ # Ensure the target directory exists
891
+ print(f"Extracting from: {zip_path}")
892
+ print(f"Extracting to: {extract_to}")
893
+
894
+ if not os.path.exists(extract_to):
895
+ os.makedirs(extract_to)
896
+
897
+ # Extract all contents of the zip file to the specified directory
898
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
899
+ zip_ref.extractall(extract_to)
900
+
901
+
902
+ def main():
903
+ parser = argparse.ArgumentParser(description="Process some integers.")
904
+ parser.add_argument("--dir", type=str, help="Directory of datasets to generate json", default="/set/the/path")
905
+
906
+ args = parser.parse_args()
907
+ data_root_path = os.path.normpath(args.dir)
908
+
909
+ if not os.path.exists(os.path.join(data_root_path, "json_files")):
910
+ os.mkdir(os.path.join(data_root_path, "json_files"))
911
+
912
+ dataset_dict = {
913
+ "cellpose_dataset": ["train.zip", "test.zip"],
914
+ "deepbacs_dataset": ["deepbacs.zip"],
915
+ "kaggle_dataset": ["data-science-bowl-2018.zip"],
916
+ "nips_dataset": ["nips_train.zip", "nips_test.zip"],
917
+ "omnipose_dataset": ["datasets.zip"],
918
+ "tissuenet_dataset": ["tissuenet_v1.0.zip"],
919
+ "livecell_dataset": [
920
+ "livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/images_per_celltype.zip"
921
+ ],
922
+ }
923
+
924
+ for key, value in dataset_dict.items():
925
+ dataset_path = os.path.join(data_root_path, key)
926
+
927
+ for each_zipped in value:
928
+ in_path = os.path.join(dataset_path, each_zipped)
929
+ try:
930
+ if os.path.exists(in_path):
931
+ print(f"File exists at: {in_path}")
932
+ except Exception:
933
+ print(f"File: {in_path} was not found")
934
+ out_path = os.path.join(dataset_path)
935
+ extract_zip(in_path, out_path)
936
+
937
+ print(
938
+ "If we reached here, that means all zip files got extracted ... Working on pre-processing and generating json files"
939
+ )
940
+
941
+ # Looping over all datasets again, Cellpose & Deepbacs have a similar directory structure
942
+ for key, _value in dataset_dict.items():
943
+ if key == "cellpose_dataset":
944
+ print("Creating Cellpose Dataset Json file ...")
945
+ dataset_path = os.path.join(data_root_path, key)
946
+ json_path = os.path.join(data_root_path, "json_files", "cellpose.json")
947
+ cellpose_json_file(dataset_dir=dataset_path, json_f_path=json_path)
948
+
949
+ elif key == "nips_dataset":
950
+ print("Creating NIPS Dataset Json file ...")
951
+ dataset_path = os.path.join(data_root_path, key)
952
+ json_path = os.path.join(data_root_path, "json_files", "nips.json")
953
+ nips_json_file(dataset_dir=dataset_path, json_f_path=json_path)
954
+
955
+ elif key == "omnipose_dataset":
956
+ print("Creating Omnipose Dataset Json files ...")
957
+ dataset_path = os.path.join(data_root_path, key)
958
+ json_path = os.path.join(data_root_path, "json_files")
959
+ omnipose_json_file(dataset_dir=dataset_path, json_path=json_path)
960
+
961
+ elif key == "kaggle_dataset":
962
+ print("Needs additional extraction")
963
+ train_zip_path = os.path.join(data_root_path, key, "stage1_train.zip")
964
+ zip_out_path = os.path.join(data_root_path, key, "stage1_train")
965
+ extract_zip(train_zip_path, zip_out_path)
966
+ print("Creating Kaggle Dataset Json files ...")
967
+ dataset_path = os.path.join(data_root_path, key)
968
+ json_f_path = os.path.join(data_root_path, "json_files", "kaggle.json")
969
+ kaggle_json_file(dataset_dir=dataset_path, json_f_path=json_f_path)
970
+
971
+ elif key == "livecell_dataset":
972
+ print("Creating LiveCell Dataset Json files ... Please note that 7 files will be created from livecell")
973
+ dataset_path = os.path.join(data_root_path, key)
974
+ json_base_name = os.path.join(data_root_path, "json_files")
975
+ livecell_json_files(dataset_dir=dataset_path, json_f_path=json_base_name)
976
+
977
+ elif key == "deepbacs_dataset":
978
+ print("Creating Deepbacs Dataset Json file ...")
979
+ dataset_path = os.path.join(data_root_path, key)
980
+ json_path = os.path.join(data_root_path, "json_files", "deepbacs.json")
981
+ deepbacs_json_file(dataset_dir=dataset_path, json_f_path=json_path)
982
+
983
+ elif key == "tissuenet_dataset":
984
+ print("Creating TissueNet Dataset Json files ... Please note that 13 files will be created from tissuenet")
985
+ dataset_path = os.path.join(data_root_path, key)
986
+ json_base_name = os.path.join(data_root_path, "json_files")
987
+ tissuenet_json_files(dataset_dir=dataset_path, json_f_path=json_base_name)
988
+
989
+ return None
990
+
991
+
992
+ if __name__ == "__main__":
993
+ main()
download_preprocessor/kaggle_download.png ADDED

Git LFS Details

  • SHA256: 6b5f428931c960a335240014964e2678c511297dd0ae0cd5a68809075b558b22
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
download_preprocessor/omnipose_download.png ADDED

Git LFS Details

  • SHA256: a441c4a02c1f215c108909340549d1e73b34c5284dba317aa9dd634786f292cd
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
download_preprocessor/process_data.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ import gc
14
+ import os
15
+ import shutil
16
+ import time
17
+ import warnings
18
+ import zipfile
19
+
20
+ import imageio.v3 as imageio
21
+ import numpy as np
22
+ from PIL import Image
23
+ from pycocotools.coco import COCO
24
+
25
+
26
+ def min_label_precision(label):
27
+ lm = label.max()
28
+
29
+ if lm <= 255:
30
+ label = label.astype(np.uint8)
31
+ elif lm <= 65535:
32
+ label = label.astype(np.uint16)
33
+ else:
34
+ label = label.astype(np.uint32)
35
+
36
+ return label
37
+
38
+
39
+ def guess_convert_to_uint16(img, margin=30):
40
+ """
41
+ Guess a multiplier that makes all pixels integers.
42
+ The input img (each channel) is already in the range 0..1, they must have been converted from uint16 integers as image / scale,
43
+ where scale was the unknown max intensity.
44
+ We could guess the scale by looking at unique values: 1/np.min(np.diff(np.unique(im)).
45
+ the hypothesis is that it will be more accurate recovery of the original image,
46
+ instead of doing a simple (img*65535).astype(np.uint16)
47
+ """
48
+
49
+ for i in range(img.shape[0]):
50
+ im = img[i]
51
+
52
+ if im.any():
53
+ start = time.time()
54
+ imsmall = im[::4, ::4] # subsample
55
+ # imsmall = im
56
+
57
+ scale = int(np.round(1 / np.min(np.diff(np.unique(imsmall))))) # guessing scale
58
+ test = [
59
+ (np.sum((imsmall * k) % 1)) for k in range(scale - margin, scale + margin)
60
+ ] # finetune, guess a multiplier that makes all pixels integers
61
+ sid = np.argmin(test) # fine tune scale
62
+
63
+ if scale < 16000 or scale > 16400:
64
+ warnings.warn("scale not in expected range")
65
+ print(
66
+ "guessing scale",
67
+ scale,
68
+ test[margin],
69
+ "fine tuning scale",
70
+ scale - margin + sid,
71
+ "dif",
72
+ test[sid],
73
+ "time",
74
+ time.time() - start,
75
+ )
76
+
77
+ scale = 16384
78
+ else:
79
+ scale = scale - margin + sid
80
+ # all the recovered scale values seems to be up to 16384,
81
+ # we can stretch to 65535 (for better visualization, most tiff viewers expect that range)
82
+ scale = min(65535, scale * 4)
83
+ img[i] = im * scale
84
+
85
+ img = img.astype(np.uint16)
86
+ return img
87
+
88
+
89
+ def concatenate_masks(mask_dir):
90
+ labeled_mask = None
91
+ i = 0
92
+ for filename in sorted(os.listdir(mask_dir)):
93
+ if filename.endswith(".png"):
94
+ mask = imageio.imread(os.path.join(mask_dir, filename))
95
+ if labeled_mask is None:
96
+ labeled_mask = np.zeros(shape=mask.shape, dtype=np.uint16)
97
+ labeled_mask[mask > 0] = i
98
+ i = i + 1
99
+
100
+ if i <= 255:
101
+ labeled_mask = labeled_mask.astype(np.uint8)
102
+
103
+ return labeled_mask
104
+
105
+
106
+ def get_filenames_exclude_masks(dir1, target_string):
107
+ filenames = []
108
+ # Combine lists of files from both directories
109
+ files = os.listdir(dir1)
110
+ # Filter files that contain the target string but exclude 'masks'
111
+ filenames = [f for f in files if target_string in f and "masks" not in f]
112
+
113
+ return filenames
114
+
115
+
116
+ def remove_overlaps(masks, medians, overlap_threshold=0.75):
117
+ """replace overlapping mask pixels with mask id of closest mask
118
+ if mask fully within another mask, remove it
119
+ masks = Nmasks x Ly x Lx
120
+ """
121
+ cellpix = masks.sum(axis=0)
122
+ igood = np.ones(masks.shape[0], "bool")
123
+ for i in masks.sum(axis=(1, 2)).argsort():
124
+ npix = float(masks[i].sum())
125
+ noverlap = float(masks[i][cellpix > 1].sum())
126
+ if noverlap / npix >= overlap_threshold:
127
+ igood[i] = False
128
+ cellpix[masks[i] > 0] -= 1
129
+ # print(cellpix.min())
130
+ print(f"removing {(~igood).sum()} masks")
131
+ masks = masks[igood]
132
+ medians = medians[igood]
133
+ cellpix = masks.sum(axis=0)
134
+ overlaps = np.array(np.nonzero(cellpix > 1.0)).T
135
+ dists = ((overlaps[:, :, np.newaxis] - medians.T) ** 2).sum(axis=1)
136
+ tocell = np.argmin(dists, axis=1)
137
+ masks[:, overlaps[:, 0], overlaps[:, 1]] = 0
138
+ masks[tocell, overlaps[:, 0], overlaps[:, 1]] = 1
139
+
140
+ # labels should be 1 to mask.shape[0]
141
+ masks = masks.astype(int) * np.arange(1, masks.shape[0] + 1, 1, int)[:, np.newaxis, np.newaxis]
142
+ masks = masks.sum(axis=0)
143
+ gc.collect()
144
+ return masks
145
+
146
+
147
+ def livecell_process_files(dataset_dir):
148
+ """
149
+ This function takes in the directory of livecell extracted dataset as input and
150
+ extracts labels from the coco format.
151
+ """
152
+
153
+ # "A172", "BT474", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"
154
+ # "BV2" is being skipped, runs into memory constraints
155
+ cell_type_list = ["A172", "BT474", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
156
+ for each_cell_tp in cell_type_list:
157
+ for split in ["train", "val", "test"]:
158
+ print(f"Working on split: {split}")
159
+
160
+ if split == "test":
161
+ img_path = os.path.join(dataset_dir, "images", "livecell_test_images", each_cell_tp)
162
+ msk_path = os.path.join(dataset_dir, "images", "livecell_test_images", each_cell_tp + "_masks")
163
+ else:
164
+ img_path = os.path.join(dataset_dir, "images", "livecell_train_val_images", each_cell_tp)
165
+ msk_path = os.path.join(dataset_dir, "images", "livecell_train_val_images", each_cell_tp + "_masks")
166
+ if not os.path.exists(msk_path):
167
+ os.makedirs(msk_path)
168
+
169
+ # annotation path
170
+ path = os.path.join(
171
+ dataset_dir,
172
+ "livecell-dataset.s3.eu-central-1.amazonaws.com",
173
+ "LIVECell_dataset_2021",
174
+ "annotations",
175
+ "LIVECell_single_cells",
176
+ each_cell_tp.lower(),
177
+ split + ".json",
178
+ )
179
+ annotation = COCO(path)
180
+ # Convert COCO format segmentation to binary mask
181
+ images = annotation.loadImgs(annotation.getImgIds())
182
+ height = []
183
+ width = []
184
+ for index, im in enumerate(images):
185
+ print("Status: {}/{}, Process image: {}".format(index, len(images), im["file_name"]))
186
+ if (
187
+ im["file_name"] == "BV2_Phase_C4_2_03d00h00m_1.tif"
188
+ or im["file_name"] == "BV2_Phase_C4_2_03d00h00m_3.tif"
189
+ ):
190
+ print("Skipping the file: BV2_Phase_C4_2_03d00h00m_1.tif, as it is troublesome")
191
+ continue
192
+ # load image
193
+ img = Image.open(os.path.join(img_path, im["file_name"])).convert("L")
194
+ height.append(img.size[0])
195
+ width.append(img.size[1])
196
+
197
+ # load and display instance annotations
198
+ annids = annotation.getAnnIds(imgIds=im["id"], iscrowd=None)
199
+ anns = annotation.loadAnns(annids)
200
+
201
+ medians = []
202
+ masks = []
203
+ k = 0
204
+ for ann in anns:
205
+ # convert segmentation to binary mask
206
+ mask = annotation.annToMask(ann)
207
+ masks.append(mask)
208
+ ypix, xpix = mask.nonzero()
209
+ medians.append(np.array([ypix.mean().astype(np.float32), xpix.mean().astype(np.float32)]))
210
+ k += 1
211
+
212
+ masks = np.array(masks).astype(np.int8)
213
+ medians = np.array(medians)
214
+ masks = remove_overlaps(masks, medians, overlap_threshold=0.75)
215
+ gc.collect()
216
+
217
+ # ## Create new name for the image and also for the mask and save them as .tif format
218
+ # masks_int32 = masks.astype(np.int32)
219
+ # mask_pil = Image.fromarray(masks_int32, 'I')
220
+
221
+ t_filename = im["file_name"]
222
+ # cell_type = t_filename.split('_')[0] #? not used
223
+ new_mask_name = t_filename[:-4] + "_masks.tif"
224
+ # mask_pil.save(os.path.join(msk_path, new_mask_name))
225
+ imageio.imwrite(os.path.join(msk_path, new_mask_name), min_label_precision(masks))
226
+ gc.collect()
227
+
228
+ print(f"In total {len(images)} images")
229
+
230
+
231
+ def tissuenet_process_files(dataset_dir):
232
+ """
233
+ This function takes in the directory of TissueNet extracted dataset as input and
234
+ creates tiled images into 4 from each image
235
+ """
236
+
237
+ for folder in ["train", "val", "test"]:
238
+ if not os.path.exists(os.path.join(dataset_dir, "tissuenet_1.0", folder)):
239
+ os.mkdir(os.path.join(dataset_dir, "tissuenet_1.0", folder))
240
+
241
+ for folder in ["train", "val", "test"]:
242
+ print(f"Working on {folder} directory of tissuenet")
243
+ f_name = f"tissuenet_1.0/tissuenet_v1.0_{folder}.npz"
244
+ dat = np.load(os.path.join(dataset_dir, f_name))
245
+ data = dat["X"]
246
+ labels = dat["y"]
247
+ tissues = dat["tissue_list"]
248
+ platforms = dat["platform_list"]
249
+ tlabels = np.unique(tissues)
250
+ plabels = np.unique(platforms)
251
+ tp = 0
252
+ for t in tlabels:
253
+ for p in plabels:
254
+ ix = ((tissues == t) * (platforms == p)).nonzero()[0]
255
+ tp += 1
256
+ if len(ix) > 0:
257
+ print(f"Working on {t} {p}")
258
+
259
+ for k, i in enumerate(ix):
260
+ print(f"Status: {k}/{len(ix)} {tp}/{len(tlabels) * len(plabels)} {t} {p}")
261
+ img = data[i].transpose(2, 0, 1)
262
+ label = labels[i][:, :, 0]
263
+
264
+ img = guess_convert_to_uint16(img) # guess inverse scale and convert to uint16
265
+ label = min_label_precision(label)
266
+
267
+ if folder == "train":
268
+ img = img.reshape(2, 2, 256, 2, 256).transpose(0, 1, 3, 2, 4).reshape(2, 4, 256, 256)
269
+ label = label.reshape(2, 256, 2, 256).transpose(0, 2, 1, 3).reshape(4, 256, 256)
270
+
271
+ zero_channel = np.zeros((1, img.shape[1], img.shape[2], img.shape[3]), dtype=img.dtype)
272
+
273
+ # Concatenate the zero channel with the original array along the first dimension
274
+ new_array = np.concatenate([img, zero_channel], axis=0)
275
+ # reshaped_array = np.transpose(new_array, (1, 2, 3, 0))
276
+ for j in range(4):
277
+ img_name = f"{folder}/{t}_{p}_{k}_{j}.tif"
278
+ mask_name = f"{folder}/{t}_{p}_{k}_{j}_masks.tif"
279
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", img_name), new_array[:, j])
280
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", mask_name), label[j])
281
+ else:
282
+ zero_channel = np.zeros((1, img.shape[1], img.shape[2]), dtype=img.dtype)
283
+ new_array = np.concatenate([img, zero_channel], axis=0)
284
+ # reshaped_array = np.transpose(new_array, (1, 2, 0))
285
+ img_name = f"{folder}/{t}_{p}_{k}.tif"
286
+ mask_name = f"{folder}/{t}_{p}_{k}_masks.tif"
287
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", img_name), new_array)
288
+ imageio.imwrite(os.path.join(dataset_dir, "tissuenet_1.0", mask_name), label)
289
+
290
+
291
+ def kaggle_process_files(dataset_dir):
292
+ """
293
+ This function takes in the directory of kaggle nuclei extracted dataset as input and
294
+ creates a json list with 5 folds.
295
+ Please note that there are some hard-coded directory names as per the original dataset.
296
+ The function creates an instance processed dataset and then a 5 fold json file based on
297
+ the instance processed dataset
298
+ """
299
+ data_dir = os.path.join(dataset_dir, "stage1_train")
300
+ saving_path = os.path.join(dataset_dir, "instance_processed_data")
301
+ if not os.path.exists(saving_path):
302
+ os.mkdir(saving_path)
303
+
304
+ # Process the images and create instance masks first
305
+ for idx, subdir in enumerate(os.listdir(data_dir)):
306
+ subdir_path = os.path.join(data_dir, subdir)
307
+ if os.path.isdir(subdir_path):
308
+ images_dir = os.path.join(subdir_path, "images")
309
+ masks_dir = os.path.join(subdir_path, "masks")
310
+ if os.path.isdir(images_dir) and os.path.isdir(masks_dir):
311
+ image_file = os.path.join(images_dir, os.listdir(images_dir)[0])
312
+ filename_prefix = f"kg_bowl_{idx}_"
313
+
314
+ mask_data = concatenate_masks(masks_dir)
315
+
316
+ # ## Apply channel-wise normalization and use only the first three channels
317
+ # image_data = imageio.imread(image_file)
318
+ # normalized_image = normalize_image(image_data[..., :3])
319
+ # imageio.imwrite(os.path.join(saving_path, f"{filename_prefix}img.tiff"), normalized_image)
320
+ shutil.copyfile(image_file, os.path.join(saving_path, f"{filename_prefix}img.png"))
321
+ imageio.imwrite(os.path.join(saving_path, f"{filename_prefix}img_masks.tiff"), mask_data)
322
+
323
+
324
+ def extract_zip(zip_path, extract_to):
325
+ # Ensure the target directory exists
326
+ print(f"Extracting from: {zip_path}")
327
+ print(f"Extracting to: {extract_to}")
328
+
329
+ if not os.path.exists(extract_to):
330
+ os.makedirs(extract_to)
331
+
332
+ # Extract all contents of the zip file to the specified directory
333
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
334
+ zip_ref.extractall(extract_to)
335
+
336
+
337
+ def main():
338
+ parser = argparse.ArgumentParser(description="Script to process the cell imaging datasets")
339
+ parser.add_argument("--dir", type=str, help="Directory of datasets to process it ...", default="/set/the/path")
340
+
341
+ args = parser.parse_args()
342
+ data_root_path = os.path.normpath(args.dir)
343
+
344
+ dataset_dict = {
345
+ "cellpose_dataset": ["train.zip", "test.zip"],
346
+ "deepbacs_dataset": ["deepbacs.zip"],
347
+ "kaggle_dataset": ["data-science-bowl-2018.zip"],
348
+ "nips_dataset": ["nips_train.zip", "nips_test.zip"],
349
+ "omnipose_dataset": ["datasets.zip"],
350
+ "tissuenet_dataset": ["tissuenet_v1.0.zip"],
351
+ "livecell_dataset": [
352
+ "livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/images_per_celltype.zip"
353
+ ],
354
+ }
355
+
356
+ for key, value in dataset_dict.items():
357
+ dataset_path = os.path.join(data_root_path, key)
358
+
359
+ for each_zipped in value:
360
+ in_path = os.path.join(dataset_path, each_zipped)
361
+ try:
362
+ if os.path.exists(in_path):
363
+ print(f"File exists at: {in_path}")
364
+ except Exception:
365
+ print(f"File: {in_path} was not found")
366
+ out_path = os.path.join(dataset_path)
367
+ extract_zip(in_path, out_path)
368
+
369
+ print("If we reached here, that means all zip files got extracted ... Working on pre-processing")
370
+
371
+ # Looping over all datasets again, Cellpose & Deepbacs have a similar directory structure
372
+ for key, _value in dataset_dict.items():
373
+ if key == "kaggle_dataset":
374
+ print("Needs additional extraction")
375
+ train_zip_path = os.path.join(data_root_path, key, "stage1_train.zip")
376
+ zip_out_path = os.path.join(data_root_path, key, "stage1_train")
377
+ extract_zip(train_zip_path, zip_out_path)
378
+ print("Processing Kaggle Dataset ...")
379
+ dataset_path = os.path.join(data_root_path, key)
380
+ kaggle_process_files(dataset_dir=dataset_path)
381
+
382
+ elif key == "livecell_dataset":
383
+ print("Processing LiveCell Dataset ...")
384
+ print(
385
+ "Fyi, this processing might take upto an hour, coffee break might be more fruitful in the meanwhile ..."
386
+ )
387
+ dataset_path = os.path.join(data_root_path, key)
388
+ livecell_process_files(dataset_dir=dataset_path)
389
+
390
+ elif key == "tissuenet_dataset":
391
+ print("Processing TissueNet Dataset ...")
392
+ dataset_path = os.path.join(data_root_path, key)
393
+ tissuenet_process_files(dataset_dir=dataset_path)
394
+
395
+ return None
396
+
397
+
398
+ if __name__ == "__main__":
399
+ main()
download_preprocessor/readme.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Tutorial: VISTA2D Model Creation
2
+
3
+ This tutorial will guide the users to setting up all the datasets, running pre-processing, creation of organized json file lists which can be provided to VISTA-2D training pipeline.
4
+ Some datasets need to be manually downloaded, others will be downloaded by a provided script. Please do not manually unzip any of the downloaded files, it will be automatically handled in the final step.
5
+
6
+ ### List of Datasets
7
+ 1.) [Cellpose](https://www.cellpose.org/dataset)
8
+
9
+ 2.) [TissueNet](https://datasets.deepcell.org/login)
10
+
11
+ 3.) [Kaggle Nuclei Segmentation](https://www.kaggle.com/c/data-science-bowl-2018/data)
12
+
13
+ 4.) [Omnipose - OSF repository](https://osf.io/xmury/)
14
+
15
+ 5.) [NIPS Cell Segmentation Challenge](https://neurips22-cellseg.grand-challenge.org/)
16
+
17
+ 6.) [LiveCell](https://sartorius-research.github.io/LIVECell/)
18
+
19
+ 7.) [Deepbacs](https://github.com/HenriquesLab/DeepBacs/wiki/Segmentation)
20
+
21
+ Datasets 1-4 need to be manually downloaded, instructions to download them have been provided below.
22
+
23
+ ### Manual Dataset Download Instructions
24
+ #### 1.) Cellpose:
25
+ The dataset can be downloaded from this [link](https://www.cellpose.org/dataset). Please see below screenshots to assist in downloading it
26
+ ![cellpose_agreement.png](cellpose_agreement.png)
27
+ Please enter your email and accept terms and conditions to download the dataset.
28
+
29
+ ![cellpose_links.png](cellpose_links.png)
30
+ Click on train.zip and test.zip to download both directories independently. They both need to be placed in a `cellpose_dataset` directory. The `cellpose_dataset` will have to be created by the user in the root data directory.
31
+
32
+ #### 2.) TissueNet
33
+ Login credentials have to be created at below provided link. Please see below screenshots for further assistance.
34
+
35
+ ![tissuenet_login.png](tissuenet_login.png)
36
+ Please create an account at the provided [link](https://datasets.deepcell.org/login).
37
+
38
+ ![tissuenet_download.png](tissuenet_download.png)
39
+ After logging in, the above page will be visible, please make sure that version 1.0 is selected for TissueNet before clicking on download button.
40
+ All the downloaded files need to be placed in a `tissuenet_dataset` directory, this directory has to be created by the user.
41
+
42
+ #### 3.) Kaggle Nuclei Segmentation
43
+ Kaggle credentials are required in order to access this dataset at this [link](https://www.kaggle.com/c/data-science-bowl-2018/data), the user will have to register for the challenge to access and download the dataset.
44
+ Please refer below screenshots for additional help.
45
+
46
+ ![kaggle_download.png](kaggle_download.png)
47
+ The `Download All` button needs to be used so all files are downloaded, the files need to be placed in a directory created by the user `kaggle_dataset`.
48
+
49
+ #### 4.) Omnipose
50
+ The Omnipose dataset is hosted on an [OSF repository](https://osf.io/xmury/) and the dataset part needs to be downloaded from it. Please refer below screenshots for further assistance.
51
+
52
+ ![omnipose_download.png](omnipose_download.png)
53
+ The `datasets` directory needs to be selected as highlighted in the screenshot, then `download as zip` needs to be pressed for downloading the dataset. The user will have to place all the files in
54
+ a user created directory named `omnipose_dataset`.
55
+
56
+ ### The remaining datasets will be downloaded by a python script.
57
+ To run the script use the following example command `python all_file_downloader.py --dir provide_the_same_root_data_path`
58
+
59
+ After completion of downloading of all datasets, below is how the data root directory should look:
60
+
61
+ ![data_tree.png](data_tree.png)
62
+
63
+ ### Process the downloaded data
64
+ To execute VISTA-2D training pipeline, some datasets require label conversion. Please use the `root_data_path` as the input to the script, example command to execute the script is given below:
65
+
66
+ `python generate_json.py --dir provide_the_same_root_data_path`
67
+
68
+ ### Generation of Json data lists (Optional)
69
+ If one desires to generate JSON files from scratch, `generate_json.py` script performs both processing and creation of JSON files.
70
+ To execute VISTA-2D training pipeline, some datasets require label conversion and then a json file list which the VISTA-2D training uses a format.
71
+ Creating the json lists from the raw dataset sources, please use the `root_data_path` as the input to the script, example command to execute the script is given below:
72
+
73
+ `python generate_json.py --dir provide_the_same_root_data_path`
download_preprocessor/tissuenet_download.png ADDED

Git LFS Details

  • SHA256: 9960dd0d30ef4e8fa75195a7f99642f9e837c15bf36fade4dd4d1d89d1a5957e
  • Pointer size: 131 Bytes
  • Size of remote file: 570 kB
download_preprocessor/tissuenet_login.png ADDED
download_preprocessor/urls.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LICENSE
2
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/
3
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/
4
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell/
5
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/
6
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/a172/
7
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/a172/test.json
8
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/a172/train.json
9
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/a172/val.json
10
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bt474/
11
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bt474/test.json
12
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bt474/train.json
13
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bt474/val.json
14
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bv2/
15
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bv2/test.json
16
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bv2/train.json
17
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/bv2/val.json
18
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/huh7/
19
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/huh7/test.json
20
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/huh7/train.json
21
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/huh7/val.json
22
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/
23
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/test.json
24
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/train.json
25
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/val.json
26
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/shsy5y/
27
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/shsy5y/test.json
28
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/shsy5y/train.json
29
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/shsy5y/val.json
30
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skbr3/
31
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skbr3/test.json
32
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skbr3/train.json
33
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skbr3/val.json
34
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skov3/
35
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skov3/test.json
36
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skov3/train.json
37
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/LIVECell_single_cells/skov3/val.json
38
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/images_per_celltype.zip
39
+ http://livecell-dataset.s3.eu-central-1.amazonaws.com/README.md
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d857047f53e7a02a7709bed4f6f533c428b0a3a0932d203c91d9de296f2a1536
3
+ size 359948402
models/sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/cell_distributed_weighted_sampler.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # based on Pytorch DistributedSampler and WeightedRandomSampler combined
13
+
14
+ import math
15
+ from typing import Iterator, Optional, Sequence, TypeVar
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from torch.utils.data import Dataset, Sampler
20
+
21
+ __all__ = ["DistributedWeightedSampler"]
22
+
23
+ T_co = TypeVar("T_co", covariant=True)
24
+
25
+
26
+ class DistributedWeightedSampler(Sampler[T_co]):
27
+ def __init__(
28
+ self,
29
+ dataset: Dataset,
30
+ weights: Sequence[float],
31
+ num_samples: int,
32
+ num_replicas: Optional[int] = None,
33
+ rank: Optional[int] = None,
34
+ shuffle: bool = True,
35
+ seed: int = 0,
36
+ drop_last: bool = False,
37
+ ) -> None:
38
+ if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
39
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
40
+
41
+ weights_tensor = torch.as_tensor(weights, dtype=torch.float)
42
+ if len(weights_tensor.shape) != 1:
43
+ raise ValueError(
44
+ "weights should be a 1d sequence but given " f"weights have shape {tuple(weights_tensor.shape)}"
45
+ )
46
+
47
+ self.weights = weights_tensor
48
+ self.num_samples = num_samples
49
+
50
+ if num_replicas is None:
51
+ if not dist.is_available():
52
+ raise RuntimeError("Requires distributed package to be available")
53
+ num_replicas = dist.get_world_size()
54
+ if rank is None:
55
+ if not dist.is_available():
56
+ raise RuntimeError("Requires distributed package to be available")
57
+ rank = dist.get_rank()
58
+ if rank >= num_replicas or rank < 0:
59
+ raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
60
+ self.dataset = dataset
61
+ self.num_replicas = num_replicas
62
+ self.rank = rank
63
+ self.epoch = 0
64
+ self.drop_last = drop_last
65
+ self.shuffle = shuffle
66
+
67
+ if self.shuffle:
68
+ self.num_samples = int(math.ceil(self.num_samples / self.num_replicas))
69
+ else:
70
+ # this is not used, as we always shuffle, the only reason to use this class
71
+
72
+ # If the dataset length is evenly divisible by # of replicas, then there
73
+ # is no need to drop any data, since the dataset will be split equally.
74
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
75
+ # Split to nearest available length that is evenly divisible.
76
+ # This is to ensure each rank receives the same amount of data when
77
+ # using this Sampler.
78
+ self.num_samples = math.ceil(
79
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
80
+ )
81
+ else:
82
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
83
+
84
+ self.total_size = self.num_samples * self.num_replicas
85
+ self.shuffle = shuffle
86
+ self.seed = seed
87
+
88
+ def __iter__(self) -> Iterator[T_co]:
89
+ if self.shuffle:
90
+ # deterministically shuffle based on epoch and seed
91
+ g = torch.Generator()
92
+ g.manual_seed(self.seed + self.epoch)
93
+ indices = torch.multinomial(input=self.weights, num_samples=self.total_size, replacement=True, generator=g).tolist() # type: ignore[arg-type]
94
+ else:
95
+ # this is not used, as we always shuffle, the only reason to use this class
96
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
97
+ if not self.drop_last:
98
+ # add extra samples to make it evenly divisible
99
+ padding_size = self.total_size - len(indices)
100
+ if padding_size <= len(indices):
101
+ indices += indices[:padding_size]
102
+ else:
103
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
104
+ else:
105
+ # remove tail of data to make it evenly divisible.
106
+ indices = indices[: self.total_size]
107
+ assert len(indices) == self.total_size
108
+
109
+ # subsample
110
+ indices = indices[self.rank : self.total_size : self.num_replicas]
111
+ assert len(indices) == self.num_samples
112
+
113
+ return iter(indices)
114
+
115
+ def __len__(self) -> int:
116
+ return self.num_samples
117
+
118
+ def set_epoch(self, epoch: int) -> None:
119
+ self.epoch = epoch
scripts/components.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import json
13
+ import os
14
+
15
+ import cv2
16
+ import fastremap
17
+ import numpy as np
18
+ import PIL
19
+ import tifffile
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from cellpose.dynamics import compute_masks, masks_to_flows
23
+ from cellpose.metrics import _intersection_over_union, _true_positive
24
+ from monai.apps import get_logger
25
+ from monai.data import MetaTensor
26
+ from monai.transforms import MapTransform
27
+ from monai.utils import ImageMetaKey, convert_to_dst_type
28
+
29
+ logger = get_logger("VistaCell")
30
+
31
+
32
+ class LoadTiffd(MapTransform):
33
+ def __call__(self, data):
34
+ d = dict(data)
35
+ for key in self.key_iterator(d):
36
+ filename = d[key]
37
+
38
+ extension = os.path.splitext(filename)[1][1:]
39
+ image_size = None
40
+
41
+ if extension in ["tif", "tiff"]:
42
+ img_array = tifffile.imread(filename) # use tifffile for tif images
43
+ image_size = img_array.shape
44
+ if len(img_array.shape) == 3 and img_array.shape[-1] <= 3:
45
+ img_array = np.transpose(img_array, (2, 0, 1)) # channels first without transpose
46
+ else:
47
+ img_array = np.array(PIL.Image.open(filename)) # PIL for all other images (png, jpeg)
48
+ image_size = img_array.shape
49
+ if len(img_array.shape) == 3:
50
+ img_array = np.transpose(img_array, (2, 0, 1)) # channels first
51
+
52
+ if len(img_array.shape) not in [2, 3]:
53
+ raise ValueError(
54
+ "Unsupported image dimensions, filename " + str(filename) + " shape " + str(img_array.shape)
55
+ )
56
+
57
+ if len(img_array.shape) == 2:
58
+ img_array = img_array[np.newaxis] # add channels_first if no channel
59
+
60
+ if key == "label":
61
+ if img_array.shape[0] > 1:
62
+ print(
63
+ f"Strange case, label with several channels {filename} shape {img_array.shape}, keeping only first"
64
+ )
65
+ img_array = img_array[[0]]
66
+
67
+ elif key == "image":
68
+ if img_array.shape[0] == 1:
69
+ img_array = np.repeat(img_array, repeats=3, axis=0) # if grayscale, repeat as 3 channels
70
+ elif img_array.shape[0] == 2:
71
+ print(
72
+ f"Strange case, image with 2 channels {filename} shape {img_array.shape}, appending first channel to make 3"
73
+ )
74
+ img_array = np.stack(
75
+ (img_array[0], img_array[1], img_array[0]), axis=0
76
+ ) # this should not happen, we got 2 channel input image
77
+ elif img_array.shape[0] > 3:
78
+ print(f"Strange case, image with >3 channels, {filename} shape {img_array.shape}, keeping first 3")
79
+ img_array = img_array[:3]
80
+
81
+ meta_data = {ImageMetaKey.FILENAME_OR_OBJ: filename, ImageMetaKey.SPATIAL_SHAPE: image_size}
82
+ d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
83
+
84
+ return d
85
+
86
+
87
+ class SaveTiffd(MapTransform):
88
+ def __init__(self, output_dir, data_root_dir="/", nested_folder=False, *args, **kwargs) -> None:
89
+ super().__init__(*args, **kwargs)
90
+
91
+ self.output_dir = output_dir
92
+ self.data_root_dir = data_root_dir
93
+ self.nested_folder = nested_folder
94
+
95
+ def set_data_root_dir(self, data_root_dir):
96
+ self.data_root_dir = data_root_dir
97
+
98
+ def __call__(self, data):
99
+ d = dict(data)
100
+ os.makedirs(self.output_dir, exist_ok=True)
101
+
102
+ for key in self.key_iterator(d):
103
+ seg = d[key]
104
+ filename = seg.meta[ImageMetaKey.FILENAME_OR_OBJ]
105
+
106
+ basename = os.path.splitext(os.path.basename(filename))[0]
107
+
108
+ if self.nested_folder:
109
+ reldir = os.path.relpath(os.path.dirname(filename), self.data_root_dir)
110
+ outdir = os.path.join(self.output_dir, reldir)
111
+ os.makedirs(outdir, exist_ok=True)
112
+ else:
113
+ outdir = self.output_dir
114
+
115
+ outname = os.path.join(outdir, basename + ".tif")
116
+
117
+ label = seg.cpu().numpy()
118
+ lm = label.max()
119
+ if lm <= 255:
120
+ label = label.astype(np.uint8)
121
+ elif lm <= 65535:
122
+ label = label.astype(np.uint16)
123
+ else:
124
+ label = label.astype(np.uint32)
125
+
126
+ tifffile.imwrite(outname, label)
127
+
128
+ print(f"Saving {outname} shape {label.shape} max {label.max()} dtype {label.dtype}")
129
+
130
+ return d
131
+
132
+
133
+ class LabelsToFlows(MapTransform):
134
+ # based on dynamics labels_to_flows()
135
+ # created a 3 channel output (foreground, flowx, flowy) and saves under flow (new) key
136
+
137
+ def __init__(self, flow_key, *args, **kwargs) -> None:
138
+ super().__init__(*args, **kwargs)
139
+ self.flow_key = flow_key
140
+
141
+ def __call__(self, data):
142
+ d = dict(data)
143
+ for key in self.key_iterator(d):
144
+ label = d[key].int().numpy()
145
+
146
+ label = fastremap.renumber(label, in_place=True)[0]
147
+ veci = masks_to_flows(label[0], device=None)
148
+
149
+ flows = np.concatenate((label > 0.5, veci), axis=0).astype(np.float32)
150
+ flows = convert_to_dst_type(flows, d[key], dtype=torch.float, device=d[key].device)[0]
151
+ d[self.flow_key] = flows
152
+ # meta_data = {ImageMetaKey.FILENAME_OR_OBJ : filename}
153
+ # d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
154
+ return d
155
+
156
+
157
+ class LogitsToLabels:
158
+ def __call__(self, logits, filename=None):
159
+ device = logits.device
160
+ logits = logits.float().cpu().numpy()
161
+ dp = logits[1:] # vectors
162
+ cellprob = logits[0] # foreground prob (logit)
163
+
164
+ try:
165
+ pred_mask, p = compute_masks(
166
+ dp, cellprob, niter=200, cellprob_threshold=0.4, flow_threshold=0.4, interp=True, device=device
167
+ )
168
+ except RuntimeError as e:
169
+ logger.warning(f"compute_masks failed on GPU retrying on CPU {logits.shape} file {filename} {e}")
170
+ pred_mask, p = compute_masks(
171
+ dp, cellprob, niter=200, cellprob_threshold=0.4, flow_threshold=0.4, interp=True, device=None
172
+ )
173
+
174
+ return pred_mask, p
175
+
176
+
177
+ class LogitsToLabelsd(MapTransform):
178
+ def __call__(self, data):
179
+ d = dict(data)
180
+ f = LogitsToLabels()
181
+ for key in self.key_iterator(d):
182
+ pred_mask, p = f(d[key])
183
+ d[key] = pred_mask
184
+ d[f"{key}_centroids"] = p
185
+ return d
186
+
187
+
188
+ class SaveTiffExd(MapTransform):
189
+ def __init__(self, output_dir, output_ext=".png", output_postfix="seg", image_key="image", *args, **kwargs) -> None:
190
+ super().__init__(*args, **kwargs)
191
+
192
+ self.output_dir = output_dir
193
+ self.output_ext = output_ext
194
+ self.output_postfix = output_postfix
195
+ self.image_key = image_key
196
+
197
+ def to_polygons(self, contours):
198
+ polygons = []
199
+ for contour in contours:
200
+ if len(contour) < 3:
201
+ continue
202
+ polygons.append(np.squeeze(contour).astype(int).tolist())
203
+ return polygons
204
+
205
+ def __call__(self, data):
206
+ d = dict(data)
207
+
208
+ output_dir = d.get("output_dir", self.output_dir)
209
+ output_ext = d.get("output_ext", self.output_ext)
210
+ overlayed_masks = d.get("overlayed_masks", False)
211
+ output_contours = d.get("output_contours", False)
212
+
213
+ os.makedirs(self.output_dir, exist_ok=True)
214
+
215
+ img = d.get(self.image_key, None)
216
+ filename = img.meta.get(ImageMetaKey.FILENAME_OR_OBJ) if img is not None else None
217
+ image_size = img.meta.get(ImageMetaKey.SPATIAL_SHAPE) if img is not None else None
218
+ basename = os.path.splitext(os.path.basename(filename))[0] if filename else "mask"
219
+ logger.info(f"File: {filename}; Base: {basename}")
220
+
221
+ for key in self.key_iterator(d):
222
+ label = d[key]
223
+ output_filename = f"{basename}{'_' + self.output_postfix if self.output_postfix else ''}{output_ext}"
224
+ output_filepath = os.path.join(output_dir, output_filename)
225
+ lm = label.max()
226
+ logger.info(f"Mask Shape: {label.shape}; Instances: {lm}")
227
+
228
+ if lm <= 255:
229
+ label = label.astype(np.uint8)
230
+ elif lm <= 65535:
231
+ label = label.astype(np.uint16)
232
+ else:
233
+ label = label.astype(np.uint32)
234
+
235
+ tifffile.imwrite(output_filepath, label)
236
+ logger.info(f"Saving {output_filepath}")
237
+
238
+ polygons = []
239
+ if overlayed_masks:
240
+ logger.info(f"Overlay Masks: Reading original Image: {filename}")
241
+ image = cv2.imread(filename)
242
+ mask = cv2.imread(output_filepath, 0)
243
+
244
+ for i in range(1, np.max(mask)):
245
+ m = np.zeros_like(mask)
246
+ m[mask == i] = 1
247
+ color = np.random.choice(range(256), size=3).tolist()
248
+ contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
249
+ polygons.extend(self.to_polygons(contours))
250
+ cv2.drawContours(image, contours, -1, color, 1)
251
+ cv2.imwrite(output_filepath, image)
252
+ logger.info(f"Overlay Masks: Saving {output_filepath}")
253
+ else:
254
+ label = cv2.convertScaleAbs(label, alpha=255.0 / label.max())
255
+ contours, _ = cv2.findContours(label, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
256
+ polygons.extend(self.to_polygons(contours))
257
+
258
+ meta_json = {"image_size": image_size, "contours": len(polygons)}
259
+ with open(os.path.join(output_dir, "meta.json"), "w") as fp:
260
+ json.dump(meta_json, fp, indent=2)
261
+
262
+ if output_contours:
263
+ logger.info(f"Total Polygons: {len(polygons)}")
264
+ with open(os.path.join(output_dir, "contours.json"), "w") as fp:
265
+ json.dump({"count": len(polygons), "contours": polygons}, fp, indent=2)
266
+
267
+ return d
268
+
269
+
270
+ # Loss (adopted from Cellpose)
271
+ class CellLoss:
272
+ def __call__(self, y_pred, y):
273
+ loss = 0.5 * F.mse_loss(y_pred[:, 1:], 5 * y[:, 1:]) + F.binary_cross_entropy_with_logits(
274
+ y_pred[:, [0]], y[:, [0]]
275
+ )
276
+ return loss
277
+
278
+
279
+ # Accuracy (adopted from Cellpose)
280
+ class CellAcc:
281
+ def __call__(self, mask_pred, mask_true):
282
+ if isinstance(mask_true, torch.Tensor):
283
+ mask_true = mask_true.cpu().numpy()
284
+
285
+ if isinstance(mask_pred, torch.Tensor):
286
+ mask_pred = mask_pred.cpu().numpy()
287
+
288
+ # print("CellAcc mask_true", mask_true.shape, 'max', np.max(mask_true), ",
289
+ # "'mask_pred', mask_pred.shape, 'max', np.max(mask_pred) )
290
+
291
+ iou = _intersection_over_union(mask_true, mask_pred)[1:, 1:]
292
+ tp = _true_positive(iou, th=0.5)
293
+
294
+ fp = np.max(mask_pred) - tp
295
+ fn = np.max(mask_true) - tp
296
+ ap = tp / (tp + fp + fn)
297
+
298
+ # print("CellAcc ap", ap, 'tp', tp, 'fp', fp, 'fn', fn)
299
+ return ap
scripts/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import logging
13
+ import os
14
+ import warnings
15
+ from logging.config import fileConfig
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ from monai.apps import get_logger
20
+ from monai.apps.utils import DEFAULT_FMT
21
+ from monai.bundle import ConfigParser
22
+ from monai.utils import RankFilter, ensure_tuple
23
+
24
+ logger = get_logger("VistaCell")
25
+
26
+ np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
27
+ logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.WARNING)
28
+ warnings.filterwarnings("ignore", message=".*Divide by zero.*") # intensity transform divide by zero warning
29
+
30
+ LOGGING_CONFIG = {
31
+ "version": 1,
32
+ "disable_existing_loggers": False,
33
+ "formatters": {"monai_default": {"format": DEFAULT_FMT}},
34
+ "loggers": {"VistaCell": {"handlers": ["file", "console"], "level": "DEBUG", "propagate": False}},
35
+ "filters": {"rank_filter": {"()": RankFilter}},
36
+ "handlers": {
37
+ "file": {
38
+ "class": "logging.FileHandler",
39
+ "filename": "default.log",
40
+ "mode": "a", # append or overwrite
41
+ "level": "DEBUG",
42
+ "formatter": "monai_default",
43
+ "filters": ["rank_filter"],
44
+ },
45
+ "console": {
46
+ "class": "logging.StreamHandler",
47
+ "level": "INFO",
48
+ "formatter": "monai_default",
49
+ "filters": ["rank_filter"],
50
+ },
51
+ },
52
+ }
53
+
54
+
55
+ def parsing_bundle_config(config_file, logging_file=None, meta_file=None):
56
+ if config_file is not None:
57
+ _config_files = ensure_tuple(config_file)
58
+ config_root_path = Path(_config_files[0]).parent
59
+ for _config_file in _config_files:
60
+ _config_file = Path(_config_file)
61
+ if _config_file.parent != config_root_path:
62
+ logger.warning(
63
+ f"Not all config files are in '{config_root_path}'. If logging_file and meta_file are"
64
+ f"not specified, '{config_root_path}' will be used as the default config root directory."
65
+ )
66
+ if not _config_file.is_file():
67
+ raise FileNotFoundError(f"Cannot find the config file: {_config_file}.")
68
+ else:
69
+ config_root_path = Path("configs")
70
+
71
+ logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file
72
+ if os.path.exists(logging_file):
73
+ fileConfig(logging_file, disable_existing_loggers=False)
74
+
75
+ parser = ConfigParser()
76
+ parser.read_config(config_file)
77
+ meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
78
+ if isinstance(meta_file, str) and not os.path.exists(meta_file):
79
+ logger.error(
80
+ f"Cannot find the metadata config file: {meta_file}. "
81
+ "Please see: https://docs.monai.io/en/stable/mb_specification.html"
82
+ )
83
+ else:
84
+ parser.read_meta(f=meta_file)
85
+
86
+ return parser
scripts/workflow.py ADDED
@@ -0,0 +1,1205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import csv
13
+ import gc
14
+ import logging
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import time
19
+ from collections import OrderedDict
20
+ from datetime import datetime
21
+
22
+ import monai.transforms as mt
23
+ import numpy as np
24
+ import torch
25
+ import torch.distributed as dist
26
+ import yaml
27
+ from monai.apps import get_logger
28
+ from monai.auto3dseg.utils import datafold_read
29
+ from monai.bundle import BundleWorkflow, ConfigParser
30
+ from monai.config import print_config
31
+ from monai.data import DataLoader, Dataset, decollate_batch
32
+ from monai.metrics import CumulativeAverage
33
+ from monai.utils import (
34
+ BundleProperty,
35
+ ImageMetaKey,
36
+ convert_to_dst_type,
37
+ ensure_tuple,
38
+ look_up_option,
39
+ optional_import,
40
+ set_determinism,
41
+ )
42
+ from torch.cuda.amp import GradScaler, autocast
43
+ from torch.utils.data import WeightedRandomSampler
44
+ from torch.utils.data.distributed import DistributedSampler
45
+ from torch.utils.tensorboard import SummaryWriter
46
+
47
+ mlflow, mlflow_is_imported = optional_import("mlflow")
48
+
49
+
50
+ if __package__ in (None, ""):
51
+ from cell_distributed_weighted_sampler import DistributedWeightedSampler
52
+ from components import LabelsToFlows, LoadTiffd, LogitsToLabels
53
+ from utils import LOGGING_CONFIG, parsing_bundle_config # type: ignore
54
+ else:
55
+ from .cell_distributed_weighted_sampler import DistributedWeightedSampler
56
+ from .components import LabelsToFlows, LoadTiffd, LogitsToLabels
57
+ from .utils import LOGGING_CONFIG, parsing_bundle_config
58
+
59
+
60
+ logger = get_logger("VistaCell")
61
+
62
+
63
+ class VistaCell(BundleWorkflow):
64
+ """
65
+ Primary vista model training workflow that extends
66
+ monai.bundle.BundleWorkflow for cell segmentation.
67
+ """
68
+
69
+ def __init__(self, config_file=None, meta_file=None, logging_file=None, workflow_type="train", **override):
70
+ """
71
+ config_file can be one or a list of config files.
72
+ the rest key-values in the `override` are to override config content.
73
+ """
74
+
75
+ parser = parsing_bundle_config(config_file, logging_file=logging_file, meta_file=meta_file)
76
+ parser.update(pairs=override)
77
+
78
+ mode = parser.get("mode", None)
79
+ if mode is not None: # if user specified a `mode` it'll override the workflow_type arg
80
+ workflow_type = mode
81
+ else:
82
+ mode = workflow_type # if user didn't specify mode, the workflow_type will be used
83
+ super().__init__(workflow_type=workflow_type)
84
+ self._props = {}
85
+ self._set_props = {}
86
+ self.parser = parser
87
+
88
+ self.rank = int(os.getenv("LOCAL_RANK", "0"))
89
+ self.global_rank = int(os.getenv("RANK", "0"))
90
+ self.is_distributed = dist.is_available() and dist.is_initialized()
91
+
92
+ # check if torchrun or bcprun started it
93
+ if dist.is_torchelastic_launched() or (
94
+ os.getenv("NGC_ARRAY_SIZE") is not None and int(os.getenv("NGC_ARRAY_SIZE")) > 1
95
+ ):
96
+ if dist.is_available():
97
+ dist.init_process_group(backend="nccl", init_method="env://")
98
+
99
+ self.is_distributed = dist.is_available() and dist.is_initialized()
100
+
101
+ torch.cuda.set_device(self.config("device"))
102
+ dist.barrier()
103
+
104
+ else:
105
+ self.is_distributed = False
106
+
107
+ if self.global_rank == 0 and self.config("ckpt_path") and not os.path.exists(self.config("ckpt_path")):
108
+ os.makedirs(self.config("ckpt_path"), exist_ok=True)
109
+
110
+ if self.rank == 0:
111
+ # make sure the log file exists, as a workaround for mult-gpu logging race condition
112
+ _log_file = self.config("log_output_file", "vista_cell.log")
113
+ _log_file_dir = os.path.dirname(_log_file)
114
+ if _log_file_dir and not os.path.exists(_log_file_dir):
115
+ os.makedirs(_log_file_dir, exist_ok=True)
116
+
117
+ print_config()
118
+
119
+ if self.is_distributed:
120
+ dist.barrier()
121
+
122
+ seed = self.config("seed", None)
123
+ if seed is not None:
124
+ set_determinism(seed)
125
+ logger.info(f"set determinism seed: {self.config('seed', None)}")
126
+ elif torch.cuda.is_available():
127
+ torch.backends.cudnn.benchmark = True
128
+ logger.info("No seed provided, using cudnn.benchmark for performance.")
129
+
130
+ if os.path.exists(self.config("ckpt_path")):
131
+ self.parser.export_config_file(
132
+ self.parser.config,
133
+ os.path.join(self.config("ckpt_path"), "working.yaml"),
134
+ fmt="yaml",
135
+ default_flow_style=None,
136
+ )
137
+
138
+ self.add_property("network", required=True)
139
+ self.add_property("train_loader", required=True)
140
+ self.add_property("val_dataset", required=False)
141
+ self.add_property("val_loader", required=False)
142
+ self.add_property("val_preprocessing", required=False)
143
+ self.add_property("train_sampler", required=True)
144
+ self.add_property("val_sampler", required=True)
145
+ self.add_property("mode", required=False)
146
+ # set evaluator as required when mode is infer or eval
147
+ # will change after we enhance the bundle properties
148
+ self.evaluator = None
149
+
150
+ def _set_property(self, name, property, value):
151
+ # stores user-reset initialized objects that should not be re-initialized.
152
+ self._set_props[name] = value
153
+
154
+ def _get_property(self, name, property):
155
+ """
156
+ The customized bundle workflow must implement required properties in:
157
+ https://github.com/Project-MONAI/MONAI/blob/dev/monai/bundle/properties.py.
158
+ """
159
+ if name in self._set_props:
160
+ self._props[name] = self._set_props[name]
161
+ return self._props[name]
162
+ if name in self._props:
163
+ return self._props[name]
164
+ try:
165
+ value = getattr(self, f"get_{name}")()
166
+ except AttributeError as err:
167
+ if property[BundleProperty.REQUIRED]:
168
+ raise ValueError(
169
+ f"Property '{name}' is required by the bundle format, "
170
+ f"but the method 'get_{name}' is not implemented."
171
+ ) from err
172
+ raise AttributeError from err
173
+ self._props[name] = value
174
+ return value
175
+
176
+ def config(self, name, default="null", **kwargs):
177
+ """read the parsed content (evaluate the expression) from the config file."""
178
+ if default != "null":
179
+ return self.parser.get_parsed_content(name, default=default, **kwargs)
180
+ return self.parser.get_parsed_content(name, **kwargs)
181
+
182
+ def initialize(self):
183
+ _log_file = self.config("log_output_file", "vista_cell.log")
184
+ if _log_file is None:
185
+ LOGGING_CONFIG["loggers"]["VistaCell"]["handlers"].remove("file")
186
+ LOGGING_CONFIG["handlers"].pop("file", None)
187
+ else:
188
+ LOGGING_CONFIG["handlers"]["file"]["filename"] = _log_file
189
+ logging.config.dictConfig(LOGGING_CONFIG)
190
+
191
+ def get_mode(self):
192
+ mode_str = self.config("mode", self.workflow_type)
193
+ return look_up_option(mode_str, ("train", "training", "infer", "inference", "eval", "evaluation"))
194
+
195
+ def run(self):
196
+ if str(self.mode).startswith("train"):
197
+ return self.train()
198
+ if str(self.mode).startswith("infer"):
199
+ return self.infer()
200
+ return self.validate()
201
+
202
+ def finalize(self):
203
+ if self.is_distributed:
204
+ dist.destroy_process_group()
205
+ set_determinism(None)
206
+
207
+ def get_network_def(self):
208
+ return self.config("network_def")
209
+
210
+ def get_network(self):
211
+ pretrained_ckpt_name = self.config("pretrained_ckpt_name", None)
212
+ pretrained_ckpt_path = self.config("pretrained_ckpt_path", None)
213
+ if pretrained_ckpt_name is not None and pretrained_ckpt_path is None:
214
+ # if relative name specified, append to default ckpt_path dir
215
+ pretrained_ckpt_path = os.path.join(self.config("ckpt_path"), pretrained_ckpt_name)
216
+
217
+ if pretrained_ckpt_path is not None and not os.path.exists(pretrained_ckpt_path):
218
+ logger.info(f"Pretrained checkpoint {pretrained_ckpt_path} not found.")
219
+ raise ValueError(f"Pretrained checkpoint {pretrained_ckpt_path} not found.")
220
+
221
+ if pretrained_ckpt_path is not None and os.path.exists(pretrained_ckpt_path):
222
+ # not loading sam weights, if we're using our own checkpoint
223
+ if "checkpoint" in self.parser.config["network_def"]:
224
+ self.parser.config["network_def"]["checkpoint"] = None
225
+ model = self.config("network")
226
+ self.checkpoint_load(ckpt=pretrained_ckpt_path, model=model)
227
+ else:
228
+ model = self.config("network")
229
+
230
+ if self.config("channels_last", False):
231
+ model = model.to(memory_format=torch.channels_last)
232
+
233
+ if self.is_distributed:
234
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
235
+
236
+ if self.config("compile", False):
237
+ model = torch.compile(model)
238
+
239
+ if self.is_distributed:
240
+ model = torch.nn.parallel.DistributedDataParallel(
241
+ module=model,
242
+ device_ids=[self.rank],
243
+ output_device=self.rank,
244
+ find_unused_parameters=self.config("find_unused_parameters", False),
245
+ )
246
+
247
+ pytorch_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
248
+ logger.info(f"total parameters count {pytorch_params} distributed {self.is_distributed}")
249
+ return model
250
+
251
+ def get_train_dataset_data(self):
252
+ train_files, valid_files = [], []
253
+ dataset_data = self.config("train#dataset#data")
254
+ val_key = None
255
+ if isinstance(dataset_data, dict):
256
+ val_key = dataset_data.get("key", None)
257
+ data_list_files = dataset_data["data_list_files"]
258
+
259
+ if isinstance(data_list_files, str):
260
+ data_list_files = ConfigParser.load_config_file(
261
+ data_list_files
262
+ ) # if it's a path to a separate file with a list of datasets
263
+ else:
264
+ data_list_files = ensure_tuple(data_list_files)
265
+
266
+ if self.global_rank == 0:
267
+ print("Using data_list_files ", data_list_files)
268
+
269
+ for idx, d in enumerate(data_list_files):
270
+ logger.info(f"adding datalist ({idx}): {d['datalist']}")
271
+ t, v = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=self.config("fold"))
272
+
273
+ if val_key is not None:
274
+ v, _ = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=-1, key=val_key) # e.g. testing
275
+
276
+ for item in t:
277
+ item["datalist_id"] = idx
278
+ item["datalist_count"] = len(t)
279
+ for item in v:
280
+ item["datalist_id"] = idx
281
+ item["datalist_count"] = len(v)
282
+ train_files.extend(t)
283
+ valid_files.extend(v)
284
+
285
+ if self.config("quick", False):
286
+ logger.info("quick_data")
287
+ train_files = train_files[:8]
288
+ valid_files = valid_files[:7]
289
+ if not valid_files:
290
+ logger.warning("No validation data found.")
291
+ return train_files, valid_files
292
+
293
+ def read_val_datalists(self, section="validate", data_list_files=None, val_key=None, merge=True):
294
+ """read the corresponding folds of the datalist for validation or inference"""
295
+ dataset_data = self.config(f"{section}#dataset#data")
296
+
297
+ if isinstance(dataset_data, list):
298
+ return dataset_data
299
+
300
+ if data_list_files is None:
301
+ data_list_files = dataset_data["data_list_files"]
302
+
303
+ if isinstance(data_list_files, str):
304
+ data_list_files = ConfigParser.load_config_file(
305
+ data_list_files
306
+ ) # if it's a path to a separate file with a list of datasets
307
+ else:
308
+ data_list_files = ensure_tuple(data_list_files)
309
+
310
+ if val_key is None:
311
+ val_key = dataset_data.get("key", None)
312
+
313
+ val_files, idx = [], 0
314
+ for d in data_list_files:
315
+ if val_key is not None:
316
+ v_files, _ = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=-1, key=val_key)
317
+ else:
318
+ _, v_files = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=self.config("fold"))
319
+ logger.info(f"adding datalist ({idx} -- {val_key}): {d['datalist']} {len(v_files)}")
320
+ if merge:
321
+ val_files.extend(v_files)
322
+ else:
323
+ val_files.append(v_files)
324
+ idx += 1
325
+
326
+ if self.config("quick", False):
327
+ logger.info("quick_data")
328
+ val_files = val_files[:8] if merge else [val_files[0][:8]]
329
+ return val_files
330
+
331
+ def get_train_preprocessing(self):
332
+ roi_size = self.config("train#dataset#preprocessing#roi_size")
333
+
334
+ train_xforms = []
335
+ train_xforms.append(LoadTiffd(keys=["image", "label"]))
336
+ train_xforms.append(mt.EnsureTyped(keys=["image", "label"], data_type="tensor", dtype=torch.float))
337
+ if self.config("prescale", True):
338
+ print("Prescaling images to 0..1")
339
+ train_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True))
340
+ train_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True))
341
+ train_xforms.append(
342
+ mt.ScaleIntensityRangePercentilesd(
343
+ keys="image", lower=1, upper=99, b_min=0.0, b_max=1.0, channel_wise=True, clip=True
344
+ )
345
+ )
346
+ train_xforms.append(mt.SpatialPadd(keys=["image", "label"], spatial_size=roi_size))
347
+ train_xforms.append(
348
+ mt.RandSpatialCropd(keys=["image", "label"], roi_size=roi_size)
349
+ ) # crop roi_size (if image is large)
350
+
351
+ # # add augmentations
352
+ train_xforms.extend(
353
+ [
354
+ mt.RandAffined(
355
+ keys=["image", "label"],
356
+ prob=0.5,
357
+ rotate_range=np.pi, # from -pi to pi
358
+ scale_range=[-0.5, 0.5], # from 0.5 to 1.5
359
+ mode=["bilinear", "nearest"],
360
+ spatial_size=roi_size,
361
+ cache_grid=True,
362
+ padding_mode="border",
363
+ ),
364
+ mt.RandAxisFlipd(keys=["image", "label"], prob=0.5),
365
+ mt.RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1),
366
+ mt.RandAdjustContrastd(keys=["image"], prob=0.25, gamma=(1, 2)),
367
+ mt.RandGaussianSmoothd(keys=["image"], prob=0.25, sigma_x=(1, 2)),
368
+ mt.RandHistogramShiftd(keys=["image"], prob=0.25, num_control_points=3),
369
+ mt.RandGaussianSharpend(keys=["image"], prob=0.25),
370
+ ]
371
+ )
372
+
373
+ train_xforms.append(
374
+ LabelsToFlows(keys="label", flow_key="flow")
375
+ ) # finally create new key "flows" with 3 channels 1) foreground 2) dx flow 3) dy flow
376
+
377
+ return train_xforms
378
+
379
+ def get_val_preprocessing(self):
380
+ val_xforms = []
381
+ val_xforms.append(LoadTiffd(keys=["image", "label"], allow_missing_keys=True))
382
+ val_xforms.append(
383
+ mt.EnsureTyped(keys=["image", "label"], data_type="tensor", dtype=torch.float, allow_missing_keys=True)
384
+ )
385
+
386
+ if self.config("prescale", True):
387
+ print("Prescaling val images to 0..1")
388
+ val_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True))
389
+
390
+ val_xforms.append(
391
+ mt.ScaleIntensityRangePercentilesd(
392
+ keys="image", lower=1, upper=99, b_min=0.0, b_max=1.0, channel_wise=True, clip=True
393
+ )
394
+ )
395
+ val_xforms.append(LabelsToFlows(keys="label", flow_key="flow", allow_missing_keys=True))
396
+
397
+ return val_xforms
398
+
399
+ def get_train_dataset(self):
400
+ train_dataset_data = self.config("train#dataset#data")
401
+ if isinstance(train_dataset_data, list): # FIXME, why check
402
+ train_files = train_dataset_data
403
+ else:
404
+ train_files, _ = self.train_dataset_data
405
+ logger.info(f"train files {len(train_files)}")
406
+ return Dataset(data=train_files, transform=mt.Compose(self.train_preprocessing))
407
+
408
+ def get_val_dataset(self):
409
+ """this is to be used for validation during training"""
410
+ val_dataset_data = self.config("validate#dataset#data")
411
+ if isinstance(val_dataset_data, list): # FIXME, why check
412
+ valid_files = val_dataset_data
413
+ else:
414
+ _, valid_files = self.train_dataset_data
415
+ return Dataset(data=valid_files, transform=mt.Compose(self.val_preprocessing))
416
+
417
+ def set_val_datalist(self, datalist_py):
418
+ self.parser["validate#dataset#data"] = datalist_py
419
+ self._props.pop("val_loader", None)
420
+ self._props.pop("val_dataset", None)
421
+ self._props.pop("val_sampler", None)
422
+
423
+ def get_train_sampler(self):
424
+ if self.config("use_weighted_sampler", False):
425
+ data = self.train_dataset.data
426
+ logger.info(f"Using weighted sampler, first item {data[0]}")
427
+ sample_weights = 1.0 / torch.as_tensor(
428
+ [item.get("datalist_count", 1.0) for item in data], dtype=torch.float
429
+ ) # inverse proportional to sub-datalist count
430
+ # if we are using weighed sampling, the number of iterations epoch must be provided
431
+ # (cant use a dataset length anymore)
432
+ num_samples_per_epoch = self.config("num_samples_per_epoch", None)
433
+ if num_samples_per_epoch is None:
434
+ num_samples_per_epoch = len(data) # a workaround if not provided
435
+ logger.warning(
436
+ "We are using weighted random sampler, but num_samples_per_epoch is not provided, "
437
+ f"so using {num_samples_per_epoch} full data length as a workaround!"
438
+ )
439
+
440
+ if self.is_distributed:
441
+ return DistributedWeightedSampler(
442
+ self.train_dataset, shuffle=True, weights=sample_weights, num_samples=num_samples_per_epoch
443
+ ) # custom implementation, as Pytorch does not have one
444
+ return WeightedRandomSampler(weights=sample_weights, num_samples=num_samples_per_epoch)
445
+
446
+ if self.is_distributed:
447
+ return DistributedSampler(self.train_dataset, shuffle=True)
448
+ return None
449
+
450
+ def get_val_sampler(self):
451
+ if self.is_distributed:
452
+ return DistributedSampler(self.val_dataset, shuffle=False)
453
+ return None
454
+
455
+ def get_train_loader(self):
456
+ sampler = self.train_sampler
457
+ return DataLoader(
458
+ self.train_dataset,
459
+ batch_size=self.config("train#batch_size"),
460
+ shuffle=(sampler is None),
461
+ sampler=sampler,
462
+ pin_memory=True,
463
+ num_workers=self.config("train#num_workers"),
464
+ )
465
+
466
+ def get_val_loader(self):
467
+ sampler = self.val_sampler
468
+ return DataLoader(
469
+ self.val_dataset,
470
+ batch_size=self.config("validate#batch_size"),
471
+ shuffle=False,
472
+ sampler=sampler,
473
+ pin_memory=True,
474
+ num_workers=self.config("validate#num_workers"),
475
+ )
476
+
477
+ def train(self):
478
+ config = self.config
479
+ distributed = self.is_distributed
480
+ sliding_inferrer = config("inferer#sliding_inferer")
481
+ use_amp = config("amp")
482
+
483
+ amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[
484
+ config("amp_dtype")
485
+ ]
486
+ if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
487
+ amp_dtype = torch.float16
488
+ logger.warning(
489
+ "bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually"
490
+ )
491
+
492
+ use_gradscaler = use_amp and amp_dtype == torch.float16
493
+ logger.info(f"Using grad scaler {use_gradscaler} amp_dtype {amp_dtype} use_amp {use_amp}")
494
+ grad_scaler = GradScaler(enabled=use_gradscaler) # using GradScaler only for AMP float16 (not bfloat16)
495
+
496
+ loss_function = config("loss_function")
497
+ acc_function = config("key_metric")
498
+
499
+ ckpt_path = config("ckpt_path")
500
+ channels_last = config("channels_last")
501
+
502
+ num_epochs_per_saving = config("train#trainer#num_epochs_per_saving")
503
+ num_epochs_per_validation = config("train#trainer#num_epochs_per_validation")
504
+ num_epochs = config("train#trainer#max_epochs")
505
+ val_schedule_list = self.schedule_validation_epochs(
506
+ num_epochs=num_epochs, num_epochs_per_validation=num_epochs_per_validation
507
+ )
508
+ logger.info(f"Scheduling validation loops at epochs: {val_schedule_list}")
509
+
510
+ train_loader = self.train_loader
511
+ val_loader = self.val_loader
512
+ optimizer = config("optimizer")
513
+ model = self.network
514
+
515
+ tb_writer = None
516
+ csv_path = progress_path = None
517
+
518
+ if self.global_rank == 0 and ckpt_path is not None:
519
+ # rank 0 is responsible for heavy lifting of logging/saving
520
+ progress_path = os.path.join(ckpt_path, "progress.yaml")
521
+
522
+ tb_writer = SummaryWriter(log_dir=ckpt_path)
523
+ logger.info(f"Writing Tensorboard logs to {tb_writer.log_dir}")
524
+
525
+ if mlflow_is_imported:
526
+ if config("mlflow_tracking_uri", None) is not None:
527
+ mlflow.set_tracking_uri(config("mlflow_tracking_uri"))
528
+ mlflow.set_experiment("vista2d")
529
+
530
+ mlflow_run_name = config("mlflow_run_name", f'vista2d train fold{config("fold")}')
531
+ mlflow.start_run(
532
+ run_name=mlflow_run_name, log_system_metrics=config("mlflow_log_system_metrics", False)
533
+ )
534
+ mlflow.log_params(self.parser.config)
535
+ mlflow.log_dict(self.parser.config, "hyper_parameters.yaml") # experimental
536
+
537
+ csv_path = os.path.join(ckpt_path, "accuracy_history.csv")
538
+ self.save_history_csv(
539
+ csv_path=csv_path,
540
+ header=["epoch", "metric", "loss", "iter", "time", "train_time", "validation_time", "epoch_time"],
541
+ )
542
+
543
+ do_torch_save = (
544
+ (self.global_rank == 0) and ckpt_path and config("ckpt_save") and not config("train#skip", False)
545
+ )
546
+ best_ckpt_path = os.path.join(ckpt_path, "model.pt")
547
+ intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt")
548
+
549
+ best_metric = float(config("best_metric", -1))
550
+ start_epoch = config("start_epoch", 0)
551
+ best_metric_epoch = -1
552
+ pre_loop_time = time.time()
553
+ report_num_epochs = num_epochs
554
+ train_time = validation_time = 0
555
+ val_acc_history = []
556
+
557
+ if start_epoch > 0:
558
+ val_schedule_list = [v for v in val_schedule_list if v >= start_epoch]
559
+ if len(val_schedule_list) == 0:
560
+ val_schedule_list = [start_epoch]
561
+ print(f"adjusted schedule_list {val_schedule_list}")
562
+
563
+ logger.info(
564
+ f"Using num_epochs => {num_epochs}\n "
565
+ f"Using start_epoch => {start_epoch}\n "
566
+ f"batch_size => {config('train#batch_size')} \n "
567
+ f"num_warmup_epochs => {config('train#trainer#num_warmup_epochs')} \n "
568
+ )
569
+
570
+ lr_scheduler = config("lr_scheduler")
571
+ if lr_scheduler is not None and start_epoch > 0:
572
+ lr_scheduler.last_epoch = start_epoch
573
+
574
+ range_num_epochs = range(start_epoch, num_epochs)
575
+
576
+ if distributed:
577
+ dist.barrier()
578
+
579
+ if self.global_rank == 0 and tb_writer is not None and mlflow_is_imported and mlflow.is_tracking_uri_set():
580
+ mlflow.log_param("len_train_set", len(train_loader.dataset))
581
+ mlflow.log_param("len_val_set", len(val_loader.dataset))
582
+
583
+ for epoch in range_num_epochs:
584
+ report_epoch = epoch
585
+
586
+ if distributed:
587
+ if isinstance(train_loader.sampler, DistributedSampler):
588
+ train_loader.sampler.set_epoch(epoch)
589
+ dist.barrier()
590
+
591
+ epoch_time = start_time = time.time()
592
+
593
+ train_loss, train_acc = 0, 0
594
+
595
+ if not config("train#skip", False):
596
+ train_loss, train_acc = self.train_epoch(
597
+ model=model,
598
+ train_loader=train_loader,
599
+ optimizer=optimizer,
600
+ loss_function=loss_function,
601
+ acc_function=acc_function,
602
+ grad_scaler=grad_scaler,
603
+ epoch=report_epoch,
604
+ rank=self.rank,
605
+ global_rank=self.global_rank,
606
+ num_epochs=report_num_epochs,
607
+ use_amp=use_amp,
608
+ amp_dtype=amp_dtype,
609
+ channels_last=channels_last,
610
+ device=config("device"),
611
+ )
612
+
613
+ train_time = time.time() - start_time
614
+ logger.info(
615
+ f"Latest training {report_epoch}/{report_num_epochs - 1} "
616
+ f"loss: {train_loss:.4f} time {train_time:.2f}s "
617
+ f"lr: {optimizer.param_groups[0]['lr']:.4e}"
618
+ )
619
+
620
+ if self.global_rank == 0 and tb_writer is not None:
621
+ tb_writer.add_scalar("train/loss", train_loss, report_epoch)
622
+
623
+ if mlflow_is_imported and mlflow.is_tracking_uri_set():
624
+ mlflow.log_metric("train/loss", train_loss, step=report_epoch)
625
+ mlflow.log_metric("train/epoch_time", train_time, step=report_epoch)
626
+
627
+ # validate every num_epochs_per_validation epochs (defaults to 1, every epoch)
628
+ val_acc_mean = -1
629
+ if (
630
+ len(val_schedule_list) > 0
631
+ and epoch + 1 >= val_schedule_list[0]
632
+ and val_loader is not None
633
+ and len(val_loader) > 0
634
+ ):
635
+ val_schedule_list.pop(0)
636
+
637
+ start_time = time.time()
638
+ torch.cuda.empty_cache()
639
+
640
+ val_loss, val_acc = self.val_epoch(
641
+ model=model,
642
+ val_loader=val_loader,
643
+ sliding_inferrer=sliding_inferrer,
644
+ loss_function=loss_function,
645
+ acc_function=acc_function,
646
+ epoch=report_epoch,
647
+ rank=self.rank,
648
+ global_rank=self.global_rank,
649
+ num_epochs=report_num_epochs,
650
+ use_amp=use_amp,
651
+ amp_dtype=amp_dtype,
652
+ channels_last=channels_last,
653
+ device=config("device"),
654
+ )
655
+
656
+ torch.cuda.empty_cache()
657
+ validation_time = time.time() - start_time
658
+
659
+ val_acc_mean = float(np.mean(val_acc))
660
+ val_acc_history.append((report_epoch, val_acc_mean))
661
+
662
+ if self.global_rank == 0:
663
+ logger.info(
664
+ f"Latest validation {report_epoch}/{report_num_epochs - 1} "
665
+ f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc: {val_acc} time: {validation_time:.2f}s"
666
+ )
667
+
668
+ if tb_writer is not None:
669
+ tb_writer.add_scalar("val/acc", val_acc_mean, report_epoch)
670
+ tb_writer.add_scalar("val/loss", val_loss, report_epoch)
671
+ if mlflow_is_imported and mlflow.is_tracking_uri_set():
672
+ mlflow.log_metric("val/acc", val_acc_mean, step=report_epoch)
673
+ mlflow.log_metric("val/epoch_time", validation_time, step=report_epoch)
674
+
675
+ timing_dict = {
676
+ "time": f"{(time.time() - pre_loop_time) / 3600:.2f} hr",
677
+ "train_time": f"{train_time:.2f}s",
678
+ "validation_time": f"{validation_time:.2f}s",
679
+ "epoch_time": f"{time.time() - epoch_time:.2f}s",
680
+ }
681
+
682
+ if val_acc_mean > best_metric:
683
+ logger.info(f"New best metric ({best_metric:.6f} --> {val_acc_mean:.6f}). ")
684
+ best_metric, best_metric_epoch = val_acc_mean, report_epoch
685
+ save_time = 0
686
+ if do_torch_save:
687
+ save_time = self.checkpoint_save(
688
+ ckpt=best_ckpt_path, model=model, epoch=best_metric_epoch, best_metric=best_metric
689
+ )
690
+
691
+ if progress_path is not None:
692
+ self.save_progress_yaml(
693
+ progress_path=progress_path,
694
+ ckpt=best_ckpt_path if do_torch_save else None,
695
+ best_avg_score_epoch=best_metric_epoch,
696
+ best_avg_score=best_metric,
697
+ save_time=save_time,
698
+ **timing_dict,
699
+ )
700
+ if csv_path is not None:
701
+ self.save_history_csv(
702
+ csv_path=csv_path,
703
+ epoch=report_epoch,
704
+ metric=f"{val_acc_mean:.4f}",
705
+ loss=f"{train_loss:.4f}",
706
+ iter=report_epoch * len(train_loader.dataset),
707
+ **timing_dict,
708
+ )
709
+
710
+ # sanity check
711
+ if epoch > max(20, num_epochs / 4) and 0 <= val_acc_mean < 0.01 and config("stop_on_lowacc", True):
712
+ logger.info(
713
+ f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. "
714
+ "Most likely optimization diverged, try setting a smaller learning_rate"
715
+ f" than {config('learning_rate')}"
716
+ )
717
+ raise ValueError(
718
+ f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. "
719
+ "Most likely optimization diverged, try setting a smaller learning_rate"
720
+ f" than {config('learning_rate')}"
721
+ )
722
+
723
+ # save intermediate checkpoint every num_epochs_per_saving epochs
724
+ if do_torch_save and ((epoch + 1) % num_epochs_per_saving == 0 or (epoch + 1) >= num_epochs):
725
+ if report_epoch != best_metric_epoch:
726
+ self.checkpoint_save(
727
+ ckpt=intermediate_ckpt_path, model=model, epoch=report_epoch, best_metric=val_acc_mean
728
+ )
729
+ else:
730
+ try:
731
+ shutil.copyfile(best_ckpt_path, intermediate_ckpt_path) # if already saved once
732
+ except Exception as err:
733
+ logger.warning(f"error copying {best_ckpt_path} {intermediate_ckpt_path} {err}")
734
+ pass
735
+
736
+ if lr_scheduler is not None:
737
+ lr_scheduler.step()
738
+
739
+ if self.global_rank == 0:
740
+ # report time estimate
741
+ time_remaining_estimate = train_time * (num_epochs - epoch)
742
+ if val_loader is not None and len(val_loader) > 0:
743
+ if validation_time == 0:
744
+ validation_time = train_time
745
+ time_remaining_estimate += validation_time * len(val_schedule_list)
746
+
747
+ logger.info(
748
+ f"Estimated remaining training time for the current model fold {config('fold')} is "
749
+ f"{time_remaining_estimate/3600:.2f} hr, "
750
+ f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, "
751
+ f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n"
752
+ )
753
+
754
+ # end of main epoch loop
755
+ train_loader = val_loader = optimizer = None
756
+
757
+ # optionally validate best checkpoint
758
+ logger.info(f"Checking to run final testing {config('run_final_testing')}")
759
+ if config("run_final_testing"):
760
+ if distributed:
761
+ dist.barrier()
762
+ _ckpt_name = best_ckpt_path if os.path.exists(best_ckpt_path) else intermediate_ckpt_path
763
+ if not os.path.exists(_ckpt_name):
764
+ logger.info(f"Unable to validate final no checkpoints found {best_ckpt_path}, {intermediate_ckpt_path}")
765
+ else:
766
+ # self._props.pop("network", None)
767
+ # self._set_props.pop("network", None)
768
+ gc.collect()
769
+ torch.cuda.empty_cache()
770
+ best_metric = self.run_final_testing(
771
+ pretrained_ckpt_path=_ckpt_name,
772
+ progress_path=progress_path,
773
+ best_metric_epoch=best_metric_epoch,
774
+ pre_loop_time=pre_loop_time,
775
+ )
776
+
777
+ if (
778
+ self.global_rank == 0
779
+ and tb_writer is not None
780
+ and mlflow_is_imported
781
+ and mlflow.is_tracking_uri_set()
782
+ ):
783
+ mlflow.log_param("acc_testing", val_acc_mean)
784
+ mlflow.log_metric("acc_testing", val_acc_mean)
785
+
786
+ if tb_writer is not None:
787
+ tb_writer.flush()
788
+ tb_writer.close()
789
+
790
+ if mlflow_is_imported and mlflow.is_tracking_uri_set():
791
+ mlflow.end_run()
792
+
793
+ logger.info(
794
+ f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs}."
795
+ f"Training time {(time.time() - pre_loop_time)/3600:.2f} hr."
796
+ )
797
+ return best_metric
798
+
799
+ def run_final_testing(self, pretrained_ckpt_path, progress_path, best_metric_epoch, pre_loop_time):
800
+ logger.info("Running final best model testing set!")
801
+
802
+ # validate
803
+ start_time = time.time()
804
+
805
+ self._props.pop("network", None)
806
+ self.parser["pretrained_ckpt_path"] = pretrained_ckpt_path
807
+ self.parser["validate#evaluator#postprocessing"] = None # not saving images
808
+
809
+ val_acc_mean, val_loss, val_acc = self.validate(val_key="testing")
810
+ validation_time = f"{time.time() - start_time:.2f}s"
811
+ val_acc_mean = float(np.mean(val_acc))
812
+ logger.info(f"Testing: loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc {val_acc} time {validation_time}")
813
+
814
+ if self.global_rank == 0 and progress_path is not None:
815
+ self.save_progress_yaml(
816
+ progress_path=progress_path,
817
+ ckpt=pretrained_ckpt_path,
818
+ best_avg_score_epoch=best_metric_epoch,
819
+ best_avg_score=val_acc_mean,
820
+ validation_time=validation_time,
821
+ run_final_testing=True,
822
+ time=f"{(time.time() - pre_loop_time) / 3600:.2f} hr",
823
+ )
824
+ return val_acc_mean
825
+
826
+ def validate(self, validation_files=None, val_key=None, datalist=None):
827
+ if self.config("pretrained_ckpt_name", None) is None and self.config("pretrained_ckpt_path", None) is None:
828
+ self.parser["pretrained_ckpt_name"] = "model.pt"
829
+ logger.info("Using default model.pt checkpoint for validation.")
830
+
831
+ grouping = self.config("validate#grouping", False) # whether to computer average per datalist
832
+ if validation_files is None:
833
+ validation_files = self.read_val_datalists("validate", datalist, val_key=val_key, merge=not grouping)
834
+ if len(validation_files) == 0:
835
+ logger.warning(f"No validation files found {datalist} {val_key}!")
836
+ return 0, 0, 0
837
+ if not grouping or not isinstance(validation_files[0], (list, tuple)):
838
+ validation_files = [validation_files]
839
+ logger.info(f"validation file groups {len(validation_files)} grouping {grouping}")
840
+ val_acc_dict = {}
841
+
842
+ amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[
843
+ self.config("amp_dtype")
844
+ ]
845
+ if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
846
+ amp_dtype = torch.float16
847
+ logger.warning(
848
+ "bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually"
849
+ )
850
+
851
+ for datalist_id, group_files in enumerate(validation_files):
852
+ self.set_val_datalist(group_files)
853
+ val_loader = self.val_loader
854
+
855
+ start_time = time.time()
856
+ val_loss, val_acc = self.val_epoch(
857
+ model=self.network,
858
+ val_loader=val_loader,
859
+ sliding_inferrer=self.config("inferer#sliding_inferer"),
860
+ loss_function=self.config("loss_function"),
861
+ acc_function=self.config("key_metric"),
862
+ rank=self.rank,
863
+ global_rank=self.global_rank,
864
+ use_amp=self.config("amp"),
865
+ amp_dtype=amp_dtype,
866
+ post_transforms=self.config("validate#evaluator#postprocessing"),
867
+ channels_last=self.config("channels_last"),
868
+ device=self.config("device"),
869
+ )
870
+ val_acc_mean = float(np.mean(val_acc))
871
+ logger.info(
872
+ f"Validation {datalist_id} complete, loss_avg: {val_loss:.4f} "
873
+ f"acc_avg: {val_acc_mean:.4f} acc {val_acc} time {time.time() - start_time:.2f}s"
874
+ )
875
+ val_acc_dict[datalist_id] = val_acc_mean
876
+ for k, v in val_acc_dict.items():
877
+ logger.info(f"group: {k} => {v:.4f}")
878
+ val_acc_mean = sum(val_acc_dict.values()) / len(val_acc_dict.values())
879
+ logger.info(f"Testing group score average: {val_acc_mean:.4f}")
880
+ return val_acc_mean, val_loss, val_acc
881
+
882
+ def infer(self, infer_files=None, infer_key=None, datalist=None):
883
+ if self.config("pretrained_ckpt_name", None) is None and self.config("pretrained_ckpt_path", None) is None:
884
+ self.parser["pretrained_ckpt_name"] = "model.pt"
885
+ logger.info("Using default model.pt checkpoint for inference.")
886
+
887
+ if infer_files is None:
888
+ infer_files = self.read_val_datalists("infer", datalist, val_key=infer_key, merge=True)
889
+ if len(infer_files) == 0:
890
+ logger.warning(f"no file to infer {datalist} {infer_key}.")
891
+ return
892
+ logger.info(f"inference files {len(infer_files)}")
893
+ self.set_val_datalist(infer_files)
894
+ val_loader = self.val_loader
895
+
896
+ amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[
897
+ self.config("amp_dtype")
898
+ ]
899
+ if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
900
+ amp_dtype = torch.bfloat16
901
+ logger.warning(
902
+ "bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually"
903
+ )
904
+
905
+ start_time = time.time()
906
+ self.val_epoch(
907
+ model=self.network,
908
+ val_loader=val_loader,
909
+ sliding_inferrer=self.config("inferer#sliding_inferer"),
910
+ loss_function=None,
911
+ acc_function=None,
912
+ rank=self.rank,
913
+ global_rank=self.global_rank,
914
+ use_amp=self.config("amp"),
915
+ amp_dtype=amp_dtype,
916
+ post_transforms=self.config("infer#evaluator#postprocessing"),
917
+ channels_last=self.config("channels_last"),
918
+ device=self.config("device"),
919
+ )
920
+ logger.info(f"Inference complete time {time.time() - start_time:.2f}s")
921
+ return
922
+
923
+ @torch.no_grad()
924
+ def val_epoch(
925
+ self,
926
+ model,
927
+ val_loader,
928
+ sliding_inferrer,
929
+ loss_function=None,
930
+ acc_function=None,
931
+ epoch=0,
932
+ rank=0,
933
+ global_rank=0,
934
+ num_epochs=0,
935
+ use_amp=True,
936
+ amp_dtype=torch.float16,
937
+ post_transforms=None,
938
+ channels_last=False,
939
+ device=None,
940
+ ):
941
+ model.eval()
942
+ distributed = dist.is_available() and dist.is_initialized()
943
+ memory_format = torch.channels_last if channels_last else torch.preserve_format
944
+
945
+ run_loss = CumulativeAverage()
946
+ run_acc = CumulativeAverage()
947
+ run_loss.append(torch.tensor(0, device=device), count=0)
948
+
949
+ avg_loss = avg_acc = 0
950
+ start_time = time.time()
951
+
952
+ # In DDP, each replica has a subset of data, but if total data length is not evenly divisible by num_replicas,
953
+ # then some replicas has 1 extra repeated item.
954
+ # For proper validation with batch of 1, we only want to collect metrics for non-repeated items,
955
+ # hence let's compute a proper subset length
956
+ nonrepeated_data_length = len(val_loader.dataset)
957
+ sampler = val_loader.sampler
958
+ if distributed and isinstance(sampler, DistributedSampler) and not sampler.drop_last:
959
+ nonrepeated_data_length = len(range(sampler.rank, len(sampler.dataset), sampler.num_replicas))
960
+
961
+ for idx, batch_data in enumerate(val_loader):
962
+ data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
963
+ filename = batch_data["image"].meta[ImageMetaKey.FILENAME_OR_OBJ]
964
+ batch_size = data.shape[0]
965
+ loss = acc = None
966
+
967
+ with autocast(enabled=use_amp, dtype=amp_dtype):
968
+ logits = sliding_inferrer(inputs=data, network=model)
969
+ data = None
970
+
971
+ # calc loss
972
+ if loss_function is not None:
973
+ target = batch_data["flow"].as_subclass(torch.Tensor).to(device=logits.device)
974
+ loss = loss_function(logits, target)
975
+ run_loss.append(loss.to(device=device), count=batch_size)
976
+ target = None
977
+
978
+ pred_mask_all = []
979
+
980
+ for b_ind in range(logits.shape[0]): # go over batch dim
981
+ pred_mask, p = LogitsToLabels()(logits=logits[b_ind], filename=filename)
982
+ pred_mask_all.append(pred_mask)
983
+
984
+ if acc_function is not None:
985
+ label = batch_data["label"].as_subclass(torch.Tensor)
986
+
987
+ for b_ind in range(label.shape[0]):
988
+ acc = acc_function(pred_mask_all[b_ind], label[b_ind, 0].long())
989
+ acc = acc.detach().clone() if isinstance(acc, torch.Tensor) else torch.tensor(acc)
990
+
991
+ if idx < nonrepeated_data_length:
992
+ run_acc.append(acc.to(device=device), count=1)
993
+ else:
994
+ run_acc.append(torch.zeros_like(acc, device=device), count=0)
995
+ label = None
996
+
997
+ avg_loss = loss.cpu() if loss is not None else 0
998
+ avg_acc = acc.cpu().numpy() if acc is not None else 0
999
+
1000
+ logger.info(
1001
+ f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} "
1002
+ f"loss: {avg_loss:.4f} acc {avg_acc} time {time.time() - start_time:.2f}s"
1003
+ )
1004
+
1005
+ if post_transforms:
1006
+ seg = torch.from_numpy(np.stack(pred_mask_all, axis=0).astype(np.int32)).unsqueeze(1)
1007
+ batch_data["seg"] = convert_to_dst_type(
1008
+ seg, batch_data["image"], dtype=torch.int32, device=torch.device("cpu")
1009
+ )[0]
1010
+ for bd in decollate_batch(batch_data):
1011
+ post_transforms(bd) # (currently only to save output mask)
1012
+
1013
+ start_time = time.time()
1014
+
1015
+ label = target = data = batch_data = None
1016
+
1017
+ if distributed:
1018
+ dist.barrier()
1019
+
1020
+ avg_loss = run_loss.aggregate()
1021
+ avg_acc = run_acc.aggregate()
1022
+
1023
+ if np.any(avg_acc < 0):
1024
+ dist.barrier()
1025
+ logger.warning(f"Avg accuracy is negative ({avg_acc}), something went wrong!!!!!")
1026
+
1027
+ return avg_loss, avg_acc
1028
+
1029
+ def train_epoch(
1030
+ self,
1031
+ model,
1032
+ train_loader,
1033
+ optimizer,
1034
+ loss_function,
1035
+ acc_function,
1036
+ grad_scaler,
1037
+ epoch,
1038
+ rank,
1039
+ global_rank=0,
1040
+ num_epochs=0,
1041
+ use_amp=True,
1042
+ amp_dtype=torch.float16,
1043
+ channels_last=False,
1044
+ device=None,
1045
+ ):
1046
+ model.train()
1047
+ memory_format = torch.channels_last if channels_last else torch.preserve_format
1048
+
1049
+ run_loss = CumulativeAverage()
1050
+
1051
+ start_time = time.time()
1052
+ avg_loss = avg_acc = 0
1053
+ for idx, batch_data in enumerate(train_loader):
1054
+ data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
1055
+ target = batch_data["flow"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device)
1056
+
1057
+ optimizer.zero_grad(set_to_none=True)
1058
+
1059
+ with autocast(enabled=use_amp, dtype=amp_dtype):
1060
+ logits = model(data)
1061
+
1062
+ # print('logits', logits.shape, logits.dtype)
1063
+ loss = loss_function(logits.float(), target)
1064
+
1065
+ grad_scaler.scale(loss).backward()
1066
+ grad_scaler.step(optimizer)
1067
+ grad_scaler.update()
1068
+
1069
+ batch_size = data.shape[0]
1070
+
1071
+ run_loss.append(loss, count=batch_size)
1072
+ avg_loss = run_loss.aggregate()
1073
+
1074
+ logger.info(
1075
+ f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} "
1076
+ f"loss: {avg_loss:.4f} time {time.time() - start_time:.2f}s "
1077
+ )
1078
+ start_time = time.time()
1079
+
1080
+ optimizer.zero_grad(set_to_none=True)
1081
+
1082
+ data = None
1083
+ target = None
1084
+ batch_data = None
1085
+
1086
+ return avg_loss, avg_acc
1087
+
1088
+ def save_history_csv(self, csv_path=None, header=None, **kwargs):
1089
+ if csv_path is not None:
1090
+ if header is not None:
1091
+ with open(csv_path, "a") as myfile:
1092
+ wrtr = csv.writer(myfile, delimiter="\t")
1093
+ wrtr.writerow(header)
1094
+ if len(kwargs):
1095
+ with open(csv_path, "a") as myfile:
1096
+ wrtr = csv.writer(myfile, delimiter="\t")
1097
+ wrtr.writerow(list(kwargs.values()))
1098
+
1099
+ def save_progress_yaml(self, progress_path=None, ckpt=None, **report):
1100
+ if ckpt is not None:
1101
+ report["model"] = ckpt
1102
+
1103
+ report["date"] = str(datetime.now())[:19]
1104
+
1105
+ if progress_path is not None:
1106
+ yaml.add_representer(
1107
+ float, lambda dumper, value: dumper.represent_scalar("tag:yaml.org,2002:float", f"{value:.4f}")
1108
+ )
1109
+ with open(progress_path, "a") as progress_file:
1110
+ yaml.dump([report], stream=progress_file, allow_unicode=True, default_flow_style=None, sort_keys=False)
1111
+
1112
+ logger.info("Progress:" + ",".join(f" {k}: {v}" for k, v in report.items()))
1113
+
1114
+ def checkpoint_save(self, ckpt: str, model: torch.nn.Module, **kwargs):
1115
+ # save checkpoint and config
1116
+ save_time = time.time()
1117
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
1118
+ state_dict = model.module.state_dict()
1119
+ else:
1120
+ state_dict = model.state_dict()
1121
+
1122
+ if self.config("compile", False):
1123
+ # remove key prefix of compiled models
1124
+ state_dict = OrderedDict(
1125
+ (k[len("_orig_mod.") :] if k.startswith("_orig_mod.") else k, v) for k, v in state_dict.items()
1126
+ )
1127
+
1128
+ torch.save({"state_dict": state_dict, "config": self.parser.config, **kwargs}, ckpt)
1129
+
1130
+ save_time = time.time() - save_time
1131
+ logger.info(f"Saving checkpoint process: {ckpt}, {kwargs}, save_time {save_time:.2f}s")
1132
+
1133
+ return save_time
1134
+
1135
+ def checkpoint_load(self, ckpt: str, model: torch.nn.Module, **kwargs):
1136
+ # load checkpoint
1137
+ if not os.path.isfile(ckpt):
1138
+ logger.warning("Invalid checkpoint file: " + str(ckpt))
1139
+ return
1140
+ checkpoint = torch.load(ckpt, map_location="cpu")
1141
+
1142
+ model.load_state_dict(checkpoint["state_dict"], strict=True)
1143
+ epoch = checkpoint.get("epoch", 0)
1144
+ best_metric = checkpoint.get("best_metric", 0)
1145
+
1146
+ if self.config("continue", False):
1147
+ if "epoch" in checkpoint:
1148
+ self.parser["start_epoch"] = checkpoint["epoch"]
1149
+ if "best_metric" in checkpoint:
1150
+ self.parser["best_metric"] = checkpoint["best_metric"]
1151
+
1152
+ logger.info(
1153
+ f"=> loaded checkpoint {ckpt} (epoch {epoch}) "
1154
+ f"(best_metric {best_metric}) setting start_epoch {self.config('start_epoch')}"
1155
+ )
1156
+ self.parser["start_epoch"] = int(self.config("start_epoch")) + 1
1157
+ return
1158
+
1159
+ def schedule_validation_epochs(self, num_epochs, num_epochs_per_validation=None, fraction=0.16) -> list:
1160
+ """
1161
+ Schedule of epochs to validate (progressively more frequently)
1162
+ num_epochs - total number of epochs
1163
+ num_epochs_per_validation - if provided use a linear schedule with this step
1164
+ init_step
1165
+ """
1166
+
1167
+ if num_epochs_per_validation is None:
1168
+ x = (np.sin(np.linspace(0, np.pi / 2, max(10, int(fraction * num_epochs)))) * num_epochs).astype(int)
1169
+ x = np.cumsum(np.sort(np.diff(np.unique(x)))[::-1])
1170
+ x[-1] = num_epochs
1171
+ x = x.tolist()
1172
+ else:
1173
+ if num_epochs_per_validation >= num_epochs:
1174
+ x = [num_epochs_per_validation]
1175
+ else:
1176
+ x = list(range(num_epochs_per_validation, num_epochs, num_epochs_per_validation))
1177
+
1178
+ if len(x) == 0:
1179
+ x = [0]
1180
+
1181
+ return x
1182
+
1183
+
1184
+ def main(**kwargs) -> None:
1185
+ workflow = VistaCell(**kwargs)
1186
+ workflow.initialize()
1187
+ workflow.run()
1188
+ workflow.finalize()
1189
+
1190
+
1191
+ if __name__ == "__main__":
1192
+ # to be able to run directly as python scripts/workflow.py --config_file=...
1193
+ # for debugging and development
1194
+
1195
+ from pathlib import Path
1196
+
1197
+ sys.path.append(str(Path(__file__).parent.parent))
1198
+
1199
+ # from scripts import *
1200
+
1201
+ fire, fire_is_imported = optional_import("fire")
1202
+ if fire_is_imported:
1203
+ fire.Fire(main)
1204
+ else:
1205
+ print("Missing package: fire")