3D GAN — 3D Generative Adversarial Networks for Volume Generation and Classification
DCGAN made convincing image generation practical in 2015. Taking inspiration from it, MIT CSAIL came up with 3D-GAN (published at NIPS 2016), which generates high-quality voxel models of 3D objects. The paper's abstract summarizes the approach:
We propose a novel framework, namely 3D Generative Adversarial Network (3D-GAN), which generates 3D objects from a probabilistic space by leveraging recent advances in volumetric convolutional networks and generative adversarial nets. The benefits of our model are three-fold: first, the use of an adversarial criterion, instead of traditional heuristic criteria, enables the generator to capture object structure implicitly and to synthesize high-quality 3D objects; second, the generator establishes a mapping from a low-dimensional probabilistic space to the space of 3D objects; third, the adversarial discriminator provides a powerful 3D shape descriptor which, learned without supervision, has wide applications in 3D object recognition.
All code related to this blog post can be found at: meetps/tf-3dgan

Architecture
The architecture of 3D-GAN is intuitive: the generator consists of deconvolutions (transposed convolutions) that upsample a small, many-channel input feature map to a larger output volume with fewer channels, and the discriminator mirrors the generator using strided convolutions. Note that there is not a single fully connected layer in the network, neither at the start of the generator nor at the end of the discriminator. It is fully convolutional in the true sense.
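To make the mirrored, fully convolutional layout concrete, here is a small sketch that traces volume sizes through both networks. The channel counts, the 200-dimensional latent vector, the 4x4x4 kernels and the explicit padding values are assumptions based on the paper's description, not read from the code in this repo, so treat the numbers as illustrative.

```python
def deconv_out(size, kernel, stride, pad):
    """Output edge length of a transposed convolution along one axis."""
    return (size - 1) * stride - 2 * pad + kernel


def conv_out(size, kernel, stride, pad):
    """Output edge length of a strided convolution along one axis."""
    return (size + 2 * pad - kernel) // stride + 1


# Assumed layer settings: (out_channels, kernel, stride, pad).
GENERATOR = [(512, 4, 1, 0), (256, 4, 2, 1), (128, 4, 2, 1), (64, 4, 2, 1), (1, 4, 2, 1)]
DISCRIMINATOR = [(64, 4, 2, 1), (128, 4, 2, 1), (256, 4, 2, 1), (512, 4, 2, 1), (1, 4, 1, 0)]


def trace(size, channels, layers, step):
    """Return the (edge_length, channels) pair before and after every layer."""
    shapes = [(size, channels)]
    for out_channels, kernel, stride, pad in layers:
        size = step(size, kernel, stride, pad)
        shapes.append((size, out_channels))
    return shapes


# The latent vector is treated as a 1x1x1 volume with 200 channels.
gen_shapes = trace(1, 200, GENERATOR, deconv_out)
# The discriminator starts from a 64x64x64 single-channel volume.
disc_shapes = trace(64, 1, DISCRIMINATOR, conv_out)

for edge, ch in gen_shapes:
    print(f"G: {edge}x{edge}x{edge}x{ch}")
for edge, ch in disc_shapes:
    print(f"D: {edge}x{edge}x{edge}x{ch}")
```

Under these assumptions the generator grows a 1x1x1 input to a 64x64x64 volume, and the discriminator shrinks it back to a single 1x1x1 score, with no dense layer needed at either end.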

Data
The data (available only in cube_length=32) can be found on Princeton's ShapeNet website.
Download and extract the data, then change the path in dataIO.py accordingly. 3D-GAN takes a volume with cube_length=64, so I've included an upsampling method in dataIO.py.
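As an illustration of going from cube_length=32 to cube_length=64, here is a minimal sketch using nearest-neighbour repetition with NumPy. This is an assumption about how such upsampling could be done, not necessarily the method implemented in dataIO.py, and `upsample_voxels` is a hypothetical name.

```python
import numpy as np


def upsample_voxels(volume, factor=2):
    """Nearest-neighbour upsampling: repeat every voxel `factor` times along each axis."""
    for axis in range(3):
        volume = np.repeat(volume, factor, axis=axis)
    return volume


# Toy 32^3 occupancy grid with a solid block in the middle.
low_res = np.zeros((32, 32, 32), dtype=np.uint8)
low_res[8:24, 8:24, 8:24] = 1

high_res = upsample_voxels(low_res)
print(low_res.shape, "->", high_res.shape)
```

Each occupied voxel becomes a 2x2x2 block, so the shape is preserved exactly and the occupied count grows by a factor of 8.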

Training
Training 3D-GAN is a non-trivial task, especially if you don't know the exact hyperparameters and tricks. I used tricks from Soumith's ganhacks repo. I faced a lot of problems during training; the hyperparameters you see in the code are fine-tuned for best results, so change them at your own peril.
Here's a bunch of things [with observations] I tried until I got to unleash the beast:
- Changed the loss of generator to be max(log(D_z)) instead of min(log(1-D_z)), as the latter has vanishing gradients when the discriminator confidently rejects generated samples [better numerical convergence of loss]
- Used sampling from Gaussian (0, 0.33) instead of uniform sampling [Faster convergence and better object generation]
- Increased the learning rate of discriminator to 10e-3 [Generator is unable to keep up and discriminator accuracy reaches near 100%]
- Autoencoder: Pretraining [Autoencoder converges to trivial solution (i.e., all voxels have values near 0.5) with box-like artifacts]
- Autoencoder: Adding L2 loss of weights to autoencoder [No Change]
- Autoencoder: Attempted changing the optimizer from Adam to RMSProp [No change]
- Autoencoder: Clipped the values below 0.5 to zero and above 0.5 to 1, forcing the network away from the trivial solution [Some non-zero voxels are now visible near desired locations, but there is little or no similarity in structure to a chair]
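The generator-loss change in the first item can be checked numerically. The sketch below compares the gradient of the two generator objectives with respect to the discriminator's logit, assuming the discriminator output is a sigmoid of that logit. The function names are hypothetical and the code is independent of the training scripts.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the objective minimised in the original GAN formulation."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), which is equivalent to maximising log D(G(z))."""
    return -(1.0 - sigmoid(a))


# Negative logits mean the discriminator is confident the sample is fake.
for logit in (-6.0, -3.0, 0.0, 3.0):
    print(f"logit={logit:+.1f}  D={sigmoid(logit):.4f}  "
          f"saturating={grad_saturating(logit):+.4f}  "
          f"non-saturating={grad_non_saturating(logit):+.4f}")
```

When the discriminator easily rejects generated samples (large negative logit), the gradient of log(1 - D) shrinks towards zero, while the gradient of -log(D) stays close to -1, which is why the max(log(D)) form keeps the generator learning early on.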
Train 3dgan_mit_biasfree.py for about 20000 batches and you'll be good to go.
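The Gaussian latent sampling from the list above can be sketched as follows. This assumes that 0.33 is the standard deviation (not the variance), that the latent vector has 200 dimensions as described in the paper, and that the uniform baseline was drawn from [-1, 1]; the batch size is a placeholder.

```python
import numpy as np

BATCH_SIZE = 32   # hypothetical batch size
Z_SIZE = 200      # latent dimension; assumed from the paper's description

rng = np.random.default_rng(0)

# Assumed uniform baseline on [-1, 1].
z_uniform = rng.uniform(-1.0, 1.0, size=(BATCH_SIZE, Z_SIZE))
# Gaussian with mean 0 and (assumed) standard deviation 0.33.
z_gaussian = rng.normal(loc=0.0, scale=0.33, size=(BATCH_SIZE, Z_SIZE))

inside = np.mean(np.abs(z_gaussian) <= 1.0)
print(f"uniform std:  {z_uniform.std():.3f}")
print(f"gaussian std: {z_gaussian.std():.3f}")
print(f"gaussian samples inside [-1, 1]: {inside:.4f}")
```

With a standard deviation of 0.33, roughly three standard deviations fit inside [-1, 1], so almost all Gaussian samples stay in the same range as the uniform ones while being concentrated around zero.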

Variants
As mentioned above, I tried several variants of the 3D-GAN. I've made a separate file for each of them, briefly summarized as:
| File | Description |
|---|---|
| 3dgan_mit_biasfree.py | 3dgan as mentioned in the paper, with same hyperparams. |
| 3dgan.py | Baseline 3dgan with fully connected layer at end of discriminator. |
| 3dgan_mit.py | 3dgan as mentioned in the paper with bias in convolutional layers. |
| 3dgan_autoencoder.py | 3dgan with support for autoencoder based pre-training. |
| 3dgan_feature_matching.py | 3dgan with additional loss of feature matching of last layers. |
| dataIO.py | Data input output and plotting utilities. |
| utils.py | TensorFlow utils like leaky_relu and batch_norm layer. |
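
For readers unfamiliar with the feature matching variant, here is a minimal NumPy sketch of the usual formulation: the squared L2 distance between the batch-mean discriminator features of real and generated samples. It is an assumption that 3dgan_feature_matching.py follows this formulation; the feature shape and the function name are hypothetical.

```python
import numpy as np


def feature_matching_loss(real_features, fake_features):
    """Squared L2 distance between batch-mean features of real and generated samples."""
    real_mean = real_features.reshape(len(real_features), -1).mean(axis=0)
    fake_mean = fake_features.reshape(len(fake_features), -1).mean(axis=0)
    return float(np.sum((real_mean - fake_mean) ** 2))


rng = np.random.default_rng(0)
# Hypothetical activations of the discriminator's last hidden layer: (batch, 4, 4, 4, 512).
real = rng.normal(0.0, 1.0, size=(8, 4, 4, 4, 512))
fake = rng.normal(0.5, 1.0, size=(8, 4, 4, 4, 512))

print(feature_matching_loss(real, real))  # identical batches give zero
print(feature_matching_loss(real, fake))  # shifted statistics give a positive loss
```

In a training script this term would be added to the generator loss, pushing generated volumes to produce the same average discriminator activations as real ones.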

Generation and Visualization
It's important to visualize the results as the network trains. I used visdom (fantastic work by FAIR). I personally prefer visdom over TensorBoard, as it's very easy to use and has a sleek interface, unlike the more complex TensorBoard. Visdom also supports visualizing 3D volumes with plotly!
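
Scatter-style 3D plots expect point coordinates rather than a dense grid, so a generated volume first has to be converted. Below is a minimal sketch of that conversion; `voxels_to_points` and the 0.5 threshold are assumptions for illustration and are not taken from dataIO.py.

```python
import numpy as np


def voxels_to_points(volume, threshold=0.5):
    """Return an (N, 3) array with the integer coordinates of occupied voxels."""
    return np.argwhere(volume > threshold)


# Toy generator output: occupancy probabilities on a 64^3 grid.
volume = np.zeros((64, 64, 64), dtype=np.float32)
volume[20:44, 20:44, 10:14] = 0.9   # a flat slab, loosely a chair seat

points = voxels_to_points(volume)
print(points.shape)
```

With a visdom server running, an array like `points` should be usable as the `X` argument of visdom's `scatter` call, which draws a 3D scatter plot when `X` has three columns; check the visdom documentation for the exact options.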