Algorithm

BYOL

Abstract: We introduce Bootstrap Your Own Latent (BYOL), a new approach to self-supervised image representation learning. BYOL relies on two neural networks, referred to as online and target networks, that interact and learn from each other. From an augmented view of an image, we train the online network to predict the target network representation of the same image under a different augmented view. At the same time, we update the target network with a slow-moving average of the online network. While state-of-the-art methods rely on negative pairs, BYOL achieves a new state of the art without them. BYOL reaches 74.3% top-1 classification accuracy on ImageNet using a linear evaluation with a ResNet-50 architecture and 79.6% with a larger ResNet. We show that BYOL performs on par or better than the current state of the art on both transfer and semi-supervised benchmarks. Our implementation and pretrained models are given on GitHub.
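The prediction objective described in the abstract can be sketched in a few lines. This is a minimal pure-Python illustration, assuming the loss is the squared distance between L2-normalized vectors (equivalently 2 - 2 * cosine similarity), summed over both orderings of the two views. The function names below are hypothetical and are not part of this library, whose own loss implementation is not shown on this page.

```python
import math

def l2_normalize(v):
    """Scale a non-zero vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def byol_loss(prediction, target):
    """Squared distance between the normalized vectors.

    Equals 2 - 2 * cosine_similarity(prediction, target), so it lies in [0, 4].
    """
    p, t = l2_normalize(prediction), l2_normalize(target)
    return sum((a - b) ** 2 for a, b in zip(p, t))

def symmetric_byol_loss(q1, z2, q2, z1):
    """Sum of both directions: view 1 predicts view 2 and vice versa.

    q1, q2: online predictions for views 1 and 2 (hypothetical names).
    z1, z2: target projections for views 1 and 2 (treated as constants).
    """
    return byol_loss(q1, z2) + byol_loss(q2, z1)

aligned = byol_loss([1.0, 0.0], [2.0, 0.0])      # same direction
orthogonal = byol_loss([1.0, 0.0], [0.0, 1.0])   # unrelated directions
opposite = byol_loss([1.0, 0.0], [-1.0, 0.0])    # opposite directions
```

Under these assumptions the per-direction loss is 0 for perfectly aligned vectors and grows to 4 for opposite ones, so the symmetric version ranges from 0 to 8.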

class BYOLModel[source]

BYOLModel(encoder, projector, predictor) :: Module

Compute predictions of v1 and v2

You can either use the BYOLModel module directly, passing predefined encoder, projector and predictor models, or call create_byol_model with just a predefined encoder. The expected input channels are set when the encoder itself is created (n_in in create_encoder).

You may notice that the projector/MLP module defined here differs from the one defined in SimCLR in that it has a batchnorm layer. You can read this great blog post for a better intuition on the effect of the batchnorm layer in BYOL.
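As a quick sanity check on the MLP layout, the parameter counts can be worked out by hand. The sketch below assumes the Linear -> BatchNorm1d -> ReLU -> Linear structure listed in the learn.summary() output in the Example Usage section, and an encoder output of 1024 features, which is inferred from the first Linear layer's parameter count there. The helper names are hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def batchnorm_params(n_features):
    """Learnable scale and shift; running statistics are buffers, not parameters."""
    return 2 * n_features

def mlp_params(n_in, hidden_size, n_out):
    """Linear -> BatchNorm1d -> ReLU -> Linear (ReLU has no parameters)."""
    return (linear_params(n_in, hidden_size)
            + batchnorm_params(hidden_size)
            + linear_params(hidden_size, n_out))

encoder_features = 1024  # assumption inferred from the summary, not a library constant
projector = mlp_params(encoder_features, hidden_size=4096, n_out=256)
predictor = mlp_params(256, hidden_size=4096, n_out=256)
print(projector, predictor)
```

The individual terms match the summary rows: 4,198,400 and 1,048,832 for the projector's Linear layers, 1,052,672 and 1,048,832 for the predictor's, and 8,192 for each BatchNorm1d.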

create_byol_model[source]

create_byol_model(encoder, hidden_size=4096, projection_size=256, bn=True, nlayers=2)

Create BYOL model

encoder = create_encoder("tf_efficientnet_b0_ns", n_in=3, pretrained=False, pool_type=PoolingType.CatAvgMax)
model = create_byol_model(encoder, hidden_size=2048, projection_size=128)
out = model(torch.randn((2,3,224,224)), torch.randn((2,3,224,224)))
out[0].shape, out[1].shape
(torch.Size([2, 128]), torch.Size([2, 128]))

BYOL Callback

The following parameters can be passed:

  • aug_pipelines: a list of augmentation pipelines (List[Pipeline]) created using functions from the self_supervised.augmentations module. Each Pipeline should be set to split_idx=0. You can simply use the get_byol_aug_pipelines utility to get aug_pipelines.
  • m: the momentum for the target encoder/model update, a similar idea to MoCo.
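The role of m can be illustrated with a small sketch of a momentum (exponential moving average) update. It assumes the usual rule target = m * target + (1 - m) * online applied to every parameter, and uses plain Python lists in place of model weights. The function name is hypothetical and the callback's actual update code is not shown on this page.

```python
def ema_update(target_params, online_params, m=0.999):
    """Return updated target parameters: m * target + (1 - m) * online, element-wise."""
    return [m * t + (1.0 - m) * o for t, o in zip(target_params, online_params)]

# Toy "weights": the target drifts slowly toward the online values.
target = [0.0, 1.0]
online = [1.0, 0.0]
for _ in range(3):
    target = ema_update(target, online, m=0.9)
```

With m close to 1, such as the default m=0.999, the target moves only a tiny step toward the online network at each update, which is the slow-moving average the abstract refers to.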

The BYOL algorithm uses 2 views of a given image, so the BYOL callback expects a list of 2 augmentation pipelines in aug_pipelines.

You can simply use the helper function get_byol_aug_pipelines(), which accepts augmentation-related arguments such as size, rotate, jitter... and returns a list of 2 pipelines that can be passed to the callback. This function uses get_multi_aug_pipelines, which in turn uses get_batch_augs. For more information, refer to the self_supervised.augmentations module.

Alternatively, you may pass your own list of aug_pipelines, which needs to be List[Pipeline, Pipeline] where each one is created as Pipeline(..., split_idx=0). Here, split_idx=0 forces augmentations to be applied in training mode.

get_byol_aug_pipelines[source]

get_byol_aug_pipelines(size, rotate=True, jitter=True, bw=True, blur=True, resize_scale=(0.2, 1.0), resize_ratio=(0.75, 1.3333333333333333), rotate_deg=30, jitter_s=0.6, blur_s=(4, 32), same_on_batch=False, flip_p=0.5, rotate_p=0.3, jitter_p=0.3, bw_p=0.3, blur_p=0.3, stats=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), cuda=True, xtra_tfms=[])

class BYOL[source]

BYOL(aug_pipelines, m=0.999, print_augs=False) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

Example Usage

path = untar_data(URLs.MNIST_TINY)
items = get_image_files(path)
tds = Datasets(items, [PILImageBW.create, [parent_label, Categorize()]], splits=GrandparentSplitter()(items))
dls = tds.dataloaders(bs=5, after_item=[ToTensor(), IntToFloatTensor()], device='cpu')
fastai_encoder = create_encoder('xresnet18', n_in=1, pretrained=False)
model = create_byol_model(fastai_encoder, hidden_size=4096, projection_size=256, bn=True)
aug_pipelines = get_byol_aug_pipelines(size=28, rotate=False, jitter=False, bw=False, blur=False, stats=None, cuda=False)
learn = Learner(dls, model, cbs=[BYOL(aug_pipelines=aug_pipelines, print_augs=True), ShortEpochCallback(0.001)])
Pipeline: RandomResizedCrop -> RandomHorizontalFlip
Pipeline: RandomResizedCrop -> RandomHorizontalFlip
learn.summary()
BYOLModel (Input shape: 5)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     5 x 32 x 14 x 14    
Conv2d                                    288        True      
BatchNorm2d                               64         True      
ReLU                                                           
Conv2d                                    9216       True      
BatchNorm2d                               64         True      
ReLU                                                           
____________________________________________________________________________
                     5 x 64 x 14 x 14    
Conv2d                                    18432      True      
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     5 x 128 x 4 x 4     
Conv2d                                    73728      True      
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     True      
BatchNorm2d                               256        True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     5 x 128 x 4 x 4     
Conv2d                                    8192       True      
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     True      
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     True      
BatchNorm2d                               256        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     5 x 256 x 2 x 2     
Conv2d                                    294912     True      
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     True      
BatchNorm2d                               512        True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     5 x 256 x 2 x 2     
Conv2d                                    32768      True      
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     True      
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     True      
BatchNorm2d                               512        True      
Sequential                                                     
ReLU                                                           
____________________________________________________________________________
                     5 x 512 x 1 x 1     
Conv2d                                    1179648    True      
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    True      
BatchNorm2d                               1024       True      
____________________________________________________________________________
                     []                  
AvgPool2d                                                      
____________________________________________________________________________
                     5 x 512 x 1 x 1     
Conv2d                                    131072     True      
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    True      
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    True      
BatchNorm2d                               1024       True      
Sequential                                                     
ReLU                                                           
AdaptiveAvgPool2d                                              
AdaptiveMaxPool2d                                              
Flatten                                                        
____________________________________________________________________________
                     5 x 4096            
Linear                                    4198400    True      
BatchNorm1d                               8192       True      
ReLU                                                           
____________________________________________________________________________
                     5 x 256             
Linear                                    1048832    True      
____________________________________________________________________________
                     5 x 4096            
Linear                                    1052672    True      
BatchNorm1d                               8192       True      
ReLU                                                           
____________________________________________________________________________
                     5 x 256             
Linear                                    1048832    True      
____________________________________________________________________________

Total params: 18,560,288
Total trainable params: 18,560,288
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7fdfb9f15a70>
Loss function: <bound method BYOL.lf of BYOL>

Callbacks:
  - TrainEvalCallback
  - ShortEpochCallback
  - BYOL
  - Recorder
  - ProgressCallback
b = dls.one_batch()
learn._split(b)
learn('before_fit')
learn('before_batch')
axes = learn.byol.show(n=5)
learn.fit(1)
epoch train_loss valid_loss time
0 00:03
learn.recorder.losses
[tensor(4.0914)]