
VariationalEncoder

latent.modules.encoder.VariationalEncoder

Variational encoder. This model compresses input data by parameterizing a latent distribution that is regularized through a Kullback-Leibler (KL) divergence loss.
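The two ingredients named above are a parameterized latent distribution and a KL regularizer. The pure-Python sketch below illustrates them for the simplest case (`latent_dist='independent'`, `prior='normal'`). It assumes the encoder outputs a mean and a log-variance per latent dimension, draws samples with the reparameterization trick, and uses the closed-form KL divergence to a unit Gaussian. The helper names (`reparameterize`, `gaussian_kl`, `kl_penalty`) and the capacity-constrained form `kld_weight * |KL - capacity|` are assumptions for illustration, not the library's internals.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), independently per latent dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def gaussian_kl(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


def kl_penalty(kl, kld_weight=0.0001, capacity=0.0):
    """Assumed capacity-constrained penalty: kld_weight * |KL - capacity|.

    With capacity = 0 this reduces to kld_weight * KL, since KL >= 0.
    """
    return kld_weight * abs(kl - capacity)


# Hypothetical encoder outputs for one input and a 3-dimensional latent space.
mu, log_var = [0.5, -1.0, 0.0], [0.0, -0.5, 0.2]
z = reparameterize(mu, log_var, random.Random(0))
kl = gaussian_kl(mu, log_var)
penalty = kl_penalty(kl)
```

Each dimension contributes 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2) to the KL term. This contribution is zero exactly when that dimension's posterior matches the unit Gaussian prior (mu = 0, sigma = 1).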

__init__(self, latent_dim=50, name='variational_encoder', initializer='glorot_normal', use_decomposed_kld=False, x_size=1000, kld_weight=0.0001, tc_weight=0.001, capacity=0.0, prior='normal', latent_dist='independent', iaf_units=[256, 256], n_pseudoinputs=200, **kwargs)

Parameters:

Name Type Description Default
latent_dim int

Integer indicating the number of dimensions in the latent space.

50
name str

String indicating the name of the model.

'variational_encoder'
initializer Union[str, Callable]

Initializer for the kernel weights matrix (see keras.initializers).

'glorot_normal'
use_decomposed_kld bool

Boolean indicating whether to use the decomposed KLD loss (Chen 2019), which splits the KL term into mutual information, total correlation, and dimension-wise KL terms.

False
x_size int

Total number of data points. Only used if use_decomposed_kld = True.

1000
kld_weight float

Float indicating the weight of the KL divergence regularization loss. If use_decomposed_kld = True, this instead indicates the weight of the dimension-wise KLD term.

0.0001
tc_weight float

Float indicating the weight of the total correlation term of the KLD loss. Only used if use_decomposed_kld = True.

0.001
capacity float

Capacity of the KLD loss. Can be linearly increased using a KL scheduler callback.

0.0
prior Literal['normal', 'iaf', 'vamp']

The choice of prior distribution. One of the following:

  • 'normal' - A unit Gaussian (standard normal) distribution.
  • 'iaf' - A unit Gaussian transformed by an Inverse Autoregressive Flow (IAF) bijector (Kingma 2016).
  • 'vamp' - A variational mixture of posteriors (VAMP) prior (Tomczak 2017)
'normal'
latent_dist Literal['independent', 'multivariate']

The choice of latent distribution. One of the following:

  • 'independent' - An independent (diagonal-covariance) normal distribution produced by tfpl.IndependentNormal.
  • 'multivariate' - A full-covariance multivariate normal distribution produced by tfpl.MultivariateNormalTriL.
'independent'
iaf_units Iterable[int]

List of integers indicating the number of units in each layer of the IAF bijector network. Only used if prior = 'iaf'.

[256, 256]
n_pseudoinputs int

Integer indicating the number of pseudoinputs for the VAMP prior. Only used if prior = 'vamp'.

200
**kwargs

Other arguments passed on to DenseStack.

{}
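When use_decomposed_kld = True, the KL term is split into index-code mutual information (MI), total correlation (TC), and dimension-wise KL (Chen 2019). The pure-Python sketch below implements the minibatch weighted sampling estimator from that work. It assumes that `VariationalEncoder` uses this estimator and that `x_size` plays the role of the dataset size N in the log(N * M) normalizer, where M is the batch size. The helpers `log_normal`, `logsumexp`, and `decomposed_kl` are hypothetical names, not part of the `latent` package.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def log_normal(z, mu, log_var):
    """Log-density of scalar z under N(mu, exp(log_var))."""
    return -0.5 * (LOG_2PI + log_var + (z - mu) ** 2 / math.exp(log_var))


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def decomposed_kl(z, mu, log_var, x_size):
    """Minibatch weighted sampling estimates of (MI, TC, dimension-wise KL).

    z, mu, log_var: M x D nested lists (one row per sample in the minibatch).
    x_size: dataset size N, used in the log(N * M) normalizer.
    """
    m_batch, dims = len(z), len(z[0])
    log_nm = math.log(x_size * m_batch)
    mi = tc = dwkl = 0.0
    for i in range(m_batch):
        # a[j][d] = log q(z_i[d] | x_j) for every sample j in the batch (including i).
        a = [[log_normal(z[i][d], mu[j][d], log_var[j][d]) for d in range(dims)]
             for j in range(m_batch)]
        log_q_zx = sum(a[i])  # log q(z_i | x_i)
        log_q_z = logsumexp([sum(row) for row in a]) - log_nm  # estimate of log q(z_i)
        log_prod_q_zd = sum(
            logsumexp([a[j][d] for j in range(m_batch)]) - log_nm for d in range(dims)
        )  # estimate of sum_d log q(z_i[d])
        log_p_z = sum(log_normal(z[i][d], 0.0, 0.0) for d in range(dims))  # unit Gaussian prior
        mi += log_q_zx - log_q_z
        tc += log_q_z - log_prod_q_zd
        dwkl += log_prod_q_zd - log_p_z
    return mi / m_batch, tc / m_batch, dwkl / m_batch


# Hypothetical minibatch: M = 8 samples, D = 3 latent dimensions.
rng = random.Random(0)
M, D = 8, 3
mu = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(M)]
log_var = [[rng.uniform(-1.0, 0.0) for _ in range(D)] for _ in range(M)]
z = [[m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mr, lr)]
     for mr, lr in zip(mu, log_var)]
mi, tc, dwkl = decomposed_kl(z, mu, log_var, x_size=1000)

# Per the parameter descriptions: tc_weight scales TC and kld_weight scales the
# dimension-wise KL. How (or whether) MI is weighted is not documented, so it is left out.
weighted = 0.001 * tc + 0.0001 * dwkl
```

By construction, the three terms telescope: MI + TC + dimension-wise KL equals the single-sample estimate of log q(z|x) - log p(z). Splitting the terms out therefore only changes how each part is weighted, not the total being estimated.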