Utilities for creating torch Modules for self supervised learning.

create_fastai_encoder[source]

create_fastai_encoder(arch:str, pretrained=True, n_in=3, pool_type='catavgmax')

Create a fastai encoder from a given arch backbone

create_timm_encoder[source]

create_timm_encoder(arch:str, pretrained=True, n_in=3, pool_type='catavgmax')

Creates a body from any model in the timm library. If pool_type is None then it uses the timm default

create_encoder[source]

create_encoder(arch:str, pretrained=True, n_in=3, pool_type='catavgmax')

A utility for creating an encoder without specifying the package

inp = torch.randn((1,3,384,384))

The fastai encoder expects a function as its first argument, whereas timm expects a string. Also, fastai defaults to concat pooling, a.k.a. catavgmax in timm. With timm's selective pooling any PoolingType can be used. Experiments show that concat pooling is better on average, so it is set as our default.

For any other pool_type fastai uses AdaptiveAvgPool2d; for timm you can choose from the remaining PoolingType options.

fastai_encoder = create_fastai_encoder(xresnet34)
out = fastai_encoder(inp); out.shape
torch.Size([1, 1024])
fastai_encoder = create_fastai_encoder(xresnet34, pool_type=False)
out = fastai_encoder(inp); out.shape
torch.Size([1, 512])
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False)
out = model(inp); out.shape
torch.Size([1, 2560])
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape
torch.Size([1, 1280])
model = create_encoder("xresnet34", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape
torch.Size([1, 512])
model = create_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape
torch.Size([1, 1280])

Vision Transformer is a special case which uses LayerNorm.

vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)
out = vit_model(inp); out.shape
torch.Size([1, 1024])

create_mlp_module[source]

create_mlp_module(dim, hidden_size, projection_size, bn=False, nlayers=2)

MLP module as described in self-supervised learning papers, used as the projection head

create_mlp_module(1024,4096,128)
Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=128, bias=True)
)
create_mlp_module(1024,4096,128,nlayers=3)
Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=4096, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=4096, out_features=128, bias=True)
)
create_mlp_module(1024,4096,128,bn=True)
Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=128, bias=True)
)
create_mlp_module(1024,4096,128,bn=True,nlayers=3)
Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=4096, out_features=128, bias=True)
)

create_cls_module[source]

create_cls_module(nf, n_out, lin_ftrs=None, ps=0.5, use_bn=True, first_bn=True, bn_final=False, lin_first=False, y_range=None)

Creates a classification layer which takes nf flattened features and outputs n_out logits

inp = torch.randn((2,3,384,384))
encoder = create_encoder("xresnet34", pretrained=False)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
tensor([[-0.5934,  0.0218, -1.0546, -0.0870, -0.0212],
        [ 0.8928,  1.1403,  0.0279, -0.5045, -1.0595]])
encoder = create_encoder("vit_large_patch16_384", pretrained=False)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
tensor([[ 0.0023, -0.0434, -0.1689,  0.7236,  1.4304],
        [ 0.2860,  0.3319, -1.1037, -0.1302, -1.2017]])

create_model can be used to quickly create a model for downstream classification training.

create_model[source]

create_model(arch, n_out, pretrained=True, n_in=3, pool_type='catavgmax', lin_ftrs=None, ps=0.5, use_bn=True, first_bn=True, bn_final=False, lin_first=False, y_range=None)

_splitter can be passed to Learner(..., splitter=splitter_func). It can be used to freeze or unfreeze encoder layers; in this case the first parameter group is the encoder and the second parameter group is the classification head. Simply by indexing into model[0] and model[1] we can access the encoder and classification head modules.
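As a minimal sketch (the library's actual _splitter may differ in detail), a splitter compatible with this two-group layout could look like:

```python
import torch.nn as nn

def splitter_func(model):
    # Two parameter groups: the encoder (model[0]) and the classification
    # head (model[1]). With this split, Learner.freeze() updates only the
    # head while the encoder parameters stay frozen.
    return [list(model[0].parameters()), list(model[1].parameters())]
```

Pass it as Learner(dls, model, splitter=splitter_func); calling learn.freeze() then trains only the second parameter group.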

model = create_model("xresnet34", 10, pretrained=False)
model[1]
Sequential(
  (0): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Dropout(p=0.25, inplace=False)
  (2): Linear(in_features=1024, out_features=512, bias=False)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=512, out_features=10, bias=False)
)
with torch.no_grad(): print(model(inp))
tensor([[ 1.6586,  2.8627,  0.5942,  1.3333, -0.7247, -0.1750,  0.4790, -4.1426,
          0.3254,  0.5920],
        [-0.6557,  0.3469, -3.1064,  0.4271,  3.6438,  0.0830, -1.9096,  4.2991,
         -1.3772,  0.3817]])
model = create_model("vit_large_patch16_384", 10, pretrained=False, use_bn=False, first_bn=False, bn_final=False)
model[1]
Sequential(
  (0): Dropout(p=0.25, inplace=False)
  (1): Linear(in_features=1024, out_features=512, bias=True)
  (2): ReLU(inplace=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=10, bias=True)
)
with torch.no_grad(): print(model(inp))
tensor([[ 2.2183,  2.0234, -4.9572, -1.5017,  5.2824, -0.1557,  1.8053,  2.5815,
          1.0612,  1.0911],
        [ 0.8053, -0.1254, -1.0162, -2.4544,  3.7484,  0.2554,  1.4608,  0.5014,
         -1.6777, -2.0474]])

Gradient Checkpointing

Gradient checkpointing conserves memory, allowing training with a larger image resolution and/or batch size. It's compatible with all timm ResNet, EfficientNet, and VisionTransformer models, as well as fastai models, and it should be easy to implement for any other encoder model you are using.

This is the current fix for using gradient checkpointing with autocast / to_fp16(): https://github.com/pytorch/pytorch/pull/49757/files
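The idea behind these wrapper classes can be sketched with PyTorch's built-in checkpoint_sequential, which splits a sequential model into chunks and recomputes intermediate activations during the backward pass instead of storing them (a toy example, not the library's implementation):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy sequential "encoder"; the wrappers above apply the same idea to the
# stages of a ResNet, the blocks of an EfficientNet or a ViT.
body = nn.Sequential(*[nn.Sequential(nn.Linear(16, 16), nn.ReLU()) for _ in range(4)])

x = torch.randn(2, 16, requires_grad=True)  # requires_grad=True lets gradients flow through the chunks
out = checkpoint_sequential(body, 2, x)     # run in 2 checkpointed chunks
out.sum().backward()                        # activations are recomputed here, not stored
```

checkpoint_nchunks plays the same role as the segments argument: more chunks means less stored activation memory at the cost of more recomputation.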

class CheckpointResNet[source]

CheckpointResNet(resnet_model, checkpoint_nchunks=2) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

class CheckpointEfficientNet[source]

CheckpointEfficientNet(effnet_model, checkpoint_nchunks=2) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

class CheckpointVisionTransformer[source]

CheckpointVisionTransformer(vit_model, checkpoint_nchunks=2) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

class CheckpointSequential[source]

CheckpointSequential(fastai_model, checkpoint_nchunks=2) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

L(timm.list_models("*resnet50*"))[-10:]
(#10) ['seresnet50','seresnet50tn','skresnet50','skresnet50d','ssl_resnet50','swsl_resnet50','tv_resnet50','vit_base_resnet50d_224','vit_small_resnet50d_s3_224','wide_resnet50_2']
encoder = create_encoder("seresnet50", pretrained=False)
encoder = CheckpointResNet(encoder, checkpoint_nchunks=4)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
tensor([[ 0.2153, -0.8222, -0.1195, -0.1419,  0.2558],
        [-0.0267,  1.1275,  0.4353, -0.2715, -1.3025]])
L(timm.list_models("*efficientnet*"))[-10:]
(#10) ['tf_efficientnet_el','tf_efficientnet_em','tf_efficientnet_es','tf_efficientnet_l2_ns','tf_efficientnet_l2_ns_475','tf_efficientnet_lite0','tf_efficientnet_lite1','tf_efficientnet_lite2','tf_efficientnet_lite3','tf_efficientnet_lite4']
encoder = create_encoder("tf_efficientnet_b0_ns", pretrained=False)
encoder = CheckpointEfficientNet(encoder, checkpoint_nchunks=4)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
tensor([[ 0.2183, -1.7747,  0.6225, -0.2091, -0.6604],
        [-0.4133,  1.4024,  0.4160, -0.6159,  0.6558]])
encoder = create_encoder("xresnet34", pretrained=False)
encoder = CheckpointSequential(encoder, checkpoint_nchunks=4)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
tensor([[ 0.2094, -0.0202,  1.5631,  0.8257,  0.8442],
        [-0.4477, -0.2046, -0.9960, -1.3508,  0.2298]])