专栏名称: 数据派THU
本订阅号是“THU数据派”的姊妹账号,致力于传播大数据价值、培养数据思维。
目录
相关文章推荐
龙岩发布  ·  龙岩交警最新发布!这些路口违法行为多发→ ·  7 小时前  
InfoTech  ·  2025年工信部职业技术/专项技术认证 ·  18 小时前  
人工智能与大数据技术  ·  Claude ... ·  昨天  
数局  ·  原想&伊肤泉:2025问题肌抗衰白皮书 ·  3 天前  
51好读  ›  专栏  ›  数据派THU

用离散标记重塑人体姿态:VQ-VAE实现关键点组合关系编码

数据派THU  · 公众号  · 大数据  · 2025-05-09 17:00

正文

请到「今天看啥」查看全文


self.numberOfMixerBlocks = numberOfMixerBlocks # N self.codebookTokenDimension = codebookTokenDimension # M self.internalMixerSize = internalMixerSize self.internalMixerTokenSize = internalMixerTokenSize self.mixerDropout = mixerDropout
self.initial_linear = nn.Linear(self.dimensionOfKeypoints, self.linearProjectionSize) # 从BxKxD投影到BxKxH
self.mixer_layers = nn.ModuleList([MixerLayer(self.linearProjectionSize, self.internalMixerSize, self.numberOfKeypoints, self.internalMixerTokenSize, self.mixerDropout) for _ in range(self.numberOfMixerBlocks)]) # BxKxH
self.mixer_layer_norm = nn.LayerNorm(self.linearProjectionSize) # BxKxH
self.token_linear = nn.Linear(self.numberOfKeypoints, self.codebookTokenDimension) # BxHxK -> BxHxM
self.feature_embed = nn.Linear(self.linearProjectionSize, self.codebookTokenDimension)
def forward(self, x):
    """Encode a batch of keypoints into M token features with an MLP-Mixer.

    Args:
        x: keypoint tensor laid out as (B, D, K) -- batch, coordinate
            dimension, keypoints -- per the shape annotations of this encoder.

    Returns:
        Tensor of shape (B, M, M): ``token_linear`` produces the M tokens and
        ``feature_embed`` gives each token width M; both output sizes are set
        from ``codebookTokenDimension`` in ``__init__``.
    """
    x = x.transpose(2, 1)            # (B, D, K) -> (B, K, D)
    x = self.initial_linear(x)       # embed coordinates: (B, K, D) -> (B, K, H)
    for mixer in self.mixer_layers:  # shape preserved: (B, K, H)
        x = mixer(x)
    x = self.mixer_layer_norm(x)     # (B, K, H)
    x = x.transpose(2, 1)            # (B, K, H) -> (B, H, K)
    x = self.token_linear(x)         # mix across keypoints: (B, H, K) -> (B, H, M)
    x = x.transpose(2, 1)            # (B, H, M) -> (B, M, H)
    x = self.feature_embed(x)        # project features: (B, M, H) -> (B, M, M)
    return x

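编码器接收一组二维关键点坐标,通过基于MLP-Mixer架构设计的网络结构将这些坐标转换为M个潜在标记特征。具体而言,关键点首先被线性嵌入到高维空间,然后由Mixer层在关节维度和特征维度上进行混合处理,最终投影到形状为B × M × C(批量大小×标记数量×标记维度)的输出特征空间。这里用C表示标记维度,以免与代码注释中表示关键点坐标维度的D混淆;在上述实现中,标记数量与标记维度都由codebookTokenDimension决定,因此实际输出形状为B × M × M。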

EMA码本(VQ层)

 class CodebookVQ(nn.Module):
     """Vector-quantization layer with an EMA-updated codebook.

     Every input token vector is replaced by the codebook entry with the
     smallest squared L2 distance. In training mode the codebook is
     re-estimated from exponential moving averages of the assignment
     statistics rather than by the optimizer (the codebook is a buffer, not a
     parameter).

     Args:
         codebookDimension: width C of each codebook vector; the last
             dimension of the input must equal C.
         numberOfCodebookTokens: number of codebook entries.
         decay: EMA decay applied to the cluster-size and summed-vector
             statistics.
         epsilon: additive smoothing on cluster sizes so entries that are
             never selected do not cause a division by zero.
     """

     def __init__(self, codebookDimension, numberOfCodebookTokens, decay=0.99, epsilon=1e-5):
         super(CodebookVQ, self).__init__()
         self.codebookDimension = codebookDimension
         self.numberOfCodebookTokens = numberOfCodebookTokens
         self.decay = decay
         self.epsilon = epsilon

         # Buffers rather than parameters: the optimizer never updates them;
         # the codebook is rewritten from the EMA statistics in forward().
         self.register_buffer('codebook', torch.empty(numberOfCodebookTokens, codebookDimension))
         self.codebook.data.normal_()
         # Running (EMA) count of inputs assigned to each entry.
         self.register_buffer('ema_cluster_size', torch.zeros(numberOfCodebookTokens))
         # Running (EMA) sum of the input vectors assigned to each entry.
         self.register_buffer('ema_w', torch.empty(numberOfCodebookTokens, codebookDimension))
         self.ema_w.data.normal_()

     def forward(self, encode_feat):
         """Quantize each token of ``encode_feat`` to its nearest codebook entry.

         Args:
             encode_feat: tensor of shape (B, M, C) with C == codebookDimension.

         Returns:
             quantized: (B, M, C) tensor of the selected codebook vectors, taken
                 from the codebook as it was before this call's EMA update.
             encoding_indices: (B, M) long tensor of the selected entry indices.
         """
         B, M = encode_feat.shape[0], encode_feat.shape[1]
         # reshape (rather than view) also accepts non-contiguous inputs.
         flat = encode_feat.reshape(-1, self.codebookDimension)  # [B*M, C]

         # Squared L2 distance via the expansion ||x||^2 - 2 x.c + ||c||^2.
         distances = (
             flat.pow(2).sum(1, keepdim=True)
             - 2 * flat @ self.codebook.t()
             + self.codebook.pow(2).sum(1)
         )  # [B*M, num_tokens]

         encoding_indices = torch.argmin(distances, dim=1)  # [B*M]
         encodings = F.one_hot(encoding_indices, self.numberOfCodebookTokens).type(flat.dtype)  # [B*M, num_tokens]

         # NOTE(review): quantized is a lookup into a buffer, so no gradient
         # flows from it back to encode_feat. If the encoder is meant to be
         # trained through this layer, a straight-through estimator
         # (encode_feat + (quantized - encode_feat).detach()) is needed --
         # confirm whether the caller applies one.
         quantized = encodings @ self.codebook  # [B*M, C]

         if self.training:
             # The EMA statistics are running state, not part of the loss.
             # Without no_grad, the in-place updates below would attach these
             # buffers to the autograd graph whenever encode_feat requires grad,
             # and that graph would keep growing across iterations.
             with torch.no_grad():
                 counts = encodings.sum(0)  # [num_tokens]
                 dw = encodings.t() @ flat  # [num_tokens, C]: per-entry sum of inputs
                 self.ema_cluster_size.mul_(self.decay).add_(counts, alpha=1 - self.decay)
                 self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)

                 # Laplace-smoothed cluster sizes; their total stays equal to n.
                 n = self.ema_cluster_size.sum()
                 cluster_size = (
                     (self.ema_cluster_size + self.epsilon)
                     / (n + self.numberOfCodebookTokens * self.epsilon)
                     * n
                 )
                 self.codebook.copy_(self.ema_w / cluster_size.unsqueeze(1))

         # Token width is the codebook width C, not the token count M.
         quantized = quantized.view(B, M, self.codebookDimension)
         encoding_indices = encoding_indices.view(B, M)
         return quantized, encoding_indices

潜在标记通过向量量化层进行离散化处理,该层采用指数移动平均(EMA)方法更新码本。在这一过程中,每个连续的标记向量都被码本中最相近的离散代码向量替换,从而将姿态表示转化为一组符号化的离散表示。具体实现中:

  • 码本包含numberOfCodebookTokens个代码向量条目(每个条目的维度为codebookDimension)
  • 每个输入标记根据L2距离独立选择最相近的码本向量
  • 码本在训练过程中通过EMA机制进行自我更新,确保码本适应训练数据分布

姿态解码器

 class PoseDecoder(nn.Module):      def __init__(self, codebookTokenDimension=64, numberOfKeypoints=11, keypointDimension=2, hiddenDimensionSize=128, numberOfMixerBlocks=4, mixerInternalDimensionSize=64, mixerTokenInternalDimensionSize=128, mixerDropout=0.1):          super(PoseDecoder, self).__init__()  
self.codebookTokenDimension = codebookTokenDimension self.numberOfKeypoints = numberOfKeypoints self.keypointDimension = keypointDimension self.hiddenDimensionSize = hiddenDimensionSize self.mixerInternalDimensionSize = mixerInternalDimensionSize self.mixerTokenInternalDimensionSize = mixerTokenInternalDimensionSize






请到「今天看啥」查看全文