Note: This notebook demonstrates how to use the DINO callback with a single GPU.

First import fastai for training and other helpers. You can choose not to use Weights & Biases by setting WANDB=False.

from fastai.vision.all import *
torch.backends.cudnn.benchmark = True
WANDB = False
if WANDB:
    try:
        from fastai.callback.wandb import WandbCallback
        import wandb
    except ImportError:
        raise ImportError("Please run '!pip install wandb' in another cell to install wandb")

Then import the self_supervised augmentations module for creating the augmentation pipelines, the layers and models.vision_transformer modules for creating the encoder and model, and finally dino for self-supervised training.

from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.models.vision_transformer import *
from self_supervised.vision.dino import *
from self_supervised.vision.swav import get_swav_aug_pipelines

In this notebook we will take a look at the ImageWang benchmark: how to train a self-supervised model using the DINO algorithm and then how to use this pretrained model for finetuning on the given downstream task.

Pretraining

def get_dls(size, bs, workers=None, n_subset=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    
    if n_subset is None: files = get_image_files(source)
    else:              files = np.random.choice(get_image_files(source), n_subset)
    tfms = [[PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)], 
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=RandomSplitter(valid_pct=0.1)(files))
    
    batch_tfms = [IntToFloatTensor]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls

ImageWang has several benchmarks for different image sizes; in this tutorial we will use size=224 and also demonstrate how to utilize GPU memory effectively.

Define the batch size, the resize resolution applied before batching, and the size used for random cropping during self-supervised training. It's generally best to use the largest batch size that fits in GPU memory (if that still isn't enough, see the gradient accumulation sketch after the next cell).

bs, resize, size = 64, 256, 224
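
If the largest batch that fits is still smaller than you'd like, fastai's GradientAccumulation callback can simulate a bigger effective batch by accumulating gradients over several steps. A minimal sketch, not used in this run (n_acc=256 is only an illustrative value):

# Optional: step the optimizer only after gradients for 256 samples have accumulated.
# accum_cb = GradientAccumulation(n_acc=256)
# learn = Learner(dls, dino_model, opt_func=Adam, cbs=cbs + [accum_cb])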

Let's create the DINO model: a DeiT-S/16 student and a teacher, each wrapped with MultiCropWrapper and followed by a DINOHead.

deits16 = deit_small(patch_size=16, drop_path_rate=0.1)
deits16 = MultiCropWrapper(deits16)
dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
student_model = nn.Sequential(deits16,dino_head)

deits16 = deit_small(patch_size=16)
deits16 = MultiCropWrapper(deits16)
dino_head = DINOHead(deits16.encoder.embed_dim, 2**16, norm_last_layer=True)
teacher_model = nn.Sequential(deits16,dino_head)

dino_model = DINOModel(student_model, teacher_model)
if WANDB:
    xtra_config = {"Arch":"deits16", "Resize":resize, "Size":size, "Algorithm":"DINO"}
    wandb.init(project="self-supervised-imagewang", config=xtra_config);
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: keremturgutlu (use `wandb login --relogin` to force relogin)
Tracking run with wandb version 0.10.30
Syncing run gentle-dust-67 to Weights & Biases (Documentation).
Project page: https://wandb.ai/keremturgutlu/self-supervised-imagewang
Run page: https://wandb.ai/keremturgutlu/self-supervised-imagewang/runs/3p86oq38
Run data is saved locally in /home/code-base/extra_space/self_supervised/examples/vision/wandb/run-20210517_143759-3p86oq38

Initialize the Dataloaders using the function above.

dls = get_dls(resize, bs, n_subset=None)

The next step is perhaps the most critical one for achieving good results on a custom problem - data augmentation. For this, we will use the utility function self_supervised.vision.dino.get_dino_aug_pipelines, but you can also use your own list of Pipeline augmentations. get_dino_aug_pipelines should be enough for most cases since under the hood it uses self_supervised.augmentations.get_multi_aug_pipelines and self_supervised.augmentations.get_batch_augs. You can do Shift+Tab and see all the arguments that can be passed to get_dino_aug_pipelines. You can simply pass anything that you could pass to get_batch_augs, including custom xtra_tfms (a sketch with xtra_tfms follows the cell below).

get_dino_aug_pipelines expects the sizes for the random resized crops of the different views of a given image, and the rest of the arguments are passed on to get_batch_augs().

aug_pipelines = get_dino_aug_pipelines(rotate=True, 
                                       rotate_deg=10, 
                                       jitter=True, 
                                       bw=True, 
                                       blur=True,
                                       blur_s=(4, 16))

# aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
#                                        crop_sizes=[224,96], 
#                                        min_scales=[0.25,0.2],
#                                        max_scales=[1.0,0.35],
#                                        rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False) 
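
Anything accepted by get_batch_augs() can be passed through here as well, including custom xtra_tfms. A minimal sketch of that option (RandomErasing is only an illustrative extra transform, not part of the recipe used in this run):

# aug_pipelines = get_dino_aug_pipelines(rotate=True, rotate_deg=10,
#                                        jitter=True, bw=True, blur=True, blur_s=(4, 16),
#                                        xtra_tfms=[RandomErasing(p=0.25)])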
 

Here, we will feed the augmentation pipelines to the DINO callback and leave the temperature parameters at their defaults.

# Redefined here so that `every_epoch` can be an integer interval: the model is
# saved every `every_epoch` epochs (used below with every_epoch=20).
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model's best during training and loads it at the end."
    _only_train_loop,order = True,TrackerCallback.order+1
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        assert not (every_epoch and at_end), "every_epoch and at_end cannot both be set to True"
        # keep track of file path for loggers
        self.last_saved_path = None
        store_attr('fname,every_epoch,at_end,with_opt')

    def _save(self, name): self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)

    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch:
            if (self.epoch%self.every_epoch) == 0: self._save(f'{self.fname}_{self.epoch}')
        else: #every improvement
            super().after_epoch()
            if self.new_best:
                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
                self._save(f'{self.fname}')

    def after_fit(self, **kwargs):
        "Load the best model."
        if self.at_end: self._save(f'{self.fname}')
        elif not self.every_epoch: self.learn.load(f'{self.fname}', with_opt=self.with_opt)
dino_cb = DINO(aug_pipelines=aug_pipelines,
               tpt_start=0.04,
               tpt_end=0.04,
               tpt_warmup_pct=0., 
               freeze_last_layer=1)
grad_clip_cb = GradientClip(max_norm=3., norm_type=2.)
save_cb = SaveModelCallback(every_epoch=20, with_opt=True, fname='dino_pretraining')
nan_cb = TerminateOnNaNCallback()

cbs=[dino_cb, grad_clip_cb, save_cb, nan_cb]
if WANDB: cbs += [WandbCallback(log_preds=False,log_model=False)]
learn = Learner(dls, dino_model, opt_func=Adam, cbs=cbs)
learn.to_fp16();

Before starting training, let's check whether our augmentations make sense. Since this step consumes GPU memory, once you are done with the inspection, restart the notebook and skip this step. We can see the augmented views of the same image side by side, and the augmentations indeed look good.

b = dls.one_batch()
learn._split(b)
learn('before_fit')
learn('before_batch')
learn.dino.show(n=5);

Use mixed precision with to_fp16() to save GPU memory, allowing a larger batch size and faster training. We could also use the gradient checkpointing wrapper models from self_supervised.layers to save even more memory, e.g. CheckpointSequential().
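
For reference, gradient checkpointing trades extra compute for memory by recomputing activations during the backward pass instead of storing them. The sketch below uses the generic torch.utils.checkpoint.checkpoint_sequential rather than the library's wrapper (whose exact constructor arguments aren't shown here), and the small nn.Sequential is a stand-in, not this notebook's model:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# Stand-in module: split it into 2 segments and recompute intermediate
# activations on the backward pass instead of keeping them in memory.
seq = nn.Sequential(nn.Linear(384, 1024), nn.GELU(), nn.Linear(1024, 384))
x = torch.randn(8, 384, requires_grad=True)
out = checkpoint_sequential(seq, 2, x)
out.sum().backward()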

 

Learning good representations via self-supervised pretraining usually takes a lot of epochs, so here the number of epochs is set to 200. This might change depending on your data distribution and dataset size.

max_lr = 2.5e-4
lr_sched = combine_scheds([0.1,0.9], [SchedLin(0.,max_lr), SchedCos(max_lr,1e-6)])
wd_sched = SchedCos(0.04,0.4)
param_scheduler = ParamScheduler({"lr":lr_sched, "wd":wd_sched})
learn.fit(200, cbs=[param_scheduler])
64.00% [128/200 9:24:05<5:17:18]
epoch train_loss valid_loss time
0 10.905972 10.884517 04:21
1 10.799708 10.744269 04:21
2 10.395583 10.386750 04:20
3 10.667659 10.673615 04:22
4 10.524393 10.505392 04:22
5 10.400627 10.372112 04:22
6 10.272190 10.226895 04:22
7 10.106242 10.131449 04:22
8 10.017877 9.977060 04:22
9 9.887113 9.915744 04:22
10 9.754071 9.774323 04:22
11 9.626770 9.621745 04:22
12 9.477569 9.425710 04:23
13 9.329666 9.296606 04:22
14 9.108964 9.161177 04:22
15 8.964575 9.002557 04:22
16 8.746999 8.871298 04:22
17 8.604566 8.608711 04:22
18 8.419492 8.454516 04:22
19 8.219825 8.190134 04:22
20 8.027300 8.059779 04:22
21 7.872963 7.858914 04:23
22 7.633177 7.692121 04:23
23 7.490386 7.514220 04:22
24 7.311434 7.348587 04:22
25 7.119255 7.091773 04:23
26 6.891378 7.027006 04:22
27 6.740506 6.842145 04:22
28 6.599480 6.673983 04:23
29 6.480856 6.577636 04:22
30 6.362226 6.396636 04:22
31 6.246682 6.244115 04:22
32 6.153521 6.171088 04:22
33 6.001417 6.037159 04:22
34 5.908531 5.947373 04:23
35 5.838302 5.888688 04:23
36 5.691142 5.746446 04:23
37 5.644607 5.668302 04:23
38 5.561379 5.583524 04:23
39 5.449399 5.511849 04:22
40 5.390159 5.463367 04:23
41 5.349072 5.400405 04:23
42 5.243585 5.337997 04:23
43 5.195243 5.233816 04:22
44 5.124351 5.185687 04:22
45 5.050556 5.126068 04:23
46 4.939539 5.039504 04:22
47 4.934534 5.003039 04:23
48 4.895814 4.986242 04:26
49 4.853435 4.898729 04:24
50 4.783492 4.947905 04:22
51 4.785825 4.843892 04:32
52 4.714669 4.878543 05:26
53 4.633375 4.771772 04:37
54 4.587667 4.739517 04:23
55 4.549111 4.758195 04:23
56 4.500559 4.674447 04:23
57 4.512155 4.652016 04:23
58 4.451552 4.579087 04:23
59 4.468405 4.592949 04:23
60 4.353909 4.508565 04:23
61 4.333078 4.453562 04:23
62 4.354620 4.467622 04:24
63 4.326721 4.483803 04:23
64 4.324405 4.431428 04:24
65 4.205381 4.447000 04:24
66 4.211459 4.421143 04:23
67 4.171113 4.318798 04:23
68 4.133680 4.359928 04:23
69 4.150734 4.344061 04:23
70 4.078065 4.354052 04:23
71 4.115148 4.271878 04:22
72 4.083210 4.302464 04:22
73 4.008277 4.303656 04:23
74 4.002521 4.210809 04:24
75 3.986869 4.249226 04:24
76 3.970350 4.143545 04:24
77 3.967119 4.237706 04:23
78 3.868739 4.180150 04:24
79 3.905615 4.213471 04:24
80 3.822973 4.097795 04:24
81 3.847088 4.128041 04:23
82 3.841082 4.113137 04:24
83 3.833511 4.166142 04:24
84 3.805576 4.123006 04:23
85 3.770877 4.086368 04:24
86 3.708668 4.051410 04:23
87 3.738222 4.024220 04:24
88 3.692777 3.988358 04:23
89 3.697853 4.012632 04:23
90 3.649306 4.127020 04:24
91 3.613917 4.052756 04:23
92 3.634349 4.061735 04:22
93 3.579834 4.034831 04:24
94 3.616243 3.999697 04:24
95 3.635092 3.999085 04:24
96 3.560299 4.018064 04:24
97 3.564267 3.969940 04:24
98 3.510925 4.027102 04:24
99 3.543701 4.034966 04:24
100 3.526968 3.984271 04:25
101 3.491561 3.976734 04:24
102 3.451368 3.946274 04:24
103 3.414409 3.946348 04:24
104 3.450127 3.914510 04:23
105 3.430591 3.908050 04:24
106 3.388823 3.949557 04:24
107 3.291402 3.916359 04:24
108 3.323584 3.959621 04:24
109 3.255738 4.005247 04:24
110 3.302168 3.855192 04:24
111 3.274943 3.917519 04:24
112 3.282359 3.913666 04:24
113 3.251249 3.931701 04:24
114 3.257837 3.906408 04:25
115 3.204905 3.909114 04:23
116 3.210003 3.909720 04:24
117 3.299423 3.889330 04:25
118 3.248588 3.895199 04:25
119 3.136573 3.871612 04:25
120 3.115912 3.881637 04:26
121 3.087857 3.940727 04:25
122 3.150950 3.894757 04:25
123 3.081744 3.878548 04:25
124 3.124973 3.842340 04:25
125 3.059661 3.821547 04:25
126 3.033804 3.862898 04:25
127 3.057881 3.829546 04:26

73.81% [31/42 00:14<00:05 3.0656]
Setting last layer to trainable
if WANDB: wandb.finish()

Downstream Task

# class MultiCropWrapper(nn.Module):
#     """
#     Perform forward pass separately on each resolution input.
#     The inputs corresponding to a single resolution are clubbed and single
#     forward is run on the same resolution inputs. Hence we do several
#     forward passes = number of different resolutions used. We then
#     concatenate all the output features and run the head forward on these
#     concatenated features.
#     """
#     def __init__(self, backbone, head):
#         super(MultiCropWrapper, self).__init__()
#         # disable layers dedicated to ImageNet labels classification
#         backbone.fc, backbone.head = nn.Identity(), nn.Identity()
#         self.backbone = backbone
#         self.head = head

#     def forward(self, x):
#         # convert to list
#         if not isinstance(x, list):
#             x = [x]
#         idx_crops = torch.cumsum(torch.unique_consecutive(
#             torch.tensor([inp.shape[-1] for inp in x]),
#             return_counts=True,
#         )[1], 0)
#         start_idx = 0
#         for end_idx in idx_crops:
#             _out = self.backbone(torch.cat(x[start_idx: end_idx]))
#             if start_idx == 0:
#                 output = _out
#             else:
#                 output = torch.cat((output, _out))
#             start_idx = end_idx
#         # Run the head forward on the concatenated features.
#         return self.head(output)

# state_dict = torch.load("/home/code-base/extra_space/dino/checkpoint.pth")
# student = deit_small(patch_size=16, drop_path_rate=0.1)
# student = MultiCropWrapper(student, DINOHead(student.embed_dim, 2**16, norm_last_layer=True))
# student_dict = OrderedDict({".".join(k.split(".")[1:]):v for k,v in state_dict['student'].items()})
# student.load_state_dict(student_dict)

# student = deit_small(patch_size=16, drop_path_rate=0.1)
# state_dict = torch.load("./models/model_99.pth")
# for n,p in student.named_parameters(): p.data = state_dict[f'0.encoder.{n}']
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    files = get_image_files(source, folders=['train', 'val'])
    splits = GrandparentSplitter(valid_name='val')(files)
    
    item_aug = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    tfms = [[PILImage.create, ToTensor, *item_aug], 
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=splits)
    
    batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls
class Classifier(Module):
    def __init__(self, vit_backbone, n_feat_layers, n_classes, lin_f=1024, lin_drop=0.3, pooling='avg'):
        self.vit_backbone  = vit_backbone
        self.n_feat_layers = n_feat_layers 
        self.pooling = pooling
        out_dim = self.vit_backbone.norm.weight.size(0)
        
        if self.n_feat_layers == 1: in_f = 2*out_dim
        else:
            if pooling == 'avg':   in_f = out_dim
            elif pooling == 'cat': in_f = out_dim*n_feat_layers
        
        self.mlp = create_cls_module(in_f, n_classes)
        
    def forward(self,x):
        
        out = self.vit_backbone.get_intermediate_layers(x,self.n_feat_layers)
        
        if self.n_feat_layers == 1:
            # cat [CLS] token and avgpooled output tokens from the last layer
            cls_token, output_tokens = out[0][:,0],out[0][:,1:]
            x = torch.cat([cls_token, output_tokens.mean(1)], dim=1)
        else:
            # avgpool or cat [CLS] tokens from last n layers
            out = [o[:,0] for o in out] 
            if self.pooling == 'avg':   x = torch.stack(out,dim=0).mean(0)
            elif self.pooling == 'cat': x = torch.cat(out, 1)
            else:                       raise Exception("Pooling should be avg or cat")
                
        return self.mlp(x)
x = torch.randn(4,3,224,224).cuda()
cls_model = Classifier(student, n_feat_layers=1, n_classes=10).cuda()
cls_model(x).shape
torch.Size([4, 10])
optdict = dict(sqr_mom=0.99,mom=0.95,beta=0.,eps=1e-4)
opt_func = partial(ranger, **optdict)
bs, resize, size = 64, 256, 224
dls = get_dls(size, bs=bs//2)
cls_model = Classifier(student, n_feat_layers=1, n_classes=dls.c)
learn = Learner(dls, cls_model, opt_func=opt_func, metrics=[accuracy,top_k_accuracy],
                loss_func=LabelSmoothingCrossEntropy())
learn.lr_find()
SuggestedLRs(lr_min=0.001096478197723627, lr_steep=0.0012022644514217973)
def finetune(learn, epochs, lr=1e-3, wd=1e-2):
    learn.unfreeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    final_acc = learn.recorder.values[-1][-2]
    return final_acc

5 epochs

When trained with all layers unfrozen, the model underperforms badly, unlike the ResNet-based pretrained models, which perform quite well (e.g. SwAV XResNet34). In contrast, when fully frozen, DeiT+DINO performs decently on the downstream task. This might be related to ViT characteristics and/or easy overfitting.
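
As a sketch of the two extremes described above (assuming a Learner built as in the cells below, where the splitter puts the ViT backbone and the MLP head in separate parameter groups; the learning rates here are only illustrative):

# Fully frozen backbone (linear probing): only the MLP head is trained.
# learn.freeze()
# learn.fit_flat_cos(5, 1e-2, wd=1e-2)

# Fully unfrozen: the ViT backbone is trained as well (underperforms here).
# learn.unfreeze()
# learn.fit_flat_cos(5, 1e-3, wd=1e-2)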

state = torch.load("./models/dino_pretraining_180.pth")

student = deit_small(patch_size=16, drop_path_rate=0.1)
student_dict = {}
for k in state['model']:
    if 'student' in k:
        student_dict[".".join(k.split(".")[3:])] = state['model'][k]
student.load_state_dict(student_dict,strict=False) # strict=False ignore MLP head
_IncompatibleKeys(missing_keys=[], unexpected_keys=['0.weight', '0.bias', '2.weight', '2.bias', '4.weight', '4.bias', 'weight_g', 'weight_v'])
optdict = dict(sqr_mom=0.99,mom=0.95,beta=0.,eps=1e-4)
opt_func = partial(ranger, **optdict)
dls = get_dls(size, bs=96)
# cls_model = Classifier(student, n_feat_layers=1, n_classes=dls.c)
# cls_model = Classifier(student, n_feat_layers=2, n_classes=dls.c)
cls_model = Classifier(student, n_feat_layers=2, n_classes=dls.c, pooling='cat')
def model_split(model):
    groups = L([model.vit_backbone, model.mlp])
    return groups.map(params)
learn = Learner(dls, cls_model, opt_func=opt_func, metrics=[accuracy,top_k_accuracy], splitter=model_split,
                loss_func=LabelSmoothingCrossEntropyFlat())
learn.freeze();
learn.lr_find()
SuggestedLRs(lr_min=0.03019951581954956, lr_steep=0.00363078061491251)
learn.freeze();
learn.fit_flat_cos(3,1e-2,wd=1e-2)
learn.unfreeze();
learn.fit_flat_cos(2,slice(1e-4,1e-3),wd=1e-2)
epoch train_loss valid_loss accuracy top_k_accuracy time
0 1.173379 2.006262 0.444897 0.875795 00:58
1 1.131220 2.043696 0.443115 0.881140 00:58
2 1.067292 1.727473 0.568592 0.903792 00:58
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.995551 1.671715 0.579537 0.923136 01:08
1 0.963949 1.623196 0.603716 0.925426 01:07

fin