Note: This notebook demonstrates how to use the SupCon callback with a single GPU. For the distributed version, DistributedSupCon, check out the documentation.

First, import fastai for training and other helpers. You can choose not to use wandb by setting WANDB=False.

from fastai.vision.all import *
torch.backends.cudnn.benchmark = True
WANDB = False
if WANDB:
    try:
        from fastai.callback.wandb import WandbCallback
        import wandb
    except ImportError:
        raise ImportError("Please run '!pip install wandb' in another cell to install wandb")

Then import the self_supervised augmentations module for creating the augmentation pipelines, the layers module for creating the encoder and model, and finally supcon for self-supervised training.

from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.supcon import *
from self_supervised.vision.metrics import *
import timm

In this notebook we will take a look at the ImageWang benchmark: how to train a self-supervised model using the SupCon algorithm, and then how to use this pretrained model for finetuning on the given downstream task.

Pretraining

def get_dls(size, bs, n_per_class=30, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    files = get_image_files(source, folders=['unsup', 'train'])

    
    # Label each file by its parent folder; 'unsup' marks the unlabeled split.
    # Hold out n_per_class labeled images per class for validation.
    labels = [o.parent.name for o in files]
    split_df = pd.DataFrame(labels, columns=['label']).reset_index()
    valid_idxs = split_df.query("label != 'unsup'").groupby("label").sample(n_per_class)['index'].values
    split_df['is_valid'] = False
    split_df.loc[split_df['index'].isin(valid_idxs), 'is_valid'] = True
    train_idxs = split_df[~split_df.is_valid]['index'].values
    valid_idxs = split_df[split_df.is_valid]['index'].values

    
    # min_scale=1. keeps the full image at the item level; random cropping happens later in the aug pipelines
    tfms = [[PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)],
            [parent_label, Categorize()]]
    dsets = Datasets(files, tfms=tfms, splits=[train_idxs, valid_idxs])
    batch_tfms = [IntToFloatTensor]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls

ImageWang has several benchmarks for different image sizes; in this tutorial we will go for size=224 and also demonstrate how to utilize GPU memory effectively.

Define the batch size, the resize resolution before batching, and the size for random cropping during self-supervised training. It's always good to use as large a batch size as fits in GPU memory.

bs, resize, size = 256, 256, 224
path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
source = untar_data(path)
files = get_image_files(source)
Counter([(o.parent.parent.name,o.parent.name) for o in files])
Counter({('imagewang', 'unsup'): 7750,
         ('val', 'n02089973'): 224,
         ('val', 'n02096294'): 407,
         ('val', 'n02086240'): 409,
         ('val', 'n02105641'): 422,
         ('val', 'n02099601'): 401,
         ('val', 'n02087394'): 408,
         ('val', 'n02088364'): 418,
         ('val', 'n02093754'): 401,
         ('val', 'n02115641'): 410,
         ('val', 'n02111889'): 429,
         ('train', 'n02089973'): 75,
         ('train', 'n02096294'): 147,
         ('train', 'n02086240'): 142,
         ('train', 'n03000684'): 1244,
         ('train', 'n01440764'): 1350,
         ('train', 'n02979186'): 1350,
         ('train', 'n02102040'): 1350,
         ('train', 'n02105641'): 126,
         ('train', 'n03394916'): 1350,
         ('train', 'n03445777'): 1350,
         ('train', 'n03417042'): 1350,
         ('train', 'n03425413'): 1350,
         ('train', 'n02099601'): 127,
         ('train', 'n03028079'): 1350,
         ('train', 'n02087394'): 119,
         ('train', 'n02088364'): 141,
         ('train', 'n02093754'): 146,
         ('train', 'n03888257'): 1350,
         ('train', 'n02115641'): 114,
         ('train', 'n02111889'): 138})

Select an architecture to train on; remember, all timm and fastai models are available! We need to set pretrained=False here because using ImageNet weights for ImageWang data would be cheating.

arch = "resnet50d"
encoder = create_encoder(arch, pretrained=False, n_in=3)

# Alternative with gradient checkpointing to save GPU memory:
# arch = "resnet34d"
# encoder = CheckpointResNet(create_encoder(arch, pretrained=False, n_in=3), checkpoint_nchunks=2)
if WANDB:
    xtra_config = {"Arch":arch, "Resize":resize, "Size":size, "Algorithm":"SupCon"}
    wandb.init(project="self-supervised-imagewang", config=xtra_config);

Initialize the Dataloaders using the function above.

dls = get_dls(resize, bs)

Create the SupCon model. You can change the values of hidden_size, projection_size, and n_layers. For this problem the defaults work just fine, so we don't change anything; an illustrative override is shown after the next cell.

model = create_supcon_model(encoder)
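
For example, a wider projection head could be requested by overriding the defaults (an illustrative sketch using the argument names listed above; check the create_supcon_model signature in self_supervised.layers for the exact API):

# hypothetical override of the projector defaults; verify argument names against self_supervised.layers
# model = create_supcon_model(encoder, hidden_size=4096, projection_size=256, n_layers=2)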

The next step is perhaps the most critical one for achieving good results on a custom problem: data augmentation. For this, we will use the utility function self_supervised.vision.supcon.get_supcon_aug_pipelines, but you can also use your own list of Pipeline augmentations. get_supcon_aug_pipelines should be enough for most cases, since under the hood it uses self_supervised.augmentations.get_multi_aug_pipelines and self_supervised.augmentations.get_batch_augs. You can do shift+tab and see all the arguments that can be passed to get_supcon_aug_pipelines. You can simply pass anything that you could pass to get_batch_augs, including custom xtra_tfms.

get_supcon_aug_pipelines expects size for the random resized cropping of the 2 views of a given image; the rest of the arguments are passed on to get_batch_augs(). An example with custom xtra_tfms follows the cell below.

aug_pipelines = get_supcon_aug_pipelines(size, rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False) 
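
For example, extra batch transforms can be passed through xtra_tfms (an illustrative sketch; RandomErasing is fastai's batch-level random-erasing transform and is our own addition here, not part of the original recipe):

# aug_pipelines = get_supcon_aug_pipelines(size, rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False,
#                                          xtra_tfms=[RandomErasing(p=0.25)])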

Here, we will feed the augmentation pipelines and leave the temperature parameter at its default. Two options are shown below: the plain SupCon callback, and the MoCo-style SupConMOCO variant, which adds a feature queue of size K and a momentum encoder with momentum m. Both cells assign to cbs, so define only the one you want to train with.

cbs=[SupCon(aug_pipelines, 
            unsup_class_id = dls.vocab.o2i['unsup'], 
            unsup_method = UnsupMethod.All, 
            reg_lambda = 1.0, 
            temp = 0.07)]

cbs=[SupConMOCO(aug_pipelines, 
                unsup_class_id = dls.vocab.o2i['unsup'], 
                unsup_method = UnsupMethod.All, 
                K=4096,
                m=0.999,
                reg_lambda = 1.0, 
                temp = 0.07)]

if WANDB: cbs += [WandbCallback(log_preds=False, log_model=False)]
# Track representation quality during pretraining with a KNN-based proxy accuracy metric
knn_metric_cb = KNNProxyMetric()
cbs += [knn_metric_cb]
metric = ValueMetric(knn_metric_cb.accuracy, metric_name='knn_accuracy')
learn = Learner(dls, model, opt_func=Lamb, cbs=cbs, metrics=metric)

Before starting training, let's check whether our augmentations make sense or not. We can see the 2 views of the same image side by side, and indeed the augmentations look pretty good. Since this step consumes GPU memory, once you are done with the inspection, restart the notebook and skip this step.

b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.sup_con.show(n=5);

Use mixed precision with to_fp16() to save GPU memory, allow a larger batch size, and train faster. We could also use the gradient checkpointing wrapper models from self_supervised.layers to save even more memory, e.g. CheckpointSequential(), sketched below.

learn.to_fp16();
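
For example, a fastai-style sequential encoder could be wrapped for gradient checkpointing like this (a sketch, assuming CheckpointSequential takes a checkpoint_nchunks argument like the CheckpointResNet example earlier; check self_supervised.layers for the exact API):

# encoder = CheckpointSequential(create_encoder("xresnet34", pretrained=False, n_in=3), checkpoint_nchunks=2)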

Learning good representations via contrastive learning usually takes a lot of epochs, so here the number of epochs is set to 100. This might change depending on your data distribution and dataset size.

lr,wd,epochs=1e-2,1e-2,100
learn.unfreeze()
learn.fit_flat_cos(epochs, lr, wd=wd, pct_start=0.5)
if WANDB: wandb.finish()

Waiting for W&B process to finish... (success).

Run history:

(sparkline charts omitted for epoch, eps_0, knn_accuracy, lr_0, mom_0, raw_loss, sqr_mom_0, train_loss, valid_loss, wd_0; train_loss decreases steadily and knn_accuracy trends upward over the run)

Run summary:

epoch: 28.25
eps_0: 1e-05
knn_accuracy: 0.18167
lr_0: 0.001
mom_0: 0.9
raw_loss: 7.75702
sqr_mom_0: 0.99
train_loss: 7.72455
valid_loss: 8.91475
wd_0: 0.0001

Synced likely-jazz-70: https://wandb.ai/keremturgutlu/self-supervised-imagewang/runs/2j8iph1f
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20220308_023935-2j8iph1f/logs
save_name = f'supcon_iwang_sz{size}_epc{epochs}'
learn.save(save_name)
torch.save(learn.model.encoder.state_dict(), learn.path/learn.model_dir/f'{save_name}_encoder.pth')
# If the encoder was wrapped with CheckpointResNet, save the underlying model instead:
# torch.save(learn.model.encoder.resnet_model.state_dict(), learn.path/learn.model_dir/f'{save_name}_encoder.pth')
learn.recorder.plot_loss()
wandb.run.name
'bumbling-paper-72'

Downstream Task
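
Now we evaluate the pretrained encoder on the supervised ImageWang task: build labeled train/val dataloaders, load the saved encoder weights, attach a classification head, and finetune under the benchmark's 5, 20, 80, and 200 epoch settings.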

optdict = dict(sqr_mom=0.99,mom=0.95,beta=0.,eps=1e-4)
opt_func = partial(ranger, **optdict)
bs = 128
size = 256
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    files = get_image_files(source, folders=['train', 'val'])
    splits = GrandparentSplitter(valid_name='val')(files)
    
    item_aug = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    tfms = [[PILImage.create, ToTensor, *item_aug], 
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=splits)
    
    batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls
# Param groups for discriminative learning rates: encoder (m[0]) and classifier head (m[1])
def split_func(m): return L(m[0], m[1]).map(params)

def create_learner(size=size, arch='xresnet34', encoder_path="models/supcon_iwang_sz224_epc100_encoder.pth"):
    
    dls = get_dls(size, bs=bs//2)
    pretrained_encoder = torch.load(encoder_path)
    encoder = create_encoder(arch, pretrained=False, n_in=3)
    encoder.load_state_dict(pretrained_encoder)
    # Infer the encoder's output feature size with a dummy forward pass
    nf = encoder(torch.randn(2,3,224,224)).size(-1)
    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier, Flatten())
    learn = Learner(dls, model, opt_func=opt_func, splitter=split_func,
                metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropyFlat())
    return learn
def finetune(size, epochs, arch, encoder_path, lr=1e-2, wd=1e-2):
    learn = create_learner(size, arch, encoder_path)
    learn.unfreeze()
#     learn.freeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    # accuracy is the second-to-last recorded value: [train_loss, valid_loss, accuracy, top_k_accuracy]
    final_acc = learn.recorder.values[-1][-2]
    return final_acc

5 epochs

acc = []
runs = 5
for i in range(runs): acc += [finetune(size, epochs=5, arch='resnet50d', encoder_path='models/true-sweep-41_encoder.pth')]
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.891920 2.030339 0.488165 0.848307 01:25
1 0.895664 2.121488 0.499618 0.815984 01:23
2 0.852885 1.861526 0.561721 0.896411 01:23
3 0.819235 1.766511 0.596844 0.926190 01:23
4 0.745480 1.537153 0.652329 0.929753 01:23
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.880797 2.130177 0.483075 0.844999 01:23
1 0.890207 1.990296 0.500382 0.856452 01:23
2 0.845114 2.021743 0.513362 0.867142 01:23
3 0.828395 2.073368 0.492492 0.843981 01:23
4 0.745889 1.514073 0.674726 0.930008 01:23
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.889642 1.790089 0.573683 0.896411 01:23
1 0.878387 1.874182 0.552812 0.901502 01:23
2 0.860759 2.032588 0.494019 0.829982 01:23
3 0.821018 1.880166 0.542123 0.876559 01:23
4 0.733482 1.522297 0.667345 0.937898 01:23
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.888339 2.043426 0.500382 0.848562 01:25
1 0.888483 2.042721 0.477221 0.846780 01:23
2 0.847030 2.154033 0.477475 0.839654 01:23
3 0.818787 1.850598 0.545940 0.890557 01:24
4 0.751503 1.554493 0.656401 0.931280 01:24
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.880823 2.091427 0.483584 0.887758 01:23
1 0.891838 1.925979 0.524561 0.878341 01:23
2 0.855217 1.836971 0.542123 0.890812 01:23
3 0.820396 1.849614 0.542886 0.882413 01:23
4 0.748284 1.551022 0.654620 0.931026 01:23
np.mean(acc)
0.661084246635437

20 epochs

acc = []
runs = 3
for i in range(runs): acc += [finetune(size, epochs=20, arch='xresnet34', encoder_path='models/confused-sweep-13_encoder.pth')]
np.mean(acc)

80 epochs

acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=80, arch='xresnet34', encoder_path='models/confused-sweep-13_encoder.pth')]
np.mean(acc)

200 epochs

acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=200, arch='xresnet34', encoder_path='models/confused-sweep-13_encoder.pth')]
(training interrupted after epoch 133 of 200)
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.869703 1.603625 0.638330 0.908628 00:43
1 0.738499 1.532995 0.658946 0.923390 00:43
2 0.709089 1.545185 0.664291 0.917282 00:43
3 0.705038 1.519349 0.665564 0.927717 00:43
4 0.701668 1.636940 0.622550 0.918809 00:46
5 0.695683 1.543493 0.663528 0.916009 00:50
6 0.686737 1.539560 0.671163 0.923136 00:44
7 0.680906 1.498526 0.675999 0.921354 00:44
8 0.677552 1.508487 0.667854 0.922627 00:43
9 0.674518 1.519361 0.670654 0.919063 00:46
10 0.665412 1.550280 0.655129 0.921100 00:44
11 0.665843 1.494319 0.669891 0.917791 00:45
12 0.660519 1.488854 0.681344 0.921354 00:47
13 0.659262 1.546521 0.662510 0.918045 00:44
14 0.655583 1.676335 0.629677 0.891066 00:46
15 0.653679 1.546218 0.658692 0.915246 00:44
16 0.654734 1.507556 0.670400 0.916518 00:44
17 0.652892 1.518740 0.672436 0.917536 00:47
18 0.649217 1.471843 0.682362 0.928735 00:44
19 0.647636 1.532747 0.665309 0.912955 00:45
20 0.647990 1.544863 0.659710 0.917791 00:46
21 0.645690 1.555919 0.664546 0.917536 00:44
22 0.644248 1.616540 0.645457 0.902774 00:46
23 0.643940 1.533293 0.671927 0.913719 00:43
24 0.643079 1.493405 0.677526 0.923136 00:44
25 0.641105 1.560501 0.656656 0.913973 00:46
26 0.641650 1.511677 0.673454 0.916009 00:44
27 0.640307 1.575970 0.652583 0.915246 00:45
28 0.639157 1.613817 0.646475 0.910410 00:47
29 0.639705 1.540645 0.664800 0.920336 00:44
30 0.640183 1.520499 0.672690 0.918300 00:44
31 0.638371 1.559659 0.652329 0.907356 00:44
32 0.638882 1.590857 0.653092 0.909901 00:44
33 0.635947 1.520144 0.673454 0.920590 00:47
34 0.637217 1.609845 0.641130 0.906592 00:44
35 0.636331 1.520351 0.675235 0.914482 00:44
36 0.634860 1.532848 0.670145 0.914228 00:45
37 0.637789 1.482974 0.684144 0.919318 00:44
38 0.637334 1.591692 0.644693 0.910410 00:45
39 0.636380 1.531559 0.669891 0.915246 00:43
40 0.636859 1.627369 0.646475 0.903029 00:44
41 0.634279 1.553971 0.655638 0.917027 00:46
42 0.631529 1.565562 0.654365 0.908628 00:43
43 0.630301 1.529892 0.671418 0.918809 00:44
44 0.634323 1.592670 0.651311 0.904301 00:46
45 0.633182 1.480278 0.687707 0.922372 00:43
46 0.632433 1.554290 0.661746 0.908883 00:46
47 0.632967 1.534170 0.668618 0.914228 00:46
48 0.629657 1.565951 0.660982 0.902774 00:44
49 0.628943 1.576432 0.659455 0.907356 00:46
50 0.630356 1.543911 0.665055 0.915246 00:44
51 0.630036 1.553484 0.669381 0.904810 00:44
52 0.629229 1.583017 0.653601 0.909392 00:45
53 0.629157 1.551492 0.666073 0.909646 00:44
54 0.629943 1.549929 0.670654 0.909901 00:46
55 0.629232 1.532147 0.669127 0.912700 00:45
56 0.630214 1.550726 0.662255 0.921863 00:44
57 0.628146 1.583061 0.664800 0.910155 00:47
58 0.628307 1.603759 0.646729 0.899975 00:43
59 0.626960 1.559176 0.662764 0.913464 00:44
60 0.629398 1.533982 0.667600 0.911173 00:47
61 0.627043 1.605522 0.648257 0.908628 00:43
62 0.629501 1.600642 0.653092 0.909646 00:44
63 0.628051 1.556016 0.668618 0.917282 00:44
64 0.628616 1.535497 0.667091 0.911173 00:43
65 0.628973 1.554626 0.662001 0.911428 00:45
66 0.628108 1.609547 0.642657 0.894630 00:43
67 0.627201 1.549583 0.662001 0.912446 00:43
68 0.630659 1.537625 0.667091 0.908374 00:47
69 0.627197 1.593526 0.650293 0.902265 00:44
70 0.625676 1.612805 0.644184 0.906847 00:43
71 0.626757 1.589640 0.650038 0.913464 00:45
72 0.626163 1.597183 0.653092 0.906847 00:43
73 0.626740 1.579899 0.657419 0.911173 00:45
74 0.624481 1.613326 0.649529 0.902520 00:45
75 0.622779 1.542646 0.670654 0.909901 00:43
76 0.625912 1.700185 0.630186 0.882158 00:45
77 0.624742 1.543221 0.670909 0.915246 00:43
78 0.625231 1.602534 0.653092 0.899975 00:45
79 0.625685 1.673348 0.636040 0.895648 00:44
80 0.625234 1.629006 0.644948 0.903283 00:43
81 0.623205 1.648422 0.643421 0.886994 00:44
82 0.625121 1.685609 0.624332 0.901502 00:43
83 0.624947 1.617991 0.648257 0.903029 00:44
84 0.624974 1.550486 0.668363 0.917536 00:45
85 0.624651 1.582956 0.668109 0.895902 00:43
86 0.625576 1.700037 0.630949 0.883431 00:44
87 0.624592 1.609895 0.645966 0.901502 00:45
88 0.625317 1.632035 0.658946 0.886231 00:43
89 0.621874 1.582937 0.648257 0.903029 00:44
90 0.626025 1.598742 0.659455 0.892848 00:43
91 0.624505 1.629864 0.651311 0.898193 00:43
92 0.624303 1.718987 0.621023 0.890557 00:45
93 0.624256 1.644525 0.641894 0.907356 00:43
94 0.621443 1.555878 0.661237 0.912446 00:43
95 0.620402 1.616238 0.645711 0.893866 00:45
96 0.624698 1.622719 0.645711 0.900738 00:43
97 0.624898 1.656592 0.638839 0.905574 00:47
98 0.623742 1.645989 0.641130 0.902774 00:44
99 0.621856 1.625195 0.646984 0.892594 00:44
100 0.624188 1.676001 0.634767 0.894630 00:46
101 0.622995 1.578023 0.657165 0.907865 00:44
102 0.622677 1.599621 0.652329 0.903029 00:43
103 0.619904 1.594792 0.655383 0.895393 00:46
104 0.620121 1.643322 0.649275 0.890812 00:43
105 0.620024 1.587936 0.653347 0.905828 00:45
106 0.621057 1.604326 0.660219 0.895393 00:44
107 0.621163 1.584998 0.663273 0.907865 00:44
108 0.620983 1.603610 0.656147 0.904556 00:45
109 0.620239 1.637596 0.647238 0.894375 00:43
110 0.620139 1.673262 0.637567 0.890812 00:43
111 0.619876 1.731578 0.616951 0.870450 00:46
112 0.619817 1.775249 0.609824 0.882413 00:44
113 0.619827 1.544595 0.671672 0.916009 00:45
114 0.619299 1.645119 0.645711 0.905574 00:49
115 0.620907 1.712692 0.625095 0.883940 00:43
116 0.619116 1.605237 0.657419 0.893866 00:46
117 0.620825 1.602566 0.656656 0.902774 00:47
118 0.623467 1.669477 0.628913 0.898956 00:44
119 0.621086 1.641958 0.650547 0.894375 00:49
120 0.620659 1.683734 0.639094 0.891321 00:47
121 0.620320 1.571884 0.671418 0.902774 00:46
122 0.619457 1.696718 0.641130 0.900484 00:49
123 0.622396 1.612049 0.648766 0.909646 00:45
124 0.625176 1.751117 0.617969 0.881649 00:47
125 0.622359 1.547589 0.669381 0.913209 00:48
126 0.631657 1.654508 0.644184 0.902265 00:48
127 0.627171 1.564172 0.661237 0.911173 00:46
128 0.624255 1.537886 0.672436 0.914991 00:47
129 0.619449 1.555512 0.664037 0.912700 00:47
130 0.621423 1.786124 0.612879 0.875286 00:43
131 0.624269 1.628261 0.644184 0.892594 00:44
132 0.623737 1.580078 0.656656 0.904556 00:44

np.mean(acc)