CLIP vision-language

  • CLIP 1B ViT-Huge
  • SwitchBack
    • int8 linear
    • 13-25% speedup
    • ~90% of transformer compute spent in linear layers
    • quant
      • quant noise grows with matrix multiply inner dimension size
        • an issue for CLIP because its large required batch size makes the inner dimension of the weight-gradient matmul large
      • use 16-bit precision for the weight-gradient matmul
      • int8 for the forward pass and the input-gradient matmul
      • provide Triton kernel for SwitchBack
  • reduce large magnitude features
    • LayerScale with its per-channel gain initialized to 0 (sketch after this list)
    • loss spikes occur 1-8 iterations after the squared gradients become under-estimated by their AdamW second moment estimator.
      • use AdamW-Adafactor, works better than grad clipping
  • StableAdamW == AdamW-Adafactor
    • AdamW + Adafactor clipping
      • tracks the average ratio of the squared gradient to its second-moment estimator and lowers the learning rate when that ratio is large (sketch after this list)
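
A minimal sketch of the layer-scale idea: a learned per-channel gain on each residual branch, initialized to 0 so every block starts as the identity and large-magnitude features are damped early in training. Module and parameter names here are assumptions for illustration, not the paper's code.

import torch
from torch import nn

class LayerScale(nn.Module):
    # learned per-channel gain on a residual branch; initialized at 0 so the
    # block starts as the identity (illustrative sketch, assumed names)
    def __init__(self, dim, init_value=0.0):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        return self.gamma * x

In a transformer block this would wrap the attention and MLP outputs, e.g. x = x + ls_attn(attn(norm(x))).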
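
A sketch of the StableAdamW / AdamW-Adafactor clipping rule for a single parameter tensor, following the description above (per-tensor RMS of the squared gradient relative to its second-moment estimate, used to scale down the learning rate). The function name and hyperparameters are illustrative assumptions, not the paper's implementation.

import torch

def stable_adamw_step(p, g, m, v, step, lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=0.01):
    # one AdamW step with Adafactor-style update clipping (illustrative sketch)
    b1, b2 = betas
    m.mul_(b1).add_(g, alpha=1 - b1)           # first moment
    v.mul_(b2).addcmul_(g, g, value=1 - b2)    # second moment
    m_hat = m / (1 - b1 ** step)
    v_hat = v / (1 - b2 ** step)

    # RMS of g^2 relative to its second-moment estimate; values > 1 mean the
    # estimator is under-estimating the gradients (the regime before loss spikes)
    rms = torch.sqrt(torch.mean(g * g / torch.clamp(v_hat, min=eps ** 2)))
    lr_t = lr / max(1.0, rms.item())           # lower the lr when the ratio is large

    p.mul_(1 - lr_t * weight_decay)            # decoupled weight decay
    p.addcdiv_(m_hat, torch.sqrt(v_hat) + eps, value=-lr_t)

Unlike gradient clipping, nothing here touches the gradient itself; only the per-tensor learning rate is scaled down.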
SwitchBack pseudocode; the quantize/matmul helpers are placeholders (sketched after the listing).

from torch import autograd, nn


class SwitchBackMatmul(autograd.Function):
    @staticmethod
    def forward(ctx, X, W):
        # X [b, n] inputs
        # W [m, n] weights (nn.Linear layout: [out_features, in_features])
        # save tensors in ctx for the backward pass
        ctx.save_for_backward(X, W)

        X_int8, state_X = rowwise_quantize(X)
        W_int8, state_W = tensorwise_quantize(W)

        # int8 matmul, then dequantize the output back to 16-bit
        return matmul_int8_and_dequantize(
            X_int8,
            W_int8.t(),
            state_X,
            state_W,
        )

    @staticmethod
    def backward(ctx, G):
        # G [b, m] gradient of the output
        # recover tensors from ctx
        X, W = ctx.saved_tensors

        G_int8, state_G = rowwise_quantize(G)
        W_int8, state_W = tensorwise_quantize_transpose(W)

        # use the 8-bit matmul only for the input gradient
        X_gradient = matmul_int8_and_dequantize(
            G_int8,
            W_int8.t(),
            state_G,
            state_W,
        )
        # keep the weight gradient in 16-bit: that matmul's inner dimension
        # is the (large) batch, where quantization noise grows
        W_gradient = matmul_fp16(G.t(), X)

        return X_gradient, W_gradient
 
 
class SwitchBackLinear(nn.Linear):
    def forward(self, X):
        return SwitchBackMatmul.apply(X, self.weight)
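
The quantize/matmul helpers used above are placeholders from the pseudocode. A minimal sketch of what they could look like in plain PyTorch (symmetric absmax int8 quantization; the int8 and fp16 matmuls are emulated in regular floating point for readability). The real speedup comes from the Triton kernels, and these exact signatures are assumptions for illustration.

import torch

def rowwise_quantize(X):
    # symmetric int8 quantization, one absmax scale per row
    scale = X.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    return torch.round(X / scale).to(torch.int8), scale

def tensorwise_quantize(W):
    # symmetric int8 quantization, a single absmax scale for the whole tensor
    scale = W.abs().amax().clamp(min=1e-12) / 127.0
    return torch.round(W / scale).to(torch.int8), scale

def tensorwise_quantize_transpose(W):
    return tensorwise_quantize(W.t().contiguous())

def matmul_int8_and_dequantize(A_int8, B_int8, state_A, state_B):
    # a real kernel multiplies in int8 and accumulates in int32;
    # emulated here in float for readability
    return (A_int8.float() @ B_int8.float()) * state_A * state_B

def matmul_fp16(A, B):
    # stand-in for a 16-bit matmul kernel
    return A @ B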
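
Drop-in usage as an nn.Linear replacement (shapes made up for illustration; bias is ignored by this minimal forward):

layer = SwitchBackLinear(1024, 4096, bias=False)
x = torch.randn(32, 1024, requires_grad=True)
y = layer(x)                 # [32, 4096]; forward via the int8 path
y.sum().backward()           # input gradient via int8, weight gradient via matmul_fp16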