Fitting a Straight Line to Data Points

A line of best fit is the line that best approximates a given set of data. It does not have to be straight: depending on the type of data you have and the purpose of the approximation, it can be linear, exponential, logarithmic, polynomial, a power curve or a moving average. Your dataset can even be in 3 dimensions.

I'm only going to look at 2D data (ordered pairs) and use the Least Squares Method to find the straight line (y=mx+c) that best fits the given data points.
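For reference, "best fit" in the least-squares sense means that the chosen m and c minimize the sum of the squared vertical distances between each data point and the line. The symbol S is used here only as a label for that sum:

S(m,c)=\sum_{i=1}^{n}\left(y_i-(mx_i+c)\right)^2

The steps that follow give the m and c that make this sum as small as possible.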



I've implemented the Least Squares Method using Processing and shall explain the code together with the math.

In order to manage all the data points, I've created a very simple class called Points. This class is used to create new points and has only one method (display), which draws a single data point.


class Points
{
  float xPos, yPos;

  Points(float _x, float _y)
  {
    xPos = _x;
    yPos = _y;
  }

  void display()
  {
    stroke(0);
    fill(0);
    ellipse(xPos, yPos, 10,10);
  }
}

Step 1:

Calculate the mean of all the x-values

\overline{X}=\frac{\sum_{i=1}^{n}x_i}{n}

and all the y-values.

\overline{Y}=\frac{\sum_{i=1}^{n}y_i}{n}

where n is the total number of data points.

What these formulas do is add all the x-values and divide the total by the number of x-values. This is also done for all the y-values.


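float meanX = 0.0;   //Running mean of the x-values
float meanY = 0.0;   //Running mean of the y-values

int maxPoints = dataPoints.size();
for (int i=0; i < maxPoints; i++)
{
  meanX += dataPoints.get(i).xPos / maxPoints;
  meanY += dataPoints.get(i).yPos / maxPoints;
}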

This is a straightforward for loop. It retrieves each data point's x and y value, divides it by the variable maxPoints, which is the number of points currently on the screen, and then adds it to the variables meanX and meanY respectively.
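An equivalent way to get the same mean is to add up all the values first and divide only once at the end. The plain Java sketch below illustrates that variant outside of Processing; the class name MeanSketch and the method mean are made up for this example and are not part of the sketch above.

```java
// Hypothetical helper for illustration; not part of the original Processing sketch.
class MeanSketch {
    // Returns the arithmetic mean: the sum of the values divided by how many there are.
    // An empty array gives NaN, because 0.0f / 0 is NaN in floating-point arithmetic.
    static float mean(float[] values) {
        float sum = 0.0f;
        for (float v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    public static void main(String[] args) {
        float[] xs = {1.0f, 2.0f, 3.0f, 6.0f};
        System.out.println(mean(xs)); // prints 3.0
    }
}
```

Dividing once instead of once per point saves a few operations, but for the handful of points in this sketch either approach is fine.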

Step 2:

Calculate the slope (m) of the straight line with this formula:

m=\frac{\sum_{i=1}^{n}(x_i-\overline{X})(y_i-\overline{Y})}{\sum_{i=1}^{n}(x_i-\overline{X})^2}

 

float mNumer = 0.0;    //Numerator used to calculate m
float mDenom = 0.0;   //Denominator used to calculate m

for (int n = 0; n < maxPoints; n++)
{
  mNumer += (dataPoints.get(n).xPos - meanX)*(dataPoints.get(n).yPos - meanY);
  mDenom += pow((dataPoints.get(n).xPos - meanX),2);
}

float m = mNumer / mDenom;

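I've decided to calculate the numerator and denominator separately so that the code more closely resembles the formula. The last line of the snippet calculates the final value for m. Note that the denominator is zero when there are fewer than two points, or when all the points share the same x-value; in that case m is undefined.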
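To try the slope formula on known numbers, here is a minimal plain Java sketch that works on arrays instead of the Points class. The names SlopeSketch and slope are assumptions made for this example, and returning NaN for a zero denominator is just one possible way of handling that case.

```java
// Hypothetical helper for illustration; not part of the original Processing sketch.
class SlopeSketch {
    // Returns the least-squares slope m, or NaN when the denominator is zero
    // (fewer than two points, or all points sharing the same x-value).
    static float slope(float[] xs, float[] ys) {
        int n = xs.length;
        float sumX = 0.0f, sumY = 0.0f;
        for (int i = 0; i < n; i++) {
            sumX += xs[i];
            sumY += ys[i];
        }
        float meanX = sumX / n;
        float meanY = sumY / n;

        float numer = 0.0f; // sum of (x_i - meanX) * (y_i - meanY)
        float denom = 0.0f; // sum of (x_i - meanX)^2
        for (int i = 0; i < n; i++) {
            float dx = xs[i] - meanX;
            numer += dx * (ys[i] - meanY);
            denom += dx * dx;
        }
        if (denom == 0.0f) {
            return Float.NaN;
        }
        return numer / denom;
    }

    public static void main(String[] args) {
        // The points (1,2), (2,4), (3,6) lie exactly on y = 2x.
        float[] xs = {1.0f, 2.0f, 3.0f};
        float[] ys = {2.0f, 4.0f, 6.0f};
        System.out.println(slope(xs, ys)); // prints 2.0
    }
}
```

Because the three sample points lie exactly on a line, the computed slope matches that line's slope.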

Step 3:

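The least-squares line always passes through the point made up of the two means, (\overline{X},\overline{Y}). Substituting that point into the straight line formula y=mx+c and solving for c gives the y-intercept: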

c = \overline{Y}-m\overline{X}


float c = meanY - m*meanX;
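A quick way to check the intercept is to plug the mean of the x-values back into the fitted line and confirm that the mean of the y-values comes out. The plain Java sketch below does that with made-up numbers; InterceptSketch, intercept and predict are hypothetical names used only here.

```java
// Hypothetical helper for illustration; not part of the original Processing sketch.
class InterceptSketch {
    // y-intercept of the line with slope m that passes through (meanX, meanY)
    static float intercept(float m, float meanX, float meanY) {
        return meanY - m * meanX;
    }

    // y-value of the line y = mx + c at a given x
    static float predict(float m, float c, float x) {
        return m * x + c;
    }

    public static void main(String[] args) {
        // Example values: the points (0,1), (1,3), (2,5) have means (1,3) and slope 2.
        float m = 2.0f, meanX = 1.0f, meanY = 3.0f;
        float c = intercept(m, meanX, meanY);
        System.out.println(c);                    // prints 1.0
        System.out.println(predict(m, c, meanX)); // prints 3.0
    }
}
```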

Step 4:

Using the values for m and c, draw a straight line.

line(0,c,width,m*width+c);  

In order to draw a straight line we need 2 points. I've selected points on the edges of the display area in order to keep the math a bit simpler.

The line is drawn from point (x1,y1) to point (x2,y2). The starting point's x is just 0, and its y is the point where the straight line crosses the y-axis, so it is just c.

Point 2 is a bit trickier. Its x is the width of the display area, essentially the maximum value x can ever be. Its y value is calculated by substituting x2 (the width of the display area) together with the previously calculated m and c into the straight line formula y=mx+c.
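The two endpoints can also be worked out separately from the drawing call, which makes them easy to inspect. This is a minimal plain Java sketch with hypothetical names (EndpointSketch, endpoints) and a display width of 500 taken from the size() call in the complete program.

```java
// Hypothetical helper for illustration; not part of the original Processing sketch.
class EndpointSketch {
    // Returns {x1, y1, x2, y2} for the line y = mx + c drawn across a display
    // area of the given width, in the same order as the arguments to line().
    static float[] endpoints(float m, float c, float width) {
        float x1 = 0.0f;
        float y1 = c;             // at x = 0, y = m*0 + c = c
        float x2 = width;
        float y2 = m * width + c; // substitute x2 into y = mx + c
        return new float[] {x1, y1, x2, y2};
    }

    public static void main(String[] args) {
        float[] p = endpoints(0.5f, 100.0f, 500.0f);
        System.out.println(p[0] + ", " + p[1]); // prints 0.0, 100.0
        System.out.println(p[2] + ", " + p[3]); // prints 500.0, 350.0
    }
}
```

Keep in mind that Processing places the origin in the top-left corner with y increasing downwards, so a positive m appears on screen as a line sloping down to the right.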

Complete Program

The complete program is shown below and can be copied and pasted into an empty Processing sketch.


//A Least Squares Line Fit algorithm

ArrayList<Points> dataPoints = new ArrayList<Points>();

void setup()
{
  size(500,500);
  frameRate(10); //Set screen refresh
}

void draw()
{
  clear();
  background(200);
  textSize(10);
  fill(0);
  text("mouse LEFT = Add/Remove data point", 10,height - 20);
  text("mouse RIGHT = clear all data points", 10, height - 10);

  //Display all the points
  for (Points p: dataPoints)
  {
    p.display();
  }
  plotLine();
}

void plotLine()
{
  float meanX = 0.0;
  float meanY = 0.0;
  float mNumer = 0.0;
  float mDenom = 0.0;

  int maxPoints = dataPoints.size();
  for (int n=0; n < maxPoints; n++)
  {
    meanX += dataPoints.get(n).xPos / maxPoints;
    meanY += dataPoints.get(n).yPos / maxPoints;
  }

  for (int n = 0; n < maxPoints; n++)
  {
    mNumer += (dataPoints.get(n).xPos - meanX)*(dataPoints.get(n).yPos - meanY);
    mDenom += pow((dataPoints.get(n).xPos - meanX),2);
  }

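  //With fewer than two points, or with all points on one vertical line,
  //the denominator is zero and the slope is undefined, so draw nothing
  if (mDenom == 0) return;

  float m = mNumer / mDenom;
  float c = meanY - m*meanX;
  line(0,c,width,m*width+c);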
}

//Use mousePressed to create new dataPoints
void mousePressed()
{
  boolean dataPointFound = false;
  int idx = 0;

  if (mousePressed && (mouseButton == LEFT))
  {
     //Step through all existing dataPoints to see if mouseX and mouseY is close to existing dataPoint xPos and yPos value
    for (int k = 0; k < dataPoints.size(); k++)
    {
      float distToDataPoint = dist(mouseX,mouseY, dataPoints.get(k).xPos, dataPoints.get(k).yPos);

      //If it is then set the flag
      if (distToDataPoint < 5)
      {
        dataPointFound = true;
        idx = k;
      }
    }

    //If flag set, then remove point from dataPoints
    if (dataPointFound)
    {
      dataPoints.remove(idx);
    }
    else
    {
      dataPoints.add(new Points(mouseX,mouseY));
    }
    println("Amount of dataPoints : "+dataPoints.size());
  }

  //Clears all the data when the RIGHT mouse button is pressed
  if (mousePressed && (mouseButton == RIGHT))
  {
    dataPoints.clear();
  }
}

class Points
{
  float xPos, yPos;

  Points(float _x, float _y)
  {
    xPos = _x;
    yPos = _y;
  }

  void display()
  {
    stroke(0);
    fill(0);
    ellipse(xPos, yPos, 10,10);
  }
}
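As a sanity check outside Processing, the plain Java sketch below fits a line to four made-up points and compares its sum of squared errors with that of slightly nudged lines. The class and method names (FitCheckSketch, fit, sumSquaredError) and the sample data are assumptions for this example only.

```java
// Hypothetical check for illustration; not part of the original Processing sketch.
class FitCheckSketch {
    // Returns {m, c} for the least-squares line through the given points.
    static double[] fit(double[] xs, double[] ys) {
        int n = xs.length;
        double meanX = 0.0, meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += xs[i] / n;
            meanY += ys[i] / n;
        }
        double numer = 0.0, denom = 0.0;
        for (int i = 0; i < n; i++) {
            numer += (xs[i] - meanX) * (ys[i] - meanY);
            denom += (xs[i] - meanX) * (xs[i] - meanX);
        }
        double m = numer / denom; // assumes the x-values are not all identical
        double c = meanY - m * meanX;
        return new double[] {m, c};
    }

    // Sum of squared vertical distances between the points and y = mx + c.
    static double sumSquaredError(double[] xs, double[] ys, double m, double c) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double residual = ys[i] - (m * xs[i] + c);
            total += residual * residual;
        }
        return total;
    }

    public static void main(String[] args) {
        double[] xs = {0.0, 1.0, 2.0, 3.0};
        double[] ys = {1.0, 2.0, 4.0, 4.0};
        double[] line = fit(xs, ys);
        double best = sumSquaredError(xs, ys, line[0], line[1]);

        // Nudging the slope or the intercept should only make the error larger.
        System.out.println(best < sumSquaredError(xs, ys, line[0] + 0.1, line[1])); // prints true
        System.out.println(best < sumSquaredError(xs, ys, line[0], line[1] - 0.1)); // prints true
    }
}
```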

 
