Foundational models represent a significant advancement in machine learning, characterized by their large-scale training on extensive datasets comprising unlabeled or semi-labeled data. These models exhibit remarkable versatility, serving as a basis for a wide range of downstream tasks while demonstrating strong zero-shot prediction capabilities. The emergence of self-supervised learning (SSL) techniques, such as SimCLR, DINO, and DINOv2, has catalyzed the development of foundational models across various domains. In this article, we specifically examine the application and impact of foundational models within the field of digital pathology.
A brief glance at DINO:
The DINO (self-DIstillation with NO labels) framework trains a student network to match the output of a teacher network on different random transformations (views) of the same image. From a single input image, a set V of multiple views is generated, which includes two global views and several local views. All views are processed by the student network, while only the global views are passed through the teacher network.
Unlike traditional knowledge distillation methods, where the teacher network is typically larger than the student, DINO uses identical architectures for the student and teacher networks; only their parameters differ. The teacher's output is centered by subtracting a center vector C, which is an exponential moving average of the teacher's batch-mean output, and then sharpened with a low softmax temperature. Together, centering and sharpening prevent the networks from collapsing to a trivial solution. The training objective minimizes the cross-entropy between the teacher and student output distributions. Only the student's parameters are updated by backpropagation. A stop-gradient on the teacher ensures that no gradients flow into it.
Instead, the teacher's parameters are updated as an exponential moving average (EMA) of the student's parameters, a so-called momentum encoder. This lets the teacher evolve with the student while remaining stable during training, so DINO learns useful representations without any labeled data.
```python
# gs, gt: student and teacher networks
# C: center (K)
# tps, tpt: student and teacher temperatures
# l, m: network and center momentum rates
gt.params = gs.params
for x in loader:  # load a minibatch x with n samples
    x1, x2 = augment(x), augment(x)  # random views
    s1, s2 = gs(x1), gs(x2)  # student output n-by-K
    t1, t2 = gt(x1), gt(x2)  # teacher output n-by-K
    loss = H(t1, s2)/2 + H(t2, s1)/2
    loss.backward()  # back-propagate
    # student, teacher and center updates
    update(gs)  # SGD
    gt.params = l*gt.params + (1-l)*gs.params
    C = m*C + (1-m)*cat([t1, t2]).mean(dim=0)

def H(t, s):
    t = t.detach()  # stop gradient
    s = softmax(s / tps, dim=1)
    t = softmax((t - C) / tpt, dim=1)  # center + sharpen
    return - (t * log(s)).sum(dim=1).mean()
```
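The pseudocode above depends on framework pieces (networks, an optimizer, autograd) that it does not define. As a framework-free sketch, assuming NumPy and random arrays standing in for the networks' outputs, the loss, the center update and the EMA teacher update can be written as follows. The gradient step on the student is omitted, the function names are illustrative, and the temperatures and momentum values are illustrative defaults rather than tuned settings.

```python
import numpy as np

def log_softmax(z, axis=1):
    z = z - z.max(axis=axis, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def H(t, s, C, tps=0.1, tpt=0.04):
    """Cross-entropy between the centered, sharpened teacher distribution and the student's."""
    t = np.exp(log_softmax((t - C) / tpt))   # teacher: subtract center, low temperature sharpens
    log_s = log_softmax(s / tps)             # student log-probabilities
    return -(t * log_s).sum(axis=1).mean()

def dino_step(s1, s2, t1, t2, C, m=0.9):
    """Symmetric two-view DINO loss and the EMA update of the center C."""
    loss = H(t1, s2, C) / 2 + H(t2, s1, C) / 2
    C_new = m * C + (1 - m) * np.concatenate([t1, t2]).mean(axis=0)
    return loss, C_new

def ema_update(teacher, student, l=0.996):
    """Teacher parameters become an exponential moving average of the student parameters."""
    return {name: l * teacher[name] + (1 - l) * student[name] for name in teacher}

rng = np.random.default_rng(0)
n, K = 8, 16                                   # batch size, output dimension
s1, s2, t1, t2 = (rng.normal(size=(n, K)) for _ in range(4))
loss, C = dino_step(s1, s2, t1, t2, C=np.zeros(K))
teacher = ema_update({"w": np.zeros(3)}, {"w": np.ones(3)})
```

Because the teacher only ever receives averaged copies of the student's weights, it needs no optimizer of its own.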
LongViT by Microsoft
Gigapixel whole slide images (WSIs) present unique challenges due to their immense size, necessitating substantial memory and computational resources. To address these limitations, WSIs are typically divided into smaller, manageable patches. These patches are then individually processed by a model, and their outputs are subsequently aggregated using a predefined aggregation function. To facilitate an efficient and comprehensive end-to-end analysis pipeline, researchers have introduced LongViT, a vision transformer inspired by the architecture of LongNet, a model capable of handling sequences containing up to a billion tokens.
LongViT is particularly well-suited for processing WSIs as it employs dilated attention mechanisms to efficiently handle extremely long sequences. The workflow involves extracting all patches from a WSI and applying a linear projection to obtain their corresponding patch embeddings. These embeddings are further enriched by adding learnable one-dimensional positional embeddings to encode spatial information. This enriched representation is then fed into LongNet, enabling the generation of feature representations that capture the global context of the entire WSI. By incorporating such context-aware embeddings, LongViT achieves a holistic understanding of the WSI, addressing the challenges posed by its gigapixel scale.
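The core idea of dilated attention can be sketched in NumPy. The sequence is cut into segments of length w. Inside each segment only every r-th token takes part in ordinary attention, so each segment costs about (w/r)² instead of w². This is a simplified illustration with a single (w, r) pair and a single head; LongNet itself mixes several (w, r) configurations with different offsets across heads so that every token is covered. Here, unselected positions simply get a zero output.

```python
import numpy as np

def attention(q, k, v):
    """Plain softmax attention over all given tokens."""
    scores = q @ k.T / np.sqrt(q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)
    w = np.exp(scores)
    return (w / w.sum(axis=1, keepdims=True)) @ v

def dilated_attention(q, k, v, segment_len, dilation):
    """One (segment length w, dilation r) configuration of LongNet-style dilated attention."""
    n = q.shape[0]
    assert n % segment_len == 0, "sequence length must be a multiple of the segment length"
    out = np.zeros_like(v)
    for start in range(0, n, segment_len):
        idx = np.arange(start, start + segment_len, dilation)  # sparsified tokens of this segment
        out[idx] = attention(q[idx], k[idx], v[idx])
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(16, 4)) for _ in range(3))
dense = dilated_attention(q, k, v, segment_len=16, dilation=1)   # w = n, r = 1: ordinary attention
sparse = dilated_attention(q, k, v, segment_len=8, dilation=2)   # 2 segments, every 2nd token
```

With w equal to the sequence length and r = 1 this reduces to standard attention. Larger w and r let attention reach far-apart tokens at a fraction of the quadratic cost.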
Using 10K slides from TCGA, they randomly take 100 crops of varying widths and heights (1,024 to 1,536 pixels) from each WSI, split each crop into 32×32 patches, and pretrain LongViT on them with the DINO framework.
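To make the sequence lengths concrete, with 32×32 patches a 1,024×1,024 crop gives (1024/32)² = 1,024 tokens and a 1,536×1,536 crop gives 48² = 2,304 tokens. The NumPy sketch below shows the three steps: patchify, linear projection, and adding 1D positional embeddings by sequence index. It assumes a random projection matrix and a zero-initialized positional table as stand-ins for learned weights, plus a hypothetical embedding width of 64.

```python
import numpy as np

def patchify(img, p=32):
    """[H, W, C] image -> [num_patches, p*p*C] rows, in row-major patch order."""
    h, w, c = img.shape
    assert h % p == 0 and w % p == 0
    x = img.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p * p * c)

def embed(img, proj, pos_table, p=32):
    tokens = patchify(img, p) @ proj          # linear projection to the embedding width
    return tokens + pos_table[: len(tokens)]  # add 1D positional embeddings by sequence index

rng = np.random.default_rng(0)
p, dim = 32, 64
max_len = (1536 // p) ** 2                               # longest sequence: a 1,536 x 1,536 crop
proj = rng.normal(scale=0.02, size=(p * p * 3, dim))     # stand-in for a learned weight
pos_table = np.zeros((max_len, dim))                     # stand-in for learned positional embeddings
crop = rng.random((1024, 1536, 3)).astype(np.float32)    # a non-square crop: 32 x 48 patches
seq = embed(crop, proj, pos_table)
```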
Campanella et al
In this study, the authors compare the pre-training of vision transformer models using the Masked Autoencoder (MAE) and DINO frameworks. As DINO has been previously discussed, the focus here shifts to the MAE framework:
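Unlike a conventional, symmetric autoencoder, MAE (He et al.) uses an asymmetric encoder-decoder design. The encoder operates only on the visible patches, and a lightweight decoder reconstructs the image from the encoded patches and mask tokens. It consists of the following stages: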
Masking:
- The image is divided into regular, non-overlapping patches. A random subset of patches is kept visible and the remaining patches are masked (i.e., removed). MAE uses a high masking ratio, 75% by default.
MAE encoder:
- The encoder is a standard Vision Transformer (ViT) that processes the sampled patches and generates their embeddings.
MAE decoder:
- A shared, learnable mask token is inserted at every masked position, alongside the encoded visible patches. Positional embeddings are added to this full set of tokens, which is then passed through the decoder, another series of transformer blocks.
Reconstructing target:
- The decoder's final layer is a linear projection that predicts the pixel values of each patch. The reconstruction is optimized with a simple mean squared error (MSE) loss computed only on the masked patches.
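The masking and loss stages can be sketched in NumPy, leaving out the encoder and decoder. The function names are illustrative, and random arrays stand in for patch pixels and decoder predictions. The sketch assumes 224×224 RGB images cut into 16×16 patches (14 × 14 = 196 patches of 768 values) and the 75% masking ratio.

```python
import numpy as np

def random_masking(patches, mask_ratio, rng):
    """Randomly keep (1 - mask_ratio) of the patches; mask[i] is True where patch i is removed."""
    n = patches.shape[0]
    n_keep = int(round(n * (1 - mask_ratio)))
    keep_idx = np.sort(rng.permutation(n)[:n_keep])
    mask = np.ones(n, dtype=bool)
    mask[keep_idx] = False
    return patches[keep_idx], keep_idx, mask

def mae_loss(pred, target, mask):
    """Mean squared error per patch, averaged over the masked patches only."""
    per_patch = ((pred - target) ** 2).mean(axis=1)
    return per_patch[mask].mean()

rng = np.random.default_rng(0)
target = rng.random((196, 16 * 16 * 3))        # 14 x 14 patches of a 224 x 224 RGB image
visible, keep_idx, mask = random_masking(target, mask_ratio=0.75, rng=rng)
pred = rng.random(target.shape)                # stand-in for the decoder's pixel predictions
loss = mae_loss(pred, target, mask)
```

Only the 49 visible patches would go through the encoder, which is what makes the asymmetric design cheap. Errors on visible patches do not contribute to the loss.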
HIPT (Hierarchical Image Pyramid Transformer)
HIPT takes its inspiration from hierarchical representations in natural language processing, where embeddings can be aggregated at the character, word, sentence and paragraph level to form document representations. In the same way, HIPT aggregates visual tokens at the cell (16×16), patch (256×256) and region (4096×4096) level to form slide representations. With these image sizes, the input sequence length is always M = 256 tokens in the forward passes at the cell and patch levels. This follows because a 256×256 patch holds 16 × 16 = 256 cells and a 4096×4096 region holds 16 × 16 = 256 patches. At the slide level M is usually less than 256, since a WSI typically contains fewer than 256 regions of 4096×4096. Another reason the authors give is that a 16×16 image at 20× magnification (roughly 0.5 µm per pixel) covers only about 8 µm across, so each token encodes visual concepts focused on a single cell.

A 256×256 patch is fed to ViT_256-16, which unrolls it into a sequence of non-overlapping 16×16 tokens. Each token is passed through a linear embedding layer and positional embeddings are added. A learnable [CLS] token is then prepended before the sequence enters the transformer blocks. The output [CLS] tokens of the 256 patches inside a 4096×4096 region are fed directly to ViT_4K-256 as the tokens representing that region. The output [CLS] token of ViT_4K-256 for each region is in turn fed to ViT_WSI-4K. Due to potential tissue segmentation irregularities when patching at 4096×4096, the authors ignore positional embeddings at this stage. The regions kept after tissue segmentation do not form a regular, complete grid, because background is discarded and tissue can be fragmented. As a result, ViT_WSI-4K treats the regions as an unordered set rather than assigning each region a position.
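The nested sequence lengths can be checked with a small NumPy sketch. This is an illustration, not HIPT's code, and the helper `split_into_tiles` is hypothetical. A 4096×4096 region splits into a 16×16 grid of 256×256 patches, which gives M = 256 tokens for ViT_4K-256. Each of those patches splits into a 16×16 grid of 16×16 cells, which gives M = 256 tokens for ViT_256-16.

```python
import numpy as np

def split_into_tiles(img, tile):
    """[H, W, C] array -> [H//tile, W//tile, tile, tile, C] grid of non-overlapping tiles."""
    h, w, c = img.shape
    assert h % tile == 0 and w % tile == 0, "image must be divisible by the tile size"
    return img.reshape(h // tile, tile, w // tile, tile, c).transpose(0, 2, 1, 3, 4)

region = np.random.default_rng(0).integers(0, 256, size=(4096, 4096, 3), dtype=np.uint8)
patches = split_into_tiles(region, 256)      # (16, 16, 256, 256, 3): tokens for ViT_4K-256
cells = split_into_tiles(patches[0, 0], 16)  # (16, 16, 16, 16, 3): tokens for ViT_256-16
print(patches.shape[0] * patches.shape[1], cells.shape[0] * cells.shape[1])  # 256 256
```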
To use HIPT_4K directly on images:
```python
import torch
from einops import rearrange
from HIPT_4K.hipt_model_utils import get_vit256, get_vit4k


class HIPT_4K(torch.nn.Module):
    """
    HIPT Model (ViT_4K-256) for encoding non-square images (with [256 x 256] patch tokens), with
    [256 x 256] patch tokens encoded via ViT_256-16 using [16 x 16] patch tokens.
    """
    def __init__(self,
                 model256_path: str = 'path/to/Checkpoints/vit256_small_dino.pth',
                 model4k_path: str = 'path/to/Checkpoints/vit4k_xs_dino.pth',
                 device256=torch.device('cuda:0'),
                 device4k=torch.device('cuda:1')):
        super().__init__()
        self.model256 = get_vit256(pretrained_weights=model256_path).to(device256)
        self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k)
        self.device256 = device256
        self.device4k = device4k

    def prepare_img_tensor(self, img: torch.Tensor, patch_size: int = 256):
        """Center-crop img so W and H are divisible by patch_size (minimal stand-in, not the repo's version)."""
        _, _, w, h = img.shape
        w_crop, h_crop = w - w % patch_size, h - h % patch_size
        left, top = (w - w_crop) // 2, (h - h_crop) // 2
        img = img[:, :, left:left + w_crop, top:top + h_crop]
        return img, w_crop // patch_size, h_crop // patch_size

    def forward(self, x):
        """
        Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K.
        1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K (e.g. - 256 x 256).
        2. x then gets unfolded into a "batch" of [256 x 256] images.
        3. A pretrained ViT_256-16 model extracts the CLS token from each [256 x 256] image in the batch.
        4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".)
        5. This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K.
        Args:
            - x (torch.Tensor): [1 x C x W' x H'] image tensor.
        Return:
            - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default).
        """
        batch_256, w_256, h_256 = self.prepare_img_tensor(x)  # 1. [1 x 3 x W x H].
        batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)  # 2. [1 x 3 x w_256 x h_256 x 256 x 256]
        batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h')  # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256)

        features_cls256 = []
        for mini_bs in range(0, batch_256.shape[0], 256):  # 3. B may be too large for ViT_256. We further take minibatches of 256.
            minibatch_256 = batch_256[mini_bs:mini_bs+256].to(self.device256, non_blocking=True)
            features_cls256.append(self.model256(minibatch_256).detach().cpu())  # 3. Extracting ViT_256 features from [256 x 3 x 256 x 256] image batches.

        features_cls256 = torch.vstack(features_cls256)  # 3. [B x 384], where 384 == dim of ViT-256 [CLS] token.
        features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0, 1).transpose(0, 2).unsqueeze(dim=0)
        features_cls256 = features_cls256.to(self.device4k, non_blocking=True)  # 4. [1 x 384 x w_256 x h_256]
        features_cls4k = self.model4k(features_cls256)  # 5. [1 x 192], where 192 == dim of ViT_4K [CLS] token.
        return features_cls4k
```
Both ViT_256-16 and ViT_4K-256 were pretrained with the DINO framework.
CLIP: The modern Big Bang
CLIP (Contrastive Language–Image Pretraining) is a multi-modal machine learning model developed by OpenAI. It learns a shared embedding space for images and text. This enables tasks such as zero-shot classification and image-text retrieval without explicit training on those tasks.
CLIP achieves this by using a contrastive learning objective to align images and their corresponding text descriptions. It learns representations where:
Matching images and texts (aligned pairs) are close to each other in the embedding space.
Unrelated images and texts (non-aligned pairs) are far apart in the embedding space.
NumPy-like pseudocode for implementing CLIP (from the CLIP paper):
```python
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I)  # [n, d_i]
T_f = text_encoder(T)  # [n, d_t]

# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
```
The temperature t is learned and log-parameterized: it is exponentiated (np.exp(t)) and used as a multiplicative scale on the cosine-similarity logits.
CLIP uses only a linear projection (W_i, W_t) to map each encoder's output into the joint multimodal space. It drops the non-linear projection head introduced by Bachman et al. and popularized by Chen et al. (SimCLR).
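A batch of N pairs contains N² image-text pairings. We want to pull the N matching pairs (the diagonal of the logits matrix) together in the joint multimodal space and push the remaining N² − N non-matching pairs apart. The loss therefore has two terms. The first, loss_i, ensures that the embedding of each image is closest to the embedding of its corresponding text. The second, loss_t, ensures that the embedding of each text is closest to the embedding of its corresponding image. The final loss is their average.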
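The pseudocode above leaves `l2_normalize` and `cross_entropy_loss` undefined. Below is a runnable NumPy sketch of the same symmetric loss. It assumes the encoder features are precomputed arrays. The helper names mirror the pseudocode, but their implementations here are illustrative.

```python
import numpy as np

def l2_normalize(x, axis=1):
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def cross_entropy_loss(logits, labels, axis):
    """Mean cross-entropy with the softmax taken along `axis`; labels index that axis."""
    z = logits - logits.max(axis=axis, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=axis, keepdims=True))
    if axis == 0:
        log_probs = log_probs.T  # rows now index the item being classified
    return -log_probs[np.arange(len(labels)), labels].mean()

def clip_loss(I_f, T_f, W_i, W_t, t):
    I_e = l2_normalize(I_f @ W_i)
    T_e = l2_normalize(T_f @ W_t)
    logits = I_e @ T_e.T * np.exp(t)   # [n, n]; the diagonal holds the matching pairs
    labels = np.arange(len(logits))
    loss_i = cross_entropy_loss(logits, labels, axis=0)
    loss_t = cross_entropy_loss(logits, labels, axis=1)
    return (loss_i + loss_t) / 2

# Perfectly aligned toy embeddings: a larger temperature scale drives the loss toward 0.
E = np.eye(4)
loose = clip_loss(E, E, E, E, t=0.0)
sharp = clip_loss(E, E, E, E, t=np.log(100.0))
```

Even with perfectly aligned embeddings the loss stays well above zero at t = 0, because cosine similarities are bounded in [-1, 1]. This is one reason the temperature is learned.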
PLIP (Pathology language-image pretraining)
PLIP applies the CLIP framework to pathology. The authors curated pathology image-text data from tweets and from other sources such as LAION, producing OpenPath and PathLAION, which they released as public datasets. They also released the weights of an image encoder and a text encoder trained with CLIP to share a joint multimodal space. Because these encoders yield better pathology image representations, they can be used for tasks such as classification, image-to-image retrieval and text-to-image retrieval.
MI-Zero
MI-Zero uses a CLIP-style framework to learn a joint multimodal space. The image encoder embeds each patch of a WSI (whole slide image), and each patch embedding is compared by cosine similarity with the prompt embeddings of a fixed set of classes. A typical prompt is "a histopathological image of <classname>", for example "a histopathological image of carcinoma of the breast, ductal pattern". In this set-based representation, a permutation-invariant operator h(.), such as mean or top-K max pooling, aggregates the patch scores of each class into a slide-level prediction.
$$\text{Image embeddings: } \{ u_i \}_{i=1,\ldots,N}\\ \text{where } N \text{ represents the number of patches from tissue in the WSI.}$$
$$\text{Prompt embeddings: } \{ w_m \}_{m=1,\ldots,C} \\ \text{where } C \text{ is the number of classes.}$$
$$s_i = \mathbf{u}_i^T [\mathbf{w}_1, \mathbf{w}_2, \ldots, \mathbf{w}_C], \quad s_i \in \mathbb{R}^C \\ \text{where } s_i \text{ holds the cosine similarities between the } i\text{-th image embedding and each of the } C \text{ prompt embeddings (all embeddings are } \ell_2\text{-normalized).}$$
$$h_{\text{mean}}(\mathbf{S}) = \frac{1}{N} \sum_{i=1}^{N} s_i$$
$$h_{\text{topK}}(\mathbf{S}) = \frac{1}{K} \left[ \sum_{i=1}^{K} \tilde{s}_i^{1}, \sum_{i=1}^{K} \tilde{s}_i^{2}, \ldots, \sum_{i=1}^{K} \tilde{s}_i^{C} \right]^\top \\ \text{where } \tilde{s}_i^{c} \text{ is the } i\text{-th largest score for class } c \in \{1, \ldots, C\} \text{ among the } N \text{ patches.}$$
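A NumPy sketch of the set-based scoring follows. It assumes the patch and prompt embeddings are already computed (hand-made arrays here), and the function names are illustrative. It computes the cosine scores S, pools them per class with h_mean and h_topK, and takes the slide-level prediction as the class with the highest pooled score.

```python
import numpy as np

def patch_scores(U, W):
    """U: [N, d] patch embeddings, W: [C, d] prompt embeddings -> S: [N, C] cosine similarities."""
    U = U / np.linalg.norm(U, axis=1, keepdims=True)
    W = W / np.linalg.norm(W, axis=1, keepdims=True)
    return U @ W.T

def h_mean(S):
    return S.mean(axis=0)

def h_topk(S, k):
    """Average of the k largest scores in each class column."""
    return np.sort(S, axis=0)[-k:].mean(axis=0)

S = np.array([[0.9, 0.1],
              [0.2, 0.8],
              [0.7, 0.3],
              [0.1, 0.6]])                  # 4 patches, 2 classes
slide_pred = int(np.argmax(h_topk(S, k=2)))  # index of the predicted class
```

Top-K pooling lets a slide be called positive on the strength of a few highly scoring patches, which suits cases where the relevant tissue occupies a small area.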
For the graph-based representation, they take the spatial position of each patch and build a directed KNN graph G. This matters because tissue carries meaning in its spatial arrangement, which the set-based representation discards. Each patch (node) is connected to its spatial neighbors, and the value at node i is its score s_i:
$$G = \{ \mathcal{M}, \mathcal{E} \}$$
The graph has node set M (the patches) and edge set E. They then spatially smooth the scores by replacing each s_i with h_mean(S_neighbors), the average over the node and its neighbors, where
$$\mathcal{S}_{\text{neighbors}} = \{ s_j : j \in \{ i \} \cup \mathcal{N}(i) \}$$
and
$$\mathcal{N}(i) = \{ j : (i, j) \in \mathcal{E} \}$$
for each node i in the graph. Finally, a permutation-invariant operator h(.) is applied to the smoothed scores to get the slide-level prediction.
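A NumPy sketch of the graph-based smoothing follows. It assumes 2D patch coordinates and a precomputed score matrix S, and uses brute-force distances for the directed KNN graph. The function names are illustrative. Each node's score is replaced by the mean over itself and its out-neighbors, and the smoothed scores are then mean-pooled.

```python
import numpy as np

def knn_neighbors(coords, k):
    """Directed kNN graph: for each patch i, the indices j of its k nearest other patches."""
    d = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)               # a patch is not its own neighbor
    return np.argsort(d, axis=1)[:, :k]

def smooth_scores(S, neighbors):
    """Replace s_i by the mean of the scores over {i} plus N(i)."""
    idx = np.concatenate([np.arange(len(S))[:, None], neighbors], axis=1)
    return S[idx].mean(axis=1)

coords = np.array([[0, 0], [1, 0], [3, 0], [10, 0]], dtype=float)  # patch positions
S = np.array([[0.9], [0.1], [0.5], [0.3]])                          # one class for brevity
nbrs = knn_neighbors(coords, k=1)
smoothed = smooth_scores(S, nbrs)
slide_score = smoothed.mean(axis=0)           # h_mean over the smoothed scores
```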
CONCH (CONtrastive learning from Captions for Histopathology)
Now you might be wondering why none of these models uses a decoder to describe the image. With CONCH, the authors introduce such a decoder using Google's CoCa (Contrastive Captioner) framework.
The CoCa architecture has an image encoder, a text encoder and a multimodal text decoder. As in a vanilla ViT, the image is split into non-overlapping patch tokens, absolute positional embeddings are added, and the tokens are fed to the image encoder (ViT). On top of the image encoder sit two attentional poolers, which we denote f_contrast and f_caption. Each one computes a fixed number N of image tokens from the last-layer representation of the ViT backbone, using multi-head attention with N learned queries. For contrastive learning, f_contrast uses a single query (N = 1) to compute one image token that captures the global representation of the image. The other pooler, f_caption, uses N = 256 queries to produce 256 image tokens that capture local, fine-grained details of the image.
The text encoder uses an embedding table to map discrete word tokens to continuous embeddings, complemented by a set of learned absolute positional embeddings. In addition, a learned <CLS> token is appended to each tokenized caption. This token has access to the full context through transformer attention, which enables it to capture a global representation of the caption.
The multimodal decoder adds a cross-attention layer after each multi-head self-attention layer to integrate information from the image tokens. It also has a final language modeling head, which predicts the distribution of the next token over the supported vocabulary.
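An attentional pooler can be sketched in NumPy as single-head cross-attention in which N learned queries attend over the ViT's output tokens. This is a simplified illustration. It assumes random weights, one head, and no output projection or layer norm, whereas CoCa's pooler uses multi-head attention followed by a layer norm. The same function yields one global token for f_contrast or 256 tokens for f_caption, depending only on the number of queries.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attentional_pool(tokens, queries, Wk, Wv):
    """tokens: [L, d] last-layer ViT tokens, queries: [N, d] learned queries -> [N, d]."""
    K, V = tokens @ Wk, tokens @ Wv
    attn = softmax(queries @ K.T / np.sqrt(K.shape[1]), axis=1)  # [N, L], rows sum to 1
    return attn @ V

rng = np.random.default_rng(0)
L, d = 196, 64
tokens = rng.normal(size=(L, d))
Wk = rng.normal(size=(d, d)) / np.sqrt(d)
Wv = rng.normal(size=(d, d)) / np.sqrt(d)
con = attentional_pool(tokens, rng.normal(size=(1, d)), Wk, Wv)    # f_contrast: 1 global token
cap = attentional_pool(tokens, rng.normal(size=(256, d)), Wk, Wv)  # f_caption: 256 tokens
```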
Dual-Encoder Contrastive learning: They use the contrastive loss (as introduced in CLIP) between the image and text encoders.
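$$\mathcal{L}_{\text{Con}} = -\frac{1}{N} \left( \sum_{i=1}^{N} \log \frac{\exp(x_i^\top y_i / \sigma)}{\sum_{j=1}^{N} \exp(x_i^\top y_j / \sigma)} + \sum_{i=1}^{N} \log \frac{\exp(y_i^\top x_i / \sigma)}{\sum_{j=1}^{N} \exp(y_i^\top x_j / \sigma)} \right)$$
where x_i and y_j are the normalized embeddings of the image in the i-th pair and of the text in the j-th pair, N is the batch size, and σ is the temperature that scales the logits.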
Encoder-Decoder Captioning: The dual-encoder approach encodes the text as a whole, so that related image and text embeddings end up close. The generative approach instead aims for finer granularity and requires the model to predict the exact tokenized text y autoregressively. The text decoder learns to maximize the conditional likelihood of the paired text y under the forward autoregressive factorization. The encoder-decoder is trained with teacher forcing (read more about it in Williams and Zipser).
$$\mathcal{L}_{\text{Cap}} = - \sum_{t=1}^{T} \log P_{\theta} \big( y_t \mid y_{<t}, x \big)$$
The overall CoCa objective then becomes:
$$\mathcal{L}_{\text{CoCa}} = \lambda_{\text{Con}} \cdot \mathcal{L}_{\text{Con}} + \lambda_{\text{Cap}} \cdot \mathcal{L}_{\text{Cap}}$$
where λ_Con and λ_Cap are loss-weighting hyperparameters.
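The captioning term and the weighted objective can be sketched in NumPy. The sketch assumes the decoder's per-step logits are given. Under teacher forcing, step t sees the ground-truth prefix y_<t, so all steps can be scored at once. The per-caption sum over t is averaged over the batch, padding is masked out, and both weights default to 1.0 purely for illustration. The function names are hypothetical.

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def caption_loss(logits, targets, mask):
    """logits: [B, T, V], targets: [B, T] next-token ids y_t, mask: [B, T] (1 = real token).

    Returns sum_t -log P(y_t | y_<t, x) per caption, averaged over the batch.
    """
    logp = log_softmax(logits, axis=-1)
    B, T = targets.shape
    token_logp = logp[np.arange(B)[:, None], np.arange(T)[None, :], targets]  # [B, T]
    return -(token_logp * mask).sum(axis=1).mean()

def coca_loss(con_loss, cap_loss, lam_con=1.0, lam_cap=1.0):
    return lam_con * con_loss + lam_cap * cap_loss

B, T, V = 2, 5, 10
targets = np.array([[1, 2, 3, 4, 0], [5, 6, 7, 0, 0]])
mask = np.array([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]], dtype=float)
cap = caption_loss(np.zeros((B, T, V)), targets, mask)  # uniform predictions over 10 tokens
total = coca_loss(con_loss=1.0, cap_loss=cap)
```

With uniform predictions every real token costs log 10, so the two captions (4 and 3 real tokens) average to 3.5 · log 10.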
```python
# image, text.ids, text.labels, text.mask: paired {image, text} data
# con_query: 1 query token for contrastive embedding
# cap_query: N query tokens for captioning embedding
# cls_token_id: a special cls_token_id in vocabulary
# vit_encoder: vision transformer based encoder
# lm_transformers: language-model transformers

def attentional_pooling(features, query):
    out = multihead_attention(features, query)
    return layer_norm(out)

img_feature = vit_encoder(image)  # [batch, seq_len, dim]
con_feature = attentional_pooling(img_feature, con_query)  # [batch, 1, dim]
cap_feature = attentional_pooling(img_feature, cap_query)  # [batch, N, dim]

ids = concat(text.ids, cls_token_id)
mask = concat(text.mask, zeros_like(cls_token_id))  # unpad cls_token_id
txt_embs = embedding_lookup(ids)
unimodal_out = lm_transformers(txt_embs, mask, cross_attn=None)
multimodal_out = lm_transformers(
    unimodal_out[:, :-1, :], mask, cross_attn=cap_feature)
cls_token_feature = layer_norm(unimodal_out)[:, -1:, :]  # [batch, 1, dim]

con_loss = contrastive_loss(con_feature, cls_token_feature)
cap_loss = softmax_cross_entropy_loss(
    multimodal_out, labels=text.labels, mask=text.mask)
```
To train CONCH, the authors first pretrain the image encoder and the text encoder/decoder separately on unimodal pathology data. They then train the whole model on pathology image-caption pairs with the CoCa objective, which combines the contrastive and captioning losses.
UNI
UNI is a ViT-Large pretrained with the DINOv2 framework on Mass-100K. This dataset contains 100 million images from 100K H&E whole slide images (WSIs), about 77 TB, across 20 major tissue types.
A brief overview of DINOv2 (self-DIstillation with NO labels):
They curate LVD-142M as follows:
Self-deduplication: To deduplicate the uncurated dataset (1.3B images), they compute image embeddings with the copy-detection pipeline of Pizzi et al. and retrieve the k = 64 nearest neighbors of each image using cosine similarity. Considering only neighbors with a similarity greater than 0.6, they extract the connected components of the resulting k-NN graph and keep just one representative image per component, resulting in 1.1B images.
Relative deduplication: To reduce redundancy and evaluate the learned features properly, they also discard images that are too similar to images in the train or test splits of the evaluation datasets. They use the same approach with a stricter similarity threshold of 0.45, which leaves 744M images.
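The self-deduplication step can be sketched in NumPy. Hand-made 2D vectors stand in for the copy-detection embeddings. Brute-force cosine similarity replaces a large-scale nearest-neighbor index, and a small union-find merges near-duplicates into connected components. The function name is illustrative.

```python
import numpy as np

def deduplicate(emb, k=64, threshold=0.6):
    """Return indices of one representative per connected component of near-duplicates."""
    x = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sim = x @ x.T
    np.fill_diagonal(sim, -np.inf)             # an image is not its own neighbor
    k = min(k, len(x) - 1)
    nn = np.argsort(-sim, axis=1)[:, :k]       # k nearest neighbors by cosine similarity

    parent = list(range(len(x)))               # union-find over images
    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]      # path halving
            i = parent[i]
        return i

    for i in range(len(x)):
        for j in nn[i]:
            if sim[i, j] > threshold:          # only sufficiently similar neighbors form edges
                parent[find(i)] = find(j)
    reps = {}
    for i in range(len(x)):
        reps.setdefault(find(i), i)            # keep the first image seen in each component
    return sorted(reps.values())

emb = np.array([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0], [0.05, 1.0]])
kept = deduplicate(emb)                        # images 0/1 and 2/3 are near-duplicates
```

The relative variant would compare the uncurated images against the evaluation images with the lower 0.45 threshold and drop any match, instead of keeping one per component.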
Retrieval: Two retrieval methods are used, sample-based and cluster-based, depending on the size of the curated dataset. For datasets larger than 1M images, sample-based retrieval is used: for each query image, the K nearest images in the uncurated data are chosen, which in effect multiplies the size of the dataset by up to K. For smaller datasets, cluster-based retrieval is used. The uncurated data is divided into 100,000 clusters with k-means, and 10,000 images are picked from each cluster associated with more than 3 query images. The 10,000-image budget applies only to those clusters, not to all 100,000, so the total stays well below 10,000 × 100,000 = 1B.
DINOv2: Improvements over DINO
iBOT (Patch level objective)
Given an input image x, two augmented views u and v are created, and each view is split into patches. As in DINO, the student sees all the views (global and local), while the teacher sees only the global views. In addition, some of the patch tokens in the student's global views are randomly masked, while the teacher receives the same views unmasked. Two losses are combined. The cross-view loss L[cls] is the DINO cross-entropy between the teacher's and student's <cls> output tokens. The masked image modelling loss L[mim] is a cross-entropy between the student's outputs at the masked patch positions and the teacher's outputs at the same (unmasked) positions.
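The masked image modelling loss can be sketched in NumPy. Random logits stand in for the patch-head outputs of both networks. Teacher centering (softmax-centering in iBOT, Sinkhorn-Knopp in DINOv2) is omitted, and the temperatures are illustrative. The key point is that only masked positions contribute.

```python
import numpy as np

def log_softmax(z, axis=1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def mim_loss(student_logits, teacher_logits, mask, tps=0.1, tpt=0.04):
    """student_logits, teacher_logits: [P, K] per-patch outputs; mask: [P], True = masked for the student."""
    t = np.exp(log_softmax(teacher_logits / tpt))   # teacher targets from the unmasked view
    log_s = log_softmax(student_logits / tps)
    per_patch = -(t * log_s).sum(axis=1)
    return per_patch[mask].mean()                   # only masked positions contribute

rng = np.random.default_rng(0)
P, K = 10, 8
student = rng.normal(size=(P, K))
teacher = rng.normal(size=(P, K))
mask = np.zeros(P, dtype=bool)
mask[[1, 4, 7]] = True
loss = mim_loss(student, teacher, mask)
```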
Sinkhorn-Knopp centering:
- Ruan et al. recommend replacing the teacher softmax-centering step of DINO and iBOT with the Sinkhorn-Knopp (SK) batch normalization of SwAV (Caron et al.), run for 3 iterations. The student still uses the softmax normalization.
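A SwAV-style Sinkhorn-Knopp normalization of the teacher outputs can be sketched in NumPy. It alternately rescales the rows (prototypes) and columns (samples) of exp(logits / temperature) so that, across the batch, prototypes are used roughly equally while each sample still gets a probability distribution. The temperature value is an illustrative assumption, and the function name is hypothetical.

```python
import numpy as np

def sinkhorn_knopp(teacher_logits, temp=0.07, n_iters=3):
    """teacher_logits: [B, K] -> [B, K] targets whose rows each sum to 1."""
    Q = np.exp((teacher_logits - teacher_logits.max()) / temp).T  # [K, B]; global shift for stability
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(axis=1, keepdims=True)  # each prototype gets total mass 1/K
        Q /= K
        Q /= Q.sum(axis=0, keepdims=True)  # each sample gets total mass 1/B
        Q /= B
    Q *= B                                 # columns (samples) now sum to 1
    return Q.T

rng = np.random.default_rng(0)
targets = sinkhorn_knopp(rng.normal(size=(32, 10)))  # batch of 32, 10 prototypes
```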
KoLeo Regularizer:
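It encourages the features within a batch to span the space uniformly. It is derived from the Kozachenko-Leonenko differential entropy estimator. For a batch of n ℓ2-normalized features x_1, ..., x_n:
$$\mathcal{L}_{\text{koleo}} = -\frac{1}{n} \sum_{i=1}^{n} \log(d_{n,i}), \qquad d_{n,i} = \min_{j \neq i} \lVert x_i - x_j \rVert$$
Minimizing it pushes each feature away from its nearest neighbor in the batch, which keeps the features from bunching together.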
(Image taken from AI Bites' YouTube video; the channel has a lot of good material.)
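The KoLeo regularizer is short enough to write directly in NumPy. In this sketch, a small epsilon inside the log is an assumption to avoid log(0) for identical features, and the function name is illustrative. Orthogonal features get a low loss, while a batch containing two near-identical features gets a much higher one.

```python
import numpy as np

def koleo_loss(features, eps=1e-8):
    """-mean_i log(distance from x_i to its nearest neighbor), on l2-normalized features."""
    x = features / np.linalg.norm(features, axis=1, keepdims=True)
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)            # ignore self-distances
    return -np.log(d.min(axis=1) + eps).mean()

spread = koleo_loss(np.eye(4))             # nearest-neighbor distance sqrt(2) for every feature
clustered = koleo_loss(np.array([[1.0, 0.0, 0.0],
                                 [0.999, 0.01, 0.0],
                                 [0.0, 1.0, 0.0],
                                 [0.0, 0.0, 1.0]]))
```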
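DINOv2 also relies on several efficiency improvements, including FlashAttention, sequence packing, an efficient implementation of stochastic depth, and fully-sharded data parallel (FSDP) training.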
H-optimus-0
Pretrained ViT using DINOv2 with register tokens (as in "Vision Transformers Need Registers") by Bioptimus
HIBOU
Pretrained ViT using DINOv2 by Hist.ai
Virchow
Pretrained ViT using DINOv2 by Paige.ai
RudolfV
Pretrained ViT using DINOv2 by Aignostics
Next in series:
LUNIT
TITAN etc
PS:
Models are not in any chronological order.
All the images are from their respective papers.
Please point out the mistakes/help in rectifying them. Thanks!