正文
self.numberOfMixerBlocks = numberOfMixerBlocks # N
self.codebookTokenDimension = codebookTokenDimension # M
self.internalMixerSize = internalMixerSize
self.internalMixerTokenSize = internalMixerTokenSize
self.mixerDropout = mixerDropout
self.initial_linear = nn.Linear(self.dimensionOfKeypoints,
self.linearProjectionSize) # 从BxKxD投影到BxKxH
self.mixer_layers = nn.ModuleList([MixerLayer(self.linearProjectionSize,
self.internalMixerSize,
self.numberOfKeypoints,
self.internalMixerTokenSize,
self.mixerDropout) for _ in range(self.numberOfMixerBlocks)]) # BxKxH
self.mixer_layer_norm = nn.LayerNorm(self.linearProjectionSize) # BxKxH
self.token_linear = nn.Linear(self.numberOfKeypoints,
self.codebookTokenDimension) # BxHxK -> BxHxM
self.feature_embed = nn.Linear(self.linearProjectionSize,
self.codebookTokenDimension)
def forward(self, x):
    """Encode a batch of keypoints into latent token features.

    Args:
        x: tensor of shape (B, D, K) — batch, keypoint coordinate
           dimension, number of keypoints (assumed from the transposes
           below — TODO confirm with callers).

    Returns:
        Tensor of shape (B, M, M') where M is the number of latent
        tokens (``codebookTokenDimension``) produced by ``token_linear``
        and M' the per-token feature size produced by ``feature_embed``.
    """
    # (B, D, K) -> (B, K, D): put keypoints on the token axis.
    tokens = x.transpose(1, 2)
    # Embed each keypoint into the hidden space: (B, K, H).
    tokens = self.initial_linear(tokens)
    # Mixer stack operates on (B, K, H) and preserves the shape.
    for block in self.mixer_layers:
        tokens = block(tokens)
    tokens = self.mixer_layer_norm(tokens)
    # Mix across the keypoint axis: (B, H, K) -> (B, H, M).
    tokens = self.token_linear(tokens.transpose(1, 2))
    # Project token features: (B, M, H) -> (B, M, M').
    tokens = self.feature_embed(tokens.transpose(1, 2))
    return tokens
编码器接收一组二维关键点坐标,通过基于MLP-Mixer架构设计的网络结构将这些坐标转换为M个潜在标记特征。具体而言,关键点首先被嵌入到高维空间,然后在关节和特征维度之间进行混合处理,最终投影到形状为B × M × D(批量大小×标记数量×标记维度)的输出特征空间。
class CodebookVQ(nn.Module):
    """EMA vector-quantization codebook (VQ-VAE style).

    Each continuous token vector is replaced by its nearest codebook
    entry (squared Euclidean distance).  During training the codebook is
    updated with an exponential moving average of the encoder outputs
    assigned to each code — no gradient-based codebook update.
    """

    def __init__(self, codebookDimension, numberOfCodebookTokens, decay=0.99, epsilon=1e-5):
        super(CodebookVQ, self).__init__()
        self.codebookDimension = codebookDimension          # D: size of each code vector
        self.numberOfCodebookTokens = numberOfCodebookTokens  # number of codebook entries
        self.decay = decay        # EMA decay factor
        self.epsilon = epsilon    # Laplace smoothing against empty clusters
        # Buffers live in the state dict but are never touched by the optimizer.
        self.register_buffer('codebook', torch.empty(numberOfCodebookTokens, codebookDimension))
        self.codebook.data.normal_()
        self.register_buffer('ema_cluster_size', torch.zeros(numberOfCodebookTokens))
        self.register_buffer('ema_w', torch.empty(numberOfCodebookTokens, codebookDimension))
        self.ema_w.data.normal_()

    def forward(self, encode_feat):
        """Quantize ``encode_feat``.

        Args:
            encode_feat: tensor of shape (B, M, D) with D == codebookDimension.

        Returns:
            quantized: (B, M, D) nearest codebook vectors.
            encoding_indices: (B, M) long tensor of selected code indices.
        """
        B, M = encode_feat.shape[0], encode_feat.shape[1]
        # reshape (not view) tolerates non-contiguous inputs.
        flat = encode_feat.reshape(-1, self.codebookDimension)  # [B*M, D]
        # Squared Euclidean distance to every codebook entry:
        # ||x||^2 - 2 x.e + ||e||^2.
        distances = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ self.codebook.t()
            + self.codebook.pow(2).sum(1)
        )  # [B*M, num_tokens]
        # Nearest codebook index per token.
        encoding_indices = torch.argmin(distances, dim=1)  # [B*M]
        encodings = F.one_hot(encoding_indices, self.numberOfCodebookTokens).type(flat.dtype)  # [B*M, num_tokens]
        # Quantized output: each row replaced by its chosen code vector.
        quantized = encodings @ self.codebook  # [B*M, D]
        if self.training:
            # EMA codebook update.  BUGFIX: run under no_grad — `flat` may
            # require grad, and without it `dw` would drag the EMA buffers
            # into the autograd graph on every step.
            with torch.no_grad():
                ema_counts = encodings.sum(0)        # [num_tokens]
                dw = encodings.t() @ flat            # [num_tokens, D]
                self.ema_cluster_size.mul_(self.decay).add_(ema_counts, alpha=1 - self.decay)
                self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)
                n = self.ema_cluster_size.sum()
                # Laplace smoothing keeps rarely-selected codes from
                # producing a divide-by-zero below.
                cluster_size = (
                    (self.ema_cluster_size + self.epsilon)
                    / (n + self.numberOfCodebookTokens * self.epsilon) * n
                )
                self.codebook.data = self.ema_w / cluster_size.unsqueeze(1)
        # BUGFIX: reshape with the code dimension D; the original
        # `view(B, M, M)` only worked when codebookDimension == M.
        quantized = quantized.view(B, M, self.codebookDimension)
        encoding_indices = encoding_indices.view(B, M)
        # NOTE(review): no straight-through estimator here — gradients do
        # not flow back to the encoder through `quantized`; confirm callers
        # apply `encode_feat + (quantized - encode_feat).detach()` if needed.
        return quantized, encoding_indices
潜在标记通过向量量化层进行离散化处理,该层采用指数移动平均(EMA)方法更新码本。在这一过程中,每个连续的标记向量都被码本中最相近的离散代码向量替换,从而将姿态表示转化为一组符号化的离散表示。具体实现中:
- 计算每个标记向量与所有码本条目之间的平方欧氏距离
- 通过 argmin 选取最近的码本索引,并以 one-hot 编码完成量化替换
- 码本在训练过程中通过EMA机制进行自我更新,确保码本适应训练数据分布
class PoseDecoder(nn.Module):
def __init__(self, codebookTokenDimension=64, numberOfKeypoints=11, keypointDimension=2, hiddenDimensionSize=128, numberOfMixerBlocks=4, mixerInternalDimensionSize=64, mixerTokenInternalDimensionSize=128, mixerDropout=0.1):
super(PoseDecoder, self).__init__()
self.codebookTokenDimension = codebookTokenDimension
self.numberOfKeypoints = numberOfKeypoints
self.keypointDimension = keypointDimension
self.hiddenDimensionSize = hiddenDimensionSize
self.mixerInternalDimensionSize = mixerInternalDimensionSize
self.mixerTokenInternalDimensionSize = mixerTokenInternalDimensionSize