

Elegance of Linear Regression

\( \def\truedist{{\overline{X}}} \def\allset{{\mathcal{X}}} \def\truesample{{\overline{x}}} \def\augdist{\mathcal{A}} \def\R{\mathbb{R}} \def\normadj{\overline{A}} \def\pca{F^*} \def\bold#1{{\bf #1}} \newcommand{\Pof}[1]{\mathbb{P}\left[{#1}\right]} \newcommand{\Eof}[1]{\mathbb{E}\left[{#1}\right]} \)

10/19/2023


These days, it seems like every hard problem is being solved via neural networks. When I interned as a quantitative trader, one of my biggest takeaways was the insane modelling power of linear regression; I had not realized that the firm was basically thousands of linear fits in a trench coat. Though there were many more sophisticated modelling techniques, it seemed like most of the success came from linear regression on features that "correctly" connected the data to the real world. This seemed to run contrary to the well-known fact that neural networks are more expressive and powerful predictors.


When I first started observing this, I thought there wasn't a "fundamental reason" to prefer linear regression and the world would eventually "neural network" away all the problems. However, as I dug deeper, I realized the true value of linear regression and how it was way more important than I thought. This blog post is a collection of fun linear regression properties that make it the dominant choice for the right application and metric. Though any individual subsection might be relatively shallow, the combination of all these properties makes linear regression robust at predicting relationships and uniquely capable of understanding data.

Introduction

We will consider \(n\) i.i.d. \(d\)-dimensional data points \(x_{i} \in \R^d\) with labels \(y_i \in \R\). We will frequently consider the stacked matrix forms of these quantities as \(X \in \R^{n\times d}\) and \(y \in \R^n\). Usually, the data points are drawn from some covariate distribution and the labels can be determined via an underlying ground-truth function \(y_i = f(x_i)\). This function is typically unknown and is potentially randomized itself.


Linear regression implicitly makes the (incredibly reductionist) assumption that the function \(f\) can be well-approximated by a linear function. Linear regression wants to find \(w \in \R^d\) such that \(Xw\) is close to \(y\), which corresponds to

\[\arg\min_{w\in\R^d} \left|\left|Xw - y\right|\right|_2^2\]

When first presented with this formulation, many of the choices seem quite arbitrary and limiting (why linear? why minimize the squared norm? how can I incorporate any more information?). Throughout this post, we will thoroughly motivate these design choices and explain why the result is so useful. Hopefully, this post can provide a little insight into what makes linear regression so great.

Optimization

One neat property of linear regression is that it has a closed-form solution for a given dataset. To get this closed form, we set the gradient of our objective with respect to \(w\) to \(0\):

\[ \begin{aligned} 0 &= \frac{\partial}{\partial w}\left( (Xw-y)^{\top}(Xw-y) \right) \\ 0 &= \frac{\partial}{\partial w} \left(y^{\top}y - 2(Xw)^{\top}y + (Xw)^{\top}Xw \right) \\ 0 &= - 2X^{\top}y + 2X^{\top}Xw \\ w &= (X^{\top}X)^{-1}X^{\top}y \\ \end{aligned} \]

Therefore, when \(X^{\top}X\) is invertible, linear regression has a unique, closed-form solution! This matrix is invertible exactly when the columns of \(X\) are linearly independent, or equivalently when "no feature can be written as a linear combination of the other features". This natural condition simply states that none of our columns are redundant. It requires at least as many samples as features, and it holds almost surely when the features are drawn from a continuous distribution, though real data with duplicated or derived features can violate it.
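To make this concrete, here is a minimal NumPy sketch of the closed form on synthetic data (the data, the noise level, and the helper name `fit_closed_form` are illustrative choices, not from any particular codebase). It solves the normal equations with `np.linalg.solve` instead of forming the inverse explicitly, and cross-checks the answer against `np.linalg.lstsq`.

```python
import numpy as np

def fit_closed_form(X, y):
    """Least squares via the normal equations (X^T X) w = X^T y.

    Solving the linear system is numerically preferable to computing
    (X^T X)^{-1} explicitly, but it expresses the same closed form.
    """
    return np.linalg.solve(X.T @ X, X.T @ y)

rng = np.random.default_rng(0)
n, d = 100, 3
X = rng.normal(size=(n, d))             # continuous features: full column rank almost surely
w_true = np.array([1.5, -2.0, 0.5])
y = X @ w_true + 0.01 * rng.normal(size=n)

w_hat = fit_closed_form(X, y)
# Cross-check against NumPy's general least-squares solver.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(w_hat)
print(np.allclose(w_hat, w_lstsq))
```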


In practice, the closed form is rarely computed literally on large problems: forming \(X^{\top}X\) costs \(O(nd^2)\) time and inverting it costs \(O(d^3)\), which becomes expensive when both \(n\) and \(d\) are large, and explicitly inverting an ill-conditioned \(X^{\top}X\) is numerically unstable. For those interested in bonus content, one way to salvage this closed form is via high-dimensional sketching (refer to David Woodruff's wonderful curriculum). At large scale, a common choice is (stochastic) gradient descent on the linear regression objective. Since this objective is convex in the weight parameters, gradient descent converges rather quickly (even provably so), making it a de facto choice for implementation.
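Here is a sketch of full-batch gradient descent on the averaged squared loss \(\frac{1}{n}\|Xw - y\|_2^2\), assuming a small synthetic dataset; `fit_gradient_descent` is a hypothetical helper. The step size \(1/L\) uses the largest eigenvalue of \(\frac{2}{n}X^{\top}X\), which is only affordable here because \(d\) is tiny; at scale one would estimate or tune it instead.

```python
import numpy as np

def fit_gradient_descent(X, y, num_steps=500):
    """Minimize (1/n) * ||Xw - y||^2 with plain full-batch gradient descent."""
    n, d = X.shape
    # The gradient (2/n) X^T (Xw - y) is Lipschitz with constant
    # L = (2/n) * largest eigenvalue of X^T X, so a step of 1/L is safe.
    L = 2.0 / n * np.linalg.eigvalsh(X.T @ X).max()
    w = np.zeros(d)
    for _ in range(num_steps):
        grad = 2.0 / n * X.T @ (X @ w - y)
        w -= grad / L
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
w_true = np.array([1.0, -0.5, 2.0, 0.0, 3.0])
y = X @ w_true + 0.1 * rng.normal(size=1000)

w_gd = fit_gradient_descent(X, y)
w_closed = np.linalg.solve(X.T @ X, X.T @ y)
# Largest coordinate-wise gap between the iterate and the closed form.
print(np.max(np.abs(w_gd - w_closed)))
```

Because the objective is a well-conditioned convex quadratic here, the iterates reach the closed-form solution to high precision in a few hundred steps.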

Expressing Functions

The knee-jerk reaction to linear regression is that the world is non-linear and not every relationship can be captured by a linear function. This is circumvented by the most important aspect of linear regression: feature engineering. Before solving the objective for a given dataset, we can transform the features to represent quantities of interest.


For example, one immediate limitation is that we cannot represent affine functions (linear functions with a non-zero intercept). To remedy this problem, we can make our data \((d+1)\)-dimensional by prepending a column of \(1\)'s to the data matrix; the coefficient on this column then expresses any intercept. We can similarly extend this approach to quadratics (add all pairwise products of features, including squares, as new features), exponentials (add exponentials of features as new features), etc.
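As a minimal sketch (assuming a one-dimensional input and a hypothetical helper `quadratic_features`), adding a column of ones and a squared column lets plain least squares recover an affine-plus-quadratic target exactly:

```python
import numpy as np

def quadratic_features(x):
    """Map a 1-D input to the columns [1, x, x^2]: an intercept plus a squared term."""
    return np.column_stack([np.ones_like(x), x, x ** 2])

x = np.linspace(-2.0, 2.0, 50)
y = 3.0 + 2.0 * x - 1.0 * x ** 2        # intercept 3, slope 2, curvature -1, no noise

Phi = quadratic_features(x)
w, *_ = np.linalg.lstsq(Phi, y, rcond=None)
print(np.round(w, 6))                   # coefficients for the columns [1, x, x^2]
```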


In fact, the Stone-Weierstrass theorem tells us that any continuous deterministic \(f\) on a compact (closed and bounded) domain can be approximated arbitrarily well by a polynomial (for smooth \(f\), think about adding more and more terms of the function's Taylor expansion). Therefore, with sufficient features, linear regression can approximate any continuous function on such a domain! Note that this only shows that we can fit the training data; we will discuss whether this fits other data from the same distribution later.
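A small numerical illustration, assuming the target \(\sin(2\pi x)\) on \([0, 1]\) and a grid of 200 points (both arbitrary choices): fitting the polynomial features \(1, x, \ldots, x^k\) by least squares drives the training error down as the degree \(k\) grows. Because the feature sets are nested, the least-squares error on the grid can never increase with the degree.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 200)
f = np.sin(2 * np.pi * x)     # a continuous (indeed smooth) target on the compact interval [0, 1]

def rms_error(degree):
    # Polynomial features 1, x, ..., x^degree, then ordinary least squares.
    Phi = np.vander(x, degree + 1, increasing=True)
    w, *_ = np.linalg.lstsq(Phi, f, rcond=None)
    return np.sqrt(np.mean((Phi @ w - f) ** 2))

errors = {deg: rms_error(deg) for deg in (1, 3, 5, 7, 9)}
for deg, err in errors.items():
    print(f"degree {deg}: RMS error {err:.2e}")
```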

Incorporating Model Knowledge

Though we can express any function, it is not clear whether we can incorporate domain knowledge beyond the data points we supply. Luckily, linear regression is amenable to utilizing extra information about the weights or data.

Functional Constraints

To start, let's consider a linear constraint on the weights. For example, if we know that our label is a convex combination of the features, we would like to enforce that the weights sum to \(1\), which naively applying linear regression will not guarantee. In this setting, instead of \(d\) degrees of freedom, we only have \(d-1\) degrees of freedom, since the linear constraint removes one free variable: the last coefficient is determined by any choice of the first \(d-1\) coefficients. If we index the \(i\)th feature column as \(X_i \in \R^{n}\) and the \(i\)th weight as \(w_i \in \R\), we can rewrite our objective as

\[\begin{aligned}\arg\min_{w} & \; \left|\left|\left[(w_1)(X_1) + (w_2)(X_2) + \ldots + (w_{d-1})(X_{d-1}) + (1-w_1 - \ldots - w_{d-1})(X_d)\right] - \left[y\right]\right|\right|_2^2 \\ = \arg\min_{w} & \; \left|\left|\left[(w_1)(X_1 - X_d) + (w_2)(X_2 - X_d) + \ldots + (w_{d-1})(X_{d-1} - X_d)\right] - \left[y - X_d\right]\right|\right|_2^2 \end{aligned}\]

If we squint hard enough, we notice that after rearrangement, the constrained problem becomes an ordinary regression problem with the \(d-1\) features \(X_i - X_d\) and the labels \(y - X_d\). Solving it gives \(w_1, \ldots, w_{d-1}\), and the constraint then gives \(w_d = 1 - w_1 - \ldots - w_{d-1}\). Therefore, we can enforce some structure through some clever feature engineering.
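The substitution can be sketched in a few lines of NumPy; `fit_sum_to_one` is a hypothetical helper and the data are synthetic. It builds the features \(X_i - X_d\) and labels \(y - X_d\), solves the unconstrained problem, and then recovers \(w_d\) from the constraint.

```python
import numpy as np

def fit_sum_to_one(X, y):
    """Least squares subject to sum(w) == 1, via w_d = 1 - (w_1 + ... + w_{d-1})."""
    X_last = X[:, -1]
    # New features X_i - X_d for i < d, and new labels y - X_d.
    Z = X[:, :-1] - X_last[:, None]
    w_head, *_ = np.linalg.lstsq(Z, y - X_last, rcond=None)
    # Recover the last coefficient from the constraint.
    return np.append(w_head, 1.0 - w_head.sum())

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 4))
w_true = np.array([0.1, 0.4, 0.3, 0.2])       # sums to 1
y = X @ w_true + 0.01 * rng.normal(size=200)

w_hat = fit_sum_to_one(X, y)
print(w_hat, w_hat.sum())
```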


This cool trick doesn't work as nicely for linear inequalities (such as requiring nonnegative weights) or more sophisticated convex constraints. To remedy this, we can abandon our earlier goal of creating a new linear regression problem, and simply carry out gradient descent while projecting our weight vector onto the closest point in our feasible set after every step. This algorithm, called Projected Gradient Descent, also provably converges for closed convex feasible sets, and will quickly reach our desired solution under more complicated conditions on our weights. It is worth noting that trying to enforce such properties on neural networks forms many active areas of research.
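A sketch of Projected Gradient Descent for the full convex-combination constraint (nonnegative weights that sum to \(1\)), which the equality-only substitution trick cannot enforce. The projection onto the probability simplex uses a standard sort-based algorithm; the helper names, step size, and synthetic data are illustrative assumptions.

```python
import numpy as np

def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1} (sort-based algorithm)."""
    u = np.sort(v)[::-1]                       # entries in decreasing order
    css = np.cumsum(u)
    k = np.arange(1, len(v) + 1)
    # Largest k with u_k - (css_k - 1) / k > 0; k = 1 always qualifies.
    rho = k[u - (css - 1.0) / k > 0][-1]
    theta = (css[rho - 1] - 1.0) / rho
    return np.maximum(v - theta, 0.0)

def projected_gradient_descent(X, y, num_steps=2000):
    """Minimize (1/n)||Xw - y||^2 over the probability simplex."""
    n, d = X.shape
    L = 2.0 / n * np.linalg.eigvalsh(X.T @ X).max()   # Lipschitz constant of the gradient
    w = np.full(d, 1.0 / d)                           # start at a feasible point
    for _ in range(num_steps):
        grad = 2.0 / n * X.T @ (X @ w - y)
        w = project_to_simplex(w - grad / L)          # gradient step, then project back
    return w

rng = np.random.default_rng(2)
X = rng.normal(size=(200, 4))
w_true = np.array([0.6, 0.4, 0.0, 0.0])   # a convex combination on the boundary of the simplex
y = X @ w_true                             # noiseless, so w_true is the constrained minimizer

w_pgd = projected_gradient_descent(X, y)
print(np.round(w_pgd, 4))
```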

Probabilistic Priors

What if our knowledge cannot be adequately captured by constraints on the weights? Let's consider a synthetic scenario where we assume that the true weight comes from a zero-mean Gaussian with covariance \(\tau^2I_d\) and we apply some Gaussian noise of variance \(\sigma^2\) to each output. Given these assumptions, we can now define the posterior, or the probability of a parameter given the data. To make the math easier, we will work with its logarithm and collect every term that does not depend on \(w\) into a constant.

\[\begin{aligned} &\log P(w \mid X, y) & \\ &= \log P(y \mid X, w) + \log P(w) + \text{const} & \text{[Bayes' rule]}\\ &= -\frac{1}{2\sigma^2}(Xw - y)^{\top}(Xw - y) - \frac{1}{2\tau^2}w^{\top}w + \text{const} & \text{[Gaussian likelihood and prior]}\\ &= -\frac{1}{2\sigma^2}\left(\|Xw -y\|_2^2 + \frac{\sigma^2}{\tau^2}\|w\|_2^2\right) + \text{const} & \text{[rearrange]} \\ \end{aligned}\]

Take a second to see if this feels familiar. For this specific choice of prior, the posterior is also Gaussian (a consequence of conjugate priors), so the mode of the posterior is the same as its mean. We will derive the mode, as it is mathematically simpler: maximizing the log posterior is the same as minimizing \(\|Xw - y\|_2^2 + \frac{\sigma^2}{\tau^2}\|w\|_2^2\). Following similar linear algebra to the closed form for linear regression, the weight that maximizes the posterior (the maximum a posteriori, or MAP, estimate) is

\[w = \left(X^{\top}X + \frac{\sigma^2}{\tau^2}I_d\right)^{-1}X^{\top}y\]

First observe that in the limit \(\tau^2 \to \infty\) (an increasingly flat prior on the weights), the penalty \(\frac{\sigma^2}{\tau^2}\) vanishes and we recover the standard linear regression closed form. However, when we have a prior that the weight vector is somewhat small (instead of a uniform prior over all possible weights), we get a new closed form that penalizes the weights for being large, effectively shrinking our solution. This new solution is called ridge regression, and it is incredibly useful for regularizing linear regression and preventing overfitting, far beyond this precise synthetic scenario.
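A minimal sketch of the ridge closed form, writing \(\lambda\) for \(\frac{\sigma^2}{\tau^2}\) (the helper name `fit_ridge` and the synthetic data are assumptions). Setting \(\lambda = 0\) reproduces ordinary least squares, and the norm of the solution shrinks as \(\lambda\) grows:

```python
import numpy as np

def fit_ridge(X, y, lam):
    """Ridge closed form (X^T X + lam * I)^{-1} X^T y, where lam plays the role of sigma^2 / tau^2."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

rng = np.random.default_rng(2)
X = rng.normal(size=(50, 5))
y = X @ rng.normal(size=5) + 0.5 * rng.normal(size=50)

w_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
for lam in (0.0, 1.0, 10.0, 100.0):
    w = fit_ridge(X, y, lam)
    print(f"lambda={lam:6.1f}  ||w||_2 = {np.linalg.norm(w):.4f}")
```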


Obviously, this is not the only prior one can use. The MAP estimate under a Laplace prior on the weights gives Lasso regression, which promotes sparsity in the weights. If we believed that our noise sometimes produces outliers, we could instead model the noise as the mixture \((1-\epsilon)\mathcal{N}(0, \sigma^2) + \epsilon\mathcal{N}(0, C\sigma^2)\) with \(C > 1\), a simple form of robust regression. The main takeaway is that priors which are not literally true, but which capture the right structure, lead to natural and better variants.
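Unlike ridge, Lasso has no general closed form. One simple way to solve it is proximal gradient descent (ISTA), sketched below under illustrative assumptions: synthetic sparse data, a hypothetical `fit_lasso_ista` helper, and the objective scaled as \(\frac{1}{n}\|Xw - y\|_2^2 + \lambda\|w\|_1\). The soft-thresholding step is what sets small coefficients exactly to zero, while ordinary least squares leaves them small but nonzero.

```python
import numpy as np

def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1: shrink each entry toward zero by t."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def fit_lasso_ista(X, y, lam, num_steps=5000):
    """Minimize (1/n)||Xw - y||^2 + lam * ||w||_1 with proximal gradient descent (ISTA)."""
    n, d = X.shape
    L = 2.0 / n * np.linalg.eigvalsh(X.T @ X).max()   # Lipschitz constant of the smooth part
    w = np.zeros(d)
    for _ in range(num_steps):
        grad = 2.0 / n * X.T @ (X @ w - y)
        w = soft_threshold(w - grad / L, lam / L)     # gradient step, then shrink
    return w

rng = np.random.default_rng(3)
X = rng.normal(size=(200, 6))
w_true = np.array([3.0, 0.0, 0.0, -2.0, 0.0, 0.0])   # sparse ground truth
y = X @ w_true + 0.05 * rng.normal(size=200)

w_lasso = fit_lasso_ista(X, y, lam=0.2)
w_ols = np.linalg.lstsq(X, y, rcond=None)[0]
print(np.round(w_lasso, 3))
print(np.round(w_ols, 3))
```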

Robustness

Distribution Shift

Linear regression manages to do quite well under a variety of shifts in the data distribution. In the synthetic example from earlier, note that the form of the estimator does not depend on the distribution of the covariates \(x_i\): as long as the linear relationship between inputs and labels holds, the same fitted weights remain appropriate when the inputs shift. Neural networks trained to solve this synthetic task can be very brittle to shifts in the distribution of \(x_i\) (Garg et al, 2022), provably so for low-capacity models (Zhang et al, 2023). Beyond this toy scenario, deep learning in both application and theory is notoriously brittle to many forms of distribution shift, highlighting a benefit of linear regression (Ovadia et al, 2019).
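A toy illustration, assuming the labels really are linear in the covariates everywhere (a well-specified model; the weights, noise level, and shift are arbitrary choices): weights fit on inputs centered at \(0\) still predict well on inputs centered at \(5\) with three times the spread. If the true relationship were non-linear, this kind of extrapolation would not be guaranteed.

```python
import numpy as np

rng = np.random.default_rng(4)
w_true = np.array([2.0, -1.0])

def sample(n, mean, scale):
    # Covariates from a shifted/scaled Gaussian; labels from the same linear rule plus noise.
    X = mean + scale * rng.normal(size=(n, 2))
    y = X @ w_true + 0.1 * rng.normal(size=n)
    return X, y

X_train, y_train = sample(500, mean=0.0, scale=1.0)
X_test, y_test = sample(500, mean=5.0, scale=3.0)    # shifted covariate distribution

w_hat, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)
train_mse = np.mean((X_train @ w_hat - y_train) ** 2)
test_mse = np.mean((X_test @ w_hat - y_test) ** 2)
print(f"train MSE {train_mse:.4f}, shifted test MSE {test_mse:.4f}")
```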

Adversarial Robustness

How much is a prediction affected by a small perturbation of the input? Linear regression is quite robust to input perturbations in the sense that their effect is simple and bounded: a perturbation \(\delta\) changes the prediction by exactly \(w^{\top}\delta\), so under a budget \(\|\delta\|_2 \le \epsilon\) the best possible attack is a step parallel to the weight vector, which changes the prediction by \(\epsilon\|w\|_2\), while any step orthogonal to \(w\) has no effect at all. In contrast, it is well known that neural networks can be incredibly brittle to tiny perturbations of the input, even after many explicit attempts to protect against such adversaries. Interestingly, some of the best methods for protecting against such adversaries involve regularizing the model to be linear around inputs (Qin et al, 2019), effectively outputting a solution closer to linear regression.
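The bound is easy to check numerically. In this sketch the weight vector, input, and budget are arbitrary illustrative values: a step of length \(\epsilon\) along \(w\) moves the prediction by \(\epsilon\|w\|_2\), an orthogonal step does nothing, and by Cauchy-Schwarz no random direction of the same length does better than the aligned one.

```python
import numpy as np

w = np.array([3.0, -4.0])       # a fitted linear model; ||w||_2 = 5
x = np.array([1.0, 2.0])
eps = 0.1                       # L2 budget for the perturbation

def predict(x):
    return w @ x

# Worst case: move along w, changing the prediction by eps * ||w||_2.
delta_worst = eps * w / np.linalg.norm(w)
# A step orthogonal to w leaves the prediction unchanged.
delta_orth = eps * np.array([4.0, 3.0]) / 5.0

print(predict(x + delta_worst) - predict(x))   # eps * ||w||_2
print(predict(x + delta_orth) - predict(x))    # zero up to rounding

# Random directions of the same length never beat the aligned step.
rng = np.random.default_rng(8)
dirs = rng.normal(size=(1000, 2))
random_deltas = eps * dirs / np.linalg.norm(dirs, axis=1, keepdims=True)
max_random_change = np.max(np.abs(random_deltas @ w))
print(max_random_change <= eps * np.linalg.norm(w) + 1e-12)
```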

Interpretability

One of the largest problems with deep learning techniques is that the resulting bag of matrices is very hard to interpret. The beauty of linear regression is that the final answer has a simple interpretation: changing the \(j\)th feature of an input by \(\delta\), with all other features held fixed, changes the prediction by \(w_j\delta\). This helps in two primary ways:

Model Debugging

One straightforward benefit of interpretability is debugging. Given a final model, we can sanity check the resulting weights to see if they align with our priors (e.g. does this feature matter more than that feature, is this weight negative or positive, etc.). This is really important since most failures in data pipelines, unlike in traditional software, are silent. Instead of getting a failed test case or a compiler error, you're more likely to get a bad model, and it will be unclear whether it was due to a data parsing error, an algorithmic issue, or a simple bug. Inspecting the final model parameters is an incredibly under-appreciated way of catching a plethora of these issues, and one that uninterpretable neural networks do not offer.
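As a hypothetical example of this kind of check, suppose domain knowledge says one feature should have a positive weight and another a negative one (the feature names, expected signs, and threshold below are made up for illustration). A simulated silent bug, shuffling one column the way a bad join might, shows up immediately as a near-zero weight:

```python
import numpy as np

rng = np.random.default_rng(5)
n = 5000
feature_names = ["signal_a", "signal_b"]           # hypothetical feature names
expected_signs = {"signal_a": +1, "signal_b": -1}  # hypothetical domain priors

X = rng.normal(size=(n, 2))
y = 2.0 * X[:, 0] - 1.0 * X[:, 1] + 0.1 * rng.normal(size=n)

def fit(X, y):
    return np.linalg.lstsq(X, y, rcond=None)[0]

def flag_suspicious(w, min_abs=0.5):
    """Return features whose weight has the wrong sign or is suspiciously close to zero."""
    flags = []
    for name, weight in zip(feature_names, w):
        if np.sign(weight) != expected_signs[name] or abs(weight) < min_abs:
            flags.append(name)
    return flags

# Simulated silent bug: signal_a's rows got shuffled, breaking their alignment with y.
X_buggy = X.copy()
X_buggy[:, 0] = rng.permutation(X_buggy[:, 0])

print(flag_suspicious(fit(X, y)))        # clean pipeline
print(flag_suspicious(fit(X_buggy, y)))  # buggy pipeline
```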

Data Understanding

Linear regression is fundamentally necessary since it can interpret data in ways that techniques optimized purely for prediction cannot. For starters, it quantifies how different features relate to the label. Importantly, the correlation coefficient tells us how strongly a feature is linearly related to what we want to predict, and for a single-feature regression with an intercept, the coefficient of determination \(R^2\) is exactly the squared correlation. Beyond raw predictive power, standard metrics like these let us understand which features are uniquely important in predicting the output.
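A quick numerical check of that relationship, on synthetic data chosen for illustration: the squared Pearson correlation matches the \(R^2\) of a single-feature regression with an intercept.

```python
import numpy as np

rng = np.random.default_rng(6)
x = rng.normal(size=300)
y = 0.8 * x + rng.normal(size=300)

r = np.corrcoef(x, y)[0, 1]                    # Pearson correlation coefficient

# Single-feature regression with an intercept column.
Phi = np.column_stack([np.ones_like(x), x])
w, *_ = np.linalg.lstsq(Phi, y, rcond=None)
residuals = y - Phi @ w
r_squared = 1.0 - np.sum(residuals ** 2) / np.sum((y - y.mean()) ** 2)

print(f"r = {r:.4f}, r^2 = {r**2:.4f}, R^2 = {r_squared:.4f}")
```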


What happens if we want to understand the influence of a feature while "controlling" for another feature, i.e. removing the influence of that other feature? One natural way to do so is to run a multiple linear regression with both features. Then, the slope on your feature of interest can be interpreted as its slope after removing the (linear) effect of the other variables. Note that this slope will likely differ from the slope of the single-feature linear regression, and can even have a different sign! Causal inference (great textbook here) builds on such linear regression methods to remove the influence of confounding variables when measuring effects, increasing your ability to measure influence with statistical guarantees.
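A synthetic example of the sign flip, assuming a confounder \(z\) that pushes both \(x\) and \(y\) up while \(x\) itself pushes \(y\) down (all coefficients are illustrative). Regressing \(y\) on \(x\) alone gives a positive slope, while controlling for \(z\) recovers the negative one. The sketch also checks the Frisch-Waugh-Lovell theorem, which makes "removing the effect of the other variables" precise: the multiple-regression slope equals the slope from regressing the part of \(y\) not explained by \(z\) on the part of \(x\) not explained by \(z\).

```python
import numpy as np

rng = np.random.default_rng(7)
n = 5000
z = rng.normal(size=n)                       # a confounding variable
x = z + 0.5 * rng.normal(size=n)             # feature of interest, correlated with z
y = -1.0 * x + 3.0 * z + 0.1 * rng.normal(size=n)

def ols(columns, target):
    """Least-squares coefficients for the given columns plus an intercept (returned first)."""
    Phi = np.column_stack([np.ones(len(target))] + list(columns))
    return np.linalg.lstsq(Phi, target, rcond=None)[0]

slope_alone = ols([x], y)[1]           # single-feature regression
slope_controlled = ols([x, z], y)[1]   # multiple regression, controlling for z

def residualize(v):
    # Remove the part of v that is linearly explained by z (with an intercept).
    b = ols([z], v)
    return v - (b[0] + b[1] * z)

# Frisch-Waugh-Lovell: regress residualized y on residualized x.
slope_fwl = ols([residualize(x)], residualize(y))[1]
print(slope_alone, slope_controlled, slope_fwl)
```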

Putting it All Together

Obviously, linear regression is not the solution to all problems. It struggles with some data properties such as heteroskedasticity, cannot model unknown complex relationships, and loses its guarantees outside of convenient priors and assumptions. However, for predictive tasks, one should see linear regression as a simple, efficient, and robust estimator for tasks where it is important to incorporate human understanding of the problem. For analysis tasks, one should see linear regression as an interpretable way to derive insights from data. Though it is much more difficult to get linear regression to work well, since it hinges on finding the right features, one can trust it so much more when it does work. Overall, one should not discount the power of linear regression, and many problems would benefit from the correct featurization and linear model :))


Thank you for reading, and feel free to reach out with any questions or thoughts!