Elevating Perceptual Sample Quality in PCs through Differentiable Sampling
NeurIPS 2021 Workshop on Pre-registration in Machine Learning, PMLR 181:1-25, 2022.
Abstract
Deep generative models have improved dramatically in recent years, due to the use of alternative losses based on perceptual assessment of generated samples. This improvement has not yet been applied to the model class of probabilistic circuits (PCs), presumably due to significant technical challenges concerning differentiable sampling, which is a key requirement for optimizing perceptual losses. This is unfortunate, since PCs allow a much wider range of probabilistic inference routines than mainstream generative models, such as exact and efficient marginalization and conditioning. Motivated by the success of loss reframing in deep generative models, we incorporate perceptual metrics into the PC learning objective. To this end, we introduce a differentiable sampling procedure for PCs, where the central challenge is the non-differentiability of sampling from the categorical distribution over latent PC variables. We take advantage of the Gumbel-Softmax trick and develop a novel inference pass that smoothly interpolates child samples, circumventing the non-differentiability of sum node sampling. We initially hypothesized that perceptual losses, unlocked by our novel differentiable sampling procedure, would elevate the generative power of PCs and improve their sample quality to be on par with neural counterparts like probabilistic auto-encoders and generative adversarial networks. Although our experimental findings empirically reject this hypothesis for now, the results demonstrate that samples drawn from PCs optimized with perceptual losses can reach a sample quality similar to that of PCs optimized with likelihood-based objectives and, at the same time, can express richer contrast, colors, and details. Whereas PCs were previously restricted to likelihood-based optimization, this work paves the way for advancing PCs with loss formulations that have been built around deep neural networks in recent years.
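The idea of relaxing sum node sampling with the Gumbel-Softmax trick can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names `gumbel_softmax_weights` and `soft_sum_node_sample`, the temperature value, and the toy weights and child samples are all assumptions made for illustration. A hard sample would pick exactly one child according to the sum node weights; the relaxed version instead returns a convex combination of the child samples, which would be differentiable with respect to the weights in an automatic differentiation framework.

```python
import numpy as np

def gumbel_softmax_weights(log_weights, temperature, rng):
    """Relaxed one-hot selection over the children of a sum node (illustrative)."""
    # Gumbel(0, 1) noise via inverse transform sampling.
    u = rng.uniform(low=1e-12, high=1.0, size=log_weights.shape)
    gumbel = -np.log(-np.log(u))
    # Perturb the log weights and apply a temperature-scaled softmax.
    logits = (log_weights + gumbel) / temperature
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    soft = np.exp(logits)
    return soft / soft.sum(axis=-1, keepdims=True)

def soft_sum_node_sample(child_samples, log_weights, temperature, rng):
    """Interpolate child samples instead of selecting a single child.

    child_samples has shape (num_children, dim); the result has shape (dim,).
    """
    soft = gumbel_softmax_weights(log_weights, temperature, rng)
    return soft @ child_samples

# Toy sum node with three children (hypothetical values).
rng = np.random.default_rng(0)
log_weights = np.log(np.array([0.2, 0.5, 0.3]))
child_samples = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])

soft = gumbel_softmax_weights(log_weights, 0.5, rng)
sample = soft_sum_node_sample(child_samples, log_weights, 0.5, rng)
```

As the temperature approaches zero, the relaxed weights approach a one-hot vector and the interpolated sample approaches a sample from a single child; larger temperatures blend the children more evenly.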