正文
,M表示dropout_mask矩阵,
表示element-wise product
-
由此可知,在整个链式推导的过程中,我们不会用到Z',即我们不用保存dropout的输入
(2)上图中,GELU的输入为什么算作激活值?
我们设GELU输入为
(为了作区分,这里符号表达和图中稍许不同),输出为
。离
最近的带参矩阵为
,那么我们可以通过求
的具体形式,来判断
在bwd中是否有用。
-
-
-
-
,这里为了表达简练,直接把Y到最后损失L的过程写成映射函数f
-
其中你看到的和
相关的那一串,就是GELU函数反向传播时的Y对Y'的偏导,本质上就是因为在链式传导到gelu这一项时需要用到Y',所以我们才将Y'作为激活值保存。
(3)上图中,如果我们把GELU替换成一个不带参的线性层,比如我们有
,那我们还需要保存Y'吗?
-
-
-
-
,这里为了表达简练,直接把Y到最后损失L的过程写成映射函数f
-
上面这3个例子在告诉我们,决定一份数据是否能作为激活值保存下来的要点就在于它会不会在bwd的链式传导过程中被使用。所以大家不要凭主观去判定“一个带参矩阵的输入输出一定就是激活值”,一定要自己动手推一下
。当然,只要看得多了,在不用手动推导的情况下,也能快速知道哪些数据是激活值了。
2.2 Attention层的激活值大小
设
b=batch_size,s=seq_len,h=hidden_size,a = head_num
-
Input LN
:数据进过layernorm前的结果将会被用在bwd的计算中,因此会被作为激活值存储下来,
其占据的存储大小为2bsh
,单位为bytes。
-
-
-
X经过带参矩阵Wq, Wk, Wv后的结果Q, K, V会被作为激活值保存,大小为3 * 2bsh = 6bsh
-
-