What Is Linear Regression

April 17, 2023 | Categories: Machine Learning

In this post, I am going to give a textual description of what exactly linear regression is and how it can be used to fit a line to data.

To start, what is linear regression? In short, linear regression is the process of fitting a straight line to a set of data. That is to say, if you have a graph with multiple data points plotted and you want to find the single non-curved line that most closely follows all the data points, you would use linear regression to figure that out. This is incredibly useful because it allows us to use the line to predict the outcome of an event given input data. For instance, if you knew that someone who ate 2 cookies consumed 20 calories, someone who ate 4 cookies consumed 40 calories, and someone who ate 6 cookies consumed 60 calories, then how many calories would someone consume if they ate 10 cookies? Intuition tells us the answer is 100. However, linear regression can solve this for us as well, and it can do it for data sets that aren't so obvious.
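As a concrete sketch of the idea (using numpy, which is my choice and not mentioned in the post itself), we can fit a degree-1 polynomial, i.e. a straight line, to the cookie data and use it to make the prediction:

```python
import numpy as np

cookies = np.array([2, 4, 6])      # cookies eaten
calories = np.array([20, 40, 60])  # calories consumed

# Fit a degree-1 polynomial (a straight line) to the data.
slope, intercept = np.polyfit(cookies, calories, deg=1)

# Use the fitted line to predict calories for 10 cookies.
print(slope * 10 + intercept)      # ~100.0
```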

So how does it work? In order to use linear regression, you have to start with a candidate line. Let's assume for our cookie data set we try the line y = 20x. For each data point, we can plug in the x value and compare the calculated y value against the known y value. In this example we would compare the following, in the format (known value, calculated value): (20, 40), (40, 80), (60, 120). The differences between each pair of values are -20, -40, and -60. Added together, this is a total difference of -120. These differences are known as the residuals. For every possible function f(x), the residuals add up to a different number. In practice, negative numbers are a problem: if some residuals are negative and some are positive, their cumulative effects cancel each other out instead of adding to the total loss. To fix this, we square the residuals, which makes every term positive. The sum of the squared residuals for this function is (-20)^2 + (-40)^2 + (-60)^2 = 400 + 1,600 + 3,600 = 5,600.
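That arithmetic is easy to check in a few lines of Python; the data and the candidate line y = 20x come from the example above:

```python
xs = [2, 4, 6]
ys = [20, 40, 60]                 # known calorie values

# Residuals: known value minus calculated value for y = 20x.
residuals = [y - 20 * x for x, y in zip(xs, ys)]
print(residuals)                  # [-20, -40, -60]
print(sum(residuals))             # -120 (signs cancel, hiding the error)

# Squaring makes every term positive before summing.
print(sum(r ** 2 for r in residuals))   # 5600
```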

We can find the sum of squared residuals with the function
((a*x1 + b) - y1)^2 + ((a*x2 + b) - y2)^2 + ((a*x3 + b) - y3)^2 + ... (one term per data point, n terms for n data points),
where "a" represents the slope and "b" represents the intercept.

If we compare the sum of squared residuals for several lines with different slopes (ignoring intercepts for now), we find the following results, which the sketch after the table reproduces:

f(x) = 20x: SSR = 5,600
f(x) = 15x: SSR = 1,400
f(x) = 10x: SSR = 0
f(x) = 5x: SSR = 1,400
f(x) = 0: SSR = 5,600
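The short loop below reproduces this table, with the intercept fixed at 0:

```python
data = [(2, 20), (4, 40), (6, 60)]

# Sum of squared residuals for each candidate slope (intercept fixed at 0).
for a in [20, 15, 10, 5, 0]:
    total = sum((a * x - y) ** 2 for x, y in data)
    print(f"f(x) = {a}x: SSR = {total}")   # 5600, 1400, 0, 1400, 5600
```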

In this example, it is obvious that the sum of squared residuals is 0 when f(x) = 10x. However, with real-world data it is incredibly unlikely to get perfectly 0 residuals, so we cannot use that as our metric to determine the best slope. We can look at the trend, though: if we plot this new information on a graph with the slope of the fitted line on the x axis and the sum of squared residuals on the y axis, we find that as the slope of the fitted line approaches 10, the slope of the resulting curve approaches 0. We can find the slope of the curve at any given point by taking the derivative. Without getting into the math of finding the derivative, it is helpful just to understand the concept at play: the x-axis coordinate where the derivative is 0 is the bottom of the curve, and the slope value it corresponds to gives the best-fitting line.
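To make the trend concrete, here is a numerical sketch (intercept held at 0). Expanding the squared terms gives the derivative of the SSR with respect to the slope a as sum(2 * x * (a*x - y)), which is negative below a = 10, zero at a = 10, and positive above:

```python
data = [(2, 20), (4, 40), (6, 60)]

# Derivative of the sum of squared residuals with respect to the slope a,
# holding the intercept at 0: d/da sum((a*x - y)^2) = sum(2*x*(a*x - y)).
def ssr_slope_derivative(a):
    return sum(2 * x * (a * x - y) for x, y in data)

for a in [8, 9, 10, 11, 12]:
    print(a, ssr_slope_derivative(a))   # -224, -112, 0, 112, 224
```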

When you consider the intercept as well, it becomes a little more complicated to visualize. Instead of a 2D curve, the sum of squared residuals becomes a 3D surface: slope on one axis, intercept on another, and the sum of squared residuals on the vertical axis. This surface has a point where the derivative is 0 in both directions, with respect to the slope and with respect to the intercept. The corresponding slope and intercept values at that point are the slope and intercept of your line of best fit.
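As a sketch of how this is done in practice, a least-squares solver finds exactly this point where both partial derivatives are 0. Here numpy's lstsq does the work, with a column of ones standing in for the intercept term (numpy is my choice, not the post's):

```python
import numpy as np

xs = np.array([2.0, 4.0, 6.0])
ys = np.array([20.0, 40.0, 60.0])

# Design matrix: one column for the slope, one column of ones for the intercept.
A = np.column_stack([xs, np.ones_like(xs)])

# Least squares solves for the point where both partial derivatives are 0.
(slope, intercept), *_ = np.linalg.lstsq(A, ys, rcond=None)
print(slope, intercept)   # ~10.0, ~0.0
```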

Testing every possible combination of slope and intercept is impractical and time consuming, so instead techniques like gradient descent are usually used, which choose each new "test line" based on the magnitude of the derivative and a learning rate. However, those are topics for another time.
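For the curious, here is a minimal sketch of the idea anyway; the starting line, learning rate, and step count are assumptions of mine, not values from the post:

```python
data = [(2, 20), (4, 40), (6, 60)]
a, b = 0.0, 0.0        # start from an arbitrary test line, y = 0x + 0
learning_rate = 0.01

for _ in range(1000):
    # Partial derivatives of the sum of squared residuals.
    grad_a = sum(2 * x * ((a * x + b) - y) for x, y in data)
    grad_b = sum(2 * ((a * x + b) - y) for x, y in data)
    # Step each parameter against its derivative, scaled by the learning rate.
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b

print(a, b)            # approaches 10.0 and 0.0
```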
