Commit
·
03ae676
0
Parent(s):
init commit
Browse files- LICENSE +58 -0
- README.md +94 -0
- config.json +6 -0
- datasets/dataset_WSI.py +52 -0
- inference.py +52 -0
- models/aggregator.py +222 -0
- models/cls_modules.py +168 -0
- models/exaonepath.py +158 -0
- models/feature_extractor.py +357 -0
- models/flex_attn.py +124 -0
- models/transformer.py +424 -0
- requirements.txt +16 -0
- utils/constants.py +1 -0
- utils/file_utils.py +35 -0
- utils/preprocessor.py +218 -0
- utils/wsi_utils.py +219 -0
LICENSE
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
EXAONEPath AI Model License Agreement 1.0 - NC
|
2 |
+
|
3 |
+
This License Agreement (“Agreement”) is entered into between you (“Licensee”) and LG Management Development Institute Co., Ltd. (“Licensor”), governing the use of the EXAONEPath AI Model (“Model”). By downloading, installing, copying, or using the Model, you agree to comply with and be bound by the terms of this Agreement. If you do not agree to all the terms, you must not download, install, copy, or use the Model. This Agreement constitutes a binding legal agreement between the Licensee and Licensor.
|
4 |
+
|
5 |
+
1. Definitions
|
6 |
+
1.1 Model: The artificial intelligence model provided by Licensor, which includes any software, algorithms, machine learning models, or related components supplied by Licensor. This definition extends to encompass all updates, enhancements, improvements, bug fixes, patches, or other modifications that may be provided by Licensor from time to time, whether automatically or manually implemented.
|
7 |
+
1.2 Derivatives: Any modifications, alterations, enhancements, improvements, adaptations, or derivative works of the Model created by Licensee or any third party. This includes changes made to the Model's architecture, parameters, data processing methods, or any other aspect of the Model that results in a modification of its functionality or output.
|
8 |
+
1.3 Output: Any data, results, content, predictions, analyses, insights, or other materials generated by the Model or Derivatives, regardless of whether they are in their original form or have been further processed or modified by the Licensee. This includes, but is not limited to, textual or numerical outputs produced directly or indirectly through the use of the Model.
|
9 |
+
1.4 Licensor: LG Management Development Institute Co., Ltd., the owner, developer, and provider of the EXAONEPath AI Model. The Licensor holds all rights, title, and interest in the Model and is responsible for granting licenses to use the Model under the terms specified in this Agreement.
|
10 |
+
1.5 Licensee: The individual, organization, corporation, academic institution, government agency, or other entity using or intending to use the Model under the terms and conditions of this Agreement. The Licensee is responsible for ensuring compliance with the Agreement by all authorized users who access or utilize the Model on behalf of the Licensee.
|
11 |
+
|
12 |
+
2. License Grant
|
13 |
+
2.1 Grant of License: Subject to the terms and conditions outlined in this Agreement, the Licensor hereby grants the Licensee a limited, non-exclusive, non-transferable, worldwide, and revocable license to:
|
14 |
+
a. Access, download, install, and use the Model solely for research purposes. This includes evaluation, testing, academic research and experimentation.
|
15 |
+
b. Publicly disclose research results and findings derived from the use of the Model or Derivatives, including publishing papers or presentations.
|
16 |
+
c. Modify the Model and create Derivatives based on the Model, provided that such modifications and Derivatives are used exclusively for research purposes. The Licensee may conduct experiments, perform analyses, and apply custom modifications to the Model to explore its capabilities and performance under various scenarios. If the Model is modified, the modified Model must include "EXAONEPath" at the beginning of its name.
|
17 |
+
d. Distribute the Model and Derivatives in each case with a copy of this Agreement.
|
18 |
+
2.2 Scope of License: The license granted herein does not authorize the Licensee to use the Model for any purpose not explicitly permitted under this Agreement. Any use beyond the scope of this license, including any commercial application or external distribution, is strictly prohibited unless explicitly agreed upon in writing by the Licensor.
|
19 |
+
|
20 |
+
3. Restrictions
|
21 |
+
3.1 Commercial Use: The Licensee is expressly prohibited from using the Model, Derivatives, or Output for any commercial purposes, including but not limited to, developing or deploying products, services, or applications that generate revenue, whether directly or indirectly. Any commercial exploitation of the Model or its derivatives requires a separate commercial license agreement with the Licensor. Furthermore, the Licensee shall not use the Model, Derivatives or Output to develop or improve other models, except for research purposes, which is explicitly permitted.
|
22 |
+
3.2 Reverse Engineering: The Licensee shall not decompile, disassemble, reverse engineer, or attempt to derive the source code, underlying ideas, algorithms, or structure of the Model, except to the extent that such activities are expressly permitted by applicable law. Any attempt to bypass or circumvent technological protection measures applied to the Model is strictly prohibited.
|
23 |
+
3.3 Unlawful Use: The Licensee shall not use the Model and Derivatives for any illegal, fraudulent, or unauthorized activities, nor for any purpose that violates applicable laws or regulations. This includes but is not limited to the creation, distribution, or dissemination of malicious, deceptive, or unlawful content.
|
24 |
+
3.4 Ethical Use: The Licensee shall ensure that the Model or Derivatives is used in an ethical and responsible manner, adhering to the following guidelines:
|
25 |
+
a. The Model and Derivatives shall not be used to generate, propagate, or amplify false, misleading, or harmful information, including fake news, misinformation, or disinformation.
|
26 |
+
b. The Model and Derivatives shall not be employed to create, distribute, or promote content that is discriminatory, harassing, defamatory, abusive, or otherwise offensive to individuals or groups based on race, gender, sexual orientation, religion, nationality, or other protected characteristics.
|
27 |
+
c. The Model and Derivatives shall not infringe on the rights of others, including intellectual property rights, privacy rights, or any other rights recognized by law. The Licensee shall obtain all necessary permissions and consents before using the Model and Derivatives in a manner that may impact the rights of third parties.
|
28 |
+
d. The Model and Derivatives shall not be used in a way that causes harm, whether physical, mental, emotional, or financial, to individuals, organizations, or communities. The Licensee shall take all reasonable measures to prevent misuse or abuse of the Model and Derivatives that could result in harm or injury.
|
29 |
+
|
30 |
+
4. Ownership
|
31 |
+
4.1 Intellectual Property: All rights, title, and interest in and to the Model, including any modifications, Derivatives, and associated documentation, are and shall remain the exclusive property of the Licensor. The Licensee acknowledges that this Agreement does not transfer any ownership rights to the Licensee. All trademarks, service marks, and logos associated with the Model are the property of the Licensor.
|
32 |
+
4.2 Output: All output generated by the Model from Licensee Data ("Output") shall be the sole property of the Licensee. Licensor hereby waives any claim of ownership or intellectual property rights to the Output. Licensee is solely responsible for the legality, accuracy, quality, integrity, and use of the Output.
|
33 |
+
4.3 Attribution: In any publication or presentation of results obtained using the Model, the Licensee shall provide appropriate attribution to the Licensor, citing the Model's name and version, along with any relevant documentation or references specified by the Licensor.
|
34 |
+
|
35 |
+
5. No Warranty
|
36 |
+
5.1 “As-Is” Basis: The Model, Derivatives, and Output are provided on an “as-is” and “as-available” basis, without any warranties or representations of any kind, whether express, implied, or statutory. The Licensor disclaims all warranties, including but not limited to, implied warranties of merchantability, fitness for a particular purpose, accuracy, reliability, non-infringement, or any warranty arising from the course of dealing or usage of trade.
|
37 |
+
5.2 Performance and Reliability: The Licensor does not warrant or guarantee that the Model, Derivatives or Output will meet the Licensee’s requirements, that the operation of the Model, Derivatives or Output will be uninterrupted or error-free, or that defects in the Model will be corrected. The Licensee acknowledges that the use of the Model, Derivatives or Output is at its own risk and that the Model, Derivatives or Output may contain bugs, errors, or other limitations.
|
38 |
+
5.3 No Endorsement: The Licensor does not endorse, approve, or certify any results, conclusions, or recommendations derived from the use of the Model. The Licensee is solely responsible for evaluating the accuracy, reliability, and suitability of the Model for its intended purposes.
|
39 |
+
|
40 |
+
6. Limitation of Liability
|
41 |
+
6.1 No Liability for Damages: To the fullest extent permitted by applicable law, in no event shall the Licensor be liable for any special, incidental, indirect, consequential, exemplary, or punitive damages, including but not limited to, damages for loss of business profits, business interruption, loss of business information, loss of data, or any other pecuniary or non-pecuniary loss arising out of or in connection with the use or inability to use the Model, Derivatives or any Output, even if the Licensor has been advised of the possibility of such damages.
|
42 |
+
6.2 Indemnification: The Licensee agrees to indemnify, defend, and hold harmless the Licensor, its affiliates, officers, directors, employees, and agents from and against any claims, liabilities, damages, losses, costs, or expenses (including reasonable attorneys' fees) arising out of or related to the Licensee's use of the Model, any Derivatives, or any Output, including any violation of this Agreement or applicable laws. This includes, but is not limited to, ensuring compliance with copyright laws, privacy regulations, defamation laws, and any other applicable legal or regulatory requirements.
|
43 |
+
|
44 |
+
7. Termination
|
45 |
+
7.1 Termination by Licensor: The Licensor reserves the right to terminate this Agreement and revoke the Licensee’s rights to use the Model at any time, with or without cause, and without prior notice if the Licensee breaches any of the terms or conditions of this Agreement. Termination shall be effective immediately upon notice.
|
46 |
+
7.2 Effect of Termination: Upon termination of this Agreement, the Licensee must immediately cease all use of the Model, Derivatives, and Output and destroy all copies of the Model, Derivatives, and Output in its possession or control, including any backup or archival copies. The Licensee shall certify in writing to the Licensor that such destruction has been completed.
|
47 |
+
7.3 Survival: The provisions of this Agreement that by their nature should survive termination, including but not limited to, Sections 4 (Ownership), 5 (No Warranty), 6 (Limitation of Liability), and this Section 7 (Termination), shall continue to apply after termination.
|
48 |
+
|
49 |
+
8. Governing Law
|
50 |
+
8.1 Governing Law: This Agreement shall be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflict of laws principles.
|
51 |
+
8.2 Arbitration: Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English.
|
52 |
+
|
53 |
+
9. Alterations
|
54 |
+
9.1 Modifications: The Licensor reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on the Licensor’s website or through other means of communication. The Licensee is responsible for reviewing the Agreement periodically for changes. Continued use of the Model after any modifications have been made constitutes acceptance of the revised Agreement.
|
55 |
+
9.2 Entire Agreement: This Agreement constitutes the entire agreement between the Licensee and Licensor concerning the subject matter hereof and supersedes all prior or contemporaneous oral or written agreements, representations, or understandings. Any terms or conditions of any purchase order or other document submitted by the Licensee in connection with the Model that are in addition to, different from, or inconsistent with the terms and conditions of this Agreement are not binding on the Licensor and are void.
|
56 |
+
|
57 |
+
By downloading, installing, or using the EXAONEPath AI Model, the Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.
|
58 |
+
|
README.md
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: other
|
3 |
+
license_name: exaonepath
|
4 |
+
license_link: LICENSE
|
5 |
+
tags:
|
6 |
+
- lg-ai
|
7 |
+
- EXAONEPath-1.5
|
8 |
+
- pathology
|
9 |
+
---
|
10 |
+
# EXAONE Path 1.5
|
11 |
+
## Introduction
|
12 |
+
EXAONE Path 1.5 is a whole-slide-image-level (WSI-level) classification framework designed for downstream tasks in pathology, such as cancer subtyping, molecular subtyping, and mutation prediction. It builds upon our previous work, EXAONE Path 1.0, which focused on patch-wise feature extraction by dividing a WSI into patches and embedding each patch into a feature vector.
|
13 |
+
In EXAONE Path 1.5, we extend this pipeline to take an entire WSI as input. Each patch is first processed using the pretrained EXAONE Path 1.0 encoder to extract patch-level features. These features are then aggregated using a ViT-based (Vision Transformer) aggregator module to produce a slide-level representation.
|
14 |
+
This aggregated representation is subsequently passed through a linear classifier to perform downstream tasks such as molecular subtyping, tumor subtyping, and mutation prediction.
|
15 |
+
To effectively train the aggregator, we adopt a two-stage learning process:
|
16 |
+
Pretraining: We employ multimodal learning by aligning slide images with various mRNA gene expression profiles to learn semantically meaningful slide-level representations.
|
17 |
+
Fine-tuning: The pretrained model is then adapted to specific downstream classification tasks.
|
18 |
+
In this repository, we release the model trained for EGFR mutation prediction in lung adenocarcinoma (LUAD), enabling researchers to leverage our pipeline for similar molecular pathology applications.
|
19 |
+
## Quickstart
|
20 |
+
### 1. Hardware Requirements ###
|
21 |
+
- NVIDIA GPU is required
|
22 |
+
- Minimum 40GB GPU memory recommended
|
23 |
+
- Tested on Ubuntu 22.04 with NVIDIA driver version 550.144.03
|
24 |
+
Note: This implementation requires NVIDIA GPU and drivers. The provided environment setup specifically uses CUDA-enabled PyTorch, making NVIDIA GPU mandatory for running the model.
|
25 |
+
### 2. Environment Setup
|
26 |
+
```
|
27 |
+
pip install -r requirements.txt
|
28 |
+
```
|
29 |
+
### 3-a. Load the model & Inference
|
30 |
+
#### Load model with HuggingFace
|
31 |
+
|
32 |
+
```python
|
33 |
+
from models.exaonepath import EXAONEPathV1p5Downstream
|
34 |
+
hf_token = "YOUR_HUGGING_FACE_ACCESS_TOKEN"
|
35 |
+
model = EXAONEPathV1p5Downstream.from_pretrained("LGAI-EXAONE/EXAONE-Path-1.5", use_auth_token=hf_token)
|
36 |
+
slide_path = './samples/wsis/1/1.svs'
|
37 |
+
probs = model(slide_path)
|
38 |
+
```
|
39 |
+
#### Fast CLI Inference
|
40 |
+
Before running the command below, make sure you update your Hugging Face token.
|
41 |
+
Open `tokens.py` and replace the placeholder with your actual token:
|
42 |
+
|
43 |
+
```python
|
44 |
+
HF_TOKEN = "YOUR_HUGGING_FACE_ACCESS_TOKEN"
|
45 |
+
```
|
46 |
+
|
47 |
+
Then, run inference with:
|
48 |
+
```bash
|
49 |
+
python inference.py --svs_path ./samples/wsis/1/1.svs
|
50 |
+
```
|
51 |
+
### 3-b. Fine-tuning with Pretrained Weights
|
52 |
+
We provide example scripts and files to help you fine-tune the model on your own dataset.
|
53 |
+
The provided script fine-tunes the model using pretrained weights stored in `./pretrained_weight.pth`.
|
54 |
+
|
55 |
+
#### Extract Features from WSI Images
|
56 |
+
To train the model using WSI images and their corresponding labels,
|
57 |
+
you must first extract patch-level features from each WSI using our provided feature extractor.
|
58 |
+
|
59 |
+
```bash
|
60 |
+
python feature_extract.py --input_dir ./samples/wsis/ --output_dir ./samples/feats/
|
61 |
+
```
|
62 |
+
This will generate .pt feature files in the output_dir.
|
63 |
+
|
64 |
+
#### Fine-tuning
|
65 |
+
```bash
|
66 |
+
bash tuning_script.sh
|
67 |
+
```
|
68 |
+
Inside tuning_script.sh, you can modify the following variables to match your dataset:
|
69 |
+
```bash
|
70 |
+
FEAT_PATH=./samples/feats
|
71 |
+
LABEL_PATH=./samples/label/label.csv
|
72 |
+
LABEL_DICT="{'n':0, 'y':1}"
|
73 |
+
SPLIT_PATH=./samples/splits
|
74 |
+
```
|
75 |
+
Change these paths to point to your own feature, label, and split files to start training.
|
76 |
+
|
77 |
+
## Model Performance Comparison
|
78 |
+
| Metric: AUC | Titan(Conch v1.5+iBot, image+text) | PRISM (virchow+pe receiver, Image+text) | CHIEF (CTransPath + CLAM, Image+text, clam+wsi contrastive) | Prov-GigaPath (GigaPath+LongNet, Image-only, mask precision manner) | UNI2-h + CLAM (Image-only) | EXAONE Path 1.5(image+gene expression) |
|
79 |
+
|--------------------------|----------------------------------|-----------------------------------------|--------------------------------------------------------------|------------------------------------------------------------------------|-----------------------------|------------------|
|
80 |
+
| **TMB (cutoff 10)** | 0.74 | 0.73 | 0.70 | 0.69 | 0.71 | 0.71 |
|
81 |
+
| **LUAD-EGFR-mut** | 0.76 | 0.80 | 0.73 | 0.73 | 0.79 | 0.81 |
|
82 |
+
| **LUAD-KRAS-mut** | 0.61 | 0.65 | 0.61 | 0.66 | 0.60 | 0.63 |
|
83 |
+
| **LUAD-Gene-overexp[1]** | 0.75 | 0.68 | 0.71 | 0.71 | 0.74 | 0.72 |
|
84 |
+
| **CRC-MSS/MSI** | 0.89 | 0.88 | 0.86 | 0.90 | 0.90 | 0.89 |
|
85 |
+
| **BRCA-ER_PR_HER2** | 0.82 | 0.79 | 0.76 | 0.79 | 0.81 | 0.77 |
|
86 |
+
| **Pan-cancer-Gene-mut[2]** | 0.79 | 0.77 | 0.73 | 0.74 | 0.77 | 0.76 |
|
87 |
+
| **Avg. AUC** | 0.77 | 0.76 | 0.73 | 0.74 | 0.77 | 0.76 |
|
88 |
+
|
89 |
+
[1]: **LUAD-Gene-overexp**: a total of 11 genes were evaluated: LAG3, CLDN6, CD274, EGFR, ERBB2, ERBB3, CD276, VTCN1, TACSTD2, FOLR1, and MET.
|
90 |
+
|
91 |
+
[2]: **Pan-cancer-Gene-mut**: total 7 genes were evaluated: TP53, KRAS, ALK, PIK3CA, MET, EGFR, PTEN
|
92 |
+
|
93 |
+
## License
|
94 |
+
The model is licensed under [EXAONEPath AI Model License Agreement 1.0 - NC](./LICENSE)
|
config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"step_size": 256,
|
3 |
+
"patch_size": 256,
|
4 |
+
"num_sampled_patch": 16384,
|
5 |
+
"architecture": "EXAONEPathV1p5Downstream"
|
6 |
+
}
|
datasets/dataset_WSI.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from openslide import OpenSlide
|
4 |
+
from utils.preprocessor import MacenkoNormalizer, preprocessor
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
|
8 |
+
class WSIPatchDataset(Dataset):
    """Dataset of fixed-size RGB patches read on the fly from one whole-slide image.

    Each item is a transformed patch image; when ``return_coord`` is set, the
    patch's top-level pixel coordinate is returned alongside it.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
        return_coord=False,
    ):
        """
        Args:
            coords: sequence of (x, y) top-left patch coordinates at level 0.
            wsi_path: path to the whole-slide image, opened with OpenSlide.
            pretrained: forwarded to ``preprocessor`` (normalization preset).
            patch_size: side length in pixels of each square patch.
            patch_level: OpenSlide pyramid level to read patches from.
            macenko: if True, apply Macenko stain normalization using the
                bundled target parameters.
            return_coord: if True, ``__getitem__`` also returns the coordinate.
        """
        self.pretrained = pretrained
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level
        self.return_coord = return_coord

        if macenko:
            # Target stain parameters live in <repo_root>/macenko_target/.
            # (Fixed: the original wrapped __file__ in a redundant
            # single-argument os.path.join.)
            normalizer = MacenkoNormalizer(
                target_path=os.path.join(
                    os.path.dirname(os.path.dirname(__file__)),
                    "macenko_target",
                    "macenko_param.pt",
                )
            )
        else:
            normalizer = None

        self.roi_transforms = preprocessor(pretrained=pretrained, normalizer=normalizer)
        self.coords = coords
        self.length = len(self.coords)

    def __len__(self):
        """Return the number of patches (one per coordinate)."""
        return self.length

    def __getitem__(self, idx):
        """Read patch ``idx`` from the slide, convert to RGB, and transform it."""
        coord = self.coords[idx]
        img = self.wsi.read_region(
            coord, self.patch_level, (self.patch_size, self.patch_size)
        ).convert("RGB")
        img = self.roi_transforms(img)
        if self.return_coord:
            return img, torch.tensor(coord)
        else:
            return img
|
inference.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
|
5 |
+
from models.exaonepath import EXAONEPathV1p5Downstream
|
6 |
+
from utils.constants import CLASS_NAMES
|
7 |
+
from tokens import HF_TOKEN
|
8 |
+
|
9 |
+
|
10 |
+
def infer(model, input_file):
    """Run the model on a single WSI file and print one probability per class."""
    print("Processing", input_file, "...")
    probs = model(input_file)
    formatted = " / ".join(
        f"{cls_name}: {probs[idx].item():.4f}"
        for idx, cls_name in enumerate(CLASS_NAMES)
    )
    print("Result -- " + formatted + "\n")
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument('--svs_path', type=str, default='./samples/wsis/1/1.svs',
                        help="Path to a single .svs file (used when --svs_dir does not exist)")
    parser.add_argument('--svs_dir', type=str, default='./samples_CRC',
                        help="Directory of .svs files; every file in it is processed when present")

    args = parser.parse_args()

    # hf_token is only needed when loading via from_pretrained; kept for that path.
    hf_token = HF_TOKEN
    # model = EXAONEPathV1p5Downstream.from_pretrained("LGAI-EXAONE/EXAONE-Path-1.5", use_auth_token=hf_token)
    model = EXAONEPathV1p5Downstream(num_sampled_patch=16384)

    # Load on CPU first, then move to the selected device below.
    model.load_state_dict(torch.load('./pytorch_model.bin', map_location='cpu'))
    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model.eval()
    # Compile the two heavy submodules for faster repeated inference.
    model.feature_extractor = torch.compile(model.feature_extractor)
    model.agg_model = torch.compile(model.agg_model)

    # Batch mode over --svs_dir (original behavior) when the directory exists;
    # otherwise fall back to the single --svs_path, which was previously
    # accepted but silently ignored (README documents the --svs_path usage).
    if os.path.isdir(args.svs_dir):
        for svs_name in sorted(os.listdir(args.svs_dir)):
            infer(model, os.path.join(args.svs_dir, svs_name))
    else:
        infer(model, args.svs_path)
models/aggregator.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer
|
8 |
+
|
9 |
+
|
10 |
+
# @dataclass
|
11 |
+
# class VisionCfg:
|
12 |
+
# layers: Union[Tuple[int, int, int, int], int] = 6
|
13 |
+
# width: int = 512
|
14 |
+
# head_width: int = 64
|
15 |
+
# mlp_ratio: float = 4.0
|
16 |
+
|
17 |
+
# ls_init_value: Optional[float] = None # layer scale initial value
|
18 |
+
# patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
19 |
+
# no_ln_pre: bool = False # disable pre transformer LayerNorm
|
20 |
+
# pool_type: str = 'none'
|
21 |
+
# final_ln_after_pool: bool = True # apply final LayerNorm after pooling
|
22 |
+
# output_tokens: bool = False
|
23 |
+
# act_kwargs: Optional[dict] = None
|
24 |
+
# norm_kwargs: Optional[dict] = None
|
25 |
+
|
26 |
+
@dataclass
class CLIPVisionCfg:
    """Configuration for the ViT-based slide-level aggregator.

    Note: callers in this file (see ``make_model``) use the class object itself
    as a mutable config namespace rather than instantiating it, so defaults
    double as the effective values unless overridden on the class.
    """

    layers: Union[Tuple[int, int, int, int], int] = 6
    width: int = 512
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'none'
    final_ln_after_pool: bool = True  # apply final LayerNorm after pooling
    pool_type: str = 'none'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
    img_embed: bool = False
    cls_embed: bool = False
    # Fixed: these two previously lacked annotations, so unlike every other
    # option they were plain class attributes rather than dataclass fields.
    projection: bool = False
    use_flex: bool = True
|
59 |
+
|
60 |
+
|
61 |
+
def get_cast_dtype(precision: str):
    """Map a precision string to the torch dtype used for weight casting.

    Returns ``torch.bfloat16`` for ``'bf16'``, ``torch.float16`` for
    ``'fp16'``, and ``None`` for any other precision (no casting).
    """
    return {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(precision)
|
68 |
+
|
69 |
+
|
70 |
+
def get_input_dtype(precision: str):
    """Map a precision string to the dtype model inputs should be cast to.

    Both mixed and pure bf16/fp16 modes map to the corresponding half
    dtype; any other precision returns ``None`` (keep inputs as-is).
    """
    if precision in ('bf16', 'pure_bf16'):
        return torch.bfloat16
    if precision in ('fp16', 'pure_fp16'):
        return torch.float16
    return None
|
77 |
+
|
78 |
+
|
79 |
+
def _build_vision_tower(
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
    dropout: float = 0.1,
    num_registers: int = 0,
):
    """Build the ViT aggregator described by ``vision_cfg``.

    A dict config is promoted to ``CLIPVisionCfg`` first. QuickGELU and an
    fp32-stable LayerNorm are selected per the ``quick_gelu`` flag and the
    cast dtype; extra act/norm kwargs from the config are bound via partial.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # Activation: optionally the (faster, approximate) QuickGELU variant.
    act_layer = QuickGELU if quick_gelu else nn.GELU
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    # Half-precision weights need the fp32-computing LayerNorm for stability.
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)

    return VisionTransformer(
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_cfg.width // vision_cfg.head_width,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        output_dim=embed_dim,
        patch_dropout=vision_cfg.patch_dropout,
        no_ln_pre=vision_cfg.no_ln_pre,
        pool_type=vision_cfg.pool_type,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=True,  # flex attention is always on here, regardless of cfg
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
    )
|
121 |
+
|
122 |
+
|
123 |
+
class MixedOmicsModel(nn.Module):
    """Slide-level aggregator: ViT over patch features, then masked mean pooling.

    ``forward`` returns a tuple of (projected mean embedding, per-patch
    embeddings after the post-LayerNorm, unprojected mean embedding).
    """

    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        drop_rate: float = 0.25,
        num_registers: int = 0,
        *args,
        **kwargs,
    ):
        """
        Args:
            embed_dim: feature/embedding width of the aggregator and projection.
            vision_cfg: ViT configuration (class or instance used as namespace).
            quick_gelu: use QuickGELU activation in the vision tower.
            cast_dtype: half dtype selects the fp32-stable LayerNorm variant.
            drop_rate: dropout rate passed to the vision tower.
            num_registers: stored on the module; the tower itself is built
                with 0 registers (as in the original implementation).
        """
        super().__init__()

        self.drop_prob = drop_rate
        self.num_registers = num_registers

        # Pooling is done by (masked) mean below, so no CLS token is used.
        vision_cfg.cls_embed = False

        self.visual = _build_vision_tower(
            embed_dim,
            vision_cfg,
            quick_gelu,
            cast_dtype,
            dropout=drop_rate,
            num_registers=0,
        )

        self.image_proj = nn.Linear(embed_dim, embed_dim)
        self.image_proj.apply(self.init_weights)

        self.ln_post = LayerNorm(embed_dim)

    def init_weights(self, module):
        """Init: N(0, 0.02) weights for Linear/Embedding, zero bias, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _check_tensor(self, tensor, name):
        """Debug helper: print the tensor's shape and warn on NaN/Inf values."""
        print(name, " : ", tensor.shape)
        if torch.isnan(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains Inf values.")

    def forward(
        self,
        image,
        coords=None,
        im_mask=None,
        *args,
        **kwargs,
    ):
        """Aggregate patch features into slide-level embeddings.

        Args:
            image: patch feature tensor fed to the vision tower.
            coords: optional per-patch coordinates for positional bias.
            im_mask: optional validity mask; True/1 marks valid patches.

        Returns:
            (projected mean embedding, per-patch embeddings, mean embedding).
        """
        ## image embedding
        # Fixed: coords defaults to None but was previously dereferenced
        # unconditionally (coords.contiguous()), crashing on the default.
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),
            key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous(),
        )
        image_embeds = self.ln_post(image_embeds)

        if im_mask is not None:
            # Mean over valid patches only; clamp avoids division by zero
            # for slides where the mask is all-False.
            mask = im_mask.unsqueeze(-1).contiguous()
            masked_embeds = image_embeds * mask
            sum_embeds = masked_embeds.sum(dim=1)
            valid_counts = mask.sum(dim=1).clamp(min=1)  # [N, 1]
            mean_embeds = sum_embeds / valid_counts  # [N, dim]
        else:
            mean_embeds = image_embeds.mean(-2)

        image_embeds_final = self.image_proj(mean_embeds)

        return image_embeds_final, image_embeds, mean_embeds
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
def make_model(
    embed_dim=768,
    droprate=0.1,
    num_registers=0,
    depth=4,
):
    """Build a MixedOmicsModel aggregator.

    Args:
        embed_dim: feature/embedding width of the vision tower.
        droprate: dropout rate passed through to the aggregator.
        num_registers: number of register tokens for the aggregator.
        depth: number of transformer layers.

    Returns:
        A MixedOmicsModel instance.
    """
    # Instantiate the config instead of mutating the CLIPVisionCfg class
    # object itself: assigning attributes on the class would leak the
    # width/layers overrides into every other user of CLIPVisionCfg and
    # make repeated make_model calls order-dependent.
    vCfg = CLIPVisionCfg()
    vCfg.width = embed_dim
    vCfg.layers = depth

    model = MixedOmicsModel(
        embed_dim=embed_dim,
        vision_cfg=vCfg,
        drop_rate=droprate,
        num_registers=num_registers,
    )

    return model
|
models/cls_modules.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class GlobalClassificationHead(nn.Module):
    """LayerNorm -> Dropout -> Linear classification head.

    Maps a pooled feature vector of size ``input_dim`` to ``num_classes``
    logits.
    """

    def __init__(self, input_dim, num_classes=32, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Normalize, regularize, then project to class logits.
        normalized = self.dropout(self.norm(x))
        return self.fc(normalized)  # (batch_size, num_classes)
|
18 |
+
|
19 |
+
|
20 |
+
class CLSHead(nn.Module):
    """Classification head with configurable pooling over token embeddings.

    pooling_type:
        "mlp"       - element-wise sum of the final embedding (x) and the
                      mean embedding (x3), then a 2-layer MLP.
        "attention" - learned attention pooling over [final token, tokens].
    Other values raise in forward().
    """

    def __init__(
        self,
        embed_dim,
        num_classes,
        dropout=0.1,
        use_norm=True,
        hidden_dim=None,
        activation="silu",
        pooling_type="cls",
    ):
        super().__init__()

        # Hidden width defaults to the input width.
        if hidden_dim is None:
            hidden_dim = embed_dim

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "silu":
            self.activation = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Value Error: {activation}")

        self.pooling_type = pooling_type

        # Produces one scalar score per token for attention pooling.
        if pooling_type == "attention":
            self.attention_pool = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 1),
                # nn.Softmax(dim=1)
            )

        if use_norm:
            self.norm = nn.LayerNorm(embed_dim)
        else:
            self.norm = nn.Identity()

        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        # Truncated-normal weights, zero biases for the MLP layers.
        nn.init.trunc_normal_(self.fc1.weight, std=0.02)
        nn.init.zeros_(self.fc1.bias)
        nn.init.trunc_normal_(self.fc2.weight, std=0.02)
        nn.init.zeros_(self.fc2.bias)

        if self.pooling_type == "attention":
            # NOTE(review): index [0] is the LayerNorm inside attention_pool,
            # not the Linear — presumably the Linear ([1]) was intended;
            # confirm before changing (affects checkpoint compatibility).
            nn.init.trunc_normal_(self.attention_pool[0].weight, std=0.02)
            nn.init.zeros_(self.attention_pool[0].bias)

    def forward(self, x, x2=None, x3=None, attn_mask=None):
        # x: final embedding [B, D]; x2: token embeddings [B, L, D];
        # x3: mean embedding [B, D]; attn_mask: boolean, True marks padded
        # tokens — shapes inferred from usage, TODO confirm against callers.
        if self.pooling_type == "mlp":
            pooled = x + x3
        elif self.pooling_type == "attention":
            # Prepend the final embedding as an extra (always-valid) token.
            x = torch.cat([x.unsqueeze(1), x2], dim=1)
            weights = self.attention_pool(x)  # [batch_size, num_tokens, 1]

            if attn_mask is not None:
                # Extend the mask with a zeros column for the prepended token.
                attn_mask = torch.cat([torch.zeros(attn_mask.shape[0], 1).to(attn_mask.dtype).to(attn_mask.device), attn_mask], dim=1)
                # Convert boolean mask to an additive -inf mask.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=weights.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask

                if len(attn_mask.shape) ==2:
                    attn_mask = attn_mask[..., None]  # [batch_size, num_tokens, 1]

                weights = weights + attn_mask
            # Softmax over the token dimension -> pooling weights.
            weights = F.softmax(weights, dim=1)

            pooled = torch.sum(x * weights, dim=1)
        else:
            # Message means "unsupported pooling type" (kept as released).
            raise ValueError(f"지원하지 않는 풀링 타입: {self.pooling_type}")

        pooled = self.norm(pooled)

        x = self.fc1(pooled)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.fc2(x)

        return x
|
106 |
+
|
107 |
+
|
108 |
+
class BaseClassifier(nn.Module):
    """Backbone + CLSHead classifier wrapper.

    The backbone must return (final_embedding, token_embeddings,
    mean_embedding); the head turns them into class logits.
    """

    def __init__(
        self,
        base_model,
        num_classes,
        cls_head_kwargs=None
    ):
        super().__init__()

        self.backbone = base_model
        # Fall back to 768 when the backbone does not expose embed_dim.
        embed_dim = getattr(self.backbone, 'embed_dim', 768)

        cls_head_config = {
            'embed_dim': embed_dim,
            'num_classes': num_classes,
        }
        if cls_head_kwargs:
            cls_head_config.update(cls_head_kwargs)

        self.cls_head = CLSHead(**cls_head_config)

    def forward(self, image, coords=None, im_mask=None):
        feat_final, feats, m_feats = self.backbone(image=image, coords=coords, im_mask=im_mask)
        # Padded positions are the complement of the validity mask.
        padding = ~im_mask.bool()
        logits = self.cls_head(feat_final, feats, m_feats, attn_mask=padding)

        Y_prob = F.softmax(logits, dim=-1)
        Y_hat = torch.topk(logits, 1, dim=1)[1]

        return logits, Y_prob, Y_hat
|
138 |
+
|
139 |
+
|
140 |
+
class LinearClassifier(nn.Module):
    """Linear probe over a backbone's pooled features.

    Args:
        base_model: backbone returning (final_embedding, token_embeddings,
            mean_embedding).
        num_classes: number of output classes.
        pool: which pooled feature to classify — "final" (projected slide
            embedding) or "mean" (mean patch embedding).
    """

    def __init__(
        self,
        base_model,
        num_classes=2,
        pool="final"
    ):
        super().__init__()
        self.backbone = base_model
        self.pool = pool
        print("Linear classifier pool type : ", self.pool)
        embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768

        # LayerNorm + single Linear layer (linear probe).
        self.cls_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, image, coords=None, im_mask=None):
        """Return (logits, softmax probabilities, argmax predictions)."""
        feat_final, _, feat_mean = self.backbone(image=image, coords=coords, im_mask=im_mask)
        if self.pool == "mean":
            logits = self.cls_head(feat_mean)
        elif self.pool == "final":
            logits = self.cls_head(feat_final)
        else:
            # Previously an unknown pool silently fell through and crashed
            # later with UnboundLocalError on `logits`; fail fast instead.
            raise ValueError(f"Unsupported pool type: {self.pool}")

        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)

        return logits, Y_prob, Y_hat
|
models/exaonepath.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from models.aggregator import make_model
|
5 |
+
from models.cls_modules import LinearClassifier
|
6 |
+
from datasets.dataset_WSI import WSIPatchDataset
|
7 |
+
from models.feature_extractor import vit_base
|
8 |
+
from utils.wsi_utils import extract_tissue_patch_coords
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
|
11 |
+
from huggingface_hub import PyTorchModelHubMixin
|
12 |
+
|
13 |
+
|
14 |
+
class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    """Slide-level classifier for whole-slide images (WSI).

    forward(svs_path) runs: tissue patch-coordinate extraction ->
    per-patch ViT feature extraction -> random patch subsampling ->
    transformer aggregation -> class probabilities.
    """

    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        # step_size/patch_size are in pixels; macenko toggles Macenko stain
        # normalization inside WSIPatchDataset; num_sampled_patch caps how
        # many patch features are aggregated per slide.
        super(EXAONEPathV1p5Downstream, self).__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.num_sampled_patch = num_sampled_patch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Kept so PyTorchModelHubMixin can serialize the hyperparameters.
        self.config = {
            "step_size": step_size,
            "patch_size": patch_size,
            "macenko": macenko,
            "num_sampled_patch": num_sampled_patch,
        }

        # Patch-level feature extractor (ViT-B/16, 768-dim output).
        self.feature_extractor = vit_base()
        # NOTE(review): self-assignment below is a no-op, apparently left
        # over from the commented .to(device)/.eval() calls.
        self.feature_extractor = self.feature_extractor
        # self.feature_extractor = self.feature_extractor.to(self.device)
        # self.feature_extractor.eval()

        # Slide-level aggregator over patch features + coordinates.
        self.agg_model = make_model(
            embed_dim=768,
            droprate=0.0,
            num_registers=0,
            depth=4,
        )

        # Mean-pooled linear head (LinearClassifier default num_classes=2).
        self.agg_model = LinearClassifier(self.agg_model, pool='mean')

        # self.agg_model.to(self.device)
        # self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Return class probabilities (1-D CPU tensor) for one slide."""
        # Extract tissue patch coordinates from the slide.
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features.
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=True,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",

        )
        features_list = []
        coords_list = []
        # NOTE(review): the loop rebinds `coords` to the per-batch coordinate
        # tensor, shadowing the full coordinate list extracted above.
        for count, items in enumerate(patch_loader):
            patches, coords = items
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            feature = self.feature_extractor(patches)  # [B, 768] (original comment said 1024; vit_base embeds 768)
            feature /= feature.norm(dim=-1, keepdim=True)  # L2-normalized features
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)

            coords = coords.to(self.device, non_blocking=True)
            coords_list.append(coords)

        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)
        coords = torch.cat(coords_list)
        total_samples = features.shape[0]

        # Randomly subsample up to num_sampled_patch patch features.
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]

        # Aggregate features into slide-level class probabilities.
        self.agg_model.eval()

        # sampled_features = torch.randn([8192, 768])
        # sampled_coords = torch.randn([8192, 2])

        logits, Y_prob, Y_hat = self.agg_model(sampled_features[None].to(self.device), sampled_coords[None].to(self.device))
        probs = Y_prob[0].cpu()

        return probs

    @torch.no_grad()
    def forward_feature_extractor(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Return the [N, 768] patch feature matrix for one slide (no aggregation)."""
        # Extract tissue patch coordinates from the slide.
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features (no coordinates returned this time).
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=False
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )
        features_list = []
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            feature = self.feature_extractor(patches)  # [B, 768] (original comment said 1024; vit_base embeds 768)
            feature /= feature.norm(dim=-1, keepdim=True)  # L2-normalized features
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)
        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)

        return features
|
models/feature_extractor.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class DropPath(nn.Module):
    """Stochastic depth: randomly zeroes entire residual paths per sample.

    A no-op in eval mode or when drop_prob is falsy.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional implementation.
        return _drop_path(x, self.drop_prob, self.training)
|
18 |
+
|
19 |
+
|
20 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> activation -> Linear).

    Dropout is applied after each linear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
44 |
+
|
45 |
+
|
46 |
+
class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention map."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scale query-key dot products by 1/sqrt(head_dim) unless overridden.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # Fused projection, then split into per-head q/k/v:
        # [B, N, 3C] -> 3 x [B, heads, N, head_dim].
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj_drop(self.proj(out))
        return out, attn
|
83 |
+
|
84 |
+
|
85 |
+
class Block(nn.Module):
    """Pre-norm transformer encoder block: MHSA + MLP with residual paths."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on both residual branches; identity when rate is 0.
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, return_attention=False):
        attn_out, attn_map = self.attn(self.norm1(x))
        if return_attention:
            # Expose the attention map instead of running the block.
            return attn_map
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))
|
126 |
+
|
127 |
+
|
128 |
+
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Splits an image into non-overlapping patch_size x patch_size patches
    and projects each to embed_dim with a strided convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        grid = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid * grid

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, E, H', W') -> (B, H'*W', E)
        return self.proj(x).flatten(2).transpose(1, 2)
|
146 |
+
|
147 |
+
|
148 |
+
class VisionTransformer(nn.Module):
    """DINO-style Vision Transformer backbone.

    forward() returns the final [CLS] token embedding ([B, embed_dim]).
    With num_classes=0 (the default) the head is an Identity, so this acts
    as a pure feature extractor.
    """

    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token and absolute positional embeddings
        # (one extra position for the [CLS] token).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        _trunc_normal_(self.pos_embed, std=0.02)
        _trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for Linear layers; standard LayerNorm init.
        if isinstance(m, nn.Linear):
            _trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        # Bicubic-resample the patch positional embeddings so inputs whose
        # size differs from the training resolution still get positions.
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(
                1, int(math.sqrt(N)), int(math.sqrt(N)), dim
            ).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            # size=(int(w0), int(h0)),
            mode="bicubic",
        )
        assert (
            int(w0) == patch_pos_embed.shape[-2]
            and int(h0) == patch_pos_embed.shape[-1]
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        # Patchify, prepend [CLS], add (possibly resampled) positions.
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        # Returns the normalized [CLS] embedding only.
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # print(x.type())
        return x[:, 0]

    def get_last_selfattention(self, x):
        # Run all but the last block, then return the last block's attention map.
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        # Normalized token outputs of the last `n` blocks.
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
|
288 |
+
|
289 |
+
|
290 |
+
def vit_base(patch_size=16, **kwargs):
    """ViT-Base backbone: 768-dim embeddings, 12 layers, 12 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs
    )
|
302 |
+
|
303 |
+
|
304 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
305 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
306 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
307 |
+
def norm_cdf(x):
|
308 |
+
# Computes standard normal cumulative distribution function
|
309 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
310 |
+
|
311 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
312 |
+
warnings.warn(
|
313 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
314 |
+
"The distribution of values may be incorrect.",
|
315 |
+
stacklevel=2,
|
316 |
+
)
|
317 |
+
|
318 |
+
with torch.no_grad():
|
319 |
+
# Values are generated by using a truncated uniform distribution and
|
320 |
+
# then using the inverse CDF for the normal distribution.
|
321 |
+
# Get upper and lower cdf values
|
322 |
+
l = norm_cdf((a - mean) / std)
|
323 |
+
u = norm_cdf((b - mean) / std)
|
324 |
+
|
325 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
326 |
+
# [2l-1, 2u-1].
|
327 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
328 |
+
|
329 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
330 |
+
# standard normal
|
331 |
+
tensor.erfinv_()
|
332 |
+
|
333 |
+
# Transform to proper mean, std
|
334 |
+
tensor.mul_(std * math.sqrt(2.0))
|
335 |
+
tensor.add_(mean)
|
336 |
+
|
337 |
+
# Clamp to ensure it's in the proper range
|
338 |
+
tensor.clamp_(min=a, max=b)
|
339 |
+
return tensor
|
340 |
+
|
341 |
+
|
342 |
+
def _trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> torch.Tensor:
    """In-place truncated-normal init of *tensor*; returns the tensor."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
345 |
+
|
346 |
+
|
347 |
+
def _drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
348 |
+
if drop_prob == 0.0 or not training:
|
349 |
+
return x
|
350 |
+
keep_prob = 1 - drop_prob
|
351 |
+
shape = (x.shape[0],) + (1,) * (
|
352 |
+
x.ndim - 1
|
353 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
354 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
355 |
+
random_tensor.floor_() # binarize
|
356 |
+
output = x.div(keep_prob) * random_tensor
|
357 |
+
return output
|
models/flex_attn.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Sequence, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn.attention.flex_attention import (
|
8 |
+
_DEFAULT_SPARSE_BLOCK_SIZE,
|
9 |
+
create_block_mask,
|
10 |
+
create_mask,
|
11 |
+
flex_attention,
|
12 |
+
)
|
13 |
+
from torch.nn.attention.flex_attention import flex_attention, _vmap_for_bhqkv
|
14 |
+
|
15 |
+
try:
|
16 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
17 |
+
except ImportError:
|
18 |
+
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
|
19 |
+
|
20 |
+
from torch._dynamo import disable
|
21 |
+
|
22 |
+
|
23 |
+
def generate_alibi_bias(H=12):
    """Per-head ALiBi slopes: head h gets 2**(-(h + 1) / H).

    Returns a length-H tensor of geometrically decreasing positive slopes.
    """
    exponents = torch.tensor([-(h + 1) / H for h in range(H)])
    return torch.exp2(exponents)
|
30 |
+
|
31 |
+
|
32 |
+
def get_rel_bias_func(scale, coords=None, qk_scale=1.0):
    """Build a flex_attention ``score_mod`` with a distance-based penalty.

    scale: per-head ALiBi slopes (see generate_alibi_bias);
    coords: patch coordinates indexed as coords[b, idx] -> (x, y) —
    TODO confirm layout against callers. The penalty is
    log1p(Euclidean distance) clamped at 1000, scaled by the head slope
    and qk_scale, and subtracted from the raw attention score.
    """
    def patch_coords_rel_bias(score, b, h, q_idx, kv_idx):
        # Identity score_mod when no coordinates are provided.
        if coords is None:
            return score
        with torch.no_grad():
            dx = coords[b, q_idx][0] - coords[b, kv_idx][0]
            dy = coords[b, q_idx][1] - coords[b, kv_idx][1]
            dist = torch.sqrt(dx * dx + dy * dy)
            dist = dist.clamp(max=1000)  # max distance
            dist = torch.log1p(dist)
        bias = dist * scale[h] * qk_scale
        return score - bias  # closer → larger score
    return patch_coords_rel_bias
|
45 |
+
|
46 |
+
|
47 |
+
def key_padding_mask(mask):
    """Wrap a boolean padding mask as a flex_attention ``mask_mod``.

    ``mask[b, kv]`` being True marks a padded key; the returned callable
    reports which (batch, head, query, key) positions are allowed, i.e.
    the complement along the key dimension.
    """
    def allowed(b, h, q_idx, kv_idx):
        return ~mask[b, kv_idx]
    return allowed
|
51 |
+
|
52 |
+
|
53 |
+
class FlexCore(nn.Module):
    """
    Thin nn.Module wrapper around flex_attention so the attention call can
    be intercepted with PyTorch forward hooks (plain function calls cannot).
    """
    def forward(self, q, k, v, score_mod=None, block_mask=None, return_lse=False):
        """
        Run flex_attention on per-head q/k/v with optional score_mod and
        block_mask.

        ``return_lse=True`` should be used with the ATTN_MAP_VIS wrapper:
        even when return_lse is True, flex_attention's primary return is the
        attention output (not attention scores).
        """
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, return_lse=return_lse)
|
63 |
+
|
64 |
+
|
65 |
+
class Flex_Attention(nn.Module):
    """Multi-head self-attention implemented with torch flex_attention.

    Supports an ALiBi-style relative bias derived from 2-D patch coordinates
    (closer patches score higher) and an optional key-padding mask applied
    as a block mask.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 12,
        qkv_bias: bool = True,
        proj_drop: float = 0.,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Fused QKV projection with nn.MultiheadAttention-style parameters.
        self.scale = self.head_dim ** -0.5
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        # Separate module so attention I/O can be observed via forward hooks.
        self.f_attn = FlexCore()

        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

        # NOTE(review): unused attribute; `use_rel_bias` is likewise accepted
        # but never read in this class.
        self.max_distance=16

    def build_rel_bias(self, coords):
        # Dense log-distance matrix between all coordinate pairs.
        # NOTE(review): not called from forward (the score_mod closure from
        # get_rel_bias_func is used instead).
        return torch.log1p(torch.cdist(coords, coords, p=2))

    def forward(self, x, coords=None, attn_mask: Optional[torch.Tensor] = None, return_attn_score=False):
        # x: [N, L, C]; coords: patch coordinates or None; attn_mask: [N, L]
        # boolean with True marking padded keys — TODO confirm against callers.
        N, L, C = x.shape
        # ensure contiguous projection before chunking
        x_proj = F.linear(x, self.in_proj_weight, self.in_proj_bias).contiguous()
        q, k, v = [t.contiguous() for t in x_proj.chunk(3, -1)]

        # [N, L, C] -> [N, heads, L, head_dim]
        q = q.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(N, L, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()

        if attn_mask is not None:
            # Block mask dropping padded keys. (sic: "maks_func" typo is a
            # local name only.)
            maks_func = create_block_mask(
                key_padding_mask(attn_mask), N, self.num_heads, L, L
            )

        qk_scale = q.size(-1) ** -0.5
        # NOTE(review): with return_attn_score=True flex_attention returns an
        # (output, lse) tuple, which the permute/reshape below would not
        # handle — presumably only used together with an unwrapping hook;
        # confirm.
        x = self.f_attn(
            q, k, v,
            score_mod = get_rel_bias_func(generate_alibi_bias(self.num_heads).to(coords.device), coords, qk_scale) if coords is not None else None,
            block_mask = maks_func if attn_mask is not None else None,
            return_lse=return_attn_score,
        )

        # [N, heads, L, head_dim] -> [N, L, C]
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.reshape(N, L, C).contiguous()

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
|
models/transformer.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
import math
|
3 |
+
from typing import Callable, List, Optional, Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from einops import pack, repeat
|
10 |
+
|
11 |
+
from .flex_attn import Flex_Attention
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32, then restores the input dtype.

    Useful under fp16/mixed precision, where normalization statistics computed
    in half precision can lose accuracy.
    """

    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        normed = F.layer_norm(
            x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normed.to(in_dtype)
|
22 |
+
|
23 |
+
|
24 |
+
class LayerNorm(nn.LayerNorm):
    """Standard LayerNorm that explicitly casts the result back to the input dtype."""

    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(in_dtype)
|
31 |
+
|
32 |
+
|
33 |
+
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # NOTE: slower than nn.GELU or nn.SiLU and uses more GPU memory.
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
|
37 |
+
|
38 |
+
|
39 |
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling (gamma), optionally applied in place."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
47 |
+
|
48 |
+
|
49 |
+
class PatchDropout(nn.Module):
    """
    Randomly keep a subset of patch tokens during training.
    https://arxiv.org/abs/2212.00794

    At eval time, or with prob == 0, this is the identity. Otherwise a random
    subset of max(1, int(num_tokens * (1 - prob))) tokens is kept per sample.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        # Identity when not training or dropout disabled.
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Fix: the original created an unused cls_tokens via torch.jit.annotate.
            cls_tokens = None

        batch = x.size()[0]
        num_tokens = x.size()[1]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Fix: draw indices on x's device — previously these were CPU tensors
        # regardless of input device, which is wrong for CUDA inputs.
        rand = torch.randn(batch, num_tokens, device=x.device)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
        batch_indices = torch.arange(batch, device=x.device)[..., None]

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
|
87 |
+
|
88 |
+
|
89 |
+
class Attention(nn.Module):
    """Multi-head self-attention with scaled-cosine (logit-scaled) similarity.

    Queries and keys are L2-normalized and their dot product scaled by a
    learned, clamped per-head logit scale (scaled_cosine attention).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        scaled_cosine: bool = True,
        scale_heads: bool = False,
        logit_scale_max: float = math.log(1. / 0.01),
        batch_first: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.
    ):
        # NOTE(review): forward() always reads self.logit_scale, so
        # scaled_cosine=False (logit_scale=None) would crash in forward —
        # only the default scaled_cosine=True path is exercised.
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max
        self.batch_first = batch_first
        # Detected but not used in this forward implementation.
        self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)


    def forward(self, x, coords, attn_mask: Optional[torch.Tensor] = None):
        # NOTE(review): `coords` is accepted only for signature compatibility
        # with Flex_Attention — it is never used here.
        if self.batch_first:
            x = x.transpose(0, 1)

        # Sequence-first layout: (L, N, C).
        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # Fold heads into the batch dim: (N*H, L, head_dim).
        q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
        k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
        v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)

        # Boolean masks become additive -inf masks in q's dtype.
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask

        # if self.logit_scale is not None:
        # Cosine similarity (normalized q/k) scaled by the clamped logit scale.
        attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
        logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
        attn = attn.view(N, self.num_heads, L, L) * logit_scale

        if attn_mask is not None:
            # Mask is broadcast over heads and query positions — assumes a
            # (N, L) key mask; TODO confirm against callers.
            attn = attn + attn_mask[:, None, None, :]
        attn = attn.view(-1, L, L)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = torch.bmm(attn, v)

        if self.head_scale is not None:
            x = x.view(N, self.num_heads, L, C) * self.head_scale
            x = x.view(-1, L, C)

        # Back to (L, N, C), then (N, L, C) if batch_first.
        x = x.transpose(0, 1).reshape(L, N, C)

        if self.batch_first:
            x = x.transpose(0, 1)

        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
|
173 |
+
|
174 |
+
|
175 |
+
class AttentionalPooler(nn.Module):
    """Pool a variable-length context into n_queries tokens via cross-attention.

    Learned query tokens attend over the (layer-normalized) input context;
    the output has shape (batch, n_queries, d_model).
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        n_head: int = 8,
        n_queries: int = 256,
        norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        batch = x.shape[0]
        keys = self.ln_k(x)
        queries = self.ln_q(self.query).unsqueeze(0).expand(batch, -1, -1)
        pooled, _ = self.attn(queries, keys, keys, need_weights=False)
        return pooled
|
196 |
+
|
197 |
+
|
198 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln(x)), then x + mlp(ln(x))."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        is_cross_attention: bool = False,
        batch_first: bool = True,
        use_flex:bool = False,
        dropout:float = 0.2,
        use_rel_bias:bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)

        # Choose the attention backend: FlexAttention (with optional relative
        # bias) or the scaled-cosine Attention fallback.
        if use_flex:
            print("Flex_Attention!")
            self.attn = Flex_Attention(dim = d_model, num_heads=n_head, proj_drop=dropout, use_rel_bias=use_rel_bias)
        else:
            self.attn = Attention(dim = d_model, num_heads=n_head, batch_first=batch_first, proj_drop=dropout, attn_drop=dropout)

        # LayerScale only when ls_init_value is given; otherwise identity.
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)

        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        coords = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask=None,
    ):
        # NOTE(review): k_x/v_x and attn_mask are prepared but never forwarded —
        # self.attn performs self-attention over q_x only, and receives
        # key_padding_mask through its `attn_mask` argument.
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None

        return self.attn(
            q_x, coords=coords, attn_mask=key_padding_mask
        )

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        coords = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask = None,
    ):
        # ln_1_kv only exists when is_cross_attention=True.
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, coords=coords, attn_mask=attn_mask, key_padding_mask=key_padding_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
|
269 |
+
|
270 |
+
|
271 |
+
def _expand_token(token, batch_size: int):
|
272 |
+
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
273 |
+
|
274 |
+
|
275 |
+
class Transformer(nn.Module):
    """A stack of ResidualAttentionBlock layers with optional NLD<->LND transposes."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        batch_first: bool = True,
        use_flex: bool = False,
        dropout: float = False,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    batch_first=batch_first,
                    use_flex=use_flex,
                    dropout=dropout,
                    use_rel_bias=use_rel_bias,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first MLP projection (int8-aware)."""
        first_fc = self.resblocks[0].mlp.c_fc
        if hasattr(first_fc, 'int8_original_dtype'):
            return first_fc.int8_original_dtype
        return first_fc.weight.dtype

    def forward(self, x: torch.Tensor, coords = None, attn_mask: Optional[torch.Tensor] = None, key_padding_mask=None):
        seq_first = not self.batch_first
        if seq_first:
            x = x.transpose(0, 1).contiguous()  # NLD -> LND
        for block in self.resblocks:
            x = block(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, coords=coords)
        if seq_first:
            x = x.transpose(0, 1).contiguous()  # LND -> NLD
        return x
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
class VisionTransformer(nn.Module):
    """Transformer aggregator over pre-extracted patch features.

    Projects 768-d patch embeddings to `width`, optionally appends learnable
    register tokens, runs the transformer stack (optionally with a
    coordinate-based relative bias), and pools to a single embedding.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        ls_init_value: float = None,
        output_dim: int = 512,
        patch_dropout: float = 0.,
        no_ln_pre: bool = False,
        pool_type: str = 'tok',
        final_ln_after_pool: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        output_tokens: bool = False,
        img_embed: bool = False,
        use_flex: bool = False,
        dropout: float = 0.1,
        num_registers: int = 0,
        use_rel_bias: bool = True,
    ):
        super().__init__()
        assert pool_type in ('tok', 'avg', 'none')
        self.output_tokens = output_tokens

        self.final_ln_after_pool = final_ln_after_pool  # currently ignored w/ attn pool enabled
        self.output_dim = output_dim
        self.img_embed = img_embed
        self.num_registers = num_registers
        # No learned positional embedding is used (fix: this attribute was
        # redundantly assigned None three times in the original).
        self.positional_embedding = None
        self.pre_linear = nn.Linear(768, width)

        if num_registers > 0:
            self.register_token = nn.Parameter(torch.empty(num_registers, width))
            nn.init.normal_(self.register_token, std=0.02)

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            use_flex=use_flex,
            dropout=dropout,
            use_rel_bias=use_rel_bias,
        )

        pool_dim = width
        self.pool_type = pool_type

        self.ln_post = norm_layer(pool_dim)

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Reduce token sequence to (pooled, tokens) per self.pool_type."""
        if self.pool_type == 'avg':
            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
        elif self.pool_type == 'tok':
            pooled, tokens = x[:, 0], x[:, 1:]
        else:
            pooled = tokens = x

        return pooled, tokens

    def forward(self, x: torch.Tensor, coords=None, mask=None, key_padding_mask=None):
        """Map (B, N, 768) patch features to a pooled (B, width) embedding."""
        x = self.pre_linear(x)

        if self.num_registers > 0:
            # Append register tokens after the patch tokens.
            r = repeat(self.register_token, 'n d -> b n d', b=x.size(0))
            x, _ = pack([x, r], 'b * d')

        x = self.patch_dropout(x)
        x = self.ln_pre(x)
        x = self.transformer(x, coords, mask, key_padding_mask=key_padding_mask)

        if self.final_ln_after_pool:
            pooled, tokens = self._global_pool(x)
            pooled = self.ln_post(pooled)
        else:
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)

        return pooled
|
423 |
+
|
424 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
2 |
+
|
3 |
+
torch==2.6.0+cu124
|
4 |
+
torchvision==0.21.0+cu124
|
5 |
+
numpy==2.1.2
|
6 |
+
opencv-python-headless==4.11.0.86
|
7 |
+
openslide-bin==4.0.0.6
|
8 |
+
openslide-python==1.4.1
|
9 |
+
pandas==2.2.3
|
10 |
+
torchstain==1.4.1
|
11 |
+
fire==0.7.0
|
12 |
+
einops==0.8.1
|
13 |
+
huggingface_hub==0.32.2
|
14 |
+
h5py==3.13.0
|
15 |
+
scikit-learn==1.6.1
|
16 |
+
tensorboard==2.19.0
|
utils/constants.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Class label names; order presumably maps to the classifier's output indices
# (index 0 = LABEL_LOW, 1 = LABEL_HIGH) — confirm against the cls head.
CLASS_NAMES = ["LABEL_LOW", "LABEL_HIGH"]
|
utils/file_utils.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import h5py
|
3 |
+
|
4 |
+
def save_pkl(filename, save_object):
    """Serialize `save_object` to `filename` with pickle.

    Uses a context manager so the file handle is closed even if pickling
    raises (the original leaked the handle on error).
    """
    with open(filename, 'wb') as writer:
        pickle.dump(save_object, writer)
|
8 |
+
|
9 |
+
def load_pkl(filename):
    """Deserialize and return the pickled object stored at `filename`.

    Uses a context manager so the file handle is closed even if unpickling
    raises (the original leaked the handle on error).
    """
    with open(filename, 'rb') as loader:
        return pickle.load(loader)
14 |
+
|
15 |
+
|
16 |
+
def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'):
    """Write (or append to) datasets in an HDF5 file.

    For each (key, array) in `asset_dict`: if the dataset does not exist it is
    created resizable along axis 0 (chunked one row at a time) and filled;
    otherwise the existing dataset is grown and the new rows appended.
    Optional per-dataset attributes from `attr_dict[key]` are set on creation.

    Fix: `h5py.File` is now opened with a context manager so the file is
    closed (and flushed) even if a write raises — the original leaked the
    handle on error.

    Returns:
        output_path, unchanged.
    """
    with h5py.File(output_path, mode) as file:
        for key, val in asset_dict.items():
            data_shape = val.shape
            if key not in file:
                data_type = val.dtype
                chunk_shape = (1, ) + data_shape[1:]
                maxshape = (None, ) + data_shape[1:]
                dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type)
                dset[:] = val
                if attr_dict is not None:
                    if key in attr_dict.keys():
                        for attr_key, attr_val in attr_dict[key].items():
                            dset.attrs[attr_key] = attr_val
            else:
                dset = file[key]
                dset.resize(len(dset) + data_shape[0], axis=0)
                dset[-data_shape[0]:] = val
    return output_path
|
utils/preprocessor.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torchstain.base.normalizers.he_normalizer import HENormalizer
|
7 |
+
from torchstain.torch.utils import cov, percentile
|
8 |
+
from torchvision import transforms
|
9 |
+
from torchvision.transforms.functional import to_pil_image
|
10 |
+
|
11 |
+
|
12 |
+
def preprocessor(pretrained=False, normalizer=None):
    """Build the patch preprocessing pipeline.

    Args:
        pretrained: if True, use ImageNet mean/std; otherwise 0.5/0.5.
        normalizer: optional callable applied to the PIL image after cropping
            (e.g. a Macenko stain normalizer); None disables it.

    Returns:
        A torchvision transform: Resize(256) -> CenterCrop(224) ->
        [normalizer] -> ToTensor -> Normalize(mean, std).
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # Fix: identity test `is None` instead of `== None` (equality can
            # be overridden and is not the idiomatic None check).
            transforms.Lambda(lambda x: x) if normalizer is None else normalizer,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )

    return preprocess
|
31 |
+
|
32 |
+
|
33 |
+
"""
|
34 |
+
Source code ported from: https://github.com/schaugf/HEnorm_python
|
35 |
+
Original implementation: https://github.com/mitkovetta/staining-normalization
|
36 |
+
"""
|
37 |
+
|
38 |
+
|
39 |
+
class TorchMacenkoNormalizer(HENormalizer):
    """Torch implementation of Macenko H&E stain normalization.

    Ported from https://github.com/schaugf/HEnorm_python (original:
    https://github.com/mitkovetta/staining-normalization). HERef/maxCRef hold
    the reference stain matrix and max concentrations; fit() replaces them
    from a target image.
    """

    def __init__(self):
        super().__init__()

        # Default reference H&E stain vectors (3 RGB channels x 2 stains).
        self.HERef = torch.tensor(
            [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]]
        )
        self.maxCRef = torch.tensor([1.9705, 1.0308])

        # Avoid using deprecated torch.lstsq (since 1.9.0)
        self.updated_lstsq = hasattr(torch.linalg, "lstsq")

    def __convert_rgb2od(self, I, Io, beta):
        # I: (C, H, W) image tensor; returns per-pixel optical densities.
        I = I.permute(1, 2, 0)

        # calculate optical density
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

        # remove transparent pixels
        ODhat = OD[~torch.any(OD < beta, dim=1)]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        # project on the plane spanned by the eigenvectors corresponding to the two
        # largest eigenvalues
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])
        # print(phi.size())

        # Extreme angular directions bracket the stain vectors.
        minPhi = percentile(phi, alpha)
        maxPhi = percentile(phi, 100 - alpha)

        vMin = torch.matmul(
            eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))
        ).unsqueeze(1)
        vMax = torch.matmul(
            eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))
        ).unsqueeze(1)

        # a heuristic to make the vector corresponding to hematoxylin first and the
        # one corresponding to eosin second
        HE = torch.where(
            vMin[0] > vMax[0],
            torch.cat((vMin, vMax), dim=1),
            torch.cat((vMax, vMin), dim=1),
        )

        return HE

    def __find_concentration(self, OD, HE):
        # rows correspond to channels (RGB), columns to OD values
        Y = OD.T

        # determine concentrations of the individual stains
        if not self.updated_lstsq:
            return torch.lstsq(Y, HE)[0][:2]

        return torch.linalg.lstsq(HE, Y)[0]

    def __compute_matrices(self, I, Io, alpha, beta):
        # Returns (stain matrix HE, concentrations C, 99th-pct max conc.).
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # compute eigenvectors
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        # Keep the two eigenvectors with the largest eigenvalues.
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)

        C = self.__find_concentration(OD, HE)
        maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        # Re-estimate the reference stain matrix / max concentrations from I.
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def normalize(
        self, I, Io=240, alpha=1, beta=0.15, stains=True, form="chw", dtype="int"
    ):
        """Normalize the staining appearance of an H&E image.

        Input:
            I: RGB input image: tensor of shape [C, H, W] and type uint8
            Io: (optional) transmitted light intensity
            alpha: percentile
            beta: transparency threshold
            stains: if true, return also H & E components
            form, dtype: accepted but unused here — kept for API compatibility.

        Output:
            Inorm: normalized image, float in [0, 1], shape [C, H, W]
            H: hematoxylin component (int, [H, W, C]) or None
            E: eosin component (int, [H, W, C]) or None

        Reference:
            A method for normalizing histology slides for quantitative
            analysis. M. Macenko et al., ISBI 2009
        """

        c, h, w = I.shape

        HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # normalize stain concentrations
        C *= (self.maxCRef / maxC).unsqueeze(-1)

        # recreate the image using reference mixing matrix
        Inorm = Io * torch.exp(-torch.matmul(self.HERef, C))
        Inorm = torch.clip(Inorm, 0, 255)

        Inorm = Inorm.reshape(c, h, w).float() / 255.0
        Inorm = torch.clip(Inorm, 0.0, 1.0)

        H, E = None, None

        if stains:
            H = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 0].unsqueeze(-1), C[0, :].unsqueeze(0))
                ),
            )
            H[H > 255] = 255
            H = H.T.reshape(h, w, c).int()

            E = torch.mul(
                Io,
                torch.exp(
                    torch.matmul(-self.HERef[:, 1].unsqueeze(-1), C[1, :].unsqueeze(0))
                ),
            )
            E[E > 255] = 255
            E = E.T.reshape(h, w, c).int()

        return Inorm, H, E
|
180 |
+
|
181 |
+
|
182 |
+
class MacenkoNormalizer:
    """Callable wrapper around TorchMacenkoNormalizer for transform pipelines.

    Fits (or loads) reference stain statistics from `target_path`, then maps
    PIL image -> stain-normalized PIL image. On numerical failure it falls
    back to returning the input image unchanged (best-effort by design).
    """

    def __init__(self, target_path=None, prob=1):
        # PIL -> float tensor in [0, 255], as the Macenko code expects.
        self.transform_before_macenko = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 255)]
        )
        self.normalizer = TorchMacenkoNormalizer()

        # Reference either from a target image (fit) or precomputed stats (.pt).
        ext = os.path.splitext(target_path)[1].lower()
        if ext in [".jpg", ".jpeg", ".png"]:
            target = Image.open(target_path)
            self.normalizer.fit(self.transform_before_macenko(target))
        elif ext in [".pt"]:
            target = torch.load(target_path)
            self.normalizer.HERef = target["HERef"]
            self.normalizer.maxCRef = target["maxCRef"]

        else:
            raise ValueError(f"Invalid extension: {ext}")
        # NOTE(review): `prob` is stored but never read in this class.
        self.prob = prob

    def __call__(self, image):
        t_to_transform = self.transform_before_macenko(image)
        try:
            image_macenko, _, _ = self.normalizer.normalize(
                I=t_to_transform, stains=False, form="chw", dtype="float"
            )
            # NaNs indicate a degenerate decomposition — keep the original.
            if torch.any(torch.isnan(image_macenko)):
                return image
            else:
                image_macenko = to_pil_image(image_macenko)
                return image_macenko
        except Exception as e:
            # Known, expected numerical failures (degenerate eigen/percentile
            # problems) are silenced; anything else is at least printed.
            # Deliberate best-effort: always fall back to the unnormalized image.
            if "kthvalue()" in str(e) or "linalg.eigh" in str(e):
                pass
            else:
                print(str(e))
            return image
|
utils/wsi_utils.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from openslide import OpenSlide
|
5 |
+
|
6 |
+
|
7 |
+
def extract_tissue_patch_coords(
    wsi_path: str,
    patch_size: int = 256,
    step_size: int = 256,
    downsample_threshold: float = 64,
    threshold: int = 8,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
    min_effective_area_factor: float = 100,  # multiplied by (ref_area)^2
    ref_area: int = 512,
    min_hole_area_factor: float = 16,  # multiplied by (ref_area)^2
    max_n_holes: int = 8,
)-> list:
    """
    Collect level-0 coordinates of patches whose centers fall within tissue.

    Process:
    1. Open the WSI.
    2. Select a segmentation level and compute a binary mask.
    3. Find contours and holes from the mask and filter them using effective area criteria.
    4. Scale the external contours and holes to full resolution.
    5. Slide a window over the full-resolution grid and keep (x, y) when the
       patch center lies inside tissue (and outside holes).

    Returns:
        A list of (x, y) tuples: level-0 top-left corners of kept patches.
        No pixel data is read for the patches themselves.

    Raises:
        ValueError: if no valid tissue contours or no patches are found.
    """
    slide = OpenSlide(wsi_path)
    full_width, full_height = slide.level_dimensions[0]

    seg_level, scale = select_segmentation_level(slide, downsample_threshold)
    binary_mask = compute_segmentation_mask(
        slide, seg_level, threshold, max_val, median_kernel, close_size
    )

    # Compute thresholds for effective area and hole area
    # (factors are expressed relative to ref_area^2 at level 0, rescaled to
    # the segmentation level via the downsample factors).
    effective_area_thresh = min_effective_area_factor * (
        ref_area**2 / (scale[0] * scale[1])
    )
    hole_area_thresh = min_hole_area_factor * (ref_area**2 / (scale[0] * scale[1]))

    ext_contours, holes_list = filter_contours_and_holes(
        binary_mask, effective_area_thresh, hole_area_thresh, max_n_holes
    )
    if not ext_contours:
        raise ValueError("No valid tissue contours found.")

    # Bring contours/holes back to level-0 pixel coordinates.
    tissue_contours = scale_contours(ext_contours, scale)
    scaled_holes = [scale_contours(holes, scale) for holes in holes_list]

    coords = []
    for y in range(0, full_height - patch_size + 1, step_size):
        for x in range(0, full_width - patch_size + 1, step_size):
            center_x = x + patch_size // 2
            center_y = y + patch_size // 2
            if not point_in_tissue(center_x, center_y, tissue_contours, scaled_holes):
                continue
            coords.append((x, y))

    if not coords:
        raise ValueError("No available patches")
    return coords
|
69 |
+
|
70 |
+
|
71 |
+
def select_segmentation_level(slide: OpenSlide, downsample_threshold: float = 64):
    """
    Pick the slide pyramid level best matching the requested downsample factor.

    Delegates level choice to OpenSlide's ``get_best_level_for_downsample`` and
    normalizes the reported downsample into a per-axis pair.

    Args:
        slide: Open WSI handle.
        downsample_threshold: Target downsample factor used to choose the level.

    Returns:
        level (int): Index of the chosen pyramid level.
        scale (tuple): Per-axis downsample factors ``(sx, sy)`` at that level.
    """
    level = slide.get_best_level_for_downsample(downsample_threshold)
    downsample = slide.level_downsamples[level]
    # OpenSlide normally reports a scalar factor; expand it to (sx, sy).
    if isinstance(downsample, (tuple, list)):
        return level, downsample
    return level, (downsample, downsample)
|
84 |
+
|
85 |
+
|
86 |
+
def compute_segmentation_mask(
    slide: OpenSlide,
    level: int,
    threshold: int = 20,
    max_val: int = 255,
    median_kernel: int = 7,
    close_size: int = 4,
):
    """
    Build a binary tissue mask from the slide image at the given pyramid level.

    Pipeline:
        1. Read the whole level and convert to RGB.
        2. Extract the saturation channel from the HSV representation
           (tissue tends to be more saturated than glass background).
        3. Median-blur to suppress speckle noise.
        4. Apply a fixed binary threshold.
        5. Optionally close small gaps with a morphological closing
           (skipped when ``close_size`` <= 0).

    Args:
        slide: Open WSI handle.
        level: Pyramid level to segment at.
        threshold: Fixed saturation threshold for binarization.
        max_val: Value assigned to foreground pixels.
        median_kernel: Aperture size for the median blur (odd integer).
        close_size: Side of the square closing kernel; <= 0 disables closing.

    Returns:
        ndarray: Binary mask (0 or ``max_val``) at the level's resolution.
    """
    region = slide.read_region((0, 0), level, slide.level_dimensions[level])
    rgb = np.array(region.convert("RGB"))
    saturation = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[:, :, 1]
    smoothed = cv2.medianBlur(saturation, median_kernel)
    _, mask = cv2.threshold(smoothed, threshold, max_val, cv2.THRESH_BINARY)
    if close_size > 0:
        structuring_elem = np.ones((close_size, close_size), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, structuring_elem)
    return mask
|
118 |
+
|
119 |
+
|
120 |
+
def filter_contours_and_holes(
    binary_mask: np.ndarray,
    min_effective_area: float,
    min_hole_area: float,
    max_n_holes: int,
):
    """
    Find contours in the binary mask and keep those with enough effective area.

    For each external contour (one with no parent in the hierarchy), its child
    contours (holes) are sorted by area descending; up to ``max_n_holes`` holes
    larger than ``min_hole_area`` are kept. The effective area is the external
    contour's area minus the kept holes' total area, and only external contours
    whose effective area exceeds ``min_effective_area`` are retained.

    Improvement over the previous version: ``cv2.contourArea`` was evaluated up
    to three times per hole (sort key, size filter, area sum); each hole's area
    is now computed exactly once.

    Args:
        binary_mask: Binary segmentation mask (uint8).
        min_effective_area: Minimum (contour - holes) area to keep a contour.
        min_hole_area: Minimum area for a hole to be subtracted.
        max_n_holes: Maximum number of holes retained per external contour.

    Returns:
        filtered_contours (list): External contours (numpy arrays) that passed.
        holes_list (list): For each kept contour, its list of selected holes.
    """
    contours, hierarchy = cv2.findContours(
        binary_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
    )
    if hierarchy is None:
        return [], []
    hierarchy = hierarchy[0]  # shape (N, 4): [next, prev, first_child, parent]

    filtered_contours = []
    holes_list = []
    for idx, node in enumerate(hierarchy):
        if node[3] != -1:
            continue  # Only external contours (no parent).
        ext_cont = contours[idx]
        ext_area = cv2.contourArea(ext_cont)

        # Pair each child contour (hole) with its area, computed once.
        holes_with_area = [
            (cv2.contourArea(contours[i]), contours[i])
            for i, child in enumerate(hierarchy)
            if child[3] == idx
        ]
        # Largest holes first; stable sort preserves the original tie order.
        holes_with_area.sort(key=lambda pair: pair[0], reverse=True)
        selected = [
            (area, hole)
            for area, hole in holes_with_area[:max_n_holes]
            if area > min_hole_area
        ]

        effective_area = ext_area - sum(area for area, _ in selected)
        if effective_area > min_effective_area:
            filtered_contours.append(ext_cont)
            holes_list.append([hole for _, hole in selected])
    return filtered_contours, holes_list
|
168 |
+
|
169 |
+
|
170 |
+
def scale_contours(contours: list, scale: tuple) -> list:
    """
    Scale contour coordinates by per-axis factors.

    Args:
        contours: List of contours (each a numpy array of points).
        scale: Tuple ``(sx, sy)`` of scaling factors.

    Returns:
        List of scaled contours with int32 coordinates (fractions truncated).
    """
    factor = np.asarray(scale, dtype=np.float32)
    return [(contour * factor).astype(np.int32) for contour in contours]
|
185 |
+
|
186 |
+
|
187 |
+
def point_in_tissue(x: int, y: int, ext_contours: list, holes_list: list) -> bool:
    """
    Test whether point (x, y) lies in tissue.

    A point is in tissue when it falls inside (or on the boundary of) any
    external contour while lying strictly outside all of that contour's holes.

    Args:
        x: Point x-coordinate at full resolution.
        y: Point y-coordinate at full resolution.
        ext_contours: External tissue contours.
        holes_list: Per-contour lists of hole contours, parallel to
            ``ext_contours``.

    Returns:
        True if the point is inside some contour and outside its holes.
    """
    point = (x, y)
    for contour, holes in zip(ext_contours, holes_list):
        # pointPolygonTest >= 0 means inside or on the contour boundary.
        if cv2.pointPolygonTest(contour, point, False) < 0:
            continue
        if all(cv2.pointPolygonTest(hole, point, False) < 0 for hole in holes):
            return True
    return False
|
204 |
+
|
205 |
+
|
206 |
+
def tile(x: torch.Tensor, size: int):
    """
    Split an image tensor into non-overlapping square tiles.

    Accepts a 3D ``(C, H, W)`` or 4D ``(B, C, H, W)`` tensor. The previous
    implementation read ``C, H, W`` from ``shape[-3:]`` (implying 3D support)
    but then indexed ``x.size(2)`` / ``x.size(3)``, which raises IndexError on
    3D input; spatial dims are now taken from the end of the shape so both
    layouts work. H and W are zero-padded (bottom/right) up to the next
    multiple of ``size``.

    Args:
        x: Input tensor laid out as ``(..., C, H, W)``.
        size: Tile side length in pixels.

    Returns:
        Tensor of shape ``(N, C, size, size)``; tiles are ordered batch-major,
        then row-major over the tile grid.
    """
    C = x.shape[-3]
    H, W = x.shape[-2:]

    # Pad bottom/right so H and W divide evenly into `size`-sized tiles.
    pad_h = (size - H % size) % size
    pad_w = (size - W % size) % size
    if pad_h > 0 or pad_w > 0:
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))

    nh, nw = x.shape[-2] // size, x.shape[-1] // size
    return (
        x.view(-1, C, nh, size, nw, size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, C, size, size)
    )
|