TemporalGradCam
Temporal interpretation of medical image data.
Dataset
The OASIS_2D dataset contains 100 brain MRI images from 53 unique patients.
- 50 images with disease (label 1) from 43 unique patients
- 50 healthy images (label 0) from 10 unique patients
- Images are 256 x 256, in color (3 channels).
- For the experiment we split the dataset 80:20 into train and test by unique patient, so the same patient never appears in both training and test. The split is stratified, so a balanced number of healthy and disease examples is present in both train and test (see the sketch after this list).
- No data augmentation is used at this point.
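A minimal sketch of such a split, using scikit-learn's `StratifiedGroupKFold` (an assumption; the report does not name the splitting utility). The `labels` and `patient_ids` arrays here are toy stand-ins for the real per-image metadata:

```python
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

# Toy per-image metadata: `labels` (0 = healthy, 1 = disease) and `patient_ids`
# (repeated for patients with several scans). Replace with the real metadata.
rng = np.random.default_rng(0)
patient_ids = rng.integers(0, 53, size=100)
labels = (patient_ids < 43).astype(int)

# 5 folds -> each test fold holds ~20% of the data; `groups` keeps every
# patient entirely on one side, and stratification balances the two classes.
splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)
train_idx, test_idx = next(splitter.split(np.zeros((len(labels), 1)), labels,
                                          groups=patient_ids))
```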
Figure: Length distribution when each patient's images are converted to a time series. A patient can have multiple scans taken on different days; most patients have only one image.
Model
We currently have two models implemented: ResNet and ViT, both pretrained. For training we freeze all layers except the output Linear layer (see the sketch after the hyperparameter list below).
- Epochs: 25
- Learning rate: 1e-3
- Early stopping patience: 5
- Experiment iterations: 5; the whole experiment is repeated 5 times with a different random seed each time. The test results and best model checkpoints are saved.
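A minimal sketch of this setup, assuming the ResNet-18 variant (which matches the 512-dimensional penultimate layer quoted later) and torchvision's pretrained weights; the `model` built here is reused in the later sketches:

```python
import torch
import torch.nn as nn
from torchvision import models

# Load a pretrained backbone and freeze every layer.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False

# Replace the output Linear layer; the new layer is trainable by default.
model.fc = nn.Linear(model.fc.in_features, 2)   # healthy (0) vs disease (1)

# Only the new head's parameters are passed to the optimizer.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```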
Figure: Training history of one iteration of the ResNet model.
Temporal Model
To create the temporal version of the OASIS model we:
- Extracted features from the images using the pretrained models (ResNet or ViT). The extracted feature dimension equals the dimension of the layer just before the output layer:
- 512 for ResNet
- 768 for ViT
- For each sample:
  - find previous images of the same patient, up to the maximum sequence length;
  - create the time-series example of shape [seq_len, feature_dim];
  - shorter sequences are padded at the beginning to the max sequence length, and longer sequences are truncated from the beginning (the oldest images are dropped);
  - we currently use seq_len 3; around 70% of examples fall within this length, and the rest are padded.
- The model is batch-first; PyTorch does not easily support variable-length time sequences, so we use fixed-length padded sequences.
- Currently we use a simple DNN model on the temporal dataset (a sketch of the full temporal pipeline follows this list):
- max epochs 100
- learning rate 1e-3
- dropout=0.1
- hidden_size=64
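The following is a minimal sketch of the pipeline under the assumptions above: features come from the frozen ResNet `model` built earlier (the ViT case is analogous), padding uses zero vectors (the padding value is not stated in this report and is assumed here), and `TemporalDNN` is a hypothetical stand-in, since the exact architecture of the simple DNN is not specified:

```python
import torch
import torch.nn as nn

# Feature extraction: drop the classification head so the backbone returns the
# penultimate-layer activations (512-dim for ResNet-18, 768-dim for ViT).
backbone = nn.Sequential(*list(model.children())[:-1], nn.Flatten())
backbone.eval()

@torch.no_grad()
def extract_feature(image):                 # image: [1, 3, 256, 256]
    return backbone(image).squeeze(0)       # -> [feature_dim]

def build_sequence(feats, seq_len=3):
    """Turn a patient's chronological feature list (oldest first) into a
    fixed-length [seq_len, feature_dim] tensor: truncate from the beginning,
    zero-pad at the beginning (zero padding is an assumption)."""
    feats = feats[-seq_len:]                # drop the oldest images
    pad = seq_len - len(feats)
    if pad > 0:
        feats = [torch.zeros_like(feats[0])] * pad + feats
    return torch.stack(feats)

class TemporalDNN(nn.Module):
    """Hypothetical simple DNN over the flattened sequence features."""
    def __init__(self, feature_dim=512, seq_len=3, hidden_size=64, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                               # [B, seq_len * feature_dim]
            nn.Linear(seq_len * feature_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 2),                  # healthy vs disease logits
        )

    def forward(self, x):                               # x: [B, seq_len, feature_dim]
        return self.net(x)
```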
Results
The following table shows the average test results across all five iterations.
| Model | Loss | Accuracy (%) | F1-score (%) | AUC (%) |
|----------------|------|--------------|--------------|---------|
| ResNet | 1.32 | 83.87 | 81.32 | 91.68 |
| ResNet (Seq 3) | 1.45 | 79.87 | 81.24 | 82.24 |
| ViT | 1.22 | 85.77 | 86.50 | 92.08 |
| ViT (Seq 3) | 0.94 | 87.72 | 88.25 | 95.13 |
The temporal model (sequence length 3) with the Vision Transformer performs best so far.
Interpretation
Interpreting sample patient images for the ResNet model. Note that this is not for the temporal model; a Captum sketch follows the table below.
| No | Sample | Gradient Shap | GradCam | Guided GradCam | Guided Backprop |
|----|--------|---------------|---------|----------------|-----------------|
| 1  |        |               |         |                |                 |
| 2  |        |               |         |                |                 |
| 3  |        |               |         |                |                 |
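A minimal sketch of how these attributions can be computed with Captum, assuming the fine-tuned ResNet `model` from the training sketch; `image` is a hypothetical preprocessed input, and `target=1` asks for attributions toward the disease class:

```python
import torch
from captum.attr import GradientShap, LayerGradCam, GuidedGradCam, GuidedBackprop

model.eval()
image = torch.randn(1, 3, 256, 256, requires_grad=True)  # placeholder input

# GradCAM variants attribute against the last convolutional block (layer4).
gradcam   = LayerGradCam(model, model.layer4).attribute(image, target=1)
guided_gc = GuidedGradCam(model, model.layer4).attribute(image, target=1)
guided_bp = GuidedBackprop(model).attribute(image, target=1)

# GradientShap samples between the input and a black-image baseline.
grad_shap = GradientShap(model).attribute(image,
                                          baselines=torch.zeros_like(image),
                                          n_samples=5, target=1)
```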
Interpreting sample patient images for the ViT model. Note that this is not for the temporal model.
| No | Sample | Gradient Shap | GradCam | Guided GradCam | Guided Backprop |
|----|--------|---------------|---------|----------------|-----------------|
| 1  |        |               |         |                |                 |
| 2  |        |               |         |                |                 |
| 3  |        |               |         |                |                 |
Files
The following files are currently available; each applies a pre-trained vision model for transfer learning on the medical dataset.
- oasis_resnet: Run the OASIS dataset with the ResNet model.
- oasis_ViT: Run the OASIS dataset with the Vision Transformer.