2ms commited on
Commit
03ae676
·
0 Parent(s):

init commit

Browse files
LICENSE ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EXAONEPath AI Model License Agreement 1.0 - NC
2
+
3
+ This License Agreement (“Agreement”) is entered into between you (“Licensee”) and LG Management Development Institute Co., Ltd. (“Licensor”), governing the use of the EXAONEPath AI Model (“Model”). By downloading, installing, copying, or using the Model, you agree to comply with and be bound by the terms of this Agreement. If you do not agree to all the terms, you must not download, install, copy, or use the Model. This Agreement constitutes a binding legal agreement between the Licensee and Licensor.
4
+
5
+ 1. Definitions
6
+ 1.1 Model: The artificial intelligence model provided by Licensor, which includes any software, algorithms, machine learning models, or related components supplied by Licensor. This definition extends to encompass all updates, enhancements, improvements, bug fixes, patches, or other modifications that may be provided by Licensor from time to time, whether automatically or manually implemented.
7
+ 1.2 Derivatives: Any modifications, alterations, enhancements, improvements, adaptations, or derivative works of the Model created by Licensee or any third party. This includes changes made to the Model's architecture, parameters, data processing methods, or any other aspect of the Model that results in a modification of its functionality or output.
8
+ 1.3 Output: Any data, results, content, predictions, analyses, insights, or other materials generated by the Model or Derivatives, regardless of whether they are in their original form or have been further processed or modified by the Licensee. This includes, but is not limited to, textual or numerical produced directly or indirectly through the use of the Model.
9
+ 1.4 Licensor: LG Management Development Institute Co., Ltd., the owner, developer, and provider of the EXAONEPath AI Model. The Licensor holds all rights, title, and interest in the Model and is responsible for granting licenses to use the Model under the terms specified in this Agreement.
10
+ 1.5 Licensee: The individual, organization, corporation, academic institution, government agency, or other entity using or intending to use the Model under the terms and conditions of this Agreement. The Licensee is responsible for ensuring compliance with the Agreement by all authorized users who access or utilize the Model on behalf of the Licensee.
11
+
12
+ 2. License Grant
13
+ 2.1 Grant of License: Subject to the terms and conditions outlined in this Agreement, the Licensor hereby grants the Licensee a limited, non-exclusive, non-transferable, worldwide, and revocable license to:
14
+ a. Access, download, install, and use the Model solely for research purposes. This includes evaluation, testing, academic research and experimentation.
15
+ b. Publicly disclose research results and findings derived from the use of the Model or Derivatives, including publishing papers or presentations.
16
+ c. Modify the Model and create Derivatives based on the Model, provided that such modifications and Derivatives are used exclusively for research purposes. The Licensee may conduct experiments, perform analyses, and apply custom modifications to the Model to explore its capabilities and performance under various scenarios. If the Model is modified, the modified Model must include "EXAONEPath" at the beginning of its name.
17
+ d. Distribute the Model and Derivatives in each case with a copy of this Agreement.
18
+ 2.2 Scope of License: The license granted herein does not authorize the Licensee to use the Model for any purpose not explicitly permitted under this Agreement. Any use beyond the scope of this license, including any commercial application or external distribution, is strictly prohibited unless explicitly agreed upon in writing by the Licensor.
19
+
20
+ 3. Restrictions
21
+ 3.1 Commercial Use: The Licensee is expressly prohibited from using the Model, Derivatives, or Output for any commercial purposes, including but not limited to, developing or deploying products, services, or applications that generate revenue, whether directly or indirectly. Any commercial exploitation of the Model or its derivatives requires a separate commercial license agreement with the Licensor. Furthermore, the Licensee shall not use the Model, Derivatives or Output to develop or improve other models, except for research purposes, which is explicitly permitted.
22
+ 3.2 Reverse Engineering: The Licensee shall not decompile, disassemble, reverse engineer, or attempt to derive the source code, underlying ideas, algorithms, or structure of the Model, except to the extent that such activities are expressly permitted by applicable law. Any attempt to bypass or circumvent technological protection measures applied to the Model is strictly prohibited.
23
+ 3.3 Unlawful Use: The Licensee shall not use the Model and Derivatives for any illegal, fraudulent, or unauthorized activities, nor for any purpose that violates applicable laws or regulations. This includes but is not limited to the creation, distribution, or dissemination of malicious, deceptive, or unlawful content.
24
+ 3.4 Ethical Use: The Licensee shall ensure that the Model or Derivatives is used in an ethical and responsible manner, adhering to the following guidelines:
25
+ a. The Model and Derivatives shall not be used to generate, propagate, or amplify false, misleading, or harmful information, including fake news, misinformation, or disinformation.
26
+ b. The Model and Derivatives shall not be employed to create, distribute, or promote content that is discriminatory, harassing, defamatory, abusive, or otherwise offensive to individuals or groups based on race, gender, sexual orientation, religion, nationality, or other protected characteristics.
27
+ c. The Model and Derivatives shall not infringe on the rights of others, including intellectual property rights, privacy rights, or any other rights recognized by law. The Licensee shall obtain all necessary permissions and consents before using the Model and Derivatives in a manner that may impact the rights of third parties.
28
+ d. The Model and Derivatives shall not be used in a way that causes harm, whether physical, mental, emotional, or financial, to individuals, organizations, or communities. The Licensee shall take all reasonable measures to prevent misuse or abuse of the Model and Derivatives that could result in harm or injury.
29
+
30
+ 4. Ownership
31
+ 4.1 Intellectual Property: All rights, title, and interest in and to the Model, including any modifications, Derivatives, and associated documentation, are and shall remain the exclusive property of the Licensor. The Licensee acknowledges that this Agreement does not transfer any ownership rights to the Licensee. All trademarks, service marks, and logos associated with the Model are the property of the Licensor.
32
+ 4.2 Output: All output generated by the Model from Licensee Data ("Output") shall be the sole property of the Licensee. Licensor hereby waives any claim of ownership or intellectual property rights to the Output. Licensee is solely responsible for the legality, accuracy, quality, integrity, and use of the Output.
33
+ 4.3 Attribution: In any publication or presentation of results obtained using the Model, the Licensee shall provide appropriate attribution to the Licensor, citing the Model's name and version, along with any relevant documentation or references specified by the Licensor.
34
+
35
+ 5. No Warranty
36
+ 5.1 “As-Is” Basis: The Model, Derivatives, and Output are provided on an “as-is” and “as-available” basis, without any warranties or representations of any kind, whether express, implied, or statutory. The Licensor disclaims all warranties, including but not limited to, implied warranties of merchantability, fitness for a particular purpose, accuracy, reliability, non-infringement, or any warranty arising from the course of dealing or usage of trade.
37
+ 5.2 Performance and Reliability: The Licensor does not warrant or guarantee that the Model, Derivatives or Output will meet the Licensee’s requirements, that the operation of the Model, Derivatives or Output will be uninterrupted or error-free, or that defects in the Model will be corrected. The Licensee acknowledges that the use of the Model, Derivatives or Output is at its own risk and that the Model, Derivatives or Output may contain bugs, errors, or other limitations.
38
+ 5.3 No Endorsement: The Licensor does not endorse, approve, or certify any results, conclusions, or recommendations derived from the use of the Model. The Licensee is solely responsible for evaluating the accuracy, reliability, and suitability of the Model for its intended purposes.
39
+
40
+ 6. Limitation of Liability
41
+ 6.1 No Liability for Damages: To the fullest extent permitted by applicable law, in no event shall the Licensor be liable for any special, incidental, indirect, consequential, exemplary, or punitive damages, including but not limited to, damages for loss of business profits, business interruption, loss of business information, loss of data, or any other pecuniary or non-pecuniary loss arising out of or in connection with the use or inability to use the Model, Derivatives or any Output, even if the Licensor has been advised of the possibility of such damages.
42
+ 6.2 Indemnification: The Licensee agrees to indemnify, defend, and hold harmless the Licensor, its affiliates, officers, directors, employees, and agents from and against any claims, liabilities, damages, losses, costs, or expenses (including reasonable attorneys' fees) arising out of or related to the Licensee's use of the Model, any Derivatives, or any Output, including any violation of this Agreement or applicable laws. This includes, but is not limited to, ensuring compliance with copyright laws, privacy regulations, defamation laws, and any other applicable legal or regulatory requirements.
43
+
44
+ 7. Termination
45
+ 7.1 Termination by Licensor: The Licensor reserves the right to terminate this Agreement and revoke the Licensee’s rights to use the Model at any time, with or without cause, and without prior notice if the Licensee breaches any of the terms or conditions of this Agreement. Termination shall be effective immediately upon notice.
46
+ 7.2 Effect of Termination: Upon termination of this Agreement, the Licensee must immediately cease all use of the Model, Derivatives, and Output and destroy all copies of the Model, Derivatives, and Output in its possession or control, including any backup or archival copies. The Licensee shall certify in writing to the Licensor that such destruction has been completed.
47
+ 7.3 Survival: The provisions of this Agreement that by their nature should survive termination, including but not limited to, Sections 4 (Ownership), 5 (No Warranty), 6 (Limitation of Liability), and this Section 7 (Termination), shall continue to apply after termination.
48
+
49
+ 8. Governing Law
50
+ 8.1 Governing Law: This Agreement shall be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflict of laws principles.
51
+ 8.2 Arbitration: Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English.
52
+
53
+ 9. Alterations
54
+ 9.1 Modifications: The Licensor reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on the Licensor’s website or through other means of communication. The Licensee is responsible for reviewing the Agreement periodically for changes. Continued use of the Model after any modifications have been made constitutes acceptance of the revised Agreement.
55
+ 9.2 Entire Agreement: This Agreement constitutes the entire agreement between the Licensee and Licensor concerning the subject matter hereof and supersedes all prior or contemporaneous oral or written agreements, representations, or understandings. Any terms or conditions of any purchase order or other document submitted by the Licensee in connection with the Model that are in addition to, different from, or inconsistent with the terms and conditions of this Agreement are not binding on the Licensor and are void.
56
+
57
+ By downloading, installing, or using the EXAONEPath AI Model, the Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.
58
+
README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ license_name: exaonepath
4
+ license_link: LICENSE
5
+ tags:
6
+ - lg-ai
7
+ - EXAONEPath-1.5
8
+ - pathology
9
+ ---
10
+ # EXAONE Path 1.5
11
+ ## Introduction
12
+ EXAONE Path 1.5 is a whole slide image level(WSI-level) classification framework designed for downstream tasks in pathology, such as cancer subtyping, molecular subtyping and mutation prediction. It builds upon our previous work, EXAONE Path v1.0, which focused on patch-wise feature extraction by dividing a WSI into patches and embedding each patch into a feature vector.
13
+ In EXAONE Path 1.5, we extend this pipeline to take an entire WSI as input. Each patch is first processed using the pretrained EXAONE Path 1.0 encoder to extract patch-level features. These features are then aggregated using a ViT-based (Vision Transformer) aggregator module to produce a slide-level representation.
14
+ This aggregated representation is subsequently passed through a linear classifier to perform downstream tasks such as molecular subtyping, tumor subtyping, and mutation prediction.
15
+ To effectively train the aggregator, we adopt a two-stage learning process:
16
+ Pretraining: We employ multimodal learning by aligning slide images with various mRNA gene expression profiles to learn semantically meaningful slide-level representations.
17
+ Fine-tuning: The pretrained model is then adapted to specific downstream classification tasks.
18
+ In this repository, we release the model trained for EGFR mutation prediction in lung adenocarcinoma (LUAD), enabling researchers to leverage our pipeline for similar molecular pathology applications.
19
+ ## Quickstart
20
+ ### 1. Hardware Requirements ###
21
+ - NVIDIA GPU is required
22
+ - Minimum 40GB GPU memory recommended
23
+ - Tested on Ubuntu 22.04 with NVIDIA driver version 550.144.03
24
+ Note: This implementation requires NVIDIA GPU and drivers. The provided environment setup specifically uses CUDA-enabled PyTorch, making NVIDIA GPU mandatory for running the model.
25
+ ### 2. Environment Setup
26
+ ```
27
+ pip install -r requirements.txt
28
+ ```
29
+ ### 3-a. Load the model & Inference
30
+ #### Load model with HuggingFace
31
+
32
+ ```python
33
+ from models.exaonepath import EXAONEPathV1p5Downstream
34
+ hf_token = "YOUR_HUGGING_FACE_ACCESS_TOKEN"
35
+ model = EXAONEPathV1p5Downstream.from_pretrained("LGAI-EXAONE/EXAONE-Path-1.5", use_auth_token=hf_token)
36
+ slide_path = './samples/wsis/1/1.svs'
37
+ probs = model(slide_path)
38
+ ```
39
+ #### Fast CLI Inference
40
+ Before running the command below, make sure you update your Hugging Face token.
41
+ Open `tokens.py` and replace the placeholder with your actual token:
42
+
43
+ ```python
44
+ HF_TOKEN = "YOUR_HUGGING_FACE_ACCESS_TOKEN"
45
+ ```
46
+
47
+ Then, run inference with:
48
+ ```bash
49
+ python inference.py --svs_path ./samples/wsis/1/1.svs
50
+ ```
51
+ ### 3-b. Fine-tuning with Pretrained Weights
52
+ We provide example scripts and files to help you fine-tune the model on your own dataset.
53
+ The provided script fine-tunes the model using pretrained weights stored in `./pretrained_weight.pth`.
54
+
55
+ #### Extract Features from WSI Images
56
+ To train the model using WSI images and their corresponding labels,
57
+ you must first extract patch-level features from each WSI using our provided feature extractor.
58
+
59
+ ```bash
60
+ python feature_extract.py --input_dir ./samples/wsis/ --output_dir ./samples/feats/
61
+ ```
62
+ This will generate .pt feature files in the output_dir.
63
+
64
+ #### Fine-tuning
65
+ ```bash
66
+ bash tuning_script.sh
67
+ ```
68
+ Inside tuning_script.sh, you can modify the following variables to match your dataset:
69
+ ```bash
70
+ FEAT_PATH=./samples/feats
71
+ LABEL_PATH=./samples/label/label.csv
72
+ LABEL_DICT="{'n':0, 'y':1}"
73
+ SPLIT_PATH=./samples/splits
74
+ ```
75
+ Change these paths to point to your own feature, label, and split files to start training.
76
+
77
+ ## Model Performance Comparison
78
+ | Metric: AUC | Titan(Conch v1.5+iBot, image+text) | PRISM (virchow+pe receiver, Image+text) | CHIEF (CTransPath + CLAM, Image+text, clam+wsi contrastive) | Prov-GigaPath (GigaPath+LongNet, Image-only, mask precision manner) | UNI2-h + CLAM (Image-only) | EXAONE Path 1.5(image+gene expression) |
79
+ |--------------------------|----------------------------------|-----------------------------------------|--------------------------------------------------------------|------------------------------------------------------------------------|-----------------------------|------------------|
80
+ | **TMB (cutoff 10)** | 0.74 | 0.73 | 0.70 | 0.69 | 0.71 | 0.71 |
81
+ | **LUAD-EGFR-mut** | 0.76 | 0.80 | 0.73 | 0.73 | 0.79 | 0.81 |
82
+ | **LUAD-KRAS-mut** | 0.61 | 0.65 | 0.61 | 0.66 | 0.60 | 0.63 |
83
+ | **LUAD-Gene-overexp[1]** | 0.75 | 0.68 | 0.71 | 0.71 | 0.74 | 0.72 |
84
+ | **CRC-MSS/MSI** | 0.89 | 0.88 | 0.86 | 0.90 | 0.90 | 0.89 |
85
+ | **BRCA-ER_PR_HER2** | 0.82 | 0.79 | 0.76 | 0.79 | 0.81 | 0.77 |
86
+ | **Pan-cancer-Gene-mut[2]** | 0.79 | 0.77 | 0.73 | 0.74 | 0.77 | 0.76 |
87
+ | **Avg. AUC** | 0.77 | 0.76 | 0.73 | 0.74 | 0.77 | 0.76 |
88
+
89
+ [1]: **lung-gene-overexp**: total 11 genes were evaluated: LAG3, CLDN6, CD274, EGFR, ERBB2, ERBB3, CD276, VTCN1, TACSTD2, FOLR1, MET.
90
+
91
+ [2]: **Pan-cancer-Gene-mut**: total 7 genes were evaluated: TP53, KRAS, ALK, PIK3CA, MET, EGFR, PTEN
92
+
93
+ ## License
94
+ The model is licensed under [EXAONEPath AI Model License Agreement 1.0 - NC](./LICENSE)
config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "step_size": 256,
3
+ "patch_size": 256,
4
+ "num_sampled_patch": 16384,
5
+ "architecture": "EXAONEPathV1p5Downstream"
6
+ }
datasets/dataset_WSI.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from openslide import OpenSlide
4
+ from utils.preprocessor import MacenkoNormalizer, preprocessor
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class WSIPatchDataset(Dataset):
9
+ def __init__(
10
+ self,
11
+ coords,
12
+ wsi_path,
13
+ pretrained=False,
14
+ patch_size=256,
15
+ patch_level=0,
16
+ macenko=True,
17
+ return_coord=False,
18
+ ):
19
+ self.pretrained = pretrained
20
+ self.wsi = OpenSlide(wsi_path)
21
+ self.patch_size = patch_size
22
+ self.patch_level = patch_level
23
+ self.return_coord = return_coord
24
+
25
+ if macenko:
26
+ normalizer = MacenkoNormalizer(
27
+ target_path=os.path.join(
28
+ os.path.dirname(os.path.dirname(os.path.join(__file__))),
29
+ "macenko_target",
30
+ "macenko_param.pt",
31
+ )
32
+ )
33
+ else:
34
+ normalizer = None
35
+
36
+ self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)
37
+ self.coords = coords
38
+ self.length = len(self.coords)
39
+
40
+ def __len__(self):
41
+ return self.length
42
+
43
+ def __getitem__(self, idx):
44
+ coord = self.coords[idx]
45
+ img = self.wsi.read_region(
46
+ coord, self.patch_level, (self.patch_size, self.patch_size)
47
+ ).convert("RGB")
48
+ img = self.roi_transforms(img)
49
+ if self.return_coord:
50
+ return img, torch.tensor(coord)
51
+ else:
52
+ return img
inference.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+
5
+ from models.exaonepath import EXAONEPathV1p5Downstream
6
+ from utils.constants import CLASS_NAMES
7
+ from tokens import HF_TOKEN
8
+
9
+
10
+ def infer(model, input_file):
11
+ print("Processing", input_file, "...")
12
+ probs = model(input_file)
13
+ result_str = "Result -- " + " / ".join(
14
+ [f"{name}: {probs[i].item():.4f}" for i, name in enumerate(CLASS_NAMES)]
15
+ )
16
+ print(result_str + "\n")
17
+
18
+
19
+ if __name__ == '__main__':
20
+ parser = argparse.ArgumentParser(description="Inference")
21
+ parser.add_argument('--svs_path', type=str, default='./samples/wsis/1/1.svs', help="Path to the .svs file")
22
+ parser.add_argument('--svs_dir', type=str, default='./samples_CRC', help="")
23
+
24
+ args = parser.parse_args()
25
+
26
+ hf_token = HF_TOKEN
27
+ # model = EXAONEPathV1p5Downstream.from_pretrained("LGAI-EXAONE/EXAONE-Path-1.5", use_auth_token=hf_token)
28
+ model = EXAONEPathV1p5Downstream(num_sampled_patch=16384)
29
+
30
+
31
+ # qwe = torch.load('./pytorch_model_ori.bin')
32
+ # aaa = model.load_state_dict(qwe, strict=False)
33
+ # hw_w = torch.load('/mnt/shared/shared_medical/shared/hi.choi/MOM/logs_eval_25/closebench2/ours/BIOMARKER_SMC_SMC/CRCSensor/CRCSensor_exaone_mom3_MOM_batch_8_lr0.00003_wd0.1_do0.1/_s100/s_0_checkpoint.pt', map_location='cpu')
34
+
35
+ # new_state_dict = {}
36
+ # for k, v in hw_w.items():
37
+ # if k.startswith("_orig_mod."):
38
+ # new_k = k.replace("_orig_mod.", "agg_model.", 1)
39
+ # else:
40
+ # new_k = k
41
+ # new_state_dict[new_k] = v
42
+
43
+ # load_result = model.load_state_dict(new_state_dict, strict=False)
44
+ model.load_state_dict(torch.load('./pytorch_model.bin'))
45
+ model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
46
+ model.eval()
47
+ model.feature_extractor = torch.compile(model.feature_extractor)
48
+ model.agg_model = torch.compile(model.agg_model)
49
+
50
+ for svs_name in os.listdir(args.svs_dir):
51
+ infer(model, os.path.join(args.svs_dir, svs_name))
52
+
models/aggregator.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+ import torch
3
+ from torch import nn
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+
7
+ from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer
8
+
9
+
10
+ # @dataclass
11
+ # class VisionCfg:
12
+ # layers: Union[Tuple[int, int, int, int], int] = 6
13
+ # width: int = 512
14
+ # head_width: int = 64
15
+ # mlp_ratio: float = 4.0
16
+
17
+ # ls_init_value: Optional[float] = None # layer scale initial value
18
+ # patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
19
+ # no_ln_pre: bool = False # disable pre transformer LayerNorm
20
+ # pool_type: str = 'none'
21
+ # final_ln_after_pool: bool = True # apply final LayerNorm after pooling
22
+ # output_tokens: bool = False
23
+ # act_kwargs: Optional[dict] = None
24
+ # norm_kwargs: Optional[dict] = None
25
+
26
+ @dataclass
27
+ class CLIPVisionCfg:
28
+ layers: Union[Tuple[int, int, int, int], int] = 6
29
+ width: int = 512
30
+ head_width: int = 64
31
+ mlp_ratio: float = 4.0
32
+ patch_size: int = 16
33
+ image_size: Union[Tuple[int, int], int] = 224
34
+
35
+ ls_init_value: Optional[float] = None # layer scale initial value
36
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
37
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
38
+ attn_pooler_queries: int = 256 # n_queries for attentional pooler
39
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
40
+ no_ln_pre: bool = False # disable pre transformer LayerNorm
41
+ pos_embed_type: str = 'none'
42
+ final_ln_after_pool: bool = True # apply final LayerNorm after pooling
43
+ pool_type: str = 'none'
44
+ output_tokens: bool = False
45
+ act_kwargs: Optional[dict] = None
46
+ norm_kwargs: Optional[dict] = None
47
+
48
+ timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
49
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
50
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
51
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
52
+ timm_proj_bias: bool = False # enable bias final projection
53
+ timm_drop: float = 0. # head dropout
54
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
55
+ img_embed: bool = False
56
+ cls_embed: bool = False
57
+ projection = False
58
+ use_flex = True
59
+
60
+
61
+ def get_cast_dtype(precision: str):
62
+ cast_dtype = None
63
+ if precision == 'bf16':
64
+ cast_dtype = torch.bfloat16
65
+ elif precision == 'fp16':
66
+ cast_dtype = torch.float16
67
+ return cast_dtype
68
+
69
+
70
+ def get_input_dtype(precision: str):
71
+ input_dtype = None
72
+ if precision in ('bf16', 'pure_bf16'):
73
+ input_dtype = torch.bfloat16
74
+ elif precision in ('fp16', 'pure_fp16'):
75
+ input_dtype = torch.float16
76
+ return input_dtype
77
+
78
+
79
+ def _build_vision_tower(
80
+ embed_dim: int,
81
+ vision_cfg: CLIPVisionCfg,
82
+ quick_gelu: bool = False,
83
+ cast_dtype: Optional[torch.dtype] = None,
84
+ dropout: float = 0.1,
85
+ num_registers: int = 0,
86
+ ):
87
+ if isinstance(vision_cfg, dict):
88
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
89
+
90
+ act_layer = QuickGELU if quick_gelu else nn.GELU
91
+
92
+ vision_heads = vision_cfg.width // vision_cfg.head_width
93
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
94
+ if vision_cfg.norm_kwargs:
95
+ norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
96
+ if vision_cfg.act_kwargs is not None:
97
+ act_layer = partial(act_layer, **vision_cfg.act_kwargs)
98
+
99
+ visual = VisionTransformer(
100
+ width=vision_cfg.width,
101
+ layers=vision_cfg.layers,
102
+ heads=vision_heads,
103
+ mlp_ratio=vision_cfg.mlp_ratio,
104
+ ls_init_value=vision_cfg.ls_init_value,
105
+ output_dim=embed_dim,
106
+ patch_dropout=vision_cfg.patch_dropout,
107
+ no_ln_pre=vision_cfg.no_ln_pre,
108
+ pool_type=vision_cfg.pool_type,
109
+ final_ln_after_pool=vision_cfg.final_ln_after_pool,
110
+ act_layer=act_layer,
111
+ norm_layer=norm_layer,
112
+ output_tokens=vision_cfg.output_tokens,
113
+ img_embed = vision_cfg.img_embed,
114
+ use_flex = True,
115
+ dropout = dropout,
116
+ num_registers = num_registers,
117
+ use_rel_bias =True,
118
+ )
119
+
120
+ return visual
121
+
122
+
123
+ class MixedOmicsModel(nn.Module):
124
+ def __init__(
125
+ self,
126
+ embed_dim: int,
127
+ vision_cfg: CLIPVisionCfg,
128
+ quick_gelu: bool = False,
129
+ cast_dtype: Optional[torch.dtype] = None,
130
+ drop_rate: float = 0.25,
131
+ num_registers: int = 0,
132
+ *args,
133
+ **kwargs,
134
+ ):
135
+ super().__init__()
136
+
137
+ self.drop_prob = drop_rate
138
+ self.num_registers = num_registers
139
+
140
+ vision_cfg.cls_embed = False
141
+
142
+ self.visual = _build_vision_tower(embed_dim,
143
+ vision_cfg,
144
+ quick_gelu,
145
+ cast_dtype,
146
+ dropout=drop_rate,
147
+ num_registers=0,
148
+ )
149
+
150
+ self.image_proj = nn.Linear(embed_dim, embed_dim)
151
+ self.image_proj.apply(self.init_weights)
152
+
153
+ self.ln_post = LayerNorm(embed_dim)
154
+
155
+
156
+
157
+ def init_weights(self, module):
158
+ if isinstance(module, (nn.Linear, nn.Embedding)):
159
+ module.weight.data.normal_(mean=0.0, std=0.02)
160
+ elif isinstance(module, nn.LayerNorm):
161
+ module.bias.data.zero_()
162
+ module.weight.data.fill_(1.0)
163
+
164
+ if isinstance(module, nn.Linear) and module.bias is not None:
165
+ module.bias.data.zero_()
166
+
167
+ def _check_tensor(self, tensor, name):
168
+ print(name, " : ", tensor.shape)
169
+ if torch.isnan(tensor).any():
170
+ print(tensor.shape)
171
+ print(f"Tensor {name} contains NaN values.")
172
+ if torch.isinf(tensor).any():
173
+ print(tensor.shape)
174
+ print(f"Tensor {name} contains Inf values.")
175
+
176
+ def forward(
177
+ self,
178
+ image,
179
+ coords=None,
180
+ im_mask=None,
181
+ *args,
182
+ **kwargs,
183
+ ):
184
+
185
+ ## image embedding
186
+ image_embeds = self.visual(image.contiguous(), coords=coords.contiguous(), key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous())
187
+ image_embeds = self.ln_post(image_embeds)
188
+
189
+ if im_mask is not None:
190
+ mask = im_mask.unsqueeze(-1).contiguous()
191
+ masked_embeds = image_embeds * mask
192
+ sum_embeds = masked_embeds.sum(dim=1)
193
+ valid_counts = mask.sum(dim=1).clamp(min=1) # [N, 1]
194
+ mean_embeds = sum_embeds / valid_counts # [N, dim]
195
+
196
+ else:
197
+ mean_embeds = image_embeds.mean(-2)
198
+
199
+ image_embeds_final = self.image_proj(mean_embeds)
200
+
201
+ return image_embeds_final, image_embeds, mean_embeds
202
+
203
+
204
+
205
+ def make_model(
206
+ embed_dim=768,
207
+ droprate=0.1,
208
+ num_registers=0,
209
+ depth=4,
210
+ ):
211
+ vCfg = CLIPVisionCfg
212
+ vCfg.width = embed_dim
213
+ vCfg.layers = depth
214
+
215
+ model = MixedOmicsModel(
216
+ embed_dim=embed_dim,
217
+ vision_cfg=vCfg,
218
+ drop_rate=droprate,
219
+ num_registers=num_registers,
220
+ )
221
+
222
+ return model
models/cls_modules.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class GlobalClassificationHead(nn.Module):
7
+ def __init__(self, input_dim, num_classes=32, dropout_rate=0.1):
8
+ super().__init__()
9
+ self.norm = nn.LayerNorm(input_dim)
10
+ self.dropout = nn.Dropout(dropout_rate)
11
+ self.fc = nn.Linear(input_dim, num_classes)
12
+
13
+ def forward(self, x):
14
+ x = self.norm(x)
15
+ x = self.dropout(x)
16
+ logits = self.fc(x) # (batch_size, num_classes)
17
+ return logits
18
+
19
+
20
+ class CLSHead(nn.Module):
21
+ def __init__(
22
+ self,
23
+ embed_dim,
24
+ num_classes,
25
+ dropout=0.1,
26
+ use_norm=True,
27
+ hidden_dim=None,
28
+ activation="silu",
29
+ pooling_type="cls",
30
+ ):
31
+ super().__init__()
32
+
33
+ if hidden_dim is None:
34
+ hidden_dim = embed_dim
35
+
36
+ if activation == "gelu":
37
+ self.activation = nn.GELU()
38
+ elif activation == "relu":
39
+ self.activation = nn.ReLU(inplace=True)
40
+ elif activation == "silu":
41
+ self.activation = nn.SiLU(inplace=True)
42
+ else:
43
+ raise ValueError(f"Value Error: {activation}")
44
+
45
+ self.pooling_type = pooling_type
46
+
47
+ if pooling_type == "attention":
48
+ self.attention_pool = nn.Sequential(
49
+ nn.LayerNorm(embed_dim),
50
+ nn.Linear(embed_dim, 1),
51
+ # nn.Softmax(dim=1)
52
+ )
53
+
54
+ if use_norm:
55
+ self.norm = nn.LayerNorm(embed_dim)
56
+ else:
57
+ self.norm = nn.Identity()
58
+
59
+ self.fc1 = nn.Linear(embed_dim, hidden_dim)
60
+ self.dropout1 = nn.Dropout(dropout)
61
+ self.fc2 = nn.Linear(hidden_dim, num_classes)
62
+
63
+ self._init_weights()
64
+
65
+ def _init_weights(self):
66
+ nn.init.trunc_normal_(self.fc1.weight, std=0.02)
67
+ nn.init.zeros_(self.fc1.bias)
68
+ nn.init.trunc_normal_(self.fc2.weight, std=0.02)
69
+ nn.init.zeros_(self.fc2.bias)
70
+
71
+ if self.pooling_type == "attention":
72
+ nn.init.trunc_normal_(self.attention_pool[0].weight, std=0.02)
73
+ nn.init.zeros_(self.attention_pool[0].bias)
74
+
75
+ def forward(self, x, x2=None, x3=None, attn_mask=None):
76
+ if self.pooling_type == "mlp":
77
+ pooled = x + x3
78
+ elif self.pooling_type == "attention":
79
+ x = torch.cat([x.unsqueeze(1), x2], dim=1)
80
+ weights = self.attention_pool(x) # [batch_size, num_tokens, 1]
81
+
82
+ if attn_mask is not None:
83
+ attn_mask = torch.cat([torch.zeros(attn_mask.shape[0], 1).to(attn_mask.dtype).to(attn_mask.device), attn_mask], dim=1)
84
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=weights.dtype)
85
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
86
+ attn_mask = new_attn_mask
87
+
88
+ if len(attn_mask.shape) ==2:
89
+ attn_mask = attn_mask[..., None] # [batch_size, num_tokens, 1]
90
+
91
+ weights = weights + attn_mask
92
+ weights = F.softmax(weights, dim=1)
93
+
94
+ pooled = torch.sum(x * weights, dim=1)
95
+ else:
96
+ raise ValueError(f"지원하지 않는 풀링 타입: {self.pooling_type}")
97
+
98
+ pooled = self.norm(pooled)
99
+
100
+ x = self.fc1(pooled)
101
+ x = self.activation(x)
102
+ x = self.dropout1(x)
103
+ x = self.fc2(x)
104
+
105
+ return x
106
+
107
+
108
+ class BaseClassifier(nn.Module):
109
+ def __init__(
110
+ self,
111
+ base_model,
112
+ num_classes,
113
+ cls_head_kwargs=None
114
+ ):
115
+ super().__init__()
116
+
117
+ self.backbone = base_model
118
+ embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768
119
+
120
+ cls_head_config = {
121
+ 'embed_dim': embed_dim,
122
+ 'num_classes': num_classes,
123
+ }
124
+
125
+ if cls_head_kwargs:
126
+ cls_head_config.update(cls_head_kwargs)
127
+
128
+ self.cls_head = CLSHead(**cls_head_config)
129
+
130
+ def forward(self, image, coords=None, im_mask=None):
131
+ feat_final, feats, m_feats = self.backbone(image=image, coords=coords, im_mask=im_mask)
132
+ logits = self.cls_head(feat_final, feats, m_feats, attn_mask=~im_mask.bool())
133
+
134
+ Y_hat = torch.topk(logits, 1, dim=1)[1]
135
+ Y_prob = F.softmax(logits, dim=-1)
136
+
137
+ return logits, Y_prob, Y_hat
138
+
139
+
140
+ class LinearClassifier(nn.Module):
141
+ def __init__(
142
+ self,
143
+ base_model,
144
+ num_classes=2,
145
+ pool="final"
146
+ ):
147
+ super().__init__()
148
+ self.backbone = base_model
149
+ self.pool = pool
150
+ print("Linear classifier pool type : ", self.pool)
151
+ embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768
152
+
153
+ self.cls_head = nn.Sequential(
154
+ nn.LayerNorm(embed_dim),
155
+ nn.Linear(embed_dim, num_classes)
156
+ )
157
+
158
+ def forward(self, image, coords=None, im_mask=None):
159
+ feat_final, _, feat_mean = self.backbone(image=image, coords=coords, im_mask=im_mask)
160
+ if self.pool == "mean":
161
+ logits = self.cls_head(feat_mean)
162
+ elif self.pool == "final":
163
+ logits = self.cls_head(feat_final)
164
+
165
+ Y_hat = torch.topk(logits, 1, dim=1)[1]
166
+ Y_prob = F.softmax(logits, dim=-1)
167
+
168
+ return logits, Y_prob, Y_hat
models/exaonepath.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from models.aggregator import make_model
5
+ from models.cls_modules import LinearClassifier
6
+ from datasets.dataset_WSI import WSIPatchDataset
7
+ from models.feature_extractor import vit_base
8
+ from utils.wsi_utils import extract_tissue_patch_coords
9
+ from torch.utils.data import DataLoader
10
+
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+
14
+ class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
15
+ def __init__(
16
+ self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
17
+ ):
18
+ super(EXAONEPathV1p5Downstream, self).__init__()
19
+ self.step_size = step_size
20
+ self.patch_size = patch_size
21
+ self.macenko = macenko
22
+ self.num_sampled_patch = num_sampled_patch
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ self.config = {
26
+ "step_size": step_size,
27
+ "patch_size": patch_size,
28
+ "macenko": macenko,
29
+ "num_sampled_patch": num_sampled_patch,
30
+ }
31
+
32
+ self.feature_extractor = vit_base()
33
+ self.feature_extractor = self.feature_extractor
34
+ # self.feature_extractor = self.feature_extractor.to(self.device)
35
+ # self.feature_extractor.eval()
36
+
37
+ self.agg_model = make_model(
38
+ embed_dim=768,
39
+ droprate=0.0,
40
+ num_registers=0,
41
+ depth=4,
42
+ )
43
+
44
+ self.agg_model = LinearClassifier(self.agg_model, pool='mean')
45
+
46
+ # self.agg_model.to(self.device)
47
+ # self.agg_model.eval()
48
+
49
+ @torch.no_grad()
50
+ def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
51
+ # Extract patches
52
+ coords = extract_tissue_patch_coords(
53
+ svs_path, patch_size=self.patch_size, step_size=self.step_size
54
+ )
55
+
56
+ # Extract patch-level features
57
+ self.feature_extractor.eval()
58
+ patch_dataset = WSIPatchDataset(
59
+ coords=coords,
60
+ wsi_path=svs_path,
61
+ pretrained=True,
62
+ macenko=self.macenko,
63
+ patch_size=self.patch_size,
64
+ return_coord=True,
65
+ )
66
+ patch_loader = DataLoader(
67
+ dataset=patch_dataset,
68
+ batch_size=feature_extractor_batch_size,
69
+ num_workers=(
70
+ feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
71
+ ),
72
+ pin_memory=self.device.type == "cuda",
73
+
74
+ )
75
+ features_list = []
76
+ coords_list = []
77
+ for count, items in enumerate(patch_loader):
78
+ patches, coords = items
79
+ print(
80
+ f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
81
+ end="\r",
82
+ )
83
+ patches = patches.to(self.device, non_blocking=True)
84
+
85
+ feature = self.feature_extractor(patches) # [B, 1024]
86
+ feature /= feature.norm(dim=-1, keepdim=True) # use normalized featuren
87
+ feature = feature.to("cpu", non_blocking=True)
88
+ features_list.append(feature)
89
+
90
+ coords = coords.to(self.device, non_blocking=True)
91
+ coords_list.append(coords)
92
+
93
+ print("")
94
+ print("Feature extraction finished")
95
+
96
+ features = torch.cat(features_list)
97
+ coords = torch.cat(coords_list)
98
+ total_samples = features.shape[0]
99
+
100
+ num_samples = min(self.num_sampled_patch, total_samples)
101
+ indices = torch.randperm(total_samples)[:num_samples]
102
+ sampled_features = features[indices]
103
+ sampled_coords = coords[indices]
104
+
105
+ # Aggregate features
106
+ self.agg_model.eval()
107
+
108
+ # sampled_features = torch.randn([8192, 768])
109
+ # sampled_coords = torch.randn([8192, 2])
110
+
111
+ logits, Y_prob, Y_hat = self.agg_model(sampled_features[None].to(self.device), sampled_coords[None].to(self.device))
112
+ probs = Y_prob[0].cpu()
113
+
114
+ return probs
115
+
116
+ @torch.no_grad()
117
+ def forward_feature_extractor(self, svs_path: str, feature_extractor_batch_size: int = 8):
118
+ # Extract patches
119
+ coords = extract_tissue_patch_coords(
120
+ svs_path, patch_size=self.patch_size, step_size=self.step_size
121
+ )
122
+
123
+ # Extract patch-level features
124
+ self.feature_extractor.eval()
125
+ patch_dataset = WSIPatchDataset(
126
+ coords=coords,
127
+ wsi_path=svs_path,
128
+ pretrained=True,
129
+ macenko=self.macenko,
130
+ patch_size=self.patch_size,
131
+ return_coord=False
132
+ )
133
+ patch_loader = DataLoader(
134
+ dataset=patch_dataset,
135
+ batch_size=feature_extractor_batch_size,
136
+ num_workers=(
137
+ feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
138
+ ),
139
+ pin_memory=self.device.type == "cuda",
140
+ )
141
+ features_list = []
142
+ for count, patches in enumerate(patch_loader):
143
+ print(
144
+ f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
145
+ end="\r",
146
+ )
147
+ patches = patches.to(self.device, non_blocking=True)
148
+
149
+ feature = self.feature_extractor(patches) # [B, 1024]
150
+ feature /= feature.norm(dim=-1, keepdim=True) # use normalized featuren
151
+ feature = feature.to("cpu", non_blocking=True)
152
+ features_list.append(feature)
153
+ print("")
154
+ print("Feature extraction finished")
155
+
156
+ features = torch.cat(features_list)
157
+
158
+ return features
models/feature_extractor.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class DropPath(nn.Module):
10
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
11
+
12
+ def __init__(self, drop_prob=None):
13
+ super(DropPath, self).__init__()
14
+ self.drop_prob = drop_prob
15
+
16
+ def forward(self, x):
17
+ return _drop_path(x, self.drop_prob, self.training)
18
+
19
+
20
+ class Mlp(nn.Module):
21
+ def __init__(
22
+ self,
23
+ in_features,
24
+ hidden_features=None,
25
+ out_features=None,
26
+ act_layer=nn.GELU,
27
+ drop=0.0,
28
+ ):
29
+ super().__init__()
30
+ out_features = out_features or in_features
31
+ hidden_features = hidden_features or in_features
32
+ self.fc1 = nn.Linear(in_features, hidden_features)
33
+ self.act = act_layer()
34
+ self.fc2 = nn.Linear(hidden_features, out_features)
35
+ self.drop = nn.Dropout(drop)
36
+
37
+ def forward(self, x):
38
+ x = self.fc1(x)
39
+ x = self.act(x)
40
+ x = self.drop(x)
41
+ x = self.fc2(x)
42
+ x = self.drop(x)
43
+ return x
44
+
45
+
46
+ class Attention(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim,
50
+ num_heads=8,
51
+ qkv_bias=False,
52
+ qk_scale=None,
53
+ attn_drop=0.0,
54
+ proj_drop=0.0,
55
+ ):
56
+ super().__init__()
57
+ self.num_heads = num_heads
58
+ head_dim = dim // num_heads
59
+ self.scale = qk_scale or head_dim**-0.5
60
+
61
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
62
+ self.attn_drop = nn.Dropout(attn_drop)
63
+ self.proj = nn.Linear(dim, dim)
64
+ self.proj_drop = nn.Dropout(proj_drop)
65
+
66
+ def forward(self, x):
67
+ B, N, C = x.shape
68
+ qkv = (
69
+ self.qkv(x)
70
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
71
+ .permute(2, 0, 3, 1, 4)
72
+ )
73
+ q, k, v = qkv[0], qkv[1], qkv[2]
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
80
+ x = self.proj(x)
81
+ x = self.proj_drop(x)
82
+ return x, attn
83
+
84
+
85
+ class Block(nn.Module):
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ num_heads,
90
+ mlp_ratio=4.0,
91
+ qkv_bias=False,
92
+ qk_scale=None,
93
+ drop=0.0,
94
+ attn_drop=0.0,
95
+ drop_path=0.0,
96
+ act_layer=nn.GELU,
97
+ norm_layer=nn.LayerNorm,
98
+ ):
99
+ super().__init__()
100
+ self.norm1 = norm_layer(dim)
101
+ self.attn = Attention(
102
+ dim,
103
+ num_heads=num_heads,
104
+ qkv_bias=qkv_bias,
105
+ qk_scale=qk_scale,
106
+ attn_drop=attn_drop,
107
+ proj_drop=drop,
108
+ )
109
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
110
+ self.norm2 = norm_layer(dim)
111
+ mlp_hidden_dim = int(dim * mlp_ratio)
112
+ self.mlp = Mlp(
113
+ in_features=dim,
114
+ hidden_features=mlp_hidden_dim,
115
+ act_layer=act_layer,
116
+ drop=drop,
117
+ )
118
+
119
+ def forward(self, x, return_attention=False):
120
+ y, attn = self.attn(self.norm1(x))
121
+ if return_attention:
122
+ return attn
123
+ x = x + self.drop_path(y)
124
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
125
+ return x
126
+
127
+
128
+ class PatchEmbed(nn.Module):
129
+ """Image to Patch Embedding"""
130
+
131
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
132
+ super().__init__()
133
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
134
+ self.img_size = img_size
135
+ self.patch_size = patch_size
136
+ self.num_patches = num_patches
137
+
138
+ self.proj = nn.Conv2d(
139
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
140
+ )
141
+
142
+ def forward(self, x):
143
+ B, C, H, W = x.shape
144
+ x = self.proj(x).flatten(2).transpose(1, 2)
145
+ return x
146
+
147
+
148
+ class VisionTransformer(nn.Module):
149
+ def __init__(
150
+ self,
151
+ img_size=[224],
152
+ patch_size=16,
153
+ in_chans=3,
154
+ num_classes=0,
155
+ embed_dim=768,
156
+ depth=12,
157
+ num_heads=12,
158
+ mlp_ratio=4.0,
159
+ qkv_bias=False,
160
+ qk_scale=None,
161
+ drop_rate=0.0,
162
+ attn_drop_rate=0.0,
163
+ drop_path_rate=0.0,
164
+ norm_layer=nn.LayerNorm,
165
+ **kwargs
166
+ ):
167
+ super().__init__()
168
+ self.num_features = self.embed_dim = embed_dim
169
+
170
+ self.patch_embed = PatchEmbed(
171
+ img_size=img_size[0],
172
+ patch_size=patch_size,
173
+ in_chans=in_chans,
174
+ embed_dim=embed_dim,
175
+ )
176
+ num_patches = self.patch_embed.num_patches
177
+
178
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
179
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
180
+ self.pos_drop = nn.Dropout(p=drop_rate)
181
+
182
+ dpr = [
183
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
184
+ ] # stochastic depth decay rule
185
+ self.blocks = nn.ModuleList(
186
+ [
187
+ Block(
188
+ dim=embed_dim,
189
+ num_heads=num_heads,
190
+ mlp_ratio=mlp_ratio,
191
+ qkv_bias=qkv_bias,
192
+ qk_scale=qk_scale,
193
+ drop=drop_rate,
194
+ attn_drop=attn_drop_rate,
195
+ drop_path=dpr[i],
196
+ norm_layer=norm_layer,
197
+ )
198
+ for i in range(depth)
199
+ ]
200
+ )
201
+ self.norm = norm_layer(embed_dim)
202
+
203
+ # Classifier head
204
+ self.head = (
205
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
206
+ )
207
+
208
+ _trunc_normal_(self.pos_embed, std=0.02)
209
+ _trunc_normal_(self.cls_token, std=0.02)
210
+ self.apply(self._init_weights)
211
+
212
+ def _init_weights(self, m):
213
+ if isinstance(m, nn.Linear):
214
+ _trunc_normal_(m.weight, std=0.02)
215
+ if isinstance(m, nn.Linear) and m.bias is not None:
216
+ nn.init.constant_(m.bias, 0)
217
+ elif isinstance(m, nn.LayerNorm):
218
+ nn.init.constant_(m.bias, 0)
219
+ nn.init.constant_(m.weight, 1.0)
220
+
221
+ def interpolate_pos_encoding(self, x, w, h):
222
+ npatch = x.shape[1] - 1
223
+ N = self.pos_embed.shape[1] - 1
224
+ if npatch == N and w == h:
225
+ return self.pos_embed
226
+ class_pos_embed = self.pos_embed[:, 0]
227
+ patch_pos_embed = self.pos_embed[:, 1:]
228
+ dim = x.shape[-1]
229
+ w0 = w // self.patch_embed.patch_size
230
+ h0 = h // self.patch_embed.patch_size
231
+ # we add a small number to avoid floating point error in the interpolation
232
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
233
+ w0, h0 = w0 + 0.1, h0 + 0.1
234
+ patch_pos_embed = nn.functional.interpolate(
235
+ patch_pos_embed.reshape(
236
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
237
+ ).permute(0, 3, 1, 2),
238
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
239
+ # size=(int(w0), int(h0)),
240
+ mode="bicubic",
241
+ )
242
+ assert (
243
+ int(w0) == patch_pos_embed.shape[-2]
244
+ and int(h0) == patch_pos_embed.shape[-1]
245
+ )
246
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
247
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
248
+
249
+ def prepare_tokens(self, x):
250
+ B, nc, w, h = x.shape
251
+ x = self.patch_embed(x) # patch linear embedding
252
+
253
+ # add the [CLS] token to the embed patch tokens
254
+ cls_tokens = self.cls_token.expand(B, -1, -1)
255
+ x = torch.cat((cls_tokens, x), dim=1)
256
+
257
+ # add positional encoding to each token
258
+ x = x + self.interpolate_pos_encoding(x, w, h)
259
+
260
+ return self.pos_drop(x)
261
+
262
+ def forward(self, x):
263
+ x = self.prepare_tokens(x)
264
+ for blk in self.blocks:
265
+ x = blk(x)
266
+ x = self.norm(x)
267
+ # print(x.type())
268
+ return x[:, 0]
269
+
270
+ def get_last_selfattention(self, x):
271
+ x = self.prepare_tokens(x)
272
+ for i, blk in enumerate(self.blocks):
273
+ if i < len(self.blocks) - 1:
274
+ x = blk(x)
275
+ else:
276
+ # return attention of the last block
277
+ return blk(x, return_attention=True)
278
+
279
+ def get_intermediate_layers(self, x, n=1):
280
+ x = self.prepare_tokens(x)
281
+ # we return the output tokens from the `n` last blocks
282
+ output = []
283
+ for i, blk in enumerate(self.blocks):
284
+ x = blk(x)
285
+ if len(self.blocks) - i <= n:
286
+ output.append(self.norm(x))
287
+ return output
288
+
289
+
290
+ def vit_base(patch_size=16, **kwargs):
291
+ model = VisionTransformer(
292
+ patch_size=patch_size,
293
+ embed_dim=768,
294
+ depth=12,
295
+ num_heads=12,
296
+ mlp_ratio=4,
297
+ qkv_bias=True,
298
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
299
+ **kwargs
300
+ )
301
+ return model
302
+
303
+
304
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
305
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
306
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
307
+ def norm_cdf(x):
308
+ # Computes standard normal cumulative distribution function
309
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
310
+
311
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
312
+ warnings.warn(
313
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
314
+ "The distribution of values may be incorrect.",
315
+ stacklevel=2,
316
+ )
317
+
318
+ with torch.no_grad():
319
+ # Values are generated by using a truncated uniform distribution and
320
+ # then using the inverse CDF for the normal distribution.
321
+ # Get upper and lower cdf values
322
+ l = norm_cdf((a - mean) / std)
323
+ u = norm_cdf((b - mean) / std)
324
+
325
+ # Uniformly fill tensor with values from [l, u], then translate to
326
+ # [2l-1, 2u-1].
327
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
328
+
329
+ # Use inverse cdf transform for normal distribution to get truncated
330
+ # standard normal
331
+ tensor.erfinv_()
332
+
333
+ # Transform to proper mean, std
334
+ tensor.mul_(std * math.sqrt(2.0))
335
+ tensor.add_(mean)
336
+
337
+ # Clamp to ensure it's in the proper range
338
+ tensor.clamp_(min=a, max=b)
339
+ return tensor
340
+
341
+
342
+ def _trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
343
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
344
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
345
+
346
+
347
+ def _drop_path(x, drop_prob: float = 0.0, training: bool = False):
348
+ if drop_prob == 0.0 or not training:
349
+ return x
350
+ keep_prob = 1 - drop_prob
351
+ shape = (x.shape[0],) + (1,) * (
352
+ x.ndim - 1
353
+ ) # work with diff dim tensors, not just 2D ConvNets
354
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
355
+ random_tensor.floor_() # binarize
356
+ output = x.div(keep_prob) * random_tensor
357
+ return output
models/flex_attn.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Sequence, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.attention.flex_attention import (
8
+ _DEFAULT_SPARSE_BLOCK_SIZE,
9
+ create_block_mask,
10
+ create_mask,
11
+ flex_attention,
12
+ )
13
+ from torch.nn.attention.flex_attention import flex_attention, _vmap_for_bhqkv
14
+
15
+ try:
16
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
17
+ except ImportError:
18
+ from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
19
+
20
+ from torch._dynamo import disable
21
+
22
+
23
+ def generate_alibi_bias(H=12):
24
+ alibi_bias = []
25
+ for h in range(H):
26
+ alibi_bias.append(-((h + 1) / H))
27
+ alibi_bias = torch.tensor(alibi_bias)
28
+ alibi_bias = torch.exp2(alibi_bias)
29
+ return alibi_bias
30
+
31
+
32
+ def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
33
+ def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
34
+ if coords is None:
35
+ return score
36
+ with torch.no_grad():
37
+ dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
38
+ dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
39
+ dist = torch.sqrt(dx * dx + dy * dy)
40
+ dist = dist.clamp(max=1000) # max distance
41
+ dist = torch.log1p(dist)
42
+ bias = dist * scale[h] * qk_scale
43
+ return score - bias # closer → larger score
44
+ return patch_coords_rel_bias
45
+
46
+
47
+ def key_padding_mask(mask):
48
+ def padding_mask(b, h, q_idx, kv_idx):
49
+ return ~mask[b, kv_idx]
50
+ return padding_mask
51
+
52
+
53
+ class FlexCore(nn.Module):
54
+ """
55
+ For using "forward hook"
56
+ """
57
+ def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
58
+ """
59
+ "return_lse=True" should be used with ATTN_MAP_VIS wrapper.
60
+ Though return_lse is "True", _flex_attention(...) only have an output (attention output, not attention scores).
61
+ """
62
+ return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
63
+
64
+
65
+ class Flex_Attention(nn.Module):
66
+ def __init__(
67
+ self,
68
+ dim: int,
69
+ num_heads: int = 12,
70
+ qkv_bias: bool = True,
71
+ proj_drop: float = 0.,
72
+ use_rel_bias: bool = True,
73
+ ):
74
+ super().__init__()
75
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
76
+ self.num_heads = num_heads
77
+ self.head_dim = dim // num_heads
78
+
79
+ self.scale = self.head_dim ** -0.5
80
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
81
+ if qkv_bias:
82
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
83
+ else:
84
+ self.in_proj_bias = None
85
+
86
+ self.f_attn = FlexCore()
87
+
88
+ self.out_proj = nn.Linear(dim, dim)
89
+ self.out_drop = nn.Dropout(proj_drop)
90
+
91
+ self.max_distance=16
92
+
93
+ def build_rel_bias(self, coords):
94
+ return torch.log1p(torch.cdist(coords, coords, p=2))
95
+
96
+ def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
97
+ N, L, C = x.shape
98
+ # ensure contiguous projection before chunking
99
+ x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
100
+ q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]
101
+
102
+ q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
103
+ k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
104
+ v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
105
+
106
+ if attn_mask is not None:
107
+ maks_func = create_block_mask(
108
+ key_padding_mask(attn_mask), N, self.num_heads, L, L
109
+ )
110
+
111
+ qk_scale = q.size(-1) ** -0.5
112
+ x = self.f_attn(
113
+ q, k, v,
114
+ score_mod = get_rel_bias_func(generate_alibi_bias(self.num_heads).to(coords.device), coords, qk_scale) if coords is not None else None,
115
+ block_mask = maks_func if attn_mask is not None else None,
116
+ return_lse=return_attn_score,
117
+ )
118
+
119
+ x = x.permute(0, 2, 1, 3).contiguous()
120
+ x = x.reshape(N, L, C).contiguous()
121
+
122
+ x = self.out_proj(x)
123
+ x = self.out_drop(x)
124
+ return x
models/transformer.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import math
3
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from einops import pack, repeat
10
+
11
+ from .flex_attn import Flex_Attention
12
+
13
+
14
+
15
+ class LayerNormFp32(nn.LayerNorm):
16
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
17
+
18
+ def forward(self, x: torch.Tensor):
19
+ orig_type = x.dtype
20
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
21
+ return x.to(orig_type)
22
+
23
+
24
+ class LayerNorm(nn.LayerNorm):
25
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
26
+
27
+ def forward(self, x: torch.Tensor):
28
+ orig_type = x.dtype
29
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
30
+ return x.to(orig_type)
31
+
32
+
33
+ class QuickGELU(nn.Module):
34
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
35
+ def forward(self, x: torch.Tensor):
36
+ return x * torch.sigmoid(1.702 * x)
37
+
38
+
39
+ class LayerScale(nn.Module):
40
+ def __init__(self, dim, init_values=1e-5, inplace=False):
41
+ super().__init__()
42
+ self.inplace = inplace
43
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
44
+
45
+ def forward(self, x):
46
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
47
+
48
+
49
+ class PatchDropout(nn.Module):
50
+ """
51
+ https://arxiv.org/abs/2212.00794
52
+ """
53
+
54
+ def __init__(self, prob, exclude_first_token=True):
55
+ super().__init__()
56
+ assert 0 <= prob < 1.
57
+ self.prob = prob
58
+ self.exclude_first_token = exclude_first_token # exclude CLS token
59
+
60
+ def forward(self, x):
61
+ if not self.training or self.prob == 0.:
62
+ return x
63
+
64
+ if self.exclude_first_token:
65
+ cls_tokens, x = x[:, :1], x[:, 1:]
66
+ else:
67
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
68
+
69
+ batch = x.size()[0]
70
+ num_tokens = x.size()[1]
71
+
72
+ batch_indices = torch.arange(batch)
73
+ batch_indices = batch_indices[..., None]
74
+
75
+ keep_prob = 1 - self.prob
76
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
77
+
78
+ rand = torch.randn(batch, num_tokens)
79
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
80
+
81
+ x = x[batch_indices, patch_indices_keep]
82
+
83
+ if self.exclude_first_token:
84
+ x = torch.cat((cls_tokens, x), dim=1)
85
+
86
+ return x
87
+
88
+
89
+ class Attention(nn.Module):
90
+ def __init__(
91
+ self,
92
+ dim: int,
93
+ num_heads: int = 8,
94
+ qkv_bias: bool = True,
95
+ scaled_cosine: bool = True,
96
+ scale_heads: bool = False,
97
+ logit_scale_max: float = math.log(1. / 0.01),
98
+ batch_first: bool = True,
99
+ attn_drop: float = 0.,
100
+ proj_drop: float = 0.
101
+ ):
102
+ super().__init__()
103
+ self.scaled_cosine = scaled_cosine
104
+ self.scale_heads = scale_heads
105
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
106
+ self.num_heads = num_heads
107
+ self.head_dim = dim // num_heads
108
+ self.scale = self.head_dim ** -0.5
109
+ self.logit_scale_max = logit_scale_max
110
+ self.batch_first = batch_first
111
+ self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')
112
+
113
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
114
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
115
+ if qkv_bias:
116
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
117
+ else:
118
+ self.in_proj_bias = None
119
+
120
+ if self.scaled_cosine:
121
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
122
+ else:
123
+ self.logit_scale = None
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ if self.scale_heads:
126
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
127
+ else:
128
+ self.head_scale = None
129
+ self.out_proj = nn.Linear(dim, dim)
130
+ self.out_drop = nn.Dropout(proj_drop)
131
+
132
+
133
+ def forward(self, x, coords, attn_mask: Optional[torch.Tensor] = None):
134
+ if self.batch_first:
135
+ x = x.transpose(0, 1)
136
+
137
+ L, N, C = x.shape
138
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
139
+ q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
140
+ k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
141
+ v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)
142
+
143
+ if attn_mask is not None and attn_mask.dtype == torch.bool:
144
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
145
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
146
+ attn_mask = new_attn_mask
147
+
148
+ # if self.logit_scale is not None:
149
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
150
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
151
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
152
+
153
+ if attn_mask is not None:
154
+ attn = attn + attn_mask[:, None, None, :]
155
+ attn = attn.view(-1, L, L)
156
+ attn = attn.softmax(dim=-1)
157
+ attn = self.attn_drop(attn)
158
+
159
+ x = torch.bmm(attn, v)
160
+
161
+ if self.head_scale is not None:
162
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
163
+ x = x.view(-1, L, C)
164
+
165
+ x = x.transpose(0, 1).reshape(L, N, C)
166
+
167
+ if self.batch_first:
168
+ x = x.transpose(0, 1)
169
+
170
+ x = self.out_proj(x)
171
+ x = self.out_drop(x)
172
+ return x
173
+
174
+
175
+ class AttentionalPooler(nn.Module):
176
+ def __init__(
177
+ self,
178
+ d_model: int,
179
+ context_dim: int,
180
+ n_head: int = 8,
181
+ n_queries: int = 256,
182
+ norm_layer: Callable = LayerNorm,
183
+ ):
184
+ super().__init__()
185
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
186
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
187
+ self.ln_q = norm_layer(d_model)
188
+ self.ln_k = norm_layer(context_dim)
189
+
190
+ def forward(self, x: torch.Tensor):
191
+ N = x.shape[0]
192
+ x = self.ln_k(x)
193
+ q = self.ln_q(self.query)
194
+ out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
195
+ return out
196
+
197
+
198
+ class ResidualAttentionBlock(nn.Module):
199
+ def __init__(
200
+ self,
201
+ d_model: int,
202
+ n_head: int,
203
+ mlp_ratio: float = 4.0,
204
+ ls_init_value: float = None,
205
+ act_layer: Callable = nn.GELU,
206
+ norm_layer: Callable = LayerNorm,
207
+ is_cross_attention: bool = False,
208
+ batch_first: bool = True,
209
+ use_flex:bool = False,
210
+ dropout:float = 0.2,
211
+ use_rel_bias:bool = True,
212
+ ):
213
+ super().__init__()
214
+
215
+ self.ln_1 = norm_layer(d_model)
216
+
217
+ if use_flex:
218
+ print("Flex_Attention!")
219
+ self.attn = Flex_Attention(dim = d_model, num_heads=n_head, proj_drop=dropout, use_rel_bias=use_rel_bias)
220
+ else:
221
+ self.attn = Attention(dim = d_model, num_heads=n_head, batch_first=batch_first, proj_drop=dropout, attn_drop=dropout)
222
+
223
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
224
+ if is_cross_attention:
225
+ self.ln_1_kv = norm_layer(d_model)
226
+
227
+ self.ln_2 = norm_layer(d_model)
228
+ mlp_width = int(d_model * mlp_ratio)
229
+
230
+ self.mlp = nn.Sequential(OrderedDict([
231
+ ("c_fc", nn.Linear(d_model, mlp_width)),
232
+ ("gelu", act_layer()),
233
+ ("c_proj", nn.Linear(mlp_width, d_model))
234
+ ]))
235
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
236
+
237
+ def attention(
238
+ self,
239
+ q_x: torch.Tensor,
240
+ k_x: Optional[torch.Tensor] = None,
241
+ v_x: Optional[torch.Tensor] = None,
242
+ coords = None,
243
+ attn_mask: Optional[torch.Tensor] = None,
244
+ key_padding_mask=None,
245
+ ):
246
+ k_x = k_x if k_x is not None else q_x
247
+ v_x = v_x if v_x is not None else q_x
248
+
249
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
250
+
251
+ return self.attn(
252
+ q_x, coords=coords, attn_mask=key_padding_mask
253
+ )
254
+
255
+ def forward(
256
+ self,
257
+ q_x: torch.Tensor,
258
+ k_x: Optional[torch.Tensor] = None,
259
+ v_x: Optional[torch.Tensor] = None,
260
+ coords = None,
261
+ attn_mask: Optional[torch.Tensor] = None,
262
+ key_padding_mask = None,
263
+ ):
264
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
265
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
266
+ x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, coords=coords, attn_mask=attn_mask, key_padding_mask=key_padding_mask))
267
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
268
+ return x
269
+
270
+
271
+ def _expand_token(token, batch_size: int):
272
+ return token.view(1, 1, -1).expand(batch_size, -1, -1)
273
+
274
+
275
+ class Transformer(nn.Module):
276
+ def __init__(
277
+ self,
278
+ width: int,
279
+ layers: int,
280
+ heads: int,
281
+ mlp_ratio: float = 4.0,
282
+ ls_init_value: float = None,
283
+ act_layer: Callable = nn.GELU,
284
+ norm_layer: Callable = LayerNorm,
285
+ batch_first: bool = True,
286
+ use_flex: bool = False,
287
+ dropout: float = False,
288
+ use_rel_bias: bool = True,
289
+ ):
290
+ super().__init__()
291
+ self.width = width
292
+ self.layers = layers
293
+ self.batch_first = batch_first
294
+ self.grad_checkpointing = False
295
+
296
+ self.resblocks = nn.ModuleList([
297
+ ResidualAttentionBlock(
298
+ width,
299
+ heads,
300
+ mlp_ratio,
301
+ ls_init_value=ls_init_value,
302
+ act_layer=act_layer,
303
+ norm_layer=norm_layer,
304
+ batch_first=batch_first,
305
+ use_flex=use_flex,
306
+ dropout=dropout,
307
+ use_rel_bias=use_rel_bias
308
+ )
309
+ for _ in range(layers)
310
+ ])
311
+
312
+ def get_cast_dtype(self) -> torch.dtype:
313
+ if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
314
+ return self.resblocks[0].mlp.c_fc.int8_original_dtype
315
+ return self.resblocks[0].mlp.c_fc.weight.dtype
316
+
317
+ def forward(self, x: torch.Tensor, coords = None, attn_mask: Optional[torch.Tensor] = None, key_padding_mask=None):
318
+ if not self.batch_first:
319
+ x = x.transpose(0, 1).contiguous() # NLD -> LND
320
+ for r in self.resblocks:
321
+ x = r(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, coords=coords)
322
+ if not self.batch_first:
323
+ x = x.transpose(0, 1).contiguous() # LND -> NLD
324
+ return x
325
+
326
+
327
+
328
+ class VisionTransformer(nn.Module):
329
+ def __init__(
330
+ self,
331
+ width: int,
332
+ layers: int,
333
+ heads: int,
334
+ mlp_ratio: float,
335
+ ls_init_value: float = None,
336
+ output_dim: int = 512,
337
+ patch_dropout: float = 0.,
338
+ no_ln_pre: bool = False,
339
+ pool_type: str = 'tok',
340
+ final_ln_after_pool: bool = False,
341
+ act_layer: Callable = nn.GELU,
342
+ norm_layer: Callable = LayerNorm,
343
+ output_tokens: bool = False,
344
+ img_embed: bool = False,
345
+ use_flex:bool = False,
346
+ dropout:float = 0.1,
347
+ num_registers: int = 0,
348
+ use_rel_bias: bool = True,
349
+ ):
350
+ super().__init__()
351
+ assert pool_type in ('tok', 'avg', 'none')
352
+ self.output_tokens = output_tokens
353
+
354
+ self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled
355
+ self.output_dim = output_dim
356
+ self.img_embed = img_embed
357
+ self.num_registers = num_registers
358
+ self.positional_embedding = None
359
+ self.pre_linear = nn.Linear(768, width)
360
+
361
+
362
+ if num_registers>0:
363
+ self.register_token = nn.Parameter(torch.empty(num_registers, width))
364
+ nn.init.normal_(self.register_token, std=0.02)
365
+
366
+
367
+ self.positional_embedding = None
368
+
369
+
370
+ self.positional_embedding = None
371
+
372
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
373
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
374
+
375
+ self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
376
+ self.transformer = Transformer(
377
+ width,
378
+ layers,
379
+ heads,
380
+ mlp_ratio,
381
+ ls_init_value=ls_init_value,
382
+ act_layer=act_layer,
383
+ norm_layer=norm_layer,
384
+ use_flex=use_flex,
385
+ dropout=dropout,
386
+ use_rel_bias=use_rel_bias,
387
+ )
388
+
389
+ pool_dim = width
390
+ self.pool_type = pool_type
391
+
392
+ self.ln_post = norm_layer(pool_dim)
393
+
394
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
395
+ if self.pool_type == 'avg':
396
+ pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
397
+ elif self.pool_type == 'tok':
398
+ pooled, tokens = x[:, 0], x[:, 1:]
399
+ else:
400
+ pooled = tokens = x
401
+
402
+ return pooled, tokens
403
+
404
+ def forward(self, x: torch.Tensor, coords=None, mask=None, key_padding_mask=None):
405
+ x = self.pre_linear(x)
406
+
407
+ if self.num_registers > 0:
408
+ r = repeat(self.register_token, 'n d -> b n d', b=x.size(0))
409
+ x, ps = pack([x, r], 'b * d')
410
+
411
+ x = self.patch_dropout(x)
412
+ x = self.ln_pre(x)
413
+ x = self.transformer(x, coords, mask, key_padding_mask=key_padding_mask)
414
+
415
+ if self.final_ln_after_pool:
416
+ pooled, tokens = self._global_pool(x)
417
+ pooled = self.ln_post(pooled)
418
+ else:
419
+ x = self.ln_post(x)
420
+ pooled, tokens = self._global_pool(x)
421
+
422
+ return pooled
423
+
424
+
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+
3
+ torch==2.6.0+cu124
4
+ torchvision==0.21.0+cu124
5
+ numpy==2.1.2
6
+ opencv-python-headless==4.11.0.86
7
+ openslide-bin==4.0.0.6
8
+ openslide-python==1.4.1
9
+ pandas==2.2.3
10
+ torchstain==1.4.1
11
+ fire==0.7.0
12
+ einops==0.8.1
13
+ huggingface_hub==0.32.2
14
+ h5py==3.13.0
15
+ scikit-learn==1.6.1
16
+ tensorboard==2.19.0
utils/constants.py ADDED
@@ -0,0 +1 @@
 
 
1
+ CLASS_NAMES = ["LABEL_LOW", "LABEL_HIGH"]
utils/file_utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import h5py
3
+
4
+ def save_pkl(filename, save_object):
5
+ writer = open(filename,'wb')
6
+ pickle.dump(save_object, writer)
7
+ writer.close()
8
+
9
+ def load_pkl(filename):
10
+ loader = open(filename,'rb')
11
+ file = pickle.load(loader)
12
+ loader.close()
13
+ return file
14
+
15
+
16
+ def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'):
17
+ file = h5py.File(output_path, mode)
18
+ for key, val in asset_dict.items():
19
+ data_shape = val.shape
20
+ if key not in file:
21
+ data_type = val.dtype
22
+ chunk_shape = (1, ) + data_shape[1:]
23
+ maxshape = (None, ) + data_shape[1:]
24
+ dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type)
25
+ dset[:] = val
26
+ if attr_dict is not None:
27
+ if key in attr_dict.keys():
28
+ for attr_key, attr_val in attr_dict[key].items():
29
+ dset.attrs[attr_key] = attr_val
30
+ else:
31
+ dset = file[key]
32
+ dset.resize(len(dset) + data_shape[0], axis=0)
33
+ dset[-data_shape[0]:] = val
34
+ file.close()
35
+ return output_path
utils/preprocessor.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from torchstain.base.normalizers.he_normalizer import HENormalizer
7
+ from torchstain.torch.utils import cov, percentile
8
+ from torchvision import transforms
9
+ from torchvision.transforms.functional import to_pil_image
10
+
11
+
12
+ def preprocessor(pretrained=False, normalizer=None):
13
+ if pretrained:
14
+ mean = (0.485, 0.456, 0.406)
15
+ std = (0.229, 0.224, 0.225)
16
+ else:
17
+ mean = (0.5, 0.5, 0.5)
18
+ std = (0.5, 0.5, 0.5)
19
+
20
+ preprocess = transforms.Compose(
21
+ [
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
24
+ transforms.Lambda(lambda x: x) if normalizer == None else normalizer,
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=mean, std=std),
27
+ ]
28
+ )
29
+
30
+ return preprocess
31
+
32
+
33
+ """
34
+ Source code ported from: https://github.com/schaugf/HEnorm_python
35
+ Original implementation: https://github.com/mitkovetta/staining-normalization
36
+ """
37
+
38
+
39
+ class TorchMacenkoNormalizer(HENormalizer):
40
+ def __init__(self):
41
+ super().__init__()
42
+
43
+ self.HERef = torch.tensor(
44
+ [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
45
+ )
46
+ self.maxCRef = torch.tensor([1.9705, 1.0308])
47
+
48
+ # Avoid using deprecated torch.lstsq (since 1.9.0)
49
+ self.updated_lstsq = hasattr(torch.linalg, "lstsq")
50
+
51
+ def __convert_rgb2od(self, I, Io, beta):
52
+ I = I.permute(1, 2, 0)
53
+
54
+ # calculate optical density
55
+ OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)
56
+
57
+ # remove transparent pixels
58
+ ODhat = OD[~torch.any(OD < beta, dim=1)]
59
+
60
+ return OD, ODhat
61
+
62
+ def __find_HE(self, ODhat, eigvecs, alpha):
63
+ # project on the plane spanned by the eigenvectors corresponding to the two
64
+ # largest eigenvalues
65
+ That = torch.matmul(ODhat, eigvecs)
66
+ phi = torch.atan2(That[:, 1], That[:, 0])
67
+ # print(phi.size())
68
+
69
+ minPhi = percentile(phi, alpha)
70
+ maxPhi = percentile(phi, 100 - alpha)
71
+
72
+ vMin = torch.matmul(
73
+ eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
74
+ ).unsqueeze(1)
75
+ vMax = torch.matmul(
76
+ eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
77
+ ).unsqueeze(1)
78
+
79
+ # a heuristic to make the vector corresponding to hematoxylin first and the
80
+ # one corresponding to eosin second
81
+ HE = torch.where(
82
+ vMin[0] > vMax[0],
83
+ torch.cat((vMin, vMax), dim=1),
84
+ torch.cat((vMax, vMin), dim=1),
85
+ )
86
+
87
+ return HE
88
+
89
+ def __find_concentration(self, OD, HE):
90
+ # rows correspond to channels (RGB), columns to OD values
91
+ Y = OD.T
92
+
93
+ # determine concentrations of the individual stains
94
+ if not self.updated_lstsq:
95
+ return torch.lstsq(Y, HE)[0][:2]
96
+
97
+ return torch.linalg.lstsq(HE, Y)[0]
98
+
99
+ def __compute_matrices(self, I, Io, alpha, beta):
100
+ OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
101
+
102
+ # compute eigenvectors
103
+ _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
104
+ eigvecs = eigvecs[:, [1, 2]]
105
+
106
+ HE = self.__find_HE(ODhat, eigvecs, alpha)
107
+
108
+ C = self.__find_concentration(OD, HE)
109
+ maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
110
+
111
+ return HE, C, maxC
112
+
113
+ def fit(self, I, Io=240, alpha=1, beta=0.15):
114
+ HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)
115
+
116
+ self.HERef = HE
117
+ self.maxCRef = maxC
118
+
119
+ def normalize(
120
+ self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
121
+ ):
122
+ """Normalize staining appearence of H&E stained images
123
+
124
+ Example use:
125
+ see test.py
126
+
127
+ Input:
128
+ I: RGB input image: tensor of shape [C, H, W] and type uint8
129
+ Io: (optional) transmitted light intensity
130
+ alpha: percentile
131
+ beta: transparency threshold
132
+ stains: if true, return also H & E components
133
+
134
+ Output:
135
+ Inorm: normalized image
136
+ H: hematoxylin image
137
+ E: eosin image
138
+
139
+ Reference:
140
+ A method for normalizing histology slides for quantitative analysis. M.
141
+ Macenko et al., ISBI 2009
142
+ """
143
+
144
+ c, h, w = I.shape
145
+
146
+ HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)
147
+
148
+ # normalize stain concentrations
149
+ C *= (self.maxCRef / maxC).unsqueeze(-1)
150
+
151
+ # recreate the image using reference mixing matrix
152
+ Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
153
+ Inorm = torch.clip(Inorm, 0, 255)
154
+
155
+ Inorm = Inorm.reshape(c, h, w).float() / 255.0
156
+ Inorm = torch.clip(Inorm, 0.0, 1.0)
157
+
158
+ H, E = None, None
159
+
160
+ if stains:
161
+ H = torch.mul(
162
+ Io,
163
+ torch.exp(
164
+ torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
165
+ ),
166
+ )
167
+ H[H > 255] = 255
168
+ H = H.T.reshape(h, w, c).int()
169
+
170
+ E = torch.mul(
171
+ Io,
172
+ torch.exp(
173
+ torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
174
+ ),
175
+ )
176
+ E[E > 255] = 255
177
+ E = E.T.reshape(h, w, c).int()
178
+
179
+ return Inorm, H, E
180
+
181
+
182
+ class MacenkoNormalizer:
183
+ def __init__(self, target_path=None, prob=1):
184
+ self.transform_before_macenko = transforms.Compose(
185
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
186
+ )
187
+ self.normalizer = TorchMacenkoNormalizer()
188
+
189
+ ext = os.path.splitext(target_path)[1].lower()
190
+ if ext in [".jpg", ".jpeg", ".png"]:
191
+ target = Image.open(target_path)
192
+ self.normalizer.fit(self.transform_before_macenko(target))
193
+ elif ext in [".pt"]:
194
+ target = torch.load(target_path)
195
+ self.normalizer.HERef = target["HERef"]
196
+ self.normalizer.maxCRef = target["maxCRef"]
197
+
198
+ else:
199
+ raise ValueError(f"Invalid extension: {ext}")
200
+ self.prob = prob
201
+
202
+ def __call__(self, image):
203
+ t_to_transform = self.transform_before_macenko(image)
204
+ try:
205
+ image_macenko, _, _ = self.normalizer.normalize(
206
+ I=t_to_transform, stains=False, form="chw", dtype="float"
207
+ )
208
+ if torch.any(torch.isnan(image_macenko)):
209
+ return image
210
+ else:
211
+ image_macenko = to_pil_image(image_macenko)
212
+ return image_macenko
213
+ except Exception as e:
214
+ if "kthvalue()" in str(e) or "linalg.eigh" in str(e):
215
+ pass
216
+ else:
217
+ print(str(e))
218
+ return image
utils/wsi_utils.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from openslide import OpenSlide
5
+
6
+
7
+ def extract_tissue_patch_coords(
8
+ wsi_path: str,
9
+ patch_size: int = 256,
10
+ step_size: int = 256,
11
+ downsample_threshold: float = 64,
12
+ threshold: int = 8,
13
+ max_val: int = 255,
14
+ median_kernel: int = 7,
15
+ close_size: int = 4,
16
+ min_effective_area_factor: float = 100, # multiplied by (ref_area)^2
17
+ ref_area: int = 512,
18
+ min_hole_area_factor: float = 16, # multiplied by (ref_area)^2
19
+ max_n_holes: int = 8,
20
+ )-> list:
21
+ """
22
+ Extract patches from the full-resolution image whose centers fall within tissue regions.
23
+
24
+ Process:
25
+ 1. Open the WSI.
26
+ 2. Select a segmentation level and compute a binary mask.
27
+ 3. Find contours and holes from the mask and filter them using effective area criteria.
28
+ 4. Scale the external contours and holes to full resolution.
29
+ 5. Slide a window over the full-resolution image and extract patches if the center is in tissue.
30
+
31
+ Returns:
32
+ A torch tensor of shape (N, 3, patch_size, patch_size) containing the patches.
33
+ """
34
+ slide = OpenSlide(wsi_path)
35
+ full_width, full_height = slide.level_dimensions[0]
36
+
37
+ seg_level, scale = select_segmentation_level(slide, downsample_threshold)
38
+ binary_mask = compute_segmentation_mask(
39
+ slide, seg_level, threshold, max_val, median_kernel, close_size
40
+ )
41
+
42
+ # Compute thresholds for effective area and hole area
43
+ effective_area_thresh = min_effective_area_factor * (
44
+ ref_area**2 / (scale[0] * scale[1])
45
+ )
46
+ hole_area_thresh = min_hole_area_factor * (ref_area**2 / (scale[0] * scale[1]))
47
+
48
+ ext_contours, holes_list = filter_contours_and_holes(
49
+ binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
50
+ )
51
+ if not ext_contours:
52
+ raise ValueError("No valid tissue contours found.")
53
+
54
+ tissue_contours = scale_contours(ext_contours, scale)
55
+ scaled_holes = [scale_contours(holes, scale) for holes in holes_list]
56
+
57
+ coords = []
58
+ for y in range(0, full_height - patch_size + 1, step_size):
59
+ for x in range(0, full_width - patch_size + 1, step_size):
60
+ center_x = x + patch_size // 2
61
+ center_y = y + patch_size // 2
62
+ if not point_in_tissue(center_x, center_y, tissue_contours, scaled_holes):
63
+ continue
64
+ coords.append((x, y))
65
+
66
+ if not coords:
67
+ raise ValueError("No available patches")
68
+ return coords
69
+
70
+
71
+ def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
72
+ """
73
+ Select a segmentation level whose downsample factor is at least the specified threshold.
74
+
75
+ Returns:
76
+ level (int): Chosen level index.
77
+ scale (tuple): Downsample factors (sx, sy) for that level.
78
+ """
79
+ level = slide.get_best_level_for_downsample(downsample_threshold)
80
+ ds = slide.level_downsamples[level]
81
+ if not isinstance(ds, (tuple, list)):
82
+ ds = (ds, ds)
83
+ return level, ds
84
+
85
+
86
+ def compute_segmentation_mask(
87
+ slide: OpenSlide,
88
+ level: int,
89
+ threshold: int = 20,
90
+ max_val: int = 255,
91
+ median_kernel: int = 7,
92
+ close_size: int = 4,
93
+ ):
94
+ """
95
+ Compute a binary mask for tissue segmentation at the specified level.
96
+
97
+ Process:
98
+ - Read the image at the given level and convert to RGB.
99
+ - Convert the image to HSV and extract the saturation channel.
100
+ - Apply median blur.
101
+ - Apply binary thresholding (either fixed or Otsu).
102
+ - Apply morphological closing.
103
+
104
+ Returns:
105
+ binary (ndarray): Binary mask image.
106
+ """
107
+ img = np.array(
108
+ slide.read_region((0, 0), level, slide.level_dimensions[level]).convert("RGB")
109
+ )
110
+ hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
111
+ sat = hsv[:, :, 1]
112
+ blurred = cv2.medianBlur(sat, median_kernel)
113
+ _, binary = cv2.threshold(blurred, threshold, max_val, cv2.THRESH_BINARY)
114
+ if close_size > 0:
115
+ kernel = np.ones((close_size, close_size), np.uint8)
116
+ binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
117
+ return binary
118
+
119
+
120
+ def filter_contours_and_holes(
121
+ binary_mask: np.ndarray,
122
+ min_effective_area: float,
123
+ min_hole_area: float,
124
+ max_n_holes: int,
125
+ ):
126
+ """
127
+ Find contours from the binary mask and filter them based on effective area.
128
+
129
+ For each external contour (one with no parent), identify child contours (holes),
130
+ sort them by area (largest first), and keep up to max_n_holes that exceed min_hole_area.
131
+ The effective area is computed as the area of the external contour minus the sum of areas
132
+ of the selected holes. Only contours with effective area above min_effective_area are retained.
133
+
134
+ Returns:
135
+ filtered_contours (list): List of external contours (numpy arrays).
136
+ holes_list (list): Corresponding list of lists of hole contours.
137
+ """
138
+ contours, hierarchy = cv2.findContours(
139
+ binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
140
+ )
141
+ if hierarchy is None:
142
+ return [], []
143
+ hierarchy = hierarchy[0] # shape: (N, 4)
144
+ filtered_contours = []
145
+ holes_list = []
146
+ for idx, h in enumerate(hierarchy):
147
+ if h[3] != -1:
148
+ continue # Only external contours
149
+ ext_cont = contours[idx]
150
+ ext_area = cv2.contourArea(ext_cont)
151
+ # Find child contours (holes)
152
+ hole_idxs = [i for i, hr in enumerate(hierarchy) if hr[3] == idx]
153
+ # Sort holes by area descending and keep up to max_n_holes
154
+ sorted_holes = sorted(
155
+ [contours[i] for i in hole_idxs], key=cv2.contourArea, reverse=True
156
+ )
157
+ selected_holes = [
158
+ hole
159
+ for hole in sorted_holes[:max_n_holes]
160
+ if cv2.contourArea(hole) > min_hole_area
161
+ ]
162
+ total_hole_area = sum(cv2.contourArea(hole) for hole in selected_holes)
163
+ effective_area = ext_area - total_hole_area
164
+ if effective_area > min_effective_area:
165
+ filtered_contours.append(ext_cont)
166
+ holes_list.append(selected_holes)
167
+ return filtered_contours, holes_list
168
+
169
+
170
+ def scale_contours(contours: list, scale: tuple) -> list:
171
+ """
172
+ Scale contour coordinates by the provided scale factors.
173
+
174
+ Args:
175
+ contours: List of contours (each a numpy array of points).
176
+ scale: Tuple (sx, sy) for scaling.
177
+
178
+ Returns:
179
+ List of scaled contours.
180
+ """
181
+ scaled = []
182
+ for cont in contours:
183
+ scaled.append((cont * np.array(scale, dtype=np.float32)).astype(np.int32))
184
+ return scaled
185
+
186
+
187
+ def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
188
+ """
189
+ Check if point (x, y) lies within any external contour and not inside its corresponding holes.
190
+
191
+ For each external contour in ext_contours (paired with holes_list),
192
+ if the point is inside the contour and not inside any of its holes, return True.
193
+ """
194
+ for cont, holes in zip(ext_contours, holes_list):
195
+ if cv2.pointPolygonTest(cont, (x, y), False) >= 0:
196
+ inside_hole = False
197
+ for hole in holes:
198
+ if cv2.pointPolygonTest(hole, (x, y), False) >= 0:
199
+ inside_hole = True
200
+ break
201
+ if not inside_hole:
202
+ return True
203
+ return False
204
+
205
+
206
+ def tile(x: torch.Tensor, size: int):
207
+ C, H, W = x.shape[-3:]
208
+
209
+ pad_h = (size - H % size) % size
210
+ pad_w = (size - W % size) % size
211
+ if pad_h > 0 or pad_w > 0:
212
+ x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
213
+
214
+ nh, nw = x.size(2) // size, x.size(3) // size
215
+ return (
216
+ x.view(-1, C, nh, size, nw, size)
217
+ .permute(0, 2, 4, 1, 3, 5)
218
+ .reshape(-1, C, size, size)
219
+ )