Skip to content

Reproduce Results on Galaxy Dataset #2

@mmubeen-6

Description

@mmubeen-6

Hi @tbepler, I am trying to reproduce the results of your paper on the galaxy dataset but am unable to match them exactly. Could you please share the exact training parameters? I am currently using the following command to train it.

python3 train_galaxy.py galaxy_zoo/galaxy_zoo_train.npy galaxy_zoo/galaxy_zoo_test.npy -d 0 --num-epochs 300 --save-prefix galaxy_zoo_models/testing -z 100 --minibatch-size 100 --dx-scale 0.125 .

Moreover, in order to visualize the reconstructed images, I am using the following code snippet. Please have a look at it.

def get_reconstruction(iterator, x_coord, p_net, q_net, img_size=64, rotate=True, translate=True, dx_scale=0.1, theta_prior=np.pi
                        , augment_rotation=False, z_scale=1, use_cuda=False):
    """Reconstruct the first image yielded by `iterator` with a spatial-VAE
    and plot the input image next to its reconstruction.

    Parameters
    ----------
    iterator : iterable yielding 1-tuples of image tensors of shape (1, D)
        where D = img_size * img_size * 3 (flattened RGB image) — assumed
        from the `view(..., img_size, img_size, 3)` reshape below; confirm
        against the data loader.
    x_coord : tensor of per-pixel spatial coordinates, shape (num_pixels, 2)
    p_net : decoder network mapping (coords, z) -> flattened RGB image
    q_net : encoder network mapping a flattened image -> (z_mu, z_logstd)
    img_size : side length used when reshaping the flat image for display
    rotate : if True, latent dimension 0 is a rotation angle applied to the
        coordinates (must match how the model was trained)
    translate : if True, the next two latent dimensions are translations
    dx_scale : scale factor applied to the translation latents
    theta_prior : unused here; kept for signature compatibility with the
        training code
    augment_rotation : unused here; kept for signature compatibility
    z_scale : multiplier applied to the remaining latent code
    use_cuda : if True, move the image tensor to the GPU before encoding

    Side effects: saves a comparison figure to 'foo.png' and shows it.
    Processes only the first batch (batch size must be 1), then returns.
    """

    def decode_tensor(t, img_size):
        # (1, img_size*img_size*3) float tensor in [0, 1] ->
        # (img_size, img_size, 3) uint8 array suitable for imshow.
        t = t.view(t.shape[0], img_size, img_size, 3)
        t = t.cpu().detach().numpy()
        t = t.clip(0., 1.) * 255.
        # Dropping the batch axis is only valid for batch size 1,
        # which is asserted in the caller.
        t = t.reshape(img_size, img_size, 3).astype("uint8")
        return t

    for y, in iterator:
        b = y.size(0)
        assert b == 1  # decode_tensor's reshape handles a single image only
        x = Variable(x_coord)
        y = Variable(y)

        x = x.expand(b, x.size(0), x.size(1))

        if use_cuda:
            y = y.cuda()

        # BUG FIX: the original called q_net on an undefined name `y_rot`
        # (NameError); the encoder must run on the flattened input image `y`.
        z_mu, z_logstd = q_net(y.view(b, -1))
        z_std = torch.exp(z_logstd)
        z_dim = z_mu.size(1)

        # Draw one sample from the variational posterior: z = mu + sigma * eps
        r = Variable(x.data.new(b, z_dim).normal_())
        z = z_std*r + z_mu

        if rotate:
            # z[0] is the rotation angle; consume it and rotate the
            # coordinate grid by theta.
            theta = z[:, 0]
            z = z[:, 1:]

            rot = Variable(theta.data.new(b, 2, 2).zero_())
            rot[:, 0, 0] = torch.cos(theta)
            rot[:, 0, 1] = torch.sin(theta)
            rot[:, 1, 0] = -torch.sin(theta)
            rot[:, 1, 1] = torch.cos(theta)
            x = torch.bmm(x, rot)  # rotate coordinates by theta

        if translate:
            # The next two latents are translations; scale and consume them.
            dx = z[:, :2]*dx_scale
            dx = dx.unsqueeze(1)
            z = z[:, 2:]

            x = x + dx  # translate coordinates

        z = z*z_scale

        # Reconstruct the image from the transformed coordinates and the
        # remaining latent code.
        y_hat = p_net(x.contiguous(), z)
        y_hat = y_hat.view(b, -1, 3)

        # BUG FIX: decode the actual input `y` (was the undefined `y_rot`).
        input_image = decode_tensor(y, img_size)
        recon_image = decode_tensor(y_hat, img_size)

        import matplotlib.pyplot as plt
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 6))
        ax1.imshow(input_image)
        ax2.imshow(recon_image)
        fig.savefig('foo.png')
        # BUG FIX: pyplot.show() accepts no positional figure argument on
        # modern matplotlib; plt.show(fig) raises a TypeError.
        plt.show()

        break

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions