2023-12-17 - Stable and low-precision training for large-scale vision-language models

#clip #vision-language

SwitchBack replaces a linear layer's matmuls so that the forward pass and the input-gradient computation run in int8, while only the weight-gradient matmul stays in 16-bit:

```python
import torch
from torch import autograd, nn


class SwitchBackMatmul(autograd.Function):

    @staticmethod
    def forward(ctx, X, W):
        # X [b, n] input activations
        # W [n, m] weights

        # Save tensors in ctx for the backward pass
        ctx.save_for_backward(X, W)

        # Quantize inputs row-wise, weights tensor-wise
        X_int8, state_X = row_wise_quantize(X)
        W_int8, state_W = tensor_wise_quantize(W)

        # Return the int8 matmul, dequantized with the saved scales
        return matmul_int8_and_dequantize(
            X_int8,
            W_int8.t(),
            state_X,
            state_W,
        )

    @staticmethod
    def backward(ctx, G):
        # G [b, m] gradient w.r.t. the output

        # Recover tensors saved in forward
        X, W = ctx.saved_tensors

        G_int8, state_G = row_wise_quantize(G)
        W_int8, state_W = tensor_wise_quantize_transpose(W)

        # Use the 8-bit matmul only for the input gradient ...
        X_gradient = matmul_int8_and_dequantize(
            G_int8,
            W_int8.t(),
            state_G,
            state_W,
        )

        # ... and keep the weight gradient X^T @ G in 16-bit (shape [n, m])
        W_gradient = matmul_fp16(X.t(), G)

        return X_gradient, W_gradient
```

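The quantization and matmul primitives are left abstract above. As a minimal sketch of what they could look like, here are CPU-runnable stand-ins assuming absmax int8 quantization; the bodies are my assumption of the intent, not the paper's fused CUDA kernels (a real int8 GEMM accumulates in int32 and emits fp16 directly):

```python
import torch


def row_wise_quantize(A):
    # One absmax scale per row; the "state" is the per-row scale vector.
    scale = A.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    return (A / scale).round().clamp(-127, 127).to(torch.int8), scale


def tensor_wise_quantize(W):
    # A single absmax scale for the whole tensor.
    scale = W.abs().max().clamp(min=1e-8) / 127.0
    return (W / scale).round().clamp(-127, 127).to(torch.int8), scale


def tensor_wise_quantize_transpose(W):
    # Quantize W and materialize its transpose, as a fused kernel would.
    return tensor_wise_quantize(W.t().contiguous())


def matmul_int8_and_dequantize(A_int8, B_int8, state_A, state_B):
    # Emulated int8 GEMM. Like cuBLAS-style int8 kernels, the second
    # operand arrives pre-transposed, so this computes A @ B.t(). A wide
    # integer accumulator stands in for the kernel's int32 accumulation,
    # and the result is rescaled by both quantization states. fp32 output
    # keeps the sketch runnable on CPU; a real kernel would emit fp16.
    C = A_int8.to(torch.int64) @ B_int8.to(torch.int64).t()
    return C.float() * state_A * state_B


def matmul_fp16(A, B):
    # Stand-in for the 16-bit matmul (fp16 matmul is CUDA-only in PyTorch).
    return A.float() @ B.float()
```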

```python
class SwitchBackLinear(nn.Linear):
    def forward(self, X):
        # nn.Linear stores its weight as [m, n] (out_features, in_features),
        # so transpose to match the [n, m] convention above; bias is omitted.
        return SwitchBackMatmul.apply(X, self.weight.t())
```
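A quick smoke test of the pieces together (hypothetical dimensions, just to check shapes and the quantization error end to end):

```python
torch.manual_seed(0)
layer = SwitchBackLinear(64, 32, bias=False)
X = torch.randn(4, 64, requires_grad=True)

out = layer(X)                                # int8 forward path
ref = X.detach() @ layer.weight.detach().t()  # fp32 reference
print("forward rel. error:", ((out - ref).norm() / ref.norm()).item())

out.sum().backward()
print(X.grad.shape)             # torch.Size([4, 64]), via the int8 matmul
print(layer.weight.grad.shape)  # torch.Size([32, 64]), via the 16-bit matmul
```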