First import fastai for training and other helpers. You can choose not to use Weights & Biases by setting WANDB=False.
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
import wandb
WANDB = False
Then import the self_supervised augmentations module for creating the augmentations pipeline, the layers module for creating the encoder and model, and finally swav for self-supervised training.
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *
In this notebook we will take a look at the ImageWang benchmark: how to train a self-supervised model using the SwAV algorithm and then how to use this pretrained model for finetuning on the given downstream task.
Warning: This notebook actually reaches the best public leaderboard score 😃
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    files = get_image_files(source)
    # Only resize here; the heavy augmentations are applied later by the SwAV aug pipelines
    tfms = [[PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)],
            [parent_label, Categorize()]]
    dsets = Datasets(files, tfms=tfms, splits=RandomSplitter(valid_pct=0.1)(files))
    batch_tfms = [IntToFloatTensor]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls
ImageWang has several benchmarks for different image sizes; in this tutorial we will go for size=224 and also demonstrate how effectively you can utilize GPU memory.
Define the batch size, the resize resolution applied before batching, and the size used for random cropping during self-supervised training.
bs, resize, size = 96, 256, 224
Set the queue size; it needs to be a multiple of the batch size.
K = bs*2**4
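As a quick sanity check, we can assert the constraint stated above (a minimal sketch; the multiple-of-batch-size requirement presumably exists because the queue stores features in batch-sized chunks):
# K = 96 * 2**4 = 1536, which divides evenly into batches of 96
assert K % bs == 0, f"Queue size {K} must be a multiple of batch size {bs}"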
Select an architecture to train on; remember, all timm and fastai models are available! We need to set pretrained=False here because using ImageNet weights for ImageWang data would be cheating.
arch = "xresnet34"
encoder = create_encoder(arch, pretrained=False, n_in=3)
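For instance, a timm architecture name should work just the same way; "resnet18d" below is only an illustrative choice, not what this notebook uses:
# a timm model instead of a fastai one (illustrative, left commented out)
# encoder = create_encoder("resnet18d", pretrained=False, n_in=3)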
if WANDB:
    xtra_config = {"Arch":arch, "Resize":resize, "Size":size, "Algorithm":"SWAV"}
    wandb.init(project="self-supervised-imagewang", config=xtra_config);
Initialize the DataLoaders using the function above.
dls = get_dls(resize, bs)
Create the SwAV model. You can change the values of hidden_size, projection_size, and n_protos (the number of prototypes/pseudo-classes used for cluster assignment). If the defaults aren't working for your problem, try changing n_protos first. For this problem the defaults work just fine, so we don't make any changes.
model = create_swav_model(encoder)
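If you did need to override the defaults, the call might look like the following sketch; the keyword names come from the text above, but these particular values are purely illustrative:
# illustrative values only; check SwAV's documentation for the actual defaults
# model = create_swav_model(encoder, hidden_size=256, projection_size=128, n_protos=3000)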
The next step is perhaps the most critical one for achieving good results on a custom problem: data augmentation. For this we will use the utility function self_supervised.vision.swav.get_swav_aug_pipelines, but you can also use your own list of Pipeline augmentations. get_swav_aug_pipelines should be enough for most cases, since under the hood it uses self_supervised.augmentations.get_multi_aug_pipelines and self_supervised.augmentations.get_batch_augs. You can do shift+tab and see all the arguments that can be passed to get_swav_aug_pipelines. You can simply pass anything that you could pass to get_batch_augs, including custom xtra_tfms (a sketch follows the call below).
get_swav_aug_pipelines expects certain arguments; for more detail please visit SwAV's documentation. To briefly explain: here we define 2 large crops and 6 smaller crops, at size and int(3/4*size) pixel resolution respectively. We also set the RandomResizedCrop scales for the large crops to (0.25, 1.0) and for the smaller crops to (0.2, 0.35). The rest of the arguments are passed through to get_batch_augs().
aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[size,int(3/4*size)],
                                       min_scales=[0.25,0.2],
                                       max_scales=[1.0,0.35],
                                       rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False)
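And here is the xtra_tfms sketch mentioned above; RandomErasing is just an example fastai batch transform, not part of this tutorial's recipe:
# custom batch transforms can be appended via xtra_tfms (illustrative, left commented out)
# aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
#                                        crop_sizes=[size,int(3/4*size)],
#                                        min_scales=[0.25,0.2],
#                                        max_scales=[1.0,0.35],
#                                        xtra_tfms=[RandomErasing(p=0.25)])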
Finally, we need to pass the indexes of the large crops to crop_assgn_ids. Since we defined the first 2 crops as large, their indexes in our aug_pipelines will be 0 and 1. I also set the queue start to 0.5, which I found to be a good value.
cbs=[SWAV(aug_pipelines, crop_assgn_ids=[0,1], K=K, queue_start_pct=0.5, temp=0.1)]
if WANDB: cbs += [WandbCallback(log_preds=False,log_model=False)]
learn = Learner(dls, model, cbs=cbs)
Before starting training, let's check whether our augmentations make sense or not. We can see that the first 2 crops are the larger ones, and the augmentations indeed look pretty good. Since this inspection step consumes GPU memory, once you are done with it, restart the notebook and skip this step.
b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.swav.show(n=5);
For more GPU memory, a larger batch size, and faster training, use mixed precision with to_fp16(). We could also use the gradient checkpointing wrapper models from self_supervised.layers to save even more memory, e.g. CheckpointSequential().
learn.to_fp16();
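A sketch of that checkpointing option; I'm assuming here that CheckpointSequential wraps a sequential encoder and takes a number of checkpoint chunks, so verify the exact signature in self_supervised.layers before relying on it:
# wrap the encoder before building the SwAV model (assumed signature, left commented out)
# encoder = CheckpointSequential(encoder, checkpoint_nchunks=2)
# model = create_swav_model(encoder)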
Learning good representations via contrastive learning usually takes a lot of epochs, so here the number of epochs is set to 100. This might change depending on your data distribution and dataset size.
lr,wd,epochs=1e-2,1e-2,100
learn.unfreeze()
learn.fit_flat_cos(epochs, lr, wd=wd, pct_start=0.5)
if WANDB: wandb.finish()
save_name = f'swav_iwang_sz{size}_epc{epochs}'
learn.save(save_name)
torch.save(learn.model.encoder.state_dict(), learn.path/learn.model_dir/f'{save_name}_encoder.pth')
learn.recorder.plot_loss()
With self-supervised pretraining done, we move on to the downstream task: supervised finetuning on ImageWang with the pretrained encoder. We use the ranger optimizer with the following hyperparameters.
optdict = dict(sqr_mom=0.99, mom=0.95, beta=0., eps=1e-4)
opt_func = partial(ranger, **optdict)
bs, size
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    files = get_image_files(source, folders=['train', 'val'])
    splits = GrandparentSplitter(valid_name='val')(files)
    # Light supervised-training augmentations, unlike the multi-crop self-supervised stage
    item_aug = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    tfms = [[PILImage.create, ToTensor, *item_aug],
            [parent_label, Categorize()]]
    dsets = Datasets(files, tfms=tfms, splits=splits)
    batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls
# Split parameters into encoder and classifier groups for discriminative learning rates
def split_func(m): return L(m[0], m[1]).map(params)
def create_learner(size=size, arch='xresnet34', encoder_path=f"models/swav_iwang_sz{size}_epc100_encoder.pth"):
    dls = get_dls(size, bs=bs//2)
    pretrained_encoder = torch.load(encoder_path)
    encoder = create_encoder(arch, pretrained=False, n_in=3)
    encoder.load_state_dict(pretrained_encoder)
    # Infer the encoder's output feature size with a dummy forward pass
    nf = encoder(torch.randn(2,3,size,size)).size(-1)
    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    learn = Learner(dls, model, opt_func=opt_func, splitter=split_func,
                    metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    return learn
def finetune(size, epochs, arch, encoder_path, lr=1e-2, wd=1e-2):
    learn = create_learner(size, arch, encoder_path)
    learn.unfreeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    # recorder.values rows are [train_loss, valid_loss, accuracy, top_k_accuracy]
    final_acc = learn.recorder.values[-1][-2]
    return final_acc
Finetune for 5 epochs, averaging accuracy over 5 runs:
acc = []
runs = 5
for i in range(runs): acc += [finetune(size, epochs=5, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)
Finetune for 20 epochs, averaging accuracy over 3 runs:
acc = []
runs = 3
for i in range(runs): acc += [finetune(size, epochs=20, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)
Finetune for 80 epochs, single run:
acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=80, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)
Finetune for 200 epochs, single run:
acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=200, arch='xresnet34', encoder_path=f'models/swav_iwang_sz{size}_epc100_encoder.pth')]
np.mean(acc)