Causal Machine Learning – Introduction

May 8th, 2019

Prof. Robert Ness

Khoury College of Computer Sciences – Northeastern University

1. Introduction

Discovering the direction of causality has long been of interest to scientists. Researchers can devote many years of their lives to discovering and validating a single causal direction in a large system.

Examples of problems in causal inference

To properly contextualize our motivation, we start by understanding how causal inference developed as a field across domains, including economics, biology, social science, computer science, anthropology, epidemiology, and statistics.

Estimation of causal effects

The problem of estimating causal effects is the primary motivation of researchers in these domains. For example, in the late 80s and 90s, doctors used to prescribe hormone replacement therapy to older women. Experts observed that younger women have a lower risk of heart disease than men do, but that after menopause their estrogen levels decline, so they reasoned that replacing the estrogen would protect against heart disease. However, a large randomized trial, in which women were randomly assigned to receive either a placebo or estrogen, showed that taking estrogen increases the chance of getting heart disease. Causal inference techniques are essential because the stakes are this high.

Counterfactual reasoning with statistics

Counterfactual reasoning means observing reality, and then imagining how reality would have unfolded differently had some causal factor been different. For example, “had I broken up with my girlfriend sooner, I would be much happier today” or “had I studied harder for my SATs, I would be in a much better school today.” An example of a question from an experimental context would be “This subject took the drug, and their condition improved. What is the difference between this amount of improvement and the improvement they would have seen had they taken a placebo?”

Counterfactual reasoning is fundamental to how we as humans reason. However, statistical methods are generally not equipped to enable this type of logic. Your counterfactual reasoning process works with data both from actual and hypothetical realities, while your statistical procedure only has access to data from actual reality.
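The gap between actual and hypothetical realities can be sketched with simulated subjects. This is only an illustration: the function names, the effect size of 2.0, and the 50% treatment rate are all hypothetical. Each simulated subject has two potential outcomes, but the observed record keeps only the one that matches the treatment they actually took.

```python
import random

random.seed(0)

def simulate_subject():
    # Hypothetical potential outcomes: improvement under placebo and under drug.
    y_placebo = random.gauss(1.0, 1.0)
    y_drug = y_placebo + 2.0  # assumed individual effect of 2.0
    took_drug = random.random() < 0.5
    return {"took_drug": took_drug, "y_placebo": y_placebo, "y_drug": y_drug}

def observe(subject):
    # Reality reveals only the outcome under the treatment actually taken.
    outcome = subject["y_drug"] if subject["took_drug"] else subject["y_placebo"]
    return {"took_drug": subject["took_drug"], "outcome": outcome}

subjects = [simulate_subject() for _ in range(5)]
observed = [observe(s) for s in subjects]

# Inside the simulation we can compute each individual effect...
individual_effects = [s["y_drug"] - s["y_placebo"] for s in subjects]
# ...but the observed data never contains both outcomes for one subject.
for row in observed:
    print(row)
```

A statistical procedure only ever sees rows like those in `observed`, so the individual effect cannot be read off the data without further assumptions.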

The same is true of cutting-edge machine learning. Intuition tells us that if we trained the most powerful deep learning methods to provide us with relationship advice based on our romantic successes and failures, something would be lacking in that advice since those counterfactual outcomes are missing from the training data.

The challenge of running experiments

In traditional statistics, randomized experiments are the gold standard for discovering the causal effect. An example of a randomized experiment is an A/B test on a new feature in an app. We randomly assign users to two groups and let one group use the feature while the other is presented with a control comparison. We then observe some key outcome, such as conversions. As we will learn, the randomization enables us to conclude the difference between the two groups is the causal effect of the feature on the conversions, because it isolates that effect from other unknown factors that are also affecting the conversions.
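A small simulation can make the role of randomization concrete. It is a sketch under made-up assumptions: a hidden "engaged" trait raises conversions by 0.30, the feature raises them by 0.10, and all names are hypothetical. It compares the difference in conversion rates when assignment is a coin flip with the difference when engaged users select themselves into the feature.

```python
import random

random.seed(1)
TRUE_EFFECT = 0.10  # assumed lift in conversion probability from the feature

def conversion(engaged, has_feature):
    # Conversion depends on the feature AND on a hidden user trait.
    p = 0.10 + 0.30 * engaged + TRUE_EFFECT * has_feature
    return random.random() < p

def difference_in_means(assign, n=200_000):
    totals = {True: [0, 0], False: [0, 0]}  # group -> [conversions, users]
    for _ in range(n):
        engaged = random.random() < 0.5
        has_feature = assign(engaged)
        totals[has_feature][0] += conversion(engaged, has_feature)
        totals[has_feature][1] += 1
    rate = {group: c / m for group, (c, m) in totals.items()}
    return rate[True] - rate[False]

# Randomized: a coin flip ignores the hidden trait.
randomized = difference_in_means(lambda engaged: random.random() < 0.5)
# Self-selected: engaged users adopt the feature far more often.
self_selected = difference_in_means(
    lambda engaged: random.random() < (0.9 if engaged else 0.1))

print(f"randomized estimate:    {randomized:.3f}")
print(f"self-selected estimate: {self_selected:.3f}")
```

Under these assumed numbers, the randomized difference lands near the true 0.10, while the self-selected difference is inflated (to roughly 0.34 in expectation) because it mixes the feature's effect with the effect of the hidden trait.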

However, in many instances, setting up this randomization might be complicated. What if users object to not getting a feature that other users are enjoying? What if the experience of the feature and probability of conversion both depend on user-related factors, such that it is unclear how to do proper randomization? What if some users object to being the subjects of an experiment? What if it is unethical to do the experiment?

What is left out

Causal inference spans many other concepts, and we won’t be able to cover all of them. Though the concepts below are important, they are out of scope for this course.

  • Causal discovery
  • Causal inference with regression models and various canonical SCM models
  • Doubly-robust estimation
  • Interference due to network effects (important in social network tech companies like Facebook or Twitter)
  • Heterogeneous treatment effects
  • Deep architectures for causal effect inference
  • Causal time series models
  • Algorithmic information theory approaches to causal inference


This course will rely heavily on the following books:

  • Pearl, Judea. Causality. Cambridge University Press, 2009.
  • Peters, Jonas, Dominik Janzing, and Bernhard Schölkopf. Elements of Causal Inference: Foundations and Learning Algorithms. MIT Press, 2017.
  • Scutari, Marco, and Jean-Baptiste Denis. Bayesian Networks: With Examples in R. Chapman and Hall/CRC, 2014.

While not necessary for the course, these books are worth buying just to have as a reference.

2. Causal modeling as an extension of generative modeling

Generative vs. Discriminative Models

Let’s focus on supervised learning for a moment. We know from our machine learning course that a discriminative model directly estimates the conditional distribution \(P(Y \mid X)\). The goal is to best distinguish the classes, or estimate \(Y\), given the observation \(X\).

A generative model learns the joint distribution \(P(X, Y)\) underlying the data. We will discuss this more in later lectures. In simple words, these models can generate data that looks like real data.

Why are we focusing on generative models? Because physical causal mechanisms are the reason we observe what we observe. For example, a city at high altitude will experience lower temperatures. In other words, these causal mechanisms generate data. Thus, a causal model is a generative model!
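The altitude example can be written as a tiny generative program, which is a minimal sketch rather than a real climate model. The altitude range, the 6.5 °C per 1000 m slope, the sea-level temperature of 25 °C, and the noise level are all illustrative assumptions. The point is the ordering: the cause is sampled first, and the effect is generated from it.

```python
import random

random.seed(2)

def sample_city():
    # Cause first: altitude in metres (hypothetical uniform range).
    altitude = random.uniform(0, 3000)
    # Effect second: temperature falls with altitude, plus local noise.
    # The slope and baseline are illustrative assumptions.
    temperature = 25.0 - 6.5 * altitude / 1000 + random.gauss(0, 2.0)
    return altitude, temperature

def mean(values):
    return sum(values) / len(values)

cities = [sample_city() for _ in range(10_000)]
low = [t for a, t in cities if a < 500]
high = [t for a, t in cities if a > 2500]
print(f"mean temperature below 500 m:  {mean(low):.1f}")
print(f"mean temperature above 2500 m: {mean(high):.1f}")
```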

Model-based ML and learning to think about the data-generating process

The following is the typical checklist in training a statistical machine learning model.

  1. Split the data into training and test sets.
  2. Choose a few models from literally thousands of algorithm choices. Typically this choice is limited to algorithms you are familiar with, that are in vogue, or that happen to be implemented in the software you have available.
  3. Manipulate the data until it fits your algorithm inputs and outputs.
  4. Evaluate the model on test data and compare it to other models.
  5. (optional) If the data doesn’t fit the algorithm’s modeling assumptions, manipulate the data until it does.
  6. (optional) If using a deep learning algorithm, search for hyperparameter settings that further optimize prediction.

This process works well. However, in this workflow, the data scientist’s time is devoted to manipulating data, hyperparameters, and often, the problem definition itself until things work.

An alternative approach is to think hard about the process that generated the data, and then explicitly build your assumptions about that process into a bespoke solution tailored to each new problem. This approach is model-based machine learning. Proponents like this approach because with an excellent model-based machine learning framework, you can create a bespoke solution to pretty much any problem, and don’t need to learn a vast number of machine learning algorithms and techniques.

Most interestingly, with the model-based machine learning approach the data scientist shifts her time from transforming her problem to fit some standard algorithm, to thinking hard about the process that generated the problem data, and then building those assumptions into the design of the algorithm.

We’ll see in this class that when we think about the data-generating process, we think about it causally, meaning it has some ordering of causes and effects. In this class, we formalize this intuition by applying causal inference theory to model-based machine learning.

Note on reinforcement learning

As reinforcement learning gains in popularity amongst machine learning researchers and practitioners, many may have encountered the term “model-based” for the first time in a reinforcement learning (RL) context. Model-based RL is indeed an example of model-based machine learning. The two kinds of RL differ as follows.

  1. Model-free RL. The agent has no model of the generating process of the data it perceives in the environment; i.e., how states and actions lead to new states and rewards. It can only learn in a Pavlovian sense, relying solely upon experience.
  2. Model-based RL. The agent has a model of the generating process of the data it perceives in the environment. This model enables the agent to make use not only of experience but also of model-based predictions of the consequences of particular actions it has less experience performing.

In supervised learning, we evaluate a bunch of models with some statistic or goodness-of-fit measure. If we don’t meet the target, we change and transform the features until we do. We mutate the data until it fits the model of our choice.

What is missing here is deep thought about how the data is actually generated. Let’s understand this with some examples.

Case Studies

From linear regression to model-based machine learning

The standard Gaussian linear regression model is represented as follows:

$$Y = \beta X + \alpha + \epsilon\text{, }\epsilon \sim \mathcal{N}(0,1)$$

When we read this model specification, it is natural to think of it as predictors \(X\) generating target variable \(Y\). Indeed, the term “generates” feels a lot like “causes” here. Usually, we moderate this feeling by remembering that linear regression models only correlation, and we could just as easily have regressed \(X\) on \(Y\). In this course, we learn how to formalize that feeling.

We can turn this model into a generative model by placing a marginal distribution on X.

$$\begin{aligned}
\epsilon &\sim \mathcal{N}(0,1)\\
X &\sim P_X\\
Y &= \beta X + \alpha + \epsilon
\end{aligned}$$

At this point, we are already telling a data generating story where \(Y\) comes from \(X\) and \(\epsilon\). Now let’s expand on that story. Suppose we observe that \(Y\) is measured from some instrument, and we suppose that this instrument is adding technical noise to \(Y\). Now the regression model becomes a noise model.

$$\begin{aligned}
\epsilon &\sim \mathcal{N}(0,1)\\
X &\sim P_X\\
Z &\sim P_Z\\
Y &= \beta X + \alpha + \epsilon + Z
\end{aligned}$$
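Both versions of the story can be run forward as samplers. The sketch below assumes concrete choices the text leaves open: \(P_X\) and \(P_Z\) are taken to be standard normals, and the coefficients \(\beta = 2\), \(\alpha = 1\) are hypothetical. It shows that adding the instrument noise \(Z\) changes the distribution of the generated \(Y\).

```python
import random
import statistics

random.seed(3)
BETA, ALPHA = 2.0, 1.0  # hypothetical regression coefficients

def sample_regression():
    x = random.gauss(0, 1)    # X ~ P_X, assumed standard normal here
    eps = random.gauss(0, 1)  # epsilon ~ N(0, 1)
    return x, BETA * x + ALPHA + eps

def sample_with_instrument_noise():
    x, y = sample_regression()
    z = random.gauss(0, 1)    # Z ~ P_Z, assumed standard normal instrument noise
    return x, y + z

clean = [sample_regression()[1] for _ in range(50_000)]
noisy = [sample_with_instrument_noise()[1] for _ in range(50_000)]
print(f"var(Y) without Z: {statistics.variance(clean):.2f}")
print(f"var(Y) with Z:    {statistics.variance(noisy):.2f}")
```

With these assumptions the variance of \(Y\) is \(\beta^2 + 1 = 5\) without \(Z\) and \(6\) with it, so the sample variances should land near those values.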

Binary classifier

The logistic regression model has the form:

$$\mathbb{E}[Y \mid X] = \texttt{logit}^{-1}(\beta X + \alpha)$$

Read literally, this formula says that \(Y\) comes from \(X\). Of course, that is not necessarily true. This model doesn’t care whether \(Y\) comes from \(X\) or vice versa. In other words, it doesn’t care how \(Y\) is generated; it merely wants to model \(P(Y=1|X)\).

In contrast, a naive Bayes classifier models \(P(X, Y)\) as \(P(X|Y)P(Y)\). \(P(X|Y)\) and \(P(Y)\) are estimated from the data, and then we use Bayes rule to find \(P(Y|X)\) and predict \(Y\) given \(X\). \(P(X|Y)P(Y)\) is a representation of the data generating process that reads as “there is some unobserved \(Y\), and then we observe \(X\) which came from \(Y\).” There is nothing that forces us to apply naive Bayes only in problems where the generation of the prediction target precedes the generation of the features. Yet, this is precisely the kind of problem where this approach tends to get applied, such as spam detection. I argue that \(P(X|Y)P(Y)\) aligns with a causal intuition that \(X\) comes from \(Y\), and that we avoid the inner cringe that comes from using naive Bayes when we suspect that \(Y\) comes from \(X\). Causal modeling gives us a language to formalize this intuition.
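The \(P(X|Y)P(Y)\) story and the Bayes rule step can be worked through on a toy spam example. Everything here is hypothetical: the two word features, the class prior, and the word probabilities are made up rather than estimated from data.

```python
# Hypothetical spam example: Y is the class, X is a set of binary word features.
P_Y = {"spam": 0.4, "ham": 0.6}
P_WORD_GIVEN_Y = {
    "spam": {"free": 0.7, "meeting": 0.1},
    "ham": {"free": 0.1, "meeting": 0.5},
}

def joint(y, x):
    # P(X, Y) = P(Y) * prod_j P(X_j | Y), the "naive" independence assumption.
    p = P_Y[y]
    for word, present in x.items():
        p_word = P_WORD_GIVEN_Y[y][word]
        p *= p_word if present else 1.0 - p_word
    return p

def posterior(x):
    # Bayes rule: P(Y | X) = P(X, Y) / sum_y P(X, y).
    scores = {y: joint(y, x) for y in P_Y}
    total = sum(scores.values())
    return {y: s / total for y, s in scores.items()}

email = {"free": True, "meeting": False}
print(posterior(email))
```

For this email the joint scores are \(0.4 \times 0.7 \times 0.9 = 0.252\) for spam and \(0.6 \times 0.1 \times 0.5 = 0.03\) for ham, so the posterior probability of spam is \(0.252 / 0.282 \approx 0.89\).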

Gaussian Mixture Model:

Recap: Let’s recap the intuition with a simple GMM with two Gaussians. We observe the data and realize that it is coming from two Gaussians with means, say, \(\mu_1\) and \(\mu_2\). Now let \(Z_i\) be a binary variable that indicates which of these two distributions \(X_i\) belongs to. The data generating story is:

  • Some probabilistic process generated \(\mu\).
  • Some other probabilistic process, here a Dirichlet distribution, generated \(\theta\).
  • Then a discrete distribution with parameter \(\theta\) generated \(Z_1\).
  • \(Z_1\) and all the \(\mu\) generated \(X_1\).

We can code this story real quick in a modern probabilistic programming language.

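 import torch
 import pyro
 from pyro.distributions import Categorical, Dirichlet, Normal

 def gmm(N, K, sigma, scale):
     # Mixture weights theta come from a Dirichlet prior.
     alphas = 0.5 * torch.ones(K)
     theta = pyro.sample('theta', Dirichlet(alphas))
     # Each mixture component k gets its own mean mu[k].
     mu = torch.stack([pyro.sample(f'mu_{k}', Normal(0., sigma)) for k in range(K)])
     X = []
     for i in range(N):
         # Z_i picks a component; X_i is drawn around that component's mean.
         Z_i = pyro.sample(f'Z_{i}', Categorical(theta))
         X.append(pyro.sample(f'X_{i}', Normal(mu[Z_i], scale)))
     return torch.stack(X)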

The plate model shows exactly this narrative, which can be cast as a graphical story of how the data is generated. In fact, any plate model, such as a hidden Markov model, a mixed membership model like LDA, or a linear factor model, tells exactly such a story of how the data is generated. A probabilistic program can be seen as a description of the data generating process. In the next lecture, we will cover the basics of Pyro, a PyTorch-based probabilistic programming language (PPL). We will also learn to code these well-known models in a PPL.

Deep generative models

Deep generative models are generative models that use deep neural network architectures. Examples include variational autoencoders and generative adversarial networks. Rather than make the data generation story explicit, their basic implementations compress the entire generative process into a latent encoding. But nothing is forcing them to do so. In this course, we will see examples of deep generative models where we model the critical components of the data generating process explicitly, and let the latent encoding handle nuisance variables that we don’t care about.

Inferring the correct story is really hard!

We can think of deep models as models that learn an optimal circuit given input signals and output channels. In contrast, when you describe a data generating process, you are writing a program. Inferring a program is a lot harder than inferring a circuit. In fact, it is an ill-specified problem, because there are numerous programs we could write to generate the same data. From algorithmic information theory, we know that finding the shortest program that generates a given output is uncomputable in general.
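To see why the problem is ill-specified, here is a minimal sketch of two different programs that generate the same data distribution. The particular coefficients are chosen by hand for illustration: both programs produce a zero-mean bivariate Gaussian with \(\mathrm{Var}(X) = 1\), \(\mathrm{Var}(Y) = 2\), and \(\mathrm{Cov}(X, Y) = 1\), yet one tells the story "X then Y" and the other "Y then X".

```python
import random
import statistics

random.seed(4)

def program_x_then_y():
    # Story A: X is generated first, then Y is generated from X.
    x = random.gauss(0, 1)
    y = x + random.gauss(0, 1)
    return x, y

def program_y_then_x():
    # Story B: Y is generated first, then X is generated from Y.
    y = random.gauss(0, 2 ** 0.5)
    x = 0.5 * y + random.gauss(0, 0.5 ** 0.5)
    return x, y

def summary(program, n=50_000):
    # Sample variances of X and Y and their sample covariance.
    xs, ys = zip(*(program() for _ in range(n)))
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    cov = sum((a - mx) * (b - my) for a, b in zip(xs, ys)) / (n - 1)
    return statistics.variance(xs), statistics.variance(ys), cov

print("X then Y:", [round(v, 2) for v in summary(program_x_then_y)])
print("Y then X:", [round(v, 2) for v in summary(program_y_then_x)])
```

Samples alone cannot tell these two programs apart, which is why extra knowledge about the generating process is needed to pick one.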

This is where domain knowledge comes to the rescue! Let’s say an economist has some models in mind about supply and demand, or wants to test some behavioral or cognitive model. Every model that has stood the test of time is backed by extensive research and has been validated rigorously. And just as we can test a machine learning model, we can test the performance of the model after updating it with modeling assumptions from domain knowledge.

Inferring Latent variables is also hard!

For well-known generative models like the GMM, we have nice modeling assumptions that enable fast inference. In general, inference is hard. The probabilistic machine learning community focuses heavily on tricks, like conjugate priors and conditional independence, that make inference easy. We can consider all of those tricks as modeling assumptions, and thus part of the story of how we believe the data was generated.

Once we are done with the story, a PPL should be able to do out-of-the-box inference.

Causal Bayesian Networks

A Bayesian network explicitly models a joint distribution by factorizing it into conditionals. We can represent the factorization with a DAG (directed acyclic graph). We will cover them in depth in later lectures. There can be multiple different factorizations that represent the same distribution. For example, whenever \(X\) and \(Z\) are conditionally independent given \(Y\), we can prove that all of the factorizations below give the same joint distribution over \(X\), \(Y\) and \(Z\).

$$P(X,Y,Z) = P(X\mid Y)P(Y\mid Z)P(Z)$$

$$P(X,Y,Z) = P(Z\mid Y)P(Y\mid X)P(X)$$

$$P(X,Y,Z) = P(Z\mid Y, X)P(Y\mid X)P(X)$$
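The equivalence of these factorizations can be checked numerically on a small example. The sketch below assumes made-up binary conditional probability tables for the first factorization, builds the joint from them, derives the conditionals needed by the other two factorizations from that joint, and compares the results.

```python
from itertools import product

# Hypothetical binary CPTs for the first factorization P(X|Y) P(Y|Z) P(Z).
P_Z = {0: 0.7, 1: 0.3}
P_Y_GIVEN_Z = {0: {0: 0.9, 1: 0.1}, 1: {0: 0.4, 1: 0.6}}  # indexed [z][y]
P_X_GIVEN_Y = {0: {0: 0.8, 1: 0.2}, 1: {0: 0.3, 1: 0.7}}  # indexed [y][x]

joint = {(x, y, z): P_X_GIVEN_Y[y][x] * P_Y_GIVEN_Z[z][y] * P_Z[z]
         for x, y, z in product((0, 1), repeat=3)}

def marginal(keep):
    # Sum the joint over every variable whose index is not in `keep`
    # (index 0 is X, index 1 is Y, index 2 is Z).
    out = {}
    for assignment, p in joint.items():
        key = tuple(assignment[i] for i in keep)
        out[key] = out.get(key, 0.0) + p
    return out

p_x, p_xy = marginal([0]), marginal([0, 1])
p_y, p_yz = marginal([1]), marginal([1, 2])

def second(x, y, z):
    # P(Z|Y) P(Y|X) P(X), with each conditional derived from the joint.
    return (p_yz[(y, z)] / p_y[(y,)]) * (p_xy[(x, y)] / p_x[(x,)]) * p_x[(x,)]

def third(x, y, z):
    # P(Z|Y,X) P(Y|X) P(X), the chain rule with no independence assumed.
    return (joint[(x, y, z)] / p_xy[(x, y)]) * (p_xy[(x, y)] / p_x[(x,)]) * p_x[(x,)]

max_gap = max(max(abs(second(*k) - p), abs(third(*k) - p))
              for k, p in joint.items())
print(f"largest disagreement with the first factorization: {max_gap:.2e}")
```

The disagreement is zero up to floating-point error, so the data distribution alone does not single out one factorization.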

So the question is, which factorization should we choose? The answer is simple: whichever fits our belief about how the data was generated. The directions of the edges then follow causality, and we call such models causal Bayesian networks. A famous example is the lung cancer Bayesian network. We could entail the same distribution with an edge from Visit to Asia to Lung Cancer, but we know that a visit to Asia does not cause lung cancer.


Inference in Bayesian network

We are not going to focus too much on inference algorithms in this class, but it is important to note that they matter and are often hard. Below are some of the simplest inference techniques for the Bayesian network above.

Suppose we observe a positive X-ray. How do we infer whether the person has tuberculosis?

One way is to compute the conditional probability, marginalizing appropriately according to the query. Here, \(\sum_A\) means marginalize over \(A\), and the network’s variables are \(A\) (visit to Asia), \(T\) (tuberculosis), \(L\) (lung cancer), \(X\) (X-ray result), \(E\) (either tuberculosis or lung cancer), \(S\) (smoking), \(B\) (bronchitis), and \(D\) (dyspnea).

$$P(T \mid X) = \frac{P(X,T)}{P(X)} = \frac{\sum_{A, L, E, S, B, D}P(A, T, L, X, E, S, B, D)}{\sum_{A, L, T, E, S, B, D}P(A, T, L, X, E, S, B, D)}$$

As you can see, this is not always scalable and becomes harder as the network size grows. Also, this approach is only viable when all the variables are discrete. Things become complicated when variables are continuous.
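On a network small enough to enumerate, the marginalization is a pair of sums. The sketch below does not use the full lung cancer network: it assumes a reduced, hypothetical chain \(A \to T \to X\) with made-up probabilities, purely to show the mechanics of summing the joint for the numerator and the denominator.

```python
from itertools import product

# Hypothetical three-node chain A -> T -> X (visit to Asia, tuberculosis, X-ray).
# The structure and every probability are made up for illustration.
P_A = {True: 0.01, False: 0.99}
P_T_GIVEN_A = {True: 0.05, False: 0.01}  # P(T = True | A)
P_X_GIVEN_T = {True: 0.98, False: 0.05}  # P(X = True | T)

def bernoulli(p_true, value):
    # Probability of a binary value given the probability of True.
    return p_true if value else 1.0 - p_true

def joint(a, t, x):
    return P_A[a] * bernoulli(P_T_GIVEN_A[a], t) * bernoulli(P_X_GIVEN_T[t], x)

def posterior_t_given_x(x=True):
    # Numerator: sum over A with T = True. Denominator: sum over A and T.
    numerator = sum(joint(a, True, x) for a in (True, False))
    denominator = sum(joint(a, t, x)
                      for a, t in product((True, False), repeat=2))
    return numerator / denominator

print(f"P(T = True | X = True) = {posterior_t_given_x():.4f}")
```

With \(n\) binary variables the denominator has on the order of \(2^n\) terms, which is the scaling problem described above.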

The other way is to use sampling techniques. In rejection sampling, we draw samples from the model and throw out any sample that doesn’t match what we observed. So for this question, we reject all the samples that have a negative X-ray.
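A minimal rejection sampler looks like this. As an assumption for illustration, it uses a reduced, hypothetical chain \(A \to T \to X\) with made-up probabilities instead of the full network. Samples are drawn from the model in causal order, and those with a negative X-ray are discarded.

```python
import random

random.seed(5)

# Hypothetical chain A -> T -> X; all probabilities are made up.
def sample_patient():
    a = random.random() < 0.01                   # visit to Asia
    t = random.random() < (0.05 if a else 0.01)  # tuberculosis
    x = random.random() < (0.98 if t else 0.05)  # positive X-ray
    return a, t, x

def rejection_estimate(n=400_000):
    kept = tb = 0
    for _ in range(n):
        _, t, x = sample_patient()
        if not x:
            continue  # reject: sample disagrees with the evidence X = True
        kept += 1
        tb += t
    return tb / kept, kept

estimate, kept = rejection_estimate()
print(f"P(T = True | X = True) ~ {estimate:.3f} from {kept} accepted samples")
```

Under these made-up numbers only about 6% of samples have a positive X-ray, so most of the work is thrown away. That waste grows as the evidence becomes rarer.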

As we saw while coding the GMM story, probabilistic programs contain for loops and other complicated control flow. Inferring unobserved latent variables in these programs is a lot harder, and we generally use advanced sampling techniques like Hamiltonian Monte Carlo (HMC).

Coming up next

In the next lecture, we will give an overview of the tools we’re going to use extensively in this course: bnlearn and Pyro. bnlearn is a simple yet powerful R package for working with Bayesian networks, and Pyro is a universal probabilistic programming language (PPL) written in Python and supported by PyTorch on the backend.
