PyTorch backward 与 in-place 赋值小记

 

前言

本文简要分析PyTorchbackward使用中遇到的in-place赋值问题, 以作记录。 backward作为PyTorch的重要函数, 用于自动计算loss对计算图中所有requires_grad=True的叶子节点的梯度。日常使用简洁明了, 但偶尔也可能有”奇怪”的需求, 例如在复现论文算法2(Normalized SGD)更新底层网络(Lower Layers)参数(${\boldsymbol \theta}$)时。其更新方式定义如下:

其中标红部分, 需要对Target error(${\boldsymbol \delta} = {\bf W} {\boldsymbol h} + {\boldsymbol b} - Y$)和顶层网络(Upper Layer)的参数${\bf W}$进行放缩, 分别乘上${\boldsymbol \Sigma^{-1} }$。但是, 另一方面, 算法还需要对顶层的网络进行更新, 方式如下:

如果我们将loss定义为$\mathcal{L} \triangleq \left( {\boldsymbol \delta} \right)^2$, 那么backward函数所计算出的loss对顶层网络参数的梯度如上述所需, 但对底层网络参数梯度则为:

这样并未实现$\eqref{lower_update}$中的放缩。为此, 考虑能否暂时地将${\bf W}$和${\boldsymbol \delta}$乘上放缩因子, 在执行完$\eqref{lower_update}$中的更新后再将其还原呢? 可以!

简单的例子

为了以上的需求, 我们先从一个简单的例子说起。类似地, 我们构造一个”五脏俱全”的计算图如下:

graph LR;
    x --> y((y));
    y --> z(z);
    w-->z;
style x fill:#9ACD32, stroke:#333
style w fill:#9ACD32, stroke:#333

计算规则如下:

其中$x, y, w, z$均为标量。

第一步: backward小试

以上代码中z.backward()计算了$\partial z / \partial x$和$\partial z / \partial x$, 打印的结果如下:

tensor([4.])
tensor([1.])

$\partial z / \partial x = 2x w = 4$, 而$\partial z / \partial w = y = x^2 = 1$, 不出意外, 这正是”日常”使用backward的效果。

第二步: 尝试改改.data

这一步中, 试着修改一下w.data将其乘以2, 企图得到不同的结果。实际上, 打印的结果如下:

tensor([4.])
tensor([1.])

与上一步中的结果毫无区别。原本预期的结果是该值修改后, 相应计算所得的$\partial z / \partial x = 2x w = 8$, 但并不是。

第三步: 试试in-place修改

这一步与上一步的唯一区别在于对w.data修改方式变为了w.data *= 2, 即in-place赋值方式1, 这一概念并非PyTorch专有, 而是程序设计中的通用称呼。这一步所得到的输出如下:

tensor([8.])
tensor([1.])

正是预期的结果。分析两种赋值方式的区别在于:

w.data = w.data * 2
graph LR;
subgraph Before assignment
    A[w.data] --> B(old value)
end
subgraph After assignment
	A .-> C(new value)
	B .-> |" times 2 "| C
end
style B fill:pink,stroke:#333
style C fill:yellow, stroke:#333

如图所示, 实际上开辟了新的内存空间, 计算新的w.data, 而原本w.data所指向的内存地址所存放的结果仍然为原值。

但对于in-place的赋值方式则不同:

w.data *= 2
graph LR;
    A[w.data] --> B("old | new value")
    B .-> |" times 2 "| B
style B fill:pink,stroke:#333

如图所示, 在in-place赋值下, w.data所指向的内存地址不会发生变化, 仅相应地址中的值变化。而这恰恰是实现$\eqref{lower_update}$所需要的。因为, backward在计算梯度时, 是通过原地址获取相应的数值。如此, 通过in-place赋值方式所得到的w.data就是修改后预期的结果, 相反第二步中非in-place的赋值方式下原地址的值不变, 所以计算梯度时仍然用的原值。

小结

  • backward计算梯度时通过地址获取相应的所需的值, 而非标签(变量名称, 如w.data), 一般情况下这两者相等(标签指向相应的值), 所以看不出差别。
  • in-place赋值不新开辟内存空间, 而是在原址上修改值。
  • tensor.data可以获取tensor中的数据, 且对tensor.data操作时不会被autograd所记录, 即不会反映到backward中。