In [1]:
import torch

## How `create_graph=True` works

In [2]:
# Two leaf tensors and a scalar built from them: z = x^T y, w = z^2
x = torch.tensor([3.0, 1.0], requires_grad=True)
y = torch.tensor([1.0, 2.0], requires_grad=True)
z = torch.sum(x * y)
w = z ** 2

# Plain (first-order) backward: fills x.grad = dw/dx and y.grad = dw/dy.
w.backward()

# z is an intermediate (non-leaf) tensor, so autograd does not store its gradient
# and z.grad stays None (call z.retain_grad() before backward() to keep it; recent
# PyTorch versions also emit a warning when a non-leaf .grad is read).
for t in (x, y, z):
    print(t.grad)


tensor([10., 20.])
tensor([30., 10.])
None


$x,y\in \mathbb{R}^2, \quad z\in \mathbb{R}$

$w=z^2, \quad z=x^T y, \quad w(x,y)=(x^T y)^2$

$\nabla_x w=2(x^T y)\,y, \quad \nabla_y w=2(x^T y)\,x$

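$\frac{\partial^2 w}{\partial x_1^2}=2y_1^2, \quad$
$\frac{\partial^2 w}{\partial x_1 \partial x_2}=2y_1y_2, \quad$
$\frac{\partial^2 w}{\partial x_1 \partial y_1}=4x_1y_1+2x_2y_2, \quad$
$\frac{\partial^2 w}{\partial x_1 \partial y_2}=2x_2y_1$

At $x=(3,1)$, $y=(1,2)$ these evaluate to $2, 4, 16, 2$, which is what the second `backward()` in the next cell produces.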

In [3]:
x = torch.tensor([3.0, 1.0], requires_grad=True)
y = torch.tensor([1.0, 2.0], requires_grad=True)
z = torch.sum(x * y)
w = z ** 2
# First backward: fills x.grad = dw/dx and y.grad = dw/dy. With create_graph=True
# the gradients themselves carry a computation graph (see the grad_fn in the
# output), i.e. they are formulas of x and y that can be differentiated again.
w.backward(create_graph=True)
print(x.grad)
print(y.grad)

partial_x_1 = x.grad[0]  # dw/dx1 = 2 (x^T y) y1, still attached to the graph
print(partial_x_1)
# backward() ACCUMULATES into .grad, so reset the stored first-order values before
# the second backward. Only the values are zeroed (through .data, bypassing
# autograd). partial_x_1 is a view of x.grad, so its value is zeroed too, but the
# graph it was built from -- which is what backward() uses -- is untouched.
x.grad.data.zero_()
y.grad.data.zero_()
print(x.grad)
print(y.grad)


partial_x_1.backward()  # get d(dw/dx1)/dx, d(dw/dx1)/dy
print(x.grad)  # [d2w/dx1dx1, d2w/dx1dx2]
print(y.grad)  # [d2w/dx1dy1, d2w/dx1dy2]

# Check against the hand-derived second derivatives at x=(3,1), y=(1,2)
assert torch.allclose(x.grad, torch.tensor([2.0, 4.0]))   # [2*y1^2, 2*y1*y2]
assert torch.allclose(y.grad, torch.tensor([16.0, 2.0]))  # [4*x1*y1 + 2*x2*y2, 2*x2*y1]

tensor([10., 20.], grad_fn=<CloneBackward>)
tensor([30., 10.], grad_fn=<CloneBackward>)
tensor(10., grad_fn=<SelectBackward>)
tensor([0., 0.], grad_fn=<CloneBackward>)
tensor([0., 0.], grad_fn=<CloneBackward>)
tensor([2., 4.], grad_fn=<CloneBackward>)
tensor([16.,  2.], grad_fn=<CloneBackward>)


In [4]:
# Same setup as before, but this time the first-order gradients are NOT reset
# before the second backward(), to show that .grad accumulates.
x = torch.tensor([3.0, 1.0], requires_grad=True)
y = torch.tensor([1.0, 2.0], requires_grad=True)
z = torch.sum(x * y)
w = z ** 2

# dw/dx and dw/dy; create_graph=True keeps a graph for the gradients themselves
w.backward(create_graph=True)
print(x.grad, y.grad, sep="\n")

partial_x_1 = x.grad[0]
print(partial_x_1)

# No reset: d(dw/dx1)/dx and d(dw/dx1)/dy are ADDED to what .grad already holds,
#   x.grad = [d2w/dx1dx1, d2w/dx1dx2] + [dw/dx1, dw/dx2]
#   y.grad = [d2w/dx1dy1, d2w/dx1dy2] + [dw/dy1, dw/dy2]
partial_x_1.backward()
print(x.grad, y.grad, sep="\n")

tensor([10., 20.], grad_fn=<CloneBackward>)
tensor([30., 10.], grad_fn=<CloneBackward>)
tensor(10., grad_fn=<SelectBackward>)
tensor([12., 24.], grad_fn=<CloneBackward>)
tensor([46., 12.], grad_fn=<CloneBackward>)


## A real example: MAML (2nd order)
https://colab.research.google.com/drive/1MFJwRdOTefd6UOYRsNjdc7BWuB7Qe3lY

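The following code does not run on its own: it uses names such as `nn`, `F`, `np`, `tqdm`, `data`, `device`, `MetaLinear` and `meta_task_data` that are not defined in this notebook. It is shown for illustration only.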

In [None]:
class net(nn.Module):
    """Small MLP (1 -> 40 -> 40 -> 1 with ReLU) used as the MAML learner.

    Two ways to build it:
    * ``net()`` creates fresh ``nn.Linear`` layers (used for the meta model);
    * ``net(init_weight=parent)`` wraps every direct child layer of ``parent`` in
      ``MetaLinear`` under the same attribute name (used for per-task copies
      whose weights are later overwritten by :meth:`update`).
    """

    def __init__(self, init_weight=None):
        super().__init__()
        if init_weight is not None:
            # Direct children only ("hidden1", "hidden2", "out"): these are valid
            # attribute names, unlike the dotted names named_modules() yields for
            # nested modules, and the root module itself is not included.
            for name, module in init_weight.named_children():
                setattr(self, name, MetaLinear(module))
        else:
            self.hidden1 = nn.Linear(1, 40)
            self.hidden2 = nn.Linear(40, 40)
            self.out = nn.Linear(40, 1)

    def zero_grad(self):
        """Call ``zero_grad()`` on every direct child layer.

        NOTE(review): this shadows ``nn.Module.zero_grad`` and does not accept its
        ``set_to_none`` argument.
        """
        for layer in self.children():
            layer.zero_grad()

    def update(self, parent, lr = 1):
        """One inner-loop SGD step: ``w <- w - lr * parent_w.grad`` for every layer.

        Args:
            parent: model with the same layer names whose ``weight.grad`` and
                ``bias.grad`` hold the task gradient (in this notebook they come
                from ``little_l.backward(create_graph=True)``).
            lr: inner-loop step size.

        The new weights are computed from the parent's gradients; when those
        gradients were produced with ``create_graph=True`` they carry a graph, so
        a loss computed with the updated model back-propagates through this step
        into ``parent`` and yields second-order gradient terms.
        """
        parent_layers = dict(parent.named_children())
        for name, layer in self.named_children():
            parent_layer = parent_layers[name]
            # NOTE: nn.Module refuses to assign a plain tensor to a registered
            # Parameter (so this fails on nn.Linear); it presumably relies on
            # MetaLinear storing weight/bias as plain tensors -- i.e. call update()
            # only on models built with init_weight.
            layer.weight = layer.weight - lr * parent_layer.weight.grad
            layer.bias = layer.bias - lr * parent_layer.bias.grad

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        return self.out(x)


In [None]:
class Meta_learning_model:
    """Holds the meta-learned network and spawns per-task copies of it.

    Args:
        init_weight: optional state dict loaded into the freshly built ``net``.
    """

    def __init__(self, init_weight = None):
        super().__init__()
        # The meta model: its parameters are the ones the outer loop optimizes.
        self.model = net().to(device)
        if init_weight is not None:
            self.model.load_state_dict(init_weight)
        # NOTE(review): only read/reset by clear_buffer() in the code shown here.
        self.grad_buffer = 0

    def gen_models(self, num, check = True):
        """Return a list of ``num`` task models built via ``net(init_weight=self.model)``.

        ``check`` is accepted for backward compatibility but is not used.
        """
        return [net(init_weight=self.model).to(device) for _ in range(num)]

    def clear_buffer(self):
        """Print the current ``grad_buffer`` and reset it to 0."""
        print("Before grad", self.grad_buffer)
        self.grad_buffer = 0

In [None]:
# Meta-batch size: number of tasks per DataLoader batch (the training loop also
# uses it as the number of per-task sub-models).
bsz = 10

# Presumably one task per row (the training loop indexes x[i] as task i).
train_x, train_y, train_label = meta_task_data(task_num=50000*10)
# Add a trailing feature dimension so every point is a 1-element vector.
train_x, train_y = [torch.Tensor(arr).unsqueeze(-1) for arr in (train_x, train_y)]

train_dataset = data.TensorDataset(train_x, train_y)
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=bsz,
    shuffle=False,
)

meta_model = Meta_learning_model()
meta_optimizer = torch.optim.Adam(meta_model.model.parameters(), lr=1e-3)

In [None]:
# Meta-training loop (second-order MAML). For each batch of tasks:
#   * inner step: on a random support split, take one SGD step per task whose
#     gradient keeps its computation graph (create_graph=True);
#   * outer step: evaluate each adapted copy on the held-out query split, average
#     the query losses over the batch, and backpropagate through the inner step
#     into the meta model.
epoch = 1
for e in range(epoch):
    meta_model.model.train()
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)
        # Size by the actual batch, not `bsz`: a final partial batch from the
        # DataLoader would otherwise index past the end of x / y and be averaged
        # over the wrong number of tasks.
        n_tasks = x.size(0)
        sub_models = meta_model.gen_models(n_tasks)

        meta_l = 0
        for task_idx, sub_model in enumerate(sub_models):
            # Random 5/5 split of this task's first 10 points into support and
            # query sets (assumes every task has at least 10 points).
            sample = list(range(10))
            np.random.shuffle(sample)
            support, query = sample[:5], sample[5:]

            # Inner step: support loss, then its gradient w.r.t. the meta model's
            # parameters. create_graph=True keeps the graph of that gradient so it
            # can be differentiated again in the outer step.
            y_tilde = sub_model(x[task_idx][support, :])
            little_l = F.mse_loss(y_tilde, y[task_idx][support, :])
            little_l.backward(create_graph=True)
            # update() sets each sub-model weight to
            #     weight - lr * parent.weight.grad
            # and that .grad carries a graph, so the adapted weights remain
            # differentiable functions of the meta parameters.
            sub_model.update(lr = 1e-2, parent = meta_model.model)
            # Clear the meta model's .grad (the optimizer owns exactly those
            # parameters) so the next task's inner gradient, and the final
            # meta-gradient, do not accumulate on top of this one.
            meta_optimizer.zero_grad()

            # Outer (query) loss of the adapted model. Its graph runs through
            # update(), so backpropagating it yields second-order terms.
            y_tilde = sub_model(x[task_idx][query, :])
            meta_l = meta_l + F.mse_loss(y_tilde, y[task_idx][query, :])

        meta_l = meta_l / n_tasks
        meta_l.backward()
        meta_optimizer.step()
        meta_optimizer.zero_grad()