SwAV ImageWang Tutorial

First, import fastai for training and other helpers. You can opt out of Weights & Biases logging by setting WANDB=False.

from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
import wandb

WANDB = False

Then import the self_supervised augmentations module for creating the augmentation pipelines, the layers module for creating the encoder and model, and finally the swav module for self-supervised training.

from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *

In this notebook we will take a look at the ImageWang benchmark: how to train a self-supervised model using the SwAV algorithm, and then how to use this pretrained model for finetuning on the given downstream task.

Warning: This notebook actually reaches the best public leaderboard score 😃

Pretraining

def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    
    files = get_image_files(source)
    # min_scale=1. keeps the full image; the SWAV callback's aug pipelines do the multi-crop later
    tfms = [[PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)],
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=RandomSplitter(valid_pct=0.1)(files))
    
    batch_tfms = [IntToFloatTensor]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls

ImageWang has several benchmarks for different image sizes; in this tutorial we will go with size=224 and also demonstrate how to utilize GPU memory effectively.

Define the batch size, the resolution to resize images to before batching, and the size for random cropping during self-supervised training.

bs, resize, size = 96, 256, 224

Set the queue size; it needs to be a multiple of the batch size.

K = bs*2**4
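
As a quick sanity check (a minimal sketch, not part of the original notebook), you can assert this divisibility constraint before training:

# The SWAV queue length K must be a multiple of the batch size
assert K % bs == 0, f"K={K} is not a multiple of bs={bs}"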

Select the architecture to train on; remember, all timm and fastai models are available! We need to set pretrained=False here because using ImageNet weights on ImageWang data would be cheating.

arch = "xresnet34"
encoder = create_encoder(arch, pretrained=False, n_in=3)
if WANDB:
    xtra_config = {"Arch":arch, "Resize":resize, "Size":size, "Algorithm":"SWAV"}
    wandb.init(project="self-supervised-imagewang", config=xtra_config);
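
Note that a timm architecture name can be passed to create_encoder in exactly the same way; for instance (an illustrative sketch with a hypothetical model choice, not used in this run, and assuming timm is installed):

# Any timm model name works too, e.g. a ResNet-D variant
encoder_timm = create_encoder("resnet34d", pretrained=False, n_in=3)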

Initialize the DataLoaders using the function above.

dls = get_dls(resize, bs)

Create the SwAV model. You can change the values of hidden_size, projection_size, and n_protos (the number of prototypes/pseudo-classes for cluster assignment). If the defaults are not working for your problem, try changing n_protos first. For this problem the defaults work just fine, so we don't change anything.

model = create_swav_model(encoder)
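
This is equivalent to spelling out the defaults explicitly. The values below are assumptions for illustration (3000 prototypes matches the SwAV paper); check the actual defaults with shift+tab:

# Assumed defaults, for illustration only; verify against the library
model = create_swav_model(encoder, hidden_size=256, projection_size=128, n_protos=3000)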

The next step is perhaps the most critical for achieving good results on a custom problem: data augmentation. For this we will use the utility function self_supervised.vision.swav.get_swav_aug_pipelines, but you can also use your own list of Pipeline augmentations. get_swav_aug_pipelines should be enough for most cases, since under the hood it uses self_supervised.augmentations.get_multi_aug_pipelines and self_supervised.augmentations.get_batch_augs. You can press shift+tab to see all the arguments that can be passed to get_swav_aug_pipelines; you can simply pass anything that you could pass to get_batch_augs, including custom xtra_tfms.

get_swav_aug_pipelines expects certain arguments; for more detail please visit SwAV's documentation. Briefly: here we define 2 large crops and 6 smaller crops, at size and int(3/4*size) pixel resolution respectively. We also set the RandomResizedCrop scales to (0.25, 1.0) for the large crops and (0.2, 0.35) for the smaller crops. The rest of the arguments are passed on to get_batch_augs().

aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[size,int(3/4*size)], 
                                       min_scales=[0.25,0.2],
                                       max_scales=[1.0,0.35],
                                       rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False) 
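
If you wanted additional batch transforms you could pass them through xtra_tfms, for instance (an illustrative sketch, not used in this run; RandomErasing comes with fastai):

# Hypothetical variant with an extra batch transform forwarded to get_batch_augs
aug_pipelines_custom = get_swav_aug_pipelines(num_crops=[2,6],
                                              crop_sizes=[size,int(3/4*size)],
                                              min_scales=[0.25,0.2],
                                              max_scales=[1.0,0.35],
                                              xtra_tfms=[RandomErasing(p=0.25)])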

Finally, we need to pass the indexes of the large crops to crop_assgn_ids; since we defined the first 2 crops as large in our aug_pipelines, the indexes will be 0 and 1. I also set queue_start_pct to 0.5, which I found to be a good value.

cbs=[SWAV(aug_pipelines, crop_assgn_ids=[0,1], K=K, queue_start_pct=0.5, temp=0.1)]
if WANDB: cbs += [WandbCallback(log_preds=False,log_model=False)]
learn = Learner(dls, model, cbs=cbs)

Before starting training, let's check whether our augmentations make sense. We can see that the first 2 crops are the larger ones and the augmentations indeed look pretty good. Since this inspection step consumes GPU memory, once you are done with it, restart the notebook and skip this step.

b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.swav.show(n=5);

To free up GPU memory, allow a larger batch size, and speed up training, use mixed precision with to_fp16(). We could also use the gradient checkpointing wrapper models from self_supervised.layers, e.g. CheckpointSequential(), to save even more memory; a sketch follows the next cell.

learn.to_fp16();
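
Here is a minimal sketch of the checkpointing option mentioned above, assuming CheckpointSequential wraps a sequential encoder and takes a number of chunks; the exact signature is an assumption, so verify it in self_supervised.layers before use:

# Not run here: trades extra forward compute for lower activation memory
encoder_ckpt = CheckpointSequential(encoder, checkpoint_nchunks=2)  # assumed signature
model_ckpt = create_swav_model(encoder_ckpt)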

Learning good representations via contrastive learning usually takes a lot of epochs, so here the number of epochs is set to 100. This might change depending on your data distribution and dataset size.

lr,wd,epochs=1e-2,1e-2,100
learn.unfreeze()
learn.fit_flat_cos(epochs, lr, wd=wd, pct_start=0.5)
epoch train_loss valid_loss time
0 7.407468 7.684866 02:58
1 7.030927 7.189878 02:52
2 6.837667 6.869849 02:52
3 6.633281 7.156127 02:52
4 6.468298 6.766192 02:54
5 6.345496 6.350045 02:53
6 6.241495 6.387009 02:54
7 6.162815 6.393086 02:53
8 6.081127 6.323627 02:53
9 5.999241 6.243154 02:51
10 5.947983 6.122771 02:52
11 5.914664 6.117457 02:54
12 5.857216 6.057860 02:53
13 5.806761 5.971296 02:53
14 5.781459 5.844307 02:54
15 5.732420 5.866521 02:52
16 5.689071 5.923126 02:52
17 5.633221 5.761728 02:52
18 5.633708 5.741578 02:54
19 5.612320 5.878681 02:53
20 5.601634 5.753514 02:54
21 5.603115 5.957873 02:53
22 5.595489 5.819541 02:54
23 5.564273 5.694228 02:52
24 5.563308 5.734323 02:53
25 5.508146 5.775989 02:52
26 5.511865 5.661479 02:52
27 5.457837 5.636052 02:53
28 5.458113 5.703008 02:55
29 5.509527 5.668771 02:55
30 5.453142 5.632955 02:52
31 5.437060 5.600440 02:52
32 5.414873 5.655412 02:52
33 5.438187 5.605145 02:54
34 5.425883 5.540411 02:53
35 5.393618 5.595127 02:53
36 5.383820 5.646406 02:51
37 5.359114 5.649547 02:53
38 5.367377 5.577794 02:52
39 5.385077 5.616859 02:52
40 5.342968 5.498207 02:54
41 5.334440 5.522304 02:52
42 5.354709 5.487960 02:53
43 5.324558 5.482665 02:53
44 5.304360 5.542214 02:53
45 5.303082 5.545982 02:53
46 5.310629 5.506954 02:52
47 5.342871 5.555840 02:52
48 5.294147 5.543742 02:52
49 5.285220 5.016583 02:54
50 5.312743 5.102065 02:52
51 4.697960 4.662322 02:53
52 4.663615 4.661558 02:53
53 4.589036 4.624248 02:53
54 4.580728 4.602210 02:53
55 4.517045 4.611333 02:53
56 4.486736 4.509859 02:54
57 4.470596 4.445058 02:53
58 4.452981 4.549795 02:53
59 4.361004 4.419662 02:52
60 4.346026 4.429433 02:55
61 4.335344 4.421149 02:54
62 4.288753 4.380322 02:53
63 4.258289 4.367942 02:53
64 4.234794 4.386373 02:53
65 4.215182 4.296172 02:54
66 4.191622 4.322562 02:57
67 4.146343 4.254191 03:00
68 4.144763 4.256152 02:59
69 4.069776 4.200845 02:58
70 4.058836 4.132379 02:59
71 4.008277 4.090577 02:59
72 3.996061 4.063117 02:59
73 3.952912 4.138320 03:00
74 3.966690 4.123882 03:01
75 3.901841 4.118014 03:00
76 3.891283 4.008917 03:00
77 3.862275 4.020964 02:59
78 3.793281 4.000199 02:59
79 3.811054 4.012441 02:58
80 3.752578 4.049435 03:00
81 3.709358 3.974121 02:59
82 3.715454 3.955640 02:59
83 3.677514 3.925742 02:58
84 3.649344 3.969006 03:00
85 3.651184 3.942744 02:59
86 3.609802 3.882123 02:58
87 3.564646 3.905103 02:59
88 3.554883 3.844343 02:59
89 3.517262 3.841999 02:58
90 3.512948 3.883128 03:00
91 3.511761 3.817661 03:00
92 3.476729 3.838302 03:00
93 3.459274 3.869549 02:59
94 3.466088 3.845832 03:00
95 3.434940 3.850080 03:00
96 3.448100 3.875105 03:00
97 3.446653 3.868119 03:01
98 3.418090 3.830288 03:02
99 3.448179 3.851897 03:01
if WANDB: wandb.finish()
save_name = f'swav_iwang_sz{size}_epc{epochs}'
learn.save(save_name)
# Save only the encoder weights for downstream finetuning
torch.save(learn.model.encoder.state_dict(), learn.path/learn.model_dir/f'{save_name}_encoder.pth')
learn.recorder.plot_loss()

Downstream Task

For finetuning we use fastai's ranger optimizer (RAdam combined with Lookahead):

optdict = dict(sqr_mom=0.99, mom=0.95, beta=0., eps=1e-4)
opt_func = partial(ranger, **optdict)
bs, size
(96, 224)
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    # Use the official ImageWang train/val folders as the benchmark split
    files = get_image_files(source, folders=['train', 'val'])
    splits = GrandparentSplitter(valid_name='val')(files)
    
    item_aug = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    tfms = [[PILImage.create, ToTensor, *item_aug], 
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=splits)
    
    batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls
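
Optionally, you can eyeball a batch from these DataLoaders before finetuning (a quick sanity check, not part of the original flow):

# Visual sanity check of the downstream training data
dls = get_dls(size, bs)
dls.show_batch(max_n=4)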
# Split parameters into encoder and classifier groups for discriminative learning rates
def split_func(m): return L(m[0], m[1]).map(params)

def create_learner(size=size, arch='xresnet34', encoder_path="models/swav_iwang_sz128_epc100_encoder.pth"):
    
    dls = get_dls(size, bs=bs//2)
    pretrained_encoder = torch.load(encoder_path)
    encoder = create_encoder(arch, pretrained=False, n_in=3)
    encoder.load_state_dict(pretrained_encoder)
    nf = encoder(torch.randn(2,3,224,224)).size(-1)  # infer the encoder's output width with a dummy batch
    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    learn = Learner(dls, model, opt_func=opt_func, splitter=split_func,
                metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    return learn
def finetune(size, epochs, arch, encoder_path, lr=1e-2, wd=1e-2):
    learn = create_learner(size, arch, encoder_path)
    learn.unfreeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    final_acc = learn.recorder.values[-1][-2]  # final-epoch accuracy (second-to-last recorded metric)
    return final_acc

5 epochs

acc = []
runs = 5
for i in range(runs): acc += [finetune(size, epochs=5, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.983487 1.558032 0.711631 0.921354 00:58
1 0.907253 1.474977 0.730211 0.930517 00:58
2 0.878498 1.430007 0.734284 0.932298 01:00
3 0.826681 1.326239 0.759990 0.948333 00:58
4 0.735671 1.216653 0.794350 0.960295 00:58
epoch train_loss valid_loss accuracy top_k_accuracy time
0 1.002475 1.796736 0.648002 0.900229 00:59
1 0.908794 1.402081 0.744719 0.931535 00:58
2 0.858086 1.445738 0.716722 0.933062 00:58
3 0.811090 1.300391 0.776024 0.954441 00:59
4 0.740800 1.217372 0.798167 0.957750 00:58
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.986952 1.708992 0.659964 0.904301 00:59
1 0.937269 1.404593 0.747264 0.942479 00:59
2 0.864935 1.405310 0.750318 0.939425 00:58
3 0.810867 1.288793 0.768389 0.949860 00:58
4 0.742714 1.219064 0.803003 0.958514 00:59
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.968935 1.604818 0.687198 0.940952 00:58
1 0.924613 1.683160 0.633749 0.912955 00:56
2 0.844297 1.621276 0.671418 0.914228 00:55
3 0.825853 1.351661 0.755918 0.949351 00:55
4 0.743408 1.207588 0.789259 0.955968 00:56
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.965100 1.479348 0.741156 0.943752 00:56
1 0.920319 1.532363 0.723085 0.930262 00:55
2 0.845511 1.352898 0.751082 0.940188 00:54
3 0.820847 1.341636 0.766607 0.951642 00:55
4 0.740365 1.209206 0.793332 0.957241 00:53
np.mean(acc)
0.7956223011016845

20 epochs

acc = []
runs = 3
for i in range(runs): acc += [finetune(size, epochs=20, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.971454 1.545572 0.708323 0.921354 00:54
1 0.927443 1.499802 0.721303 0.939679 00:54
2 0.873696 1.475774 0.729448 0.940443 00:54
3 0.803617 1.430663 0.724103 0.944261 00:54
4 0.803485 1.309435 0.775006 0.945024 00:54
5 0.766212 1.365433 0.744464 0.941715 00:54
6 0.753622 1.369751 0.731738 0.941970 00:54
7 0.743506 1.371420 0.760499 0.940697 00:56
8 0.726871 1.382386 0.747518 0.948587 00:53
9 0.720248 1.393038 0.751082 0.943752 00:54
10 0.722475 1.423445 0.734029 0.932298 00:54
11 0.721780 1.344736 0.767371 0.945533 00:54
12 0.714426 1.463007 0.725884 0.931026 00:54
13 0.715922 1.395860 0.765080 0.943752 00:54
14 0.702909 1.322655 0.772970 0.945024 00:54
15 0.690696 1.275325 0.783405 0.944515 00:54
16 0.681320 1.292186 0.774497 0.942479 00:54
17 0.664168 1.251853 0.789005 0.948842 00:54
18 0.646135 1.225087 0.788496 0.954187 00:54
19 0.642103 1.224002 0.791550 0.953932 00:54

epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.963382 1.713247 0.656656 0.932044 00:56
1 0.931585 1.536007 0.724357 0.937643 00:55
2 0.851089 1.420539 0.756172 0.949606 00:54
3 0.817316 1.328225 0.764062 0.938407 00:54
4 0.793688 1.371521 0.735811 0.932044 00:54
5 0.778133 1.326565 0.766353 0.949351 00:56
6 0.748602 1.449237 0.734029 0.937134 00:55
7 0.751250 1.322996 0.754390 0.945024 00:56
8 0.734199 1.411296 0.729702 0.936880 00:55
9 0.727239 1.264135 0.772970 0.957496 00:56
10 0.721183 1.405088 0.736320 0.937643 00:57
11 0.718222 1.369370 0.744973 0.950624 00:57
12 0.720674 1.349388 0.756427 0.939425 00:56
13 0.712852 1.440273 0.737592 0.945024 00:56
14 0.718118 1.347861 0.755918 0.940952 00:56
15 0.696110 1.370475 0.753627 0.941970 00:58
16 0.677008 1.312105 0.764826 0.945279 00:56
17 0.650657 1.244593 0.789768 0.952151 00:56
18 0.643058 1.212385 0.795877 0.956986 00:55
19 0.639072 1.211426 0.794095 0.955205 00:56

epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.985622 1.549440 0.721812 0.921863 00:56
1 0.917132 1.575578 0.733265 0.920081 00:56
2 0.857101 1.328448 0.764571 0.956986 00:54
3 0.824363 1.346856 0.765589 0.946551 00:54
4 0.794153 1.344439 0.756681 0.948842 00:54
5 0.771290 1.381032 0.756681 0.947569 00:54
6 0.761646 1.365927 0.754645 0.947315 00:54
7 0.734839 1.393742 0.755154 0.936371 00:54
8 0.744643 1.329592 0.759226 0.945788 00:54
9 0.723056 1.427026 0.728430 0.933825 00:54
10 0.721213 1.308481 0.767625 0.941461 00:54
11 0.721407 1.414454 0.731229 0.928990 00:55
12 0.715903 1.535586 0.701705 0.921609 00:54
13 0.713083 1.468722 0.715704 0.933316 00:54
14 0.712866 1.331187 0.764571 0.946297 00:54

80 epochs

acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=80, arch='xresnet34',encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.969396 1.548739 0.710104 0.932553 00:54
1 0.910081 1.519832 0.713668 0.939170 00:55
2 0.853625 1.308844 0.771443 0.944770 00:54
3 0.823330 1.447318 0.729702 0.926190 00:55
4 0.789518 1.322515 0.770171 0.954950 00:54
5 0.774128 1.364938 0.759481 0.939170 00:55
6 0.758874 1.454771 0.720031 0.929244 00:55
7 0.747683 1.323867 0.751082 0.946042 00:54
8 0.731047 1.421696 0.727921 0.933316 00:54
9 0.729223 1.306388 0.767116 0.949096 00:54
10 0.735498 1.387965 0.752100 0.941715 00:55
11 0.727468 1.409651 0.741410 0.926953 00:54
12 0.701236 1.309646 0.766862 0.933825 00:54
13 0.715139 1.500116 0.726393 0.941206 00:54
14 0.715735 1.333505 0.757445 0.942988 00:54
15 0.707127 1.346936 0.762535 0.933062 00:55
16 0.695886 1.401636 0.750064 0.941461 00:55
17 0.705250 1.397648 0.738610 0.935607 00:54
18 0.701646 1.428903 0.735302 0.930008 00:54
19 0.692144 1.318277 0.764571 0.940443 00:54
20 0.687428 1.427480 0.732502 0.928990 00:55
21 0.706741 1.534121 0.699160 0.914482 00:54
22 0.687827 1.362269 0.748028 0.929499 00:54
23 0.698861 1.325396 0.758717 0.931535 00:54
24 0.688978 1.514965 0.720794 0.911428 00:54
25 0.688998 1.426650 0.729448 0.934334 00:54
26 0.690303 1.405096 0.744464 0.932044 00:54
27 0.694694 1.494388 0.731229 0.928735 00:54
28 0.693535 1.367281 0.749809 0.941461 00:54
29 0.678336 1.520411 0.703741 0.908374 00:54
30 0.687282 1.845479 0.634004 0.897429 00:54
31 0.686466 1.446187 0.721303 0.931280 00:53
32 0.683836 1.432986 0.730720 0.923390 00:54
33 0.681060 1.454526 0.721558 0.928226 00:53
34 0.687230 1.506918 0.718503 0.923390 00:53
35 0.684928 1.485413 0.709341 0.926953 00:55
36 0.680879 1.605373 0.684144 0.899211 00:54
37 0.673911 1.479267 0.710613 0.917282 00:53
38 0.686266 1.418930 0.732756 0.920845 00:53
39 0.698295 1.533268 0.703741 0.914737 00:54
40 0.686656 1.482220 0.726903 0.928226 00:53
41 0.675766 1.437706 0.725375 0.922881 00:53
42 0.690365 1.598217 0.691016 0.915246 00:54
43 0.688183 1.547985 0.691779 0.901247 00:54
44 0.681453 1.496727 0.704250 0.917282 00:53
45 0.694414 1.459416 0.717231 0.928990 00:53
46 0.695265 1.485699 0.718758 0.925426 00:54
47 0.689777 1.517926 0.708068 0.909137 00:54
48 0.686182 1.756034 0.645202 0.874014 00:53
49 0.681257 1.704624 0.657165 0.901502 00:53
50 0.681298 1.515766 0.705269 0.915500 00:54
51 0.686041 1.584682 0.691525 0.909901 00:53
52 0.688857 1.528068 0.706796 0.910155 00:53
53 0.692053 1.479343 0.711631 0.925172 00:54
54 0.691768 1.634907 0.670909 0.895648 00:53
55 0.694984 1.554712 0.705523 0.920081 00:55
56 0.681746 1.661562 0.669381 0.877068 00:54
57 0.674440 1.570412 0.688470 0.908119 00:54
58 0.685405 1.488756 0.718503 0.929244 00:54
59 0.679044 1.631579 0.669636 0.902520 00:54
60 0.670184 1.583849 0.686943 0.911428 00:54
61 0.681118 1.508625 0.708068 0.914482 00:54
62 0.674748 1.576503 0.684653 0.916773 00:54
63 0.668170 1.550465 0.691016 0.907610 00:54
64 0.673996 1.683372 0.662255 0.890812 00:54
65 0.667660 1.440590 0.726648 0.919318 00:54
66 0.656487 1.452681 0.719522 0.930517 00:54
67 0.654437 1.475030 0.725375 0.918300 00:55
68 0.640541 1.501212 0.701960 0.910155 00:53
69 0.638184 1.447417 0.726903 0.920081 00:53
70 0.627773 1.422424 0.727412 0.924154 00:53
71 0.626549 1.327196 0.761262 0.934843 00:54
72 0.621915 1.385796 0.741919 0.928226 00:54
73 0.623199 1.339253 0.749809 0.928481 00:54
74 0.618367 1.329801 0.754645 0.930517 00:54
75 0.614396 1.309854 0.761517 0.932807 00:54
76 0.616890 1.317712 0.763553 0.933062 00:54
77 0.619391 1.318682 0.763808 0.931026 00:54
78 0.615834 1.318036 0.764571 0.931280 00:54
79 0.614837 1.324565 0.762789 0.931535 00:55

0.7627894878387451

200 epochs

acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=200, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)
epoch train_loss valid_loss accuracy top_k_accuracy time
0 0.984085 1.678459 0.670654 0.933825 00:54
1 0.922175 1.501608 0.704759 0.941461 00:54
2 0.854042 1.435633 0.751591 0.942734 00:55
3 0.826787 1.373917 0.754136 0.943497 00:54
4 0.789434 1.392040 0.741919 0.947824 00:54
5 0.776616 1.347452 0.751336 0.944770 00:54
6 0.751912 1.359706 0.756172 0.951387 00:54
7 0.753233 1.408145 0.728175 0.939425 00:55
8 0.747958 1.393444 0.749809 0.944006 00:53
9 0.727515 1.344913 0.763808 0.945533 00:54
10 0.724958 1.352103 0.758717 0.940188 00:54
11 0.724346 1.312932 0.749046 0.948587 00:54
12 0.714598 1.373003 0.746246 0.937134 00:54
13 0.707744 1.425695 0.737338 0.931026 00:54
14 0.702472 1.352487 0.741156 0.937898 00:54
15 0.706678 1.375309 0.751336 0.940443 00:54
16 0.704602 1.431296 0.734793 0.931535 00:54
17 0.704455 1.472749 0.726903 0.932807 00:56
18 0.695439 1.398680 0.733774 0.938661 00:54
19 0.697567 1.469991 0.719267 0.927717 00:54
20 0.687049 1.387192 0.750827 0.937389 00:55
21 0.703025 1.322433 0.757190 0.939679 00:54
22 0.689635 1.391278 0.744464 0.930771 00:55
23 0.693994 1.482016 0.716467 0.927208 00:54
24 0.684093 1.436301 0.734284 0.926699 00:55
25 0.690816 1.487099 0.717231 0.920081 00:55
26 0.696051 1.415452 0.734538 0.930008 00:55
27 0.689084 1.431780 0.727666 0.942988 00:55
28 0.701508 1.409245 0.738356 0.931789 00:55
29 0.697585 1.458436 0.745482 0.913719 00:55
30 0.694256 1.497586 0.709086 0.907356 00:55
31 0.683301 1.491587 0.723085 0.923899 00:55
32 0.690459 1.410597 0.740646 0.931535 00:55
33 0.683572 1.490317 0.710868 0.916773 00:56
34 0.686878 1.458840 0.727921 0.920336 00:55
35 0.677830 1.410740 0.741665 0.945788 00:54
36 0.688049 1.471387 0.719267 0.926190 00:56
37 0.683464 1.708444 0.649784 0.893866 00:55
38 0.688690 1.576261 0.692034 0.912446 00:55
39 0.686766 1.470237 0.707814 0.925426 00:55
40 0.686739 1.541277 0.694579 0.916773 00:55
41 0.689952 1.594274 0.680580 0.918300 00:55
42 0.683161 1.576835 0.684907 0.908374 00:55
43 0.684455 1.508371 0.720540 0.915246 00:55
44 0.680635 1.431406 0.726393 0.920845 00:55
45 0.676348 1.486763 0.713922 0.911682 00:55
46 0.692117 1.584741 0.678799 0.904810 00:55
47 0.693484 1.425312 0.727666 0.923899 00:55
48 0.690247 1.500790 0.715449 0.916773 00:55
49 0.690122 1.564491 0.697633 0.907101 00:55
50 0.678335 1.473881 0.716722 0.919827 00:54
51 0.688312 1.495640 0.708323 0.903538 00:54
52 0.688115 1.571938 0.687198 0.896411 00:54
53 0.693030 1.453131 0.708323 0.925935 00:55
54 0.682282 1.547777 0.688470 0.914737 00:55
55 0.690334 1.631448 0.658692 0.909901 00:55
56 0.669336 1.638186 0.663782 0.906592 00:55
57 0.681890 1.644338 0.667854 0.909646 00:55
58 0.683376 1.688515 0.658946 0.891830 00:55
59 0.682141 1.513852 0.698397 0.913464 00:55
60 0.678162 1.470327 0.709086 0.918045 00:54
61 0.692525 1.623360 0.667091 0.895902 00:55
62 0.687947 1.482525 0.723085 0.918300 00:55
63 0.681053 1.539435 0.691270 0.921609 00:55
64 0.689001 1.514883 0.707559 0.912446 00:55
65 0.674240 1.605520 0.688216 0.892084 00:55
66 0.680477 1.616489 0.681853 0.897938 00:55
67 0.679482 1.544586 0.693052 0.916009 00:55
68 0.688315 1.602148 0.678544 0.906847 00:55
69 0.679371 1.697513 0.670909 0.881140 00:55
70 0.682603 1.538098 0.695342 0.910410 00:55
71 0.694602 1.494579 0.716976 0.918045 00:56
72 0.677493 1.733421 0.650038 0.875541 00:54
73 0.677590 1.731535 0.640112 0.872232 00:54
74 0.678715 1.619843 0.672945 0.914228 00:54
75 0.674638 1.604907 0.677272 0.911428 00:55
76 0.686663 1.523086 0.703487 0.908883 00:55
77 0.691598 1.572895 0.688470 0.910155 00:55
78 0.679832 1.597538 0.680326 0.901247 00:55
79 0.673768 1.800995 0.618987 0.887249 00:55
80 0.687403 1.599654 0.682362 0.906847 00:55
81 0.679005 1.623663 0.685925 0.925172 00:55
82 0.695586 1.600658 0.683889 0.897938 00:55
83 0.679262 1.616485 0.666582 0.916009 00:55
84 0.680654 1.850782 0.612115 0.874014 00:55
85 0.683910 1.560729 0.692797 0.906083 00:55
86 0.686606 1.638184 0.670654 0.907101 00:55
87 0.687335 1.585300 0.682107 0.906083 00:55
88 0.680553 1.622772 0.668363 0.917536 00:54
89 0.687137 1.643962 0.671418 0.892848 00:54
90 0.675912 1.506958 0.708323 0.932044 00:55
91 0.690031 1.620577 0.674726 0.892594 00:54
92 0.682069 1.735025 0.639857 0.880122 00:54
93 0.677696 1.599553 0.661237 0.903538 00:54
94 0.675856 1.754643 0.626877 0.878850 00:54
95 0.677197 1.626098 0.661491 0.900229 00:53
96 0.679812 1.729993 0.637821 0.905828 00:54
97 0.692947 2.047096 0.563502 0.820820 00:55
98 0.672682 1.672114 0.668109 0.894884 00:55
99 0.676195 1.509647 0.711377 0.923136 00:53
100 0.681926 1.667116 0.658437 0.905828 00:54
101 0.680554 1.794548 0.623823 0.891321 00:53
102 0.675588 1.550800 0.688470 0.903792 00:53
103 0.673606 1.595899 0.672690 0.905319 00:54
104 0.680733 1.718332 0.653347 0.887249 00:55
105 0.668970 1.623140 0.669127 0.912700 00:53
106 0.681201 1.556153 0.688216 0.908883 00:54
107 0.680996 1.674697 0.670400 0.900484 00:55
108 0.674758 1.632093 0.658183 0.901502 00:55
109 0.678867 1.691918 0.666327 0.891830 00:55
110 0.669255 1.691689 0.652838 0.879868 00:54

np.mean(acc)
0.7314838171005249