PyTorch autograd -- grad can be implicitly created only for scalar outputs
Problem description
I am using the autograd tool in PyTorch, and have found myself in a situation where I need to access the values in a 1D tensor by means of an integer index. Something like this:
import torch
from torch.autograd import Variable

def basic_fun(x_cloned):
    res = []
    for i in range(len(x)):
        res.append(x_cloned[i] * x_cloned[i])
    print(res)
    return Variable(torch.FloatTensor(res))

def get_grad(inp, grad_var):
    A = basic_fun(inp)
    A.backward()
    return grad_var.grad

x = Variable(torch.FloatTensor([1, 2, 3, 4, 5]), requires_grad=True)
x_cloned = x.clone()
print(get_grad(x_cloned, x))
I am getting the following error message:
[tensor(1., grad_fn=<ThMulBackward>), tensor(4., grad_fn=<ThMulBackward>), tensor(9., grad_fn=<ThMulBackward>), tensor(16., grad_fn=<ThMulBackward>), tensor(25., grad_fn=<ThMulBackward>)]
Traceback (most recent call last):
  File "/home/mhy/projects/pytorch-optim/predict.py", line 74, in <module>
    print(get_grad(x_cloned, x))
  File "/home/mhy/projects/pytorch-optim/predict.py", line 68, in get_grad
    A.backward()
  File "/home/mhy/.local/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
In general, I am a bit skeptical about how using the cloned version of a variable is supposed to keep that variable in the gradient computation. The variable itself is effectively not used in the computation of A, and so when you call A.backward(), it should not be part of that operation.
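For context, a minimal sketch (written against the current tensor API rather than Variable, so the exact calls are illustrative) shows that clone() does stay in the autograd graph: indexing the clone and reducing the squares to a scalar lets gradients flow back to x:

import torch

# x is a leaf tensor; clone() is a differentiable op, so anything computed
# from x_cloned backpropagates into x.grad.
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
x_cloned = x.clone()

loss = sum(x_cloned[i] * x_cloned[i] for i in range(len(x_cloned)))
loss.backward()
print(x.grad)  # expected: tensor([ 2.,  4.,  6.,  8., 10.])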
I appreciate your help with this approach, or if there is a better way to avoid losing the gradient history and still index through a 1D tensor with requires_grad=True!
res is a list of zero-dimensional tensors containing the squared values of 1 to 5. To concatenate them into a single tensor containing [1.0, 4.0, ..., 25.0], I changed return Variable(torch.FloatTensor(res)) to torch.stack(res, dim=0), which produces tensor([ 1., 4., 9., 16., 25.], grad_fn=<StackBackward>).
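Spelled out, the modified function would presumably look something like this (a sketch of the change described above, looping over x_cloned directly):

def basic_fun(x_cloned):
    res = []
    for i in range(len(x_cloned)):
        res.append(x_cloned[i] * x_cloned[i])
    # torch.stack keeps the autograd history that a fresh FloatTensor would
    # discard, but the output is still a 1D tensor rather than a scalar.
    return torch.stack(res, dim=0)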
However, I am getting this new error, caused by the A.backward() line.
Traceback (most recent call last):
  File "<project_path>/playground.py", line 22, in <module>
    print(get_grad(x_cloned, x))
  File "<project_path>/playground.py", line 16, in get_grad
    A.backward()
  File "/home/mhy/.local/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 84, in backward
    grad_tensors = _make_grads(tensors, grad_tensors)
  File "/home/mhy/.local/lib/python3.5/site-packages/torch/autograd/__init__.py", line 28, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs
Recommended answer
I changed my basic_fun to the following, which resolved my problem:
def basic_fun(x_cloned):
    res = torch.FloatTensor([0])
    for i in range(len(x)):
        res += x_cloned[i] * x_cloned[i]
    return res
This version returns a scalar value.
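Put together, a self-contained sketch of this fix (written against the current tensor API rather than Variable, and iterating over x_cloned directly, so treat the details as illustrative) would be expected to behave like this:

import torch

def basic_fun(x_cloned):
    # Accumulate the squared entries into a one-element tensor, so that
    # backward() can implicitly use a gradient of 1.0 for the output.
    res = torch.FloatTensor([0])
    for i in range(len(x_cloned)):
        res += x_cloned[i] * x_cloned[i]
    return res

def get_grad(inp, grad_var):
    A = basic_fun(inp)
    A.backward()
    return grad_var.grad

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
x_cloned = x.clone()
print(get_grad(x_cloned, x))  # expected: tensor([ 2.,  4.,  6.,  8., 10.])

An alternative, if the stacked 1D output from torch.stack were kept, would be to pass an explicit gradient to backward(), e.g. A.backward(torch.ones_like(A)); returning a single-element tensor as above simply avoids the need for that argument.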