A two-way switch example to better understand Total Correlation
Recently, I was working on a project that required learning a latent representation with disentangled factors for high-dimensional inputs. To briefly introduce disentanglement: while we can use an autoencoder (AE) to compress a high-dimensional input into a compact embedding, the embedding dimensions are generally dependent on one another, meaning that changing one dimension tends to come with correlated changes in others. This is undesirable in many scenarios, for example:
- We’d like to train a generative model that maps a latent embedding to a high-dimensional output (e.g., an image), while controlling the generated result by modifying only one dimension of the embedding at a time. (This also helps us interpret the structure of the embedding space.)
- We’d like to train a policy that operates on the latent embedding. For training efficiency, we need to keep the action space from becoming combinatorial. To do so, we let each action modify only one dimension, while still producing a meaningful change in the observation space. This is similar in spirit to Independently Controllable Factors.
There are many techniques for enforcing disentanglement in the representation learning literature, among which the variational autoencoder (VAE) is probably the best known. A 2019 research work examined the Kullback-Leibler divergence (KLD) term of the VAE objective and concluded that the component of the KLD most responsible for disentanglement is the total correlation (TC) among latent dimensions. (As an aside, the KLD also contains a dimension-wise term that matches each latent marginal to the prior, and this term can actually hurt reconstruction precision.)
Formally, let the latent embedding be $z \in \mathbb{R}^N$ where $N$ is the total dimensionality. The TC among dimensions is defined as
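$$ \mathcal{TC}(z)=\mathbb{E}_{p(z)}\log\frac{p(z)}{\prod_n p(z_n)} $$
where $p(z_n)$ is the marginal distribution of the $n$-th dimension. Equivalently, TC is the KLD between the joint $p(z)$ and the product of its marginals, so it is always non-negative and equals 0 if and only if all dimensions are mutually independent. TC generalizes mutual information (MI) from two variables to multiple variables: for $N=2$ it is exactly $I(z_0;z_1)$.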
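As a minimal illustration of the definition, the NumPy sketch below computes TC exactly for discrete variables. It assumes, purely for illustration, that $p(z)$ is available as a full probability table with one axis per dimension, which is only practical for a handful of low-cardinality variables. The function name `total_correlation` is my own, and logs are natural (nats).

```python
import numpy as np


def total_correlation(joint) -> float:
    """Exact TC (in nats) of a discrete joint distribution.

    `joint` has one axis per variable, e.g. joint[i, j] = p(z_0 = i, z_1 = j).
    """
    joint = np.asarray(joint, dtype=float)
    # Build prod_n p(z_n) by broadcasting each marginal back to the joint's shape.
    prod_marginals = np.ones_like(joint)
    for axis in range(joint.ndim):
        other = tuple(a for a in range(joint.ndim) if a != axis)
        prod_marginals = prod_marginals * joint.sum(axis=other, keepdims=True)
    # E_{p(z)} log p(z) / prod_n p(z_n); cells with p(z) = 0 contribute 0.
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / prod_marginals[mask])))


# Three mutually independent variables: the joint factorizes, so TC is 0 up to rounding.
independent = np.einsum(
    "i,j,k->ijk", np.array([0.3, 0.7]), np.array([0.5, 0.5]), np.array([0.2, 0.8])
)
print(total_correlation(independent))

# Two perfectly correlated fair bits: TC reduces to I(z_0; z_1) = log 2.
correlated = np.array([[0.5, 0.0], [0.0, 0.5]])
print(total_correlation(correlated), np.log(2))
```

Building the product of marginals with `keepdims=True` lets NumPy broadcasting line each marginal up with its own axis, so the same code works for any number of variables.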
There are many ways to estimate TC from samples of $z$. Since they are not the main focus of this post, we’ll briefly mention two of them. One trick is the Neural Dependency Measure (NDM), which trains a discriminator to estimate the KLD between the distribution of the original $z$ and that of $z$ with each dimension shuffled independently along the batch axis (the shuffled samples approximately follow $\prod_n p(z_n)$). Another technique recursively reduces TC to a tree-like or chain-like sequence of MI calculations between subsets of dimensions, e.g., via the chain rule $\mathcal{TC}(z)=\sum_{n=1}^{N-1} I(z_n; z_0,\dots,z_{n-1})$.
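The key ingredient of the shuffling trick is easy to sketch. Assuming a batch of latent codes stored as a `(batch, N)` NumPy array, permuting each column independently keeps every marginal $p(z_n)$ intact while destroying the dependence between columns, so the shuffled batch approximately samples $\prod_n p(z_n)$. The helper `shuffle_dims` is hypothetical, and the discriminator that would be trained to tell the two batches apart is omitted.

```python
import numpy as np


def shuffle_dims(z: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Permute each column of a (batch, N) array independently along the batch axis.

    Each column keeps exactly the same values (so each marginal p(z_n) is unchanged),
    but the pairing of values across columns is randomized, which breaks dependence.
    """
    z_shuffled = np.empty_like(z)
    for n in range(z.shape[1]):
        z_shuffled[:, n] = rng.permutation(z[:, n])
    return z_shuffled


rng = np.random.default_rng(0)
x = rng.normal(size=10_000)
z = np.stack([x, x], axis=1)  # two perfectly dependent dimensions
z_shuffled = shuffle_dims(z, rng)

print(np.corrcoef(z.T)[0, 1])           # essentially 1: the original dimensions are identical
print(np.corrcoef(z_shuffled.T)[0, 1])  # close to 0 after independent shuffling
```

Note that the columns must be permuted separately: shuffling whole rows of `z` would only reorder the batch and leave the dependence between dimensions untouched.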
Before I came across the above literature on the definition and estimation of TC, I believed that, to enforce disentanglement in my project, I should minimize the MI between every two dimensions $n\neq m$ with $0\le n,m \le N-1$. Mathematically, the objective I attempted to minimize was
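$$ \mathcal{TMI}(z)\triangleq\sum_{n,m,n\neq m}\mathbb{E}_{p(z_n,z_m)}\log\frac{p(z_n,z_m)}{p(z_n)p(z_m)} $$
Comparing the two definitions, it’s easy to see that in general $\mathcal{TMI}(z)\neq\mathcal{TC}(z)$. After thinking for a while, I also realized that $\mathcal{TMI}$ is a weaker objective to minimize than $\mathcal{TC}$. Why? MI and TC tell us how much we can infer about the value of a variable given the values of other variables. $\mathcal{TMI}$ only considers inferring one variable from one other variable at a time, whereas $\mathcal{TC}$ also accounts for inferences that combine several variables (clues). If the dimensions are mutually independent ($\mathcal{TC}(z)=0$), then every pair is independent too ($\mathcal{TMI}(z)=0$), but the converse does not hold: a group of variables can be dependent even when every pair of them is independent.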
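To make the comparison concrete, here is a minimal NumPy sketch that evaluates $\mathcal{TMI}$ for discrete variables. It assumes, purely for illustration, that $p(z)$ is given as a probability table with one axis per dimension; the helper names `pairwise_mi` and `tmi` are hypothetical, and logs are natural (nats). The sum runs over ordered pairs, matching the $n\neq m$ definition, so each unordered pair is counted twice.

```python
import itertools

import numpy as np


def pairwise_mi(joint, n: int, m: int) -> float:
    """I(z_n; z_m) in nats from a discrete joint table with one axis per variable."""
    joint = np.asarray(joint, dtype=float)
    other = tuple(a for a in range(joint.ndim) if a not in (n, m))
    p_nm = joint.sum(axis=other)  # 2-D pairwise joint; MI is symmetric, so axis order is irrelevant
    outer = p_nm.sum(axis=1, keepdims=True) * p_nm.sum(axis=0, keepdims=True)
    mask = p_nm > 0  # cells with zero probability contribute 0
    return float(np.sum(p_nm[mask] * np.log(p_nm[mask] / outer[mask])))


def tmi(joint) -> float:
    """Sum of I(z_n; z_m) over all ordered pairs n != m, as in the TMI definition."""
    num_vars = np.asarray(joint).ndim
    return sum(pairwise_mi(joint, n, m) for n, m in itertools.permutations(range(num_vars), 2))


# Two perfectly correlated fair bits: each ordered pair contributes log 2, so TMI = 2 log 2.
correlated = np.array([[0.5, 0.0], [0.0, 0.5]])
print(tmi(correlated), 2 * np.log(2))
```

For $N=2$ the double counting gives $\mathcal{TMI}(z)=2\,\mathcal{TC}(z)$, which is one simple case where the two quantities differ.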
But can we actually come up with a concrete example? Below I show a simple example of three random variables where $\mathcal{TMI}(z)=0$ but $\mathcal{TC}(z)>0$, which gives a more intuitive understanding of TC.
Suppose that we have three Bernoulli variables $s_1,s_2,l$, where $s_i$ denotes the state of switch $i$ and $l$ denotes the state of a light. The switches work as $l=s_1 \oplus s_2$ (XOR); in other words:
$l$ | $s_1$ | $s_2$ |
---|---|---|
0 | 0 | 0 |
1 | 1 | 0 |
1 | 0 | 1 |
0 | 1 | 1 |
Namely, the light is on only when the two switches are in different states. This arrangement is called two-way switching. It is typically used when one wants to control a light from two places that are far apart; for example, the two switches can be placed at the two ends of a long hallway to control a light in the middle. (I realized that there is in fact a two-way switching design in my own kitchen!)
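Now suppose $p(s_i=1)=0.5$ for each switch and that the two switches are independent. It’s easy to see that the marginal of the light is $p(l=1)=p(s_1=1,s_2=0)+p(s_1=0,s_2=1)=0.5$. Moreover, for every combination of values, $p(s_i,l)=p(s_i)p(l|s_i)=0.5\times 0.5=0.25=p(s_i)p(l)$, because the other switch is a fair coin independent of $s_i$, so the light is equally likely to be on or off whatever $s_i$ is. Together with $p(s_1,s_2)=p(s_1)p(s_2)$, this means every pair of variables is independent, so by definition $\mathcal{TMI}(z)=0$. In contrast, the joint $p(s_1,s_2,l)$ puts probability $0.25$ on each of the four rows of the table above and 0 elsewhere, so
$$ \mathcal{TC}(z)=\mathbb{E}_{p(s_1,s_2,l)}\log\frac{p(s_1,s_2,l)}{p(s_1)p(s_2)p(l)}=4\times 0.25\times\log \frac{0.5^2}{0.5^3}=\log 2\approx 0.693 $$
(in nats, i.e., exactly 1 bit). That is, in this scenario no single variable reveals anything about another, yet any two of them fully determine the third: the state of a switch or the light cannot be inferred without looking at both remaining variables.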
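The numbers in the switch example can be checked numerically. The sketch below (NumPy, natural logs) builds the joint table $p(s_1,s_2,l)$ from the truth table and uses the standard entropy identities $I(a;b)=H(a)+H(b)-H(a,b)$ and $\mathcal{TC}(z)=\sum_n H(z_n)-H(z)$. The helper names `entropy` and `marginal` are my own, chosen for illustration.

```python
import itertools

import numpy as np


def entropy(p) -> float:
    """Shannon entropy in nats; zero-probability cells contribute nothing."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))


def marginal(joint: np.ndarray, axes: tuple) -> np.ndarray:
    """Marginal table over `axes`, obtained by summing out every other axis."""
    other = tuple(a for a in range(joint.ndim) if a not in axes)
    return joint.sum(axis=other)


# Axes of `joint` are (s1, s2, l): two independent fair switches and l = s1 XOR s2.
joint = np.zeros((2, 2, 2))
for s1, s2 in itertools.product((0, 1), repeat=2):
    joint[s1, s2, s1 ^ s2] = 0.25

# Pairwise MI for (s1, s2), (s1, l), (s2, l); each is 0 up to rounding, so TMI = 0.
pair_mi = {
    (a, b): entropy(marginal(joint, (a,))) + entropy(marginal(joint, (b,)))
    - entropy(marginal(joint, (a, b)))
    for a, b in itertools.combinations(range(3), 2)
}

# TC = sum of marginal entropies minus joint entropy = 3 log 2 - 2 log 2 = log 2.
tc = sum(entropy(marginal(joint, (n,))) for n in range(3)) - entropy(joint)

# I(l; s1, s2): the two switches together fully determine the light, giving log 2.
mi_light_both = entropy(marginal(joint, (2,))) + entropy(marginal(joint, (0, 1))) - entropy(joint)

print(pair_mi)
print(tc, np.log(2))
```

The last quantity shows where the TC comes from: every pairwise term vanishes, and all of the dependence sits in the three-way term $I(l;s_1,s_2)$.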
What does this example imply? It implies that if I use $\mathcal{TMI}(z)$ to enforce disentanglement in the latent representation, I cannot completely remove dependence that involves more than two latent dimensions at once, such as “when the light is on, the two switches are in different states”. Thus, even driving $\mathcal{TMI}(z)$ to 0 does not in theory guarantee a fully disentangled representation. For that purpose, I need to rely on $\mathcal{TC}(z)$ instead.