企业网站首页怎么优化,做网站和管理系统,移动互联网开发招人,免费涨热度软件论文链接#xff1a;https://arxiv.org/pdf/2304.03977.pdf
代码#xff1a;https://github.com/tsb0601/EMP-SSL
其他学习链接#xff1a;突破自监督学习效率极限#xff01;马毅、LeCun联合发布EMP-SSL#xff1a;无需花哨trick#xff0c;30个epoch即可实现SOTA 主要…论文链接https://arxiv.org/pdf/2304.03977.pdf
代码https://github.com/tsb0601/EMP-SSL
其他学习链接突破自监督学习效率极限马毅、LeCun联合发布EMP-SSL无需花哨trick30个epoch即可实现SOTA 主要思想
如图一张图片裁剪成不同的 patch对不同的 patch 做数据增强分别输入 encoder得到多个 embedding对它们求均值得到 作为这张图片的 embedding。最后拉近每个 patch 的 embedding 和图片的 embedding之间的余弦距离再用 Total Coding Rate(TCR) 防止坍塌即 encoder 对所有输入都输出相同的 embedding Total Coding Rate(TCR)
公式如下 其中det 表示求矩阵的行列式d 是 feature vector 的 dimensionb 是 batch size
查了查该公式的含义expand all features of Z as large as possible即尽可能拉远矩阵中特征之间的距离。
源自 PPT 第 24 页
https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf
至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~ 核心代码解读
数据处理
https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27
class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch 4):self.num_patch num_patchdef __call__(self, x):normalize transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform transforms.Compose([transforms.RandomResizedCrop(32,scale(0.25, 0.25), ratio(1,1)),transforms.RandomHorizontalFlip(p0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p0.8),transforms.RandomGrayscale(p0.2),GBlur(p0.1),transforms.RandomApply([Solarization()], p0.1),transforms.ToTensor(), normalize])augmented_x [aug_transform(x) for i in range(self.num_patch)]return augmented_x
由此看出返回的 数据 为长度为 num_patches 个 tensor 的列表。其中每个 tensor 的 shape 为 (B, C, H, W)。 主函数
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63
for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data torch.cat(data, dim0) data data.cuda()z_proj net(data)z_list z_proj.chunk(num_patches, dim0)z_avg chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ contractive_loss(z_list, z_avg)loss_TCR cal_TCR(z_proj, criterion, num_patches)
这里要稍微注意一下几个变量的 shape
data 被 cat 完后num_patches * BCHWz_projnum_patches * BCz_listnum_patchesBCz_avgBC
其中chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67
def chunk_avg(x,n_chunks2,normalizeFalse):x_list x.chunk(n_chunks,dim0)x torch.stack(x_list,dim0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim1) loss
contractive_loss 就是计算每个 patch 的 embedding 和均值的余弦距离
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76
class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim 0num_patch len(z_list)z_list torch.stack(list(z_list), dim0)z_avg z_list.mean(dim0)z_sim 0for i in range(num_patch):z_sim F.cosine_similarity(z_list[i], z_avg, dim1).mean()z_sim z_sim/num_patchz_sim_out z_sim.clone().detach()return -z_sim, z_sim_out TCR loss最大化矩阵之间特征的距离即拉远负样本不是来自同一个样本的 patches之间的距离
https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96
def cal_TCR(z, criterion, num_patches):z_list z.chunk(num_patches,dim0)loss 0for i in range(num_patches):loss criterion(z_list[i])loss loss/num_patchesreturn loss
需要注意函数输入的 z 是 z_proj形状为num_patches * BC。
所以函数内部 z_list 的形状为num_patchesBC即将数据分为了 num_patches 个组每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss最大化组内不同图片的 patch特征的距离。
所以公式中的 指的是一组来自不同图片里 patch 的 embedding形状为BC。
每个组内求 TCR loss 的代码按照公式计算如下 https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76
class TotalCodingRate(nn.Module):def __init__(self, eps0.01):super(TotalCodingRate, self).__init__()self.eps epsdef compute_discrimn_loss(self, W):Discriminative Loss.p, m W.shape #[d, B]I torch.eye(p,deviceW.device)scalar p / (m * self.eps)logdet torch.logdet(I scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)