- Automatic gradient computation (autograd) API
Tutorial
- Automatically computing gradients of $y$ w.r.t. $x$
import torch

x = torch.randn(2, requires_grad=True)
y = x * 3
gradients = torch.tensor([100, 0.1], dtype=torch.float)
y.backward(gradients)  # backpropagate the supplied gradient through y = 3x
print(x.grad)
# tensor([300.0000, 0.3000])
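- Because $y$ is not a scalar, backward() needs a gradient tensor with the same shape as $y$; autograd then computes the vector-Jacobian product (here 3 * [100, 0.1] = [300, 0.3]). For a scalar output such as a loss, backward() can be called with no argument. A minimal sketch:
x = torch.randn(2, requires_grad=True)
loss = (x * 3).sum()  # scalar output
loss.backward()       # no gradient argument needed for a scalar
print(x.grad)
# tensor([3., 3.])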
- Calling backward() twice
- We need to pass backward(retain_graph=True) on the first call to tell autograd not to free the intermediate buffers of the graph
x = torch.randn(2, requires_grad=True)
y = x * 3
gradients = torch.tensor([100, 0.1], dtype=torch.float)
y.backward(gradients, retain_graph=True)
print(x.grad)
y.backward(gradients)
print(x.grad)
# tensor([300.0000, 0.3000])
# tensor([600.0000, 0.6000])  <- gradients accumulate in x.grad
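- Without retain_graph=True the second backward() call raises a RuntimeError, because the graph buffers are freed after the first pass; and since x.grad accumulates, it can be reset with x.grad.zero_(). A minimal sketch:
x = torch.randn(2, requires_grad=True)
y = x * 3
y.backward(torch.tensor([100, 0.1]))      # first backward frees the graph
try:
    y.backward(torch.tensor([100, 0.1]))  # second call fails without retain_graph=True
except RuntimeError as err:
    print("second backward failed:", err)
x.grad.zero_()  # clear the accumulated gradient before reusing x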
- A tensor produced by an operation, such as y, carries a grad_fn attribute
- grad_fn references the Function (class) that was called to construct the tensor
x = torch.randn(2, requires_grad=True)
y = x * 3
z = x / 2
w = x + y
w, y, z
#(tensor([ 0.7003, -0.1958], grad_fn=<AddBackward0>),
#tensor([ 0.5252, -0.1468], grad_fn=<MulBackward0>),
#tensor([ 0.0875, -0.0245], grad_fn=<DivBackward0>))
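- In contrast, a leaf tensor created directly by the user has no grad_fn. A quick check, continuing the example above:
print(x.grad_fn)  # None -- x is a leaf tensor created by the user
print(w.grad_fn)  # <AddBackward0 object at 0x...>
print(x.is_leaf, w.is_leaf)
# True False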
- We can hook into a computation by registering hook functions (e.g. Tensor.register_hook, or the nn.Module hook methods listed below)
- Hooking means altering the behavior of software components by intercepting the function calls, messages, or events passed between them
- Define a hook function
- register_forward_hook
- register_forward_pre_hook
- register_backward_hook
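- A minimal sketch of both kinds of hooks, assuming a simple nn.Linear module; the hook names and the gradient scaling are illustrative:
import torch
import torch.nn as nn

# Tensor hook: intercepts the gradient flowing through y during backward
x = torch.randn(2, requires_grad=True)
y = x * 3
y.register_hook(lambda grad: grad * 2)  # e.g. scale the incoming gradient
y.sum().backward()
print(x.grad)  # tensor([6., 6.]) instead of tensor([3., 3.])

# Module hooks: intercept inputs/outputs around forward() and gradients during backward()
def forward_pre_hook(module, inputs):
    print("before forward:", [t.shape for t in inputs])

def forward_hook(module, inputs, output):
    print("after forward:", output.shape)

def backward_hook(module, grad_input, grad_output):
    print("backward:", [g.shape for g in grad_output])

layer = nn.Linear(2, 4)
layer.register_forward_pre_hook(forward_pre_hook)
layer.register_forward_hook(forward_hook)
layer.register_full_backward_hook(backward_hook)  # newer replacement for register_backward_hook
inp = torch.randn(1, 2, requires_grad=True)
out = layer(inp)
out.sum().backward()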