import torch
import matplotlib.pyplot as plt
from torch.nn.init import trunc_normal_
Why truncated normal initialization?
Neurons whose pre-activations fall below about -2 would be dead, since GeLU can be approximated as the input times a scaled sigmoid:
$$ x\sigma(1.702x) $$
which is close to zero, with a near-zero gradient, for strongly negative inputs. Truncated normal initialization redraws any sample that falls outside $[a, b]$ (here $[-2, 2]$), which helps keep initial values out of that saturated region and makes sure all neurons are updated.
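A quick numerical check of how tight that approximation is, against torch's exact GeLU:

import torch.nn.functional as F

xs = torch.linspace(-5, 5, steps=1000)
approx = xs * torch.sigmoid(1.702 * xs)   # sigmoid approximation of GeLU
exact = F.gelu(xs)                        # exact GeLU
(approx - exact).abs().max()              # max gap is on the order of 1e-2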
# draw 5,000 samples from a normal truncated to [-2, 2]
trunc_dist = [trunc_normal_(torch.tensor([0.]), std=1.5, a=-2, b=2).item() for o in range(5000)]
plt.hist(trunc_dist, bins=30);
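For comparison, a plain normal with the same std leaves a noticeable share of samples outside [-2, 2]; truncation redraws exactly those. A quick check:

plain = torch.randn(5000) * 1.5
(plain.abs() > 2).float().mean()   # roughly 0.18 of plain-normal samples land outside [-2, 2]
max(abs(v) for v in trunc_dist)    # never exceeds 2: out-of-range draws were resampled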
bs = 4
# two global crops at 224x224 and four local crops at 96x96, as in DINO's multi-crop setup
x_large = [torch.randn(bs, 3, 224, 224)] * 2
x_small = [torch.randn(16, 3, 96, 96)] * 4
x = x_large + x_small; [xi.size() for xi in x]
from functools import partial
import torch.nn as nn

# VisionTransformer is the ViT implementation from the DINO repo (vision_transformer.py)
vit_encoder = VisionTransformer(patch_size=32, embed_dim=128, depth=4, num_heads=4, mlp_ratio=4,
                                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
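With patch_size=32, a 224x224 crop is split into 7x7 = 49 patch tokens and a 96x96 crop into 3x3 = 9, plus one class token each. Assuming the encoder returns the class-token embedding, as DINO's ViT does, a single crop group maps to one 128-d feature row per image:

feats = vit_encoder(x_large[0])
feats.shape   # expected: torch.Size([4, 128])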
# MultiCropWrapper (from DINO's utils.py) batches same-resolution crops together through the backbone
vit = MultiCropWrapper(vit_encoder)
out = vit(x)
len(out)
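Under the hood the wrapper avoids one forward pass per crop: it groups consecutive crops that share a spatial size, runs each group through the backbone as a single batch, and concatenates the results. A minimal sketch of that idea, simplified from DINO's implementation (the pooling backbone below is a hypothetical stand-in for the ViT):

import torch
import torch.nn as nn

class MultiCropSketch(nn.Module):
    """Simplified multi-crop forward: one backbone pass per crop resolution."""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, crops):
        # count how many consecutive crops share each spatial size
        sizes = torch.tensor([c.shape[-1] for c in crops])
        counts = torch.unique_consecutive(sizes, return_counts=True)[1]
        outputs, start = [], 0
        for end in torch.cumsum(counts, 0):
            # same-size crops are concatenated along the batch dim and encoded in one pass
            outputs.append(self.backbone(torch.cat(crops[start:end])))
            start = end
        return torch.cat(outputs)

dummy = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 128))  # hypothetical backbone
MultiCropSketch(dummy)(x).shape   # (4*2 + 16*4, 128) = (72, 128): one row per crop per image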