In the last post I tried to summarise the rationale behind the free energy principle and the resulting active inference dynamics. I will assume you have read that post and have a rough idea of the free energy principle for action and perception. Now I want to get concrete and introduce you to what I call a “Deep Active Inference” agent. While I tried to minimise formalism in the last post, I’m sorry that it will get quite technical now. But I will try to be very specific in this post, to allow people to implement their own versions of this type of agent. You can find a reference implementation here.
Our agent will live in a discrete version of the mountaincar world. It will start at the bottom of a potential landscape that looks like this:
Our agent is a small car on this landscape, whose true, physical state is given by its current position and velocity . We will use bold face to indicate vectors.
The dynamics are given by a set of update equations, which can be summarised symbolically by
Here the next physical state depends on the current physical state and the action state of the agent. Think of this action state as the current state of the agent’s effector organs, i.e. muscles at time .
This equation can be decomposed into a set of equations. First we have the downhill force acting on our agent, depending on its position
The agent’s motor can generate a force , depending on the agent’s current action state :
Finally, we have a laminar friction force , depending linearly on the agent’s current velocity :
Combined, this yields the total force acting on the agent:
We use this force to update the velocity:
and then update the position:
The initial state of our agent is resting at the bottom of the landscape, i.e.
The goal of our agent will be to reach and stay at . However, looking at the downhill force
we notice that the landscape gets very steep at .
As the action force generated by the motor is limited to the interval due to the function, the agent is not strong enough to climb the slope at , which results in a downhill force of .
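The update equations described above can be sketched in a few lines of NumPy. This is a minimal sketch, not the reference implementation: the shape of the landscape and every numerical constant (`MOTOR_SCALE`, `FRICTION`, the coefficients in `downhill_force`, the start position) are assumptions chosen for illustration, as is the choice to update the position with the already updated velocity.

```python
import numpy as np

# All constants below are illustrative assumptions, not values taken from the text.
MOTOR_SCALE = 0.03   # assumed maximal motor force
FRICTION = 0.25      # assumed laminar friction coefficient


def downhill_force(x):
    """Assumed piecewise downhill force: linear on the left, steep near x = 0."""
    x = np.asarray(x, dtype=float)
    left = -0.05 * (2.0 * x + 1.0)
    s = 1.0 + 5.0 * x**2
    right = -0.05 * (s**-0.5 + x**2 * s**-1.5 + x**4 / 16.0)
    return np.where(x < 0.0, left, right)


def step(x, v, a):
    """One update of the physical state (x, v) given the action state a."""
    motor_force = MOTOR_SCALE * np.tanh(a)    # bounded by the tanh
    friction_force = -FRICTION * v            # linear in the velocity
    total_force = downhill_force(x) + motor_force + friction_force
    v_next = v + total_force                  # first update the velocity ...
    x_next = x + v_next                       # ... then the position
    return x_next, v_next


# Pushing right with full force from rest at the (assumed) bottom x = -0.5:
x, v = -0.5, 0.0
x_max = x
for _ in range(200):
    x, v = step(x, v, a=10.0)
    x_max = max(x_max, float(x))
```

Under these assumed constants the car stalls well before the steep part of the slope, which is the situation the text describes: the motor alone is too weak, so the agent has to find another strategy.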
Our agent has a noisy sense of its real position :
To show that the agent can indeed learn a complex, nonlinear, generative model of its sensory input, we add another sensory channel with a nonlinear, non-bijective transformation of :
One could use this sensory channel as a “reward” channel, as it encodes the agent’s distance to what we will later define as its “goal position”, . However, right now we will not use it this way, but directly give the agent a strong expectation (i.e. prior) on its noisy sense of position to define its goal in this environment (we will come to that later). But feel free to experiment with the code and try how expectations on this “reward-like” channel might shape the agent’s actions.
Note that the agent has no direct measurement of its current velocity , i.e. it has to estimate it from the sequence of its sensory inputs, which all just depend on .
To see if the agent understands its own actions, we also give it a proprioceptive sensory channel, allowing it to feel its current action state :
Note that having a concrete sense of its action state is *not* necessary for an active inference agent to successfully control its environment. E.g. the reflex arcs that we use to increase the likelihood of our sensory inputs, given our generative model of the world, feature their own closed-loop neural dynamics at the level of the spinal cord. So we do not (and do not have to) have direct access to the “action states” of our muscles when we just lift our arm by expecting it to lift. However, adding this channel will allow us later to directly sample the agent’s proprioceptive sensations from its generative model, to check if it understands its own action on the environment.
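The three sensory channels can be sketched as follows. The noise level, the goal position and the exact form of the nonlinear channel (here a Gaussian bump around the goal, which is nonlinear and non-bijective) are assumptions made for this example; only the overall structure, a noisy position, a nonlinear function of position and a proprioceptive channel, comes from the description above.

```python
import numpy as np

# Illustrative assumptions, not values taken from the text.
NOISE_STD = 0.01     # assumed sensory noise level
GOAL_X = 1.0         # assumed goal position
BUMP_WIDTH = 0.3     # assumed width of the nonlinear channel


def observe(x, a, rng):
    """Return the three noisy observations for position x and action state a."""
    x = np.asarray(x, dtype=float)
    a = np.asarray(a, dtype=float)
    # Noisy sense of position.
    o_x = x + NOISE_STD * rng.standard_normal(x.shape)
    # Nonlinear, non-bijective function of position (peaks at the goal).
    bump = np.exp(-((x - GOAL_X) ** 2) / (2.0 * BUMP_WIDTH**2))
    o_h = bump + NOISE_STD * rng.standard_normal(x.shape)
    # Proprioceptive sense of the action state.
    o_a = a + NOISE_STD * rng.standard_normal(a.shape)
    return o_x, o_h, o_a
```

Note that none of the channels carries the velocity, so a model that wants to predict the next observation has to infer it from the history of its inputs.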
The Agent’s Generative Model of the Environment
First, the agent has a prior over the hidden states , i.e. the latent variables in its generative model:
with the initial state, before it has encountered any observations, .
This factorisation means that the next state in the agent’s generative model only depends on the current state of its inner model of the world. In our concrete implementation, we model the distribution of as diagonal Gaussian, where the means and standard deviations are calculated from the current state using neural networks. I.e.
We use to encompass all parameters of the generative model, that we are going to optimise. In practice, this means all the parameters of the neural networks to calculate the means and standard-deviations.
Similarly, the likelihood functions for each of the three observables also factorise
So the likelihood of each observable for a given time only depends on the current state . We also use Gaussian distributions, to obtain
Sampling from the Generative Model
Once the agent has acquired a generative model of its sensory inputs, by optimising the parameters , one can sample from this model simply by propagating batches of processes like this:
- Evaluate using the state sampled at the previous iteration.
- Draw a single sample from this distribution.
- Use this sample to evaluate the likelihoods
- Sample a single observation from each of these likelihoods.
- Carry over to the next timestep.
The fact that we will sample many processes in parallel will allow us to get a good approximation, although we only sample once per individual process and per timestep.
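The sampling loop above can be sketched like this. The network sizes, the `tanh` hidden layer, the softplus used to keep standard deviations positive, the zero initial state and the channel names are all assumptions for illustration, and the weights here are random rather than trained, so the samples only demonstrate the mechanics.

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)


def init_layer(n_in, n_out, rng):
    return {"W": rng.standard_normal((n_in, n_out)) / np.sqrt(n_in),
            "b": np.zeros(n_out)}


def init_head(n_in, n_hidden, n_out, rng):
    """A small network returning the mean and std of a diagonal Gaussian."""
    return {"hidden": init_layer(n_in, n_hidden, rng),
            "mean": init_layer(n_hidden, n_out, rng),
            "std": init_layer(n_hidden, n_out, rng)}


def gaussian_head(params, inputs):
    h = np.tanh(inputs @ params["hidden"]["W"] + params["hidden"]["b"])
    mean = h @ params["mean"]["W"] + params["mean"]["b"]
    std = softplus(h @ params["std"]["W"] + params["std"]["b"]) + 1e-6
    return mean, std


def sample_generative_model(prior, likelihoods, n_processes, n_steps, n_states, rng):
    """Propagate a batch of processes, one sample per process and timestep."""
    s = np.zeros((n_processes, n_states))          # assumed initial state
    observations = {name: [] for name in likelihoods}
    for _ in range(n_steps):
        # Evaluate the prior from the previous sample and draw the next state.
        mean, std = gaussian_head(prior, s)
        s = mean + std * rng.standard_normal(mean.shape)
        # Evaluate each likelihood at the new state and draw one observation.
        for name, head in likelihoods.items():
            o_mean, o_std = gaussian_head(head, s)
            observations[name].append(o_mean + o_std * rng.standard_normal(o_mean.shape))
    # Arrays of shape (n_steps, n_processes, 1) per channel.
    return {name: np.stack(seq) for name, seq in observations.items()}
```

A call such as `sample_generative_model(prior, likelihoods, n_processes=1000, n_steps=30, n_states=4, rng=rng)` would then return one array per sensory channel.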
Following Kingma & Welling/Rezende, Mohamed & Wierstra we do not explicitly represent the sufficient statistics of the variational posterior at every time step. This would require an additional, costly optimisation of these variational parameters at each individual time step. Instead we use an inference network, which approximates the dependency of the variational posterior on the previous state and the current sensory input and which is parameterised by time-invariant parameters. This allows us to learn these parameters together with the parameters of the generative model and the action function (c.f. below) and also allows for very fast inference later on. We use the following factorisation for this approximation of the variational density (approximate posterior) on the states , given the agent’s observations :
and we again use diagonal Gaussians
In general, the action state of the agent could be optimised directly at each timestep, to minimise the free energy by driving the sensory input (which depends on the action state via the dynamics of the world) towards the agent’s expectations, in terms of its likelihood function. However, this would require a costly optimisation for each timestep (and every single simulated process). Thus, we use the same rationale that Rezende, Mohamed & Wierstra/Kingma & Welling use (and that we used in the previous section) for the variational density , and approximate the complex dependency of the action state on the agent’s current estimate of the world (and via this on the true state of the world) by fixed, but very flexible functions, i.e. (deep) neural networks. This yields an explicit functional dependency
whose parameters we add, together with the parameters of the generative model and the variational density, to the set of parameters that we will optimise.
Again we use a diagonal Gaussian form and neural networks to calculate the means and standard deviations:
We now can just optimise the time-invariant parameters . This approximation makes learning and propagating the agent very fast, allowing for efficient simultaneous learning of the generative model, the variational density, and the action function.
The approximation of both the sufficient statistics of the variational density and the action states by (possibly) deep neural networks is the reason why we call this class of agents Deep Active Inference agents.
However, in this very concrete case it turned out that it is enough to use two single-layer networks to calculate the mean and standard deviation of the action state from the agent’s state estimate . This reduces the computational complexity and the number of parameters a bit and allows us to fit the model on a single (four-year-old) GPU using evolution strategies.
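As a rough sketch of such an action function, the following uses one single-layer network for the mean and one for the standard deviation of a diagonal Gaussian over the action state. The linear mean, the softplus on the standard deviation and all names (`init_action_function`, `sample_action`) are assumptions for this example and may differ from the reference implementation.

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)


def init_action_function(n_states, rng, n_actions=1):
    """Two single-layer networks: one for the mean, one for the std."""
    scale = 1.0 / np.sqrt(n_states)
    return {"W_mean": scale * rng.standard_normal((n_states, n_actions)),
            "b_mean": np.zeros(n_actions),
            "W_std": scale * rng.standard_normal((n_states, n_actions)),
            "b_std": np.zeros(n_actions)}


def sample_action(params, s, rng):
    """Draw one action per process from the Gaussian given the state estimate s."""
    mean = s @ params["W_mean"] + params["b_mean"]
    std = softplus(s @ params["W_std"] + params["b_std"]) + 1e-6
    return mean + std * rng.standard_normal(mean.shape)
```

Because the parameters are shared across timesteps, choosing an action is a single forward pass instead of an optimisation per timestep.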
All in all, the causal structure of our model looks like this:
Here the solid lines correspond to the factors of the agent’s generative model, and . The dashed lines correspond to the variational density . The dotted lines correspond to the true generative process and . The wiggly line describes the dependency of the action on the hidden states .
The Free Energy Objective
Now we have everything that we need to put our objective function together. We use the following form of the variational free energy bound (which I discussed in the previous post):
where means the average with respect to the variational density
Using the above factorisation, the free energy becomes:
Ooookay, but how do we evaluate this humongous expression? The idea is that we simulate several thousand processes in parallel, which allows us to approximate the variational density just by a single sample per process and per timestep (analogous to stochastic backpropagation/the variational autoencoder, where only one sample per data point is enough, since the gradients depend on entire (mini-)batches of datapoints):
- For each timestep, we first calculate the means and standard deviations of
, using the sampled state from the previous timestep, where , and the current observations.
- Then we draw a single sample of from this distribution and evaluate the sum of negative log-likelihoods
, approximating the expectation over using just this single sample.
- We calculate the means and standard deviations of the prior using the sample of the last timestep. We then use the closed form of the KL-Divergence for diagonal Gaussians for the second term:
- To generate the observations for the next timestep, we draw a single sample from , using our sample of . We forward the sampled action to the generative process (i.e. the simulated world), to generate the next state of the world and the resulting observations . Initially we use .
- We can now evaluate the next timestep, using the new observations and the sampled state , which will be used as previous state by the next iteration.
As stated above, the fact that we will be running a lot of processes in parallel allows us to resort to this simple sampling scheme for the individual processes.
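The per-timestep terms of this sampling scheme can be sketched as below: a Gaussian negative log-likelihood evaluated at the single sampled state, plus the closed-form KL divergence between two diagonal Gaussians. The function names and the dictionary layout are assumptions for illustration; the means and standard deviations are taken as given, i.e. as the outputs of the inference, prior and likelihood networks.

```python
import numpy as np


def gaussian_nll(x, mean, std):
    """Negative log-density of x under a diagonal Gaussian, summed over dimensions."""
    nll = 0.5 * np.log(2.0 * np.pi) + np.log(std) + 0.5 * ((x - mean) / std) ** 2
    return np.sum(nll, axis=-1)


def kl_diag_gaussians(mean_q, std_q, mean_p, std_p):
    """Closed-form KL(q || p) for diagonal Gaussians, summed over dimensions."""
    kl = (np.log(std_p / std_q)
          + (std_q**2 + (mean_q - mean_p) ** 2) / (2.0 * std_p**2)
          - 0.5)
    return np.sum(kl, axis=-1)


def free_energy_step(observations, likelihoods, q_mean, q_std, p_mean, p_std):
    """Single-sample free energy contribution of one timestep, per process.

    observations: dict, channel name -> observed values, shape (n_processes, d)
    likelihoods:  dict, channel name -> (mean, std) evaluated at the sampled state
    q_*, p_*:     statistics of the variational density and of the prior
    """
    nll = sum(gaussian_nll(observations[name], *likelihoods[name])
              for name in observations)
    return nll + kl_diag_gaussians(q_mean, q_std, p_mean, p_std)
```

Summing `free_energy_step` over timesteps and averaging over the parallel processes then gives the estimate of the free energy for one parameter set.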
The minimisation of the free energy with respect to the parameters will improve the agent’s generative model , by maximising a lower bound on the evidence of the observations , given the generative model. Simultaneously it will make the variational density a better approximation of the true posterior , as can be seen from the following, equivalent form of the free energy (c.f. last post):
Additionally, the parameters of the action function will be optimised, so that the agent seeks out expected states under its own model of the world, minimizing .
Optimising its generative model of the world gives the agent a sense of epistemic value. I.e. it will not only seek out the states, which it expects, but it will also try to improve its general understanding of the world. This is very similar to (but in my opinion more natural than) recent work on artificial curiosity. However, in contrast to this work, our agent is not only interested in those environmental dynamics which it can directly influence by its own action, but also tries to learn the statistics of its environment which are beyond its control.
Goal Directed Behavior
If we just propagate and optimise our agent as it is now, it will look for a stable equilibrium with its environment and settle there. However, to be practically applicable to real-life problems, we have to instil some concrete goals in our agent. We can achieve this by defining states that it will expect to be in. Then the action states will try to fulfil these expectations.
In this concrete case, we want to propagate the agent for 30 timesteps and want it to be at for at least the last 10 timesteps. As the agent’s priors act on the hidden states, we introduce a hard-wired state which just represents the agent’s current position. We do this by hard-coding the first dimension of the state vector to:
This can be seen as a homeostatic, i.e. vitally important, state parameter. E.g. the CO₂ concentration in our blood is extremely important and tightly controlled, as opposed to the possible brightness perceived at the individual receptors of our retina, which can vary by orders of magnitude. Though we might not directly change our behavior depending on visual stimuli, a slight increase in the CO₂ concentration of our blood and the concurrent decrease in the pH will trigger chemoreceptors in the carotid and aortic bodies, which in turn will increase the activity of the respiratory centers in the medulla oblongata and the pons, leading to a fast and strong increase in ventilation, which might be accompanied by a subjective feeling of dyspnoea or respiratory distress. This hard-wired connection between vitally important body parameters and direct changes in perception and action might be very similar to our approach of encoding the goal-relevant states explicitly.
But besides explicitly encoding relevant environmental parameters in the hidden states of the agent’s latent model, we also have to specify the corresponding prior expectations (such as explicit boundaries for the pH of our blood). We do this by explicitly setting the prior over the first dimension of the state vector for :
While this kind of hard-coded inference dynamics and expectations might be fixed for individual agents of a class (i.e. species) within their lifetimes, these mappings can be optimised on a longer timescale over populations of agents by evolution. In fact, evolution might be seen as a very similar learning process, only on different spatial and temporal scales (c.f. the post of Marc Harper on John Baez’s Blog, and this talk by John Baez).
Without action, the model and the objective would be very similar to the objective functions of Rezende, Mohamed & Wierstra, Kingma & Welling, and Chung et al., and we could just sample a lot of random processes and do a gradient descent on the estimated free energy with respect to the parameters , using backpropagation and an optimizer like ADAM. However, to do this here, we would have to backpropagate through the dynamics of the world. I.e. our agent would have to know the equations of motion of its environment (or at least their partial derivatives). As this is obviously not the case, and as many environments are not even differentiable, we have to resort to another approach.
Luckily, it was recently shown that evolution strategies allow for efficient optimisation of non-differentiable objective functions .
Instead of searching for a single optimal parameter set , we will introduce a distribution on the space of parameters, which is called the “population density”. We will optimise its sufficient statistics to minimize the expected free energy under this distribution. The population density can be seen as a population of individual agents, whose average fitness we optimise, hence the name. The expected free energy over this population, as a function of the sufficient statistics of the population density, is:
Now we can calculate the gradient of with respect to the sufficient statistics :
Using again a diagonal Gaussian for we get for the gradients with respect to the means (Try and calculate it yourself!):
Following the discussion at Ferenc Huszár’s Blog, we will also optimize the standard deviations of our population density, using the following gradients (Try also to derive this on your own!):
Drawing samples from a standard normal distribution , we can approximate samples from by . Thus, we can approximate the gradients via sampling by:
Actually, for reasons of stability we are not optimising directly, but calculate the standard deviations using:
with . By choosing constant and optimising we prevent divisions-by-zero and make sure that there is no sign-switch during the optimisation. The chain rule gives:
Now we have everything: a model of the world, the agent, an objective function which we can evaluate and which has gradients that we can optimise. The final step is to run your optimisation method of choice (mine is ADAM) with the gradient estimates for the means and standard deviations of our population density and watch!
Note that we do not need any tricks like batch-normalization for our method to converge and that it can run on a single NVIDIA GPU (I tested it on the first generation Titan from 2013), using about 3GB of graphics memory.
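The sampling-based gradient estimates for the population density can be sketched as follows, with a toy quadratic objective standing in for the free energy. The names, the softplus parameterisation with a small constant offset `SIGMA_MIN`, and its value are assumptions for this example. Subtracting the mean objective as a baseline is an addition of mine to reduce the variance of the estimate; it leaves the expected gradient unchanged.

```python
import numpy as np

SIGMA_MIN = 1e-6   # assumed constant offset keeping the std strictly positive


def softplus(z):
    return np.logaddexp(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def es_gradients(objective, mu, sigma_hat, n_samples, rng):
    """Estimate gradients of the expected objective w.r.t. mu and sigma_hat.

    objective maps a batch of parameter vectors (n_samples, d) to values (n_samples,).
    """
    sigma = softplus(sigma_hat) + SIGMA_MIN
    eps = rng.standard_normal((n_samples, mu.size))
    theta = mu + sigma * eps                      # one agent per row
    f = objective(theta)
    f_centred = (f - f.mean())[:, None]           # baseline (variance reduction)
    grad_mu = np.mean(f_centred * eps / sigma, axis=0)
    grad_sigma = np.mean(f_centred * (eps**2 - 1.0) / sigma, axis=0)
    # Chain rule through sigma = softplus(sigma_hat) + SIGMA_MIN.
    grad_sigma_hat = grad_sigma * sigmoid(sigma_hat)
    return grad_mu, grad_sigma_hat


def toy_objective(theta):
    """Stand-in for the free energy of each sampled agent."""
    return np.sum(theta**2, axis=-1)
```

For the toy objective the true gradient with respect to the means is `2 * mu`, which gives a simple sanity check before plugging in the (much more expensive) free energy of simulated agents. The two returned gradients can be handed to any first-order optimiser.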
Results (Wait? It really worked?)
So let’s look at the convergence of a simple agent with hidden state variables (as implemented here):
It quickly converges from its random starting parameters to a plateau, on which it tries to directly climb the hill and gets stuck at the steep slope. However, after a while (about 10,000 updates of the population density) it discovers that it can get higher by first moving in the opposite direction, thereby gaining some momentum, which it can use to overcome the steep parts of the slope:
This sudden, rapid decline in the free energy might (okay, I am overinterpreting, but try and stop me…) be our agent’s first (and in this environment only) “AHA”-moment.
After about 50 000 iterations our agent’s trajectory looks very efficient: It takes a short left swing, to gain just the required momentum, and then it directly swings up to its target position and stays there. Because the environment still is quite sloped there, it needs to counteract the downhill force by just the right motor force .
Also, our agent has now developed quite some understanding of its environment. We can compare the above image, which was generated by actually propagating the agent in its environment, with the following one, which was generated just by sampling from the agent’s generative model of its environment:
We see that the agent did not only learn the timecourse of its proprioceptive sensory channel and its sense of position , but also the – in this setting irrelevant – channel , which is just a nonlinear transformation of its position. Note that we are plotting here 10 processes sampled from the generative model as described above. Bear in mind that we are approximating each density just by a single sample per timestep and per process. Thus, although our estimates seem quite noisy, they are very consistent with the actual behavior of the agent, and the variability could be ameliorated by averaging over several processes.
Having such a generative model of the environment, we can not only propagate it freely, but we can also use it to test the hypothesis of the agent, given some a priori assumptions on the timecourse of certain states and/or sensory channels. For this we are using the sampling approach developed in Appendix F of Rezende, Mohamed & Wierstra (and shown in their Figure 5). This way one can for example sample the agent’s prior beliefs about its trajectory , given its proprioceptive inputs , i.e. . Thinking about autonomous agents in a real environment (e.g. robots, autonomous vehicles, …) this might be a good way to “ask” the trained models how they think the world might react to their actions, sneaking a peek into the “black-box” associated with classical deep reinforcement learning algorithms. Using the above example we can just take the average timecourse of the proprioceptive channel for the true interaction with the environment and shift it 10 timesteps back and we get:
First, we see that not all of the 10 sampled processes converged. This might be due to the Markov-Chain-Monte-Carlo-Sampling approach, in which the chain has to be initialized close enough to the solution to guarantee convergence. For 9 out of 10 processes, though, the results look quite reasonable. Note that in this example the relative weighting of the agent’s prior expectations on its position is quite strong, so that it tightly sticks to the optimal trajectory as soon as it has learned it. Thus, the dynamics it infers might deviate from the true dynamics. However, if you put the agent in a very noisy environment and take away its effector organs, it reduces to a generative recurrent latent variable model, which has been shown to be able to model and generate very complex data.
So what are possible directions to go?
First, there are many degrees of freedom in terms of the used nonlinearities, the recurrent architecture of the a priori state transition function, the action function, the variational posterior, or the likelihood functions. So playing with these might actually be very fun and lead to some significant improvements in performance and understanding.
Second, it would be interesting to test this architecture in a more complex setting. A challenging, but realistic next step might be learning pong from pixels and testing the architecture on the Atari 2600 environments.
Third, one could look into implementing more complex a priori beliefs in the model. This could be done by sampling from the agent’s generative model already during the optimisation. E.g. one could sample some processes from the generative model, calculate the quantity on which a constraint in terms of prior expectations of the agent should be placed, and calculate the difference between the sampled and the target distribution, e.g. using the KL-divergence. Then one could add this difference as penalty to the free energy. To enforce this constraint, one could use for example the Basic Differential Multiplier Method (BDMM), which is similar to the use of Lagrange multipliers, but which can be used with gradient descent optimisation schemes. The idea is to add the penalty term , which is equal to zero if the constraint is fulfilled, to the function to be minimised , scaled by a multiplier . The combined objective would look like this:
To optimise the objective while forcing the penalty term to be zero, one can perform a gradient descent on for the parameters , but a gradient ascent for the penalty parameter . This prospective sampling to optimise the goals of the agent might be actually what parts of prefrontal cortex do when thinking about the future and how to reach certain goals.
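To make the descent/ascent scheme concrete, here is a minimal sketch of the Basic Differential Multiplier Method on a one-dimensional toy problem: minimise a function `f` subject to a constraint `g(x) = 0`. The toy functions, the step size and all names are assumptions for illustration; in the setting described above, `f` would be the free energy and `g` the penalty computed from samples of the generative model.

```python
def bdmm_minimise(grad_f, g, grad_g, x0, lam0=0.0, lr=0.05, n_steps=2000):
    """Gradient descent on x and gradient ascent on the multiplier lam
    for the combined objective f(x) + lam * g(x)."""
    x, lam = float(x0), float(lam0)
    for _ in range(n_steps):
        dx = grad_f(x) + lam * grad_g(x)   # derivative of the objective w.r.t. x
        dlam = g(x)                        # derivative of the objective w.r.t. lam
        x -= lr * dx                       # descend in the parameter
        lam += lr * dlam                   # ascend in the multiplier
    return x, lam


# Toy problem (assumed): minimise x**2 subject to x - 1 = 0.
x_opt, lam_opt = bdmm_minimise(
    grad_f=lambda x: 2.0 * x,
    g=lambda x: x - 1.0,
    grad_g=lambda x: 1.0,
    x0=0.0,
)
```

On this toy problem the iteration settles at the constrained minimum `x = 1`, where the multiplier balances the gradient of `f` against the gradient of the constraint.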
Fourth, it might be really interesting to see how episodic memory (i.e. an artificial hippocampus) might enhance the agent’s learning, following DeepMind’s cool work on differentiable neural computers.
I think there are even more possible directions to head from here, which is why I thought it might be nice to put the idea and the code out here and see what people think about it.