문제 설명
리프 텐서의 값을 업데이트하는 적절한 방법은 무엇입니까(예: 경사하강법 업데이트 단계 중) (What's the proper way to update a leaf tensor's values (e.g. during the update step of gradient descent))
장난감 예제
일부 장난감 데이터에 선형 회귀(mx + b)를 맞추려고 시도하는 매우 간단한 경사 하강법 구현을 고려하십시오.
import torch
# Make some data
torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)
# Initialize m and b
m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)
# Pass 1
yhat = X * m + b # Calculate yhat
loss = torch.sqrt(torch.mean((yhat ‑ Y)**2)) # Calculate the loss
loss.backward() # Reverse mode differentiation
m = m ‑ 0.1*m.grad # update m
b = b ‑ 0.1*b.grad # update b
m.grad = None # zero out m gradient
b.grad = None # zero out b gradient
# Pass 2
yhat = X * m + b # Calculate yhat
loss = torch.sqrt(torch.mean((yhat ‑ Y)**2)) # Calculate the loss
loss.backward() # Reverse mode differentiation
m = m ‑ 0.1*m.grad # ERROR
첫 번째 단계 잘 작동하지만 마지막 줄에 두 번째 패스 오류가 있습니다. m = m ‑ 0.1*m.grad
.
Error
/usr/local/lib/python3.7/dist‑packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non‑leaf Tensor, use .retain_grad() on the non‑leaf Tensor. If you access the non‑leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:417.)
return self._grad
이유에 대한 이해 이것은 패스 1 동안 이 줄
m = m ‑ 0.1*m.grad
복사 m
를 새로운 텐서(즉, 완전히 별도의 메모리 블록)로 만드는 것입니다. 그래서 리프 텐서에서 리프가 아닌 텐서로 바뀝니다.
# Pass 1
...
print(f"{m.is_leaf}") # True
m = m ‑ 0.1*m.grad
print(f"{m.is_leaf}") # False
업데이트를 어떻게 수행합니까?
사용할 수 있다는 언급을 본 적이 있습니다. m 라인을 따라 뭔가.
참조 솔루션
방법 1:
You're observation is correct, in order to perform the update you should:
Apply the modification with in‑place operators.
Wrap the calls with torch.no_grad
context manager.
</ol> For instance:
with torch.no_grad():
m ‑= 0.1*m.grad # update m
b ‑= 0.1*b.grad # update b
참조 문서