PyTorch Sub-pixel Convolution: 3D Pixel Shuffle 代码

import torch, torchvision
import torch.nn.functional as F
# Script-wide defaults. Neither demo in this script references them.
# Constructing torch.device("cuda:0") does not itself require a GPU.
dtype = torch.float
device = torch.device("cuda:0")

# Reference PyTorch pull request / commit cited by the original author:
# https://github.com/pytorch/pytorch/pull/5051/commits/60984f01602bf36c3de3b88dad732440a7e98f0c?diff=split

# 2-D reference demo: torch's built-in pixel shuffle turns a 9-channel 1x1
# map into a 1-channel 3x3 map.
# torch.range is deprecated (its end is inclusive, unlike Python's range);
# arange with a float start yields the same values 1..9 in the default
# float dtype.
X = torch.arange(1.0, 10.0).view(1, 9, 1, 1)
print(X.size(), X.view(-1))
Y = F.pixel_shuffle(X, 3)
print(Y.size())
print(Y)

def pixel_shuffle_3d(input, upscale_factor):
    """Rearrange channel depth into 3-D spatial resolution (volumetric pixel shuffle).

    The 3-D analogue of ``F.pixel_shuffle``: an input of shape
    ``(N, C * r**3, H, W, D)`` becomes ``(N, C, H * r, W * r, D * r)``,
    where ``r`` is ``upscale_factor``. Output element
    ``[n, c, h*r + i, w*r + j, d*r + k]`` is taken from input element
    ``[n, c*r**3 + i*r**2 + j*r + k, h, w, d]``.

    Args:
        input: 5-D tensor of shape ``(N, C * r**3, H, W, D)``.
        upscale_factor: factor ``r`` by which each of the three trailing
            dimensions is enlarged.

    Returns:
        Contiguous tensor of shape ``(N, C, H * r, W * r, D * r)``.

    Raises:
        ValueError: if ``input`` is not 5-D.
        RuntimeError: if the channel dimension is not divisible by
            ``upscale_factor ** 3``.
    """
    if input.dim() != 5:
        raise ValueError(
            f"pixel_shuffle_3d expects a 5-D input (N, C, H, W, D), "
            f"got a {input.dim()}-D tensor"
        )
    batch_size, in_channels, in_height, in_width, in_depth = input.size()

    block = upscale_factor ** 3
    if in_channels % block != 0:
        # Flooring the channel count would only make the view below fail
        # with an opaque size-mismatch error; report the real cause instead.
        raise RuntimeError(
            f"pixel_shuffle_3d expects the channel dimension ({in_channels}) "
            f"to be divisible by upscale_factor ** 3 ({block})"
        )
    channels = in_channels // block

    out_height = in_height * upscale_factor
    out_width = in_width * upscale_factor
    out_depth = in_depth * upscale_factor

    # Split channels into (C, r, r, r): (N, C, r, r, r, H, W, D).
    input_view = input.contiguous().view(
        batch_size, channels, upscale_factor, upscale_factor, upscale_factor,
        in_height, in_width, in_depth)

    # Interleave each r-axis right after its spatial axis:
    # (N, C, H, r, W, r, D, r), so merging adjacent pairs gives H*r, W*r, D*r.
    shuffle_out = input_view.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
    return shuffle_out.view(batch_size, channels, out_height, out_width, out_depth)

# 3-D demo: 27 channels of a 1x1x1 volume are shuffled into one 3x3x3 volume.
# torch.range is deprecated; arange with a float start yields the same
# values 1..27 in the default float dtype.
X3 = torch.arange(1.0, 28.0).view(1, 27, 1, 1, 1)
Y3 = pixel_shuffle_3d(X3, 3)
print(Y3.size())
print(Y3)


发表于:2018-05-25 15:18:25

原文链接(转载请保留): http://www.multisilicon.com/blog/a25332339.html

友情链接: MICROIC
首页