MLP-Mixer: An all-MLP Architecture for Vision


https://arxiv.org/pdf/2105.01601

----------------------------------------------------------

2021-09-02

Perceptron: a discriminative model; a linear binary classifier.
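As a quick aside on that note, a minimal sketch of the perceptron decision rule (purely illustrative; the weights and input below are arbitrary, not from the paper):

import torch

w, b = torch.tensor([0.5, -0.2]), 0.1
x = torch.tensor([1.0, 2.0])
# predict +1 if w.x + b > 0, else -1: a linear binary classifier
pred = 1 if (w @ x + b).item() > 0 else -1
print(pred)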

token-mixing: acts on the columns of the patch table (one channel across all patches), mixing and refining features from different patches; behaves like a depth-wise convolution with a full receptive field and shared parameters.

channel-mixing: acts on the rows (one patch across all channels), mixing and refining features from different channels; equivalent to a 1×1 convolution. A minimal shape sketch of both steps follows below.
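A minimal shape sketch of the two mixing steps (an illustration with arbitrary tensor sizes, not code from the paper): for a patch table of shape (batch, patches, channels), channel-mixing applies an MLP along the last axis, while token-mixing applies the same idea along the patch axis.

import torch
from torch import nn

batch, num_patches, channels = 2, 196, 512
x = torch.randn(batch, num_patches, channels)

channel_mix = nn.Linear(channels, channels)        # rows: per patch, across channels (like a 1x1 conv)
token_mix = nn.Linear(num_patches, num_patches)    # columns: per channel, across patches

y = channel_mix(x)                                  # (2, 196, 512)
z = token_mix(x.transpose(1, 2)).transpose(1, 2)    # (2, 196, 512)
print(y.shape, z.shape)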

       

from functools import partial

from torch import nn
from einops.layers.torch import Rearrange, Reduce


class PreNormResidual(nn.Module):
    # Pre-norm residual: LayerNorm, then the wrapped module, plus a skip connection
    def __init__(self, dim, fn):
        super(PreNormResidual, self).__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


def FeedForward(dim,
                expansion_factor=4,
                dropout=0,
                dense=nn.Linear):
    # Two-layer MLP with GELU; `dense` is nn.Linear for channel-mixing
    # or a 1x1 Conv1d (via partial) for token-mixing
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )


def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes,
             expansion_factor=4,
             dropout=0):
    assert (image_size % patch_size) == 0, "image size must be divisible by patch size"
    num_patches = (image_size // patch_size) ** 2
    # A 1x1 Conv1d applied to a (batch, patches, dim) tensor treats the patch axis as its
    # channel axis, so it mixes across patches (token-mixing); a plain Linear on the last
    # axis mixes across channels (channel-mixing)
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

    return nn.Sequential(
        # split the image into non-overlapping patches and flatten each patch
        Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
        # per-patch linear embedding into `dim` channels
        nn.Linear((patch_size ** 2) * channels, dim),

        *[
            nn.Sequential(
                # token-mixing MLP: mixes features of different patches
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                # channel-mixing MLP: mixes features of different channels within each patch
                PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
            )
            for _ in range(depth)
        ],

        nn.LayerNorm(dim),
        # global average pooling over patches, then the classification head
        Reduce("b n c -> b c", "mean"),
        nn.Linear(dim, num_classes)
    )
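A quick sanity check of the model above (hyperparameters are chosen only for illustration, roughly the scale of the paper's Mixer-S/16):

import torch

model = MLPMixer(image_size=224, channels=3, patch_size=16,
                 dim=512, depth=8, num_classes=1000)
img = torch.randn(1, 3, 224, 224)
logits = model(img)
print(logits.shape)   # torch.Size([1, 1000])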

Original post: https://www.cnblogs.com/shuimobanchengyan/p/15218928.html