Quadratic Regression in Java

A comprehensive 5-minute tutorial on building a quadratic regression tool in Java.


Data science is taking the world by storm

In this tutorial, you will learn how to create a simple quadratic regression algorithm in Java. The same approach can be modified to work as a linear, logistic or polynomial regression tool, which makes it quite versatile. It uses a brute-force approach: it tries every combination of coefficients within a fixed range and keeps the one that fits the data best. This is not how regression is typically done in machine learning, but it is easy to follow. This is the first in a series of tutorials that all use the same problem-solving approach. It is aimed at beginners; some knowledge of linear algebra helps, but it is not necessary.

The mathematics behind this

The only formula you need to be familiar with is the standard form of a quadratic:

$$y = ax^2 + bx + c$$

We will use three nested for loops, one for each of the coefficients a, b and c, to search for the combination that produces the most accurate equation.

Setting up the class in Java

We will first create a Dataset class with the necessary instance variables and a constructor:

public class Dataset {
	// x- and y-coordinates of the data points; xdata[i] pairs with ydata[i]
	private int[] xdata;
	private int[] ydata;

	public Dataset(int[] x, int[] y) {
		xdata = x;
		ydata = y;
	}
}

As you can see, the class stores one array for the x-coordinates and one for the y-coordinates, along with a parameterized constructor to initialize a Dataset object. The two arrays must have the same length, since xdata[i] and ydata[i] together describe a single point. For simplicity, we also assume that the x-coordinates have been sorted in ascending order. Now that that’s out of the way, let’s move on to the real coding.
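The constructor above trusts its inputs. A slightly more defensive version, sketched below as one possible option rather than the tutorial's own class, rejects null, empty or mismatched arrays and copies them so that later changes to the caller's arrays cannot alter the data set. The class name `CheckedDataset` and the `size()` method are hypothetical additions for illustration.

```java
import java.util.Arrays;

// Hypothetical variant of Dataset that validates its input (not the tutorial's class).
class CheckedDataset {
    private final int[] xdata;
    private final int[] ydata;

    CheckedDataset(int[] x, int[] y) {
        if (x == null || y == null) {
            throw new IllegalArgumentException("x and y must not be null");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                "x and y must have the same length: " + x.length + " vs " + y.length);
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("at least one data point is required");
        }
        // Defensive copies: changing the caller's arrays later will not change this data set.
        xdata = Arrays.copyOf(x, x.length);
        ydata = Arrays.copyOf(y, y.length);
    }

    // Number of (x, y) points stored.
    int size() {
        return xdata.length;
    }

    public static void main(String[] args) {
        CheckedDataset ok = new CheckedDataset(new int[] {1, 2, 3}, new int[] {2, 4, 6});
        System.out.println("points: " + ok.size()); // prints "points: 3"
        try {
            new CheckedDataset(new int[] {1, 2}, new int[] {1});
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Failing early in the constructor means a mismatched pair of arrays is reported immediately, instead of surfacing later as an ArrayIndexOutOfBoundsException inside the cost calculation.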

The cost function

We need a method that takes candidate values for a, b and c and tells us how closely the resulting curve matches the actual data. For the i-th data point $(x_i, y_i)$, the error is the absolute difference between the actual and predicted y-values:

$$e_i = |y_i - (ax_i^2 + bx_i + c)|$$

The total cost for a particular set of values a, b and c is the sum of these errors over all n data points, so a lower cost means a better fit:

$$\text{cost}(a, b, c) = \sum_{i=1}^{n} |y_i - (ax_i^2 + bx_i + c)|$$

We implement this with a for loop that visits every x-value and compares the corresponding y-value with the predicted value.

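	// Returns the sum of absolute errors between the actual y-values and
	// the values predicted by y = a*x^2 + b*x + c
	public double quadraticCost(double a, double b, double c) {
		double cost = 0;
		for (int i = 0; i < xdata.length; i++) {
			double predictedValue = a * xdata[i] * xdata[i] + b * xdata[i] + c;
			double actualValue = ydata[i];
			cost += Math.abs(actualValue - predictedValue);
		}
		return cost;
	}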

This method serves as a helper for the search method, which we will write next.
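To make the cost concrete, here is a standalone sketch of the same absolute-error calculation, applied to the points used in the Testing section. The class and method names are hypothetical. For a = 1, b = 0, c = 0 the predictions are 1, 4, 9, 16, 25 against actual values 1, 3, 6, 10, 15, so the errors are 0, 1, 3, 6, 10 and the cost is 20. For a = 0.5, b = 0.5, c = 0 every prediction is exact and the cost is 0.

```java
// Standalone version of the tutorial's absolute-error cost (hypothetical names).
class CostExample {
    // Sum of |actual - predicted| over all points, where predicted = a*x^2 + b*x + c.
    static double absoluteCost(int[] xs, int[] ys, double a, double b, double c) {
        double cost = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double predicted = a * xs[i] * xs[i] + b * xs[i] + c;
            cost += Math.abs(ys[i] - predicted);
        }
        return cost;
    }

    public static void main(String[] args) {
        int[] x = {1, 2, 3, 4, 5};
        int[] y = {1, 3, 6, 10, 15};
        System.out.println(absoluteCost(x, y, 1.0, 0.0, 0.0)); // prints 20.0
        System.out.println(absoluteCost(x, y, 0.5, 0.5, 0.0)); // prints 0.0
    }
}
```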

Getting the best set of variables

Here, we use three nested loops that step through a range of candidate values for each coefficient and keep whichever combination has the lowest cost. It looks something like this:

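	// Tries every combination of a, b and c from -9.9 to 9.9 in steps of 0.1
	// and prints the combination with the lowest cost
	public void quadraticEquation() {
		double min = Double.MAX_VALUE;
		double a = 0.0;
		double b = 0.0;
		double c = 0.0;
		// The loops count in tenths with int counters (i / 10.0 is the value),
		// which avoids the rounding error that builds up when 0.1 is added repeatedly
		for (int i = -99; i < 100; i++) {
			double a1 = i / 10.0;
			for (int j = -99; j < 100; j++) {
				double b1 = j / 10.0;
				for (int k = -99; k < 100; k++) {
					double c1 = k / 10.0;
					double cost = quadraticCost(a1, b1, c1);
					if (cost < min) {
						min = cost;
						a = a1;
						b = b1;
						c = c1;
					}
				}
			}
		}
		System.out.println("a=" + a + "\nb=" + b + "\nc=" + c);
	}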

Of course, these loops will not suit every graph: if a coefficient lies outside the range -9.9 to 9.9, or needs more precision than 0.1, the loops will have to be adjusted to your data set. Up to this point, our class looks like this:

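public class Dataset {
	// x- and y-coordinates of the data points; xdata[i] pairs with ydata[i]
	private int[] xdata;
	private int[] ydata;

	public Dataset(int[] x, int[] y) {
		xdata = x;
		ydata = y;
	}

	// Returns the sum of absolute errors between the actual y-values and
	// the values predicted by y = a*x^2 + b*x + c
	public double quadraticCost(double a, double b, double c) {
		double cost = 0;
		for (int i = 0; i < xdata.length; i++) {
			double predictedValue = a * xdata[i] * xdata[i] + b * xdata[i] + c;
			double actualValue = ydata[i];
			cost += Math.abs(actualValue - predictedValue);
		}
		return cost;
	}

	// Tries every combination of a, b and c from -9.9 to 9.9 in steps of 0.1
	// and prints the combination with the lowest cost
	public void quadraticEquation() {
		double min = Double.MAX_VALUE;
		double a = 0.0;
		double b = 0.0;
		double c = 0.0;
		// The loops count in tenths with int counters (i / 10.0 is the value),
		// which avoids the rounding error that builds up when 0.1 is added repeatedly
		for (int i = -99; i < 100; i++) {
			double a1 = i / 10.0;
			for (int j = -99; j < 100; j++) {
				double b1 = j / 10.0;
				for (int k = -99; k < 100; k++) {
					double c1 = k / 10.0;
					double cost = quadraticCost(a1, b1, c1);
					if (cost < min) {
						min = cost;
						a = a1;
						b = b1;
						c = c1;
					}
				}
			}
		}
		System.out.println("a=" + a + "\nb=" + b + "\nc=" + c);
	}
}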
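Because the ranges are hard-coded, adapting the class to a new data set means editing three loops. One possible refactoring, sketched here with hypothetical names (`GridSearch`, `search`, and the parameters `lo`, `hi` and `step`), passes the range and step size as arguments and returns the best coefficients instead of printing them. Like the tutorial's version, it minimizes the sum of absolute errors, and it counts grid steps with integers so each candidate value is computed as lo + k·step rather than by repeated addition.

```java
import java.util.Arrays;

// Hypothetical generalization of quadraticEquation(): same brute-force search,
// but the range [lo, hi] and the step size are parameters.
class GridSearch {
    // Sum of absolute errors for y = a*x^2 + b*x + c.
    static double cost(int[] xs, int[] ys, double a, double b, double c) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double predicted = a * xs[i] * xs[i] + b * xs[i] + c;
            total += Math.abs(ys[i] - predicted);
        }
        return total;
    }

    // Returns {a, b, c} with the lowest cost among all grid values lo + k*step <= hi.
    static double[] search(int[] xs, int[] ys, double lo, double hi, double step) {
        if (step <= 0 || hi < lo) {
            throw new IllegalArgumentException("need step > 0 and hi >= lo");
        }
        // Number of whole steps in the range; the small epsilon guards against
        // (hi - lo) / step landing just below a whole number.
        int steps = (int) Math.floor((hi - lo) / step + 1e-9);
        double[] best = new double[3];
        double min = Double.MAX_VALUE;
        for (int i = 0; i <= steps; i++) {
            double a = lo + i * step;
            for (int j = 0; j <= steps; j++) {
                double b = lo + j * step;
                for (int k = 0; k <= steps; k++) {
                    double c = lo + k * step;
                    double current = cost(xs, ys, a, b, c);
                    if (current < min) {
                        min = current;
                        best[0] = a;
                        best[1] = b;
                        best[2] = c;
                    }
                }
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] x = {1, 2, 3, 4, 5};
        int[] y = {1, 3, 6, 10, 15};
        // Coarse grid from -2.0 to 2.0 in steps of 0.5 (9 values per coefficient).
        System.out.println(Arrays.toString(search(x, y, -2.0, 2.0, 0.5))); // prints [0.5, 0.5, 0.0]
    }
}
```

Returning the coefficients rather than printing them makes it easier to reuse the result, for example to compute predictions or to start a second, finer search around the best values.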

Testing

It is now time to test our class and see if it works. I will start with points that lie exactly on a quadratic curve:

public class Tester {
	public static void main(String[] args) {
		int[] x = {1, 2, 3, 4, 5};
		int[] y = {1, 3, 6, 10, 15};
		Dataset set = new Dataset(x, y);
		set.quadraticEquation();
	}
}

This code produces the output:

a=0.5
b=0.5
c=0
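A quick numerical check is to print each data point next to the value predicted by the fitted curve and look at the largest difference. This is a minimal sketch; the class name `FitCheck` and the method `maxResidual` are hypothetical, and the coefficients passed in are the ones printed above.

```java
// Hypothetical helper: compares each data point with the fitted quadratic's prediction.
class FitCheck {
    // Prints every residual and returns the largest absolute difference
    // between actual and predicted y-values.
    static double maxResidual(int[] xs, int[] ys, double a, double b, double c) {
        double worst = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double predicted = a * xs[i] * xs[i] + b * xs[i] + c;
            double residual = ys[i] - predicted;
            System.out.println("x=" + xs[i] + "  actual=" + ys[i]
                    + "  predicted=" + predicted + "  residual=" + residual);
            worst = Math.max(worst, Math.abs(residual));
        }
        return worst;
    }

    public static void main(String[] args) {
        int[] x = {1, 2, 3, 4, 5};
        int[] y = {1, 3, 6, 10, 15};
        double worst = maxResidual(x, y, 0.5, 0.5, 0.0);
        System.out.println("largest residual: " + worst); // prints "largest residual: 0.0"
    }
}
```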

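We can check these values visually by plotting the points together with the curve in Desmos, an excellent online graphing tool.

The curve passes through every point, which makes sense: the y-values are the triangular numbers, which satisfy $y = 0.5x^2 + 0.5x$ exactly. Let’s try a more complex group of points: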

public class Tester {
	public static void main(String[] args) {
		int[] x = {1, 2, 3, 4, 5, 6, 7};
		int[] y = {20, 19, 17, 18, 15, 13, 5};
		Dataset set = new Dataset(x, y);
		set.quadraticEquation();
	}
}

With the original ranges, the result was close but not quite right, because the best intercept lies above 9.9, outside the range searched for c. To fix this, I extended the innermost loop, which searches c, so that it stops at 20.0 instead of 10.0. The output was:

a=-0.6
b=2.4
c=18

This yielded a much more accurate result.

Note that quadraticEquation() can be tuned by changing where each loop starts, how much it advances on each iteration, and where it stops. To increase precision, decrease the step size. Be aware, however, that this makes the search much slower: the number of combinations tested grows with the cube of the number of values per coefficient, so halving the step multiplies the running time by roughly eight. Memory use stays the same, because only the best combination found so far is stored.
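To see how quickly the work grows, the sketch below counts the cost evaluations a three-coefficient grid search performs over the range -9.9 to 9.9 (inclusive) for two step sizes. It only does the arithmetic; the class and method names are hypothetical. A step of 0.1 gives 199 values per coefficient and 199³ = 7,880,599 combinations, while a step of 0.05 gives 397 values and 397³ = 62,570,773 combinations, about eight times as many.

```java
// Hypothetical helper: counts the cost evaluations a 3-coefficient grid search performs.
class SearchSize {
    // Grid values per coefficient when the range runs from -limitTicks to +limitTicks,
    // measured in units of the step size (e.g. 99 ticks of 0.1 covers -9.9 to 9.9).
    static long valuesPerAxis(long limitTicks) {
        return 2 * limitTicks + 1;
    }

    // One cost evaluation per (a, b, c) combination.
    static long evaluations(long limitTicks) {
        long n = valuesPerAxis(limitTicks);
        return n * n * n;
    }

    public static void main(String[] args) {
        // Step 0.1: -9.9..9.9 is 99 steps either side of zero.
        System.out.println(evaluations(99));  // prints 7880599
        // Step 0.05: -9.9..9.9 is 198 steps either side of zero.
        System.out.println(evaluations(198)); // prints 62570773
    }
}
```

Each of those evaluations also loops over every data point, so the total work is the number of combinations multiplied by the size of the data set.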

Feel free to test it with other inputs. Although this is definitely not the most efficient algorithm, I think it is still useful in applications where speed isn’t a priority.
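As mentioned at the start, the same brute-force idea can be adapted to other models. A minimal sketch for a straight line $y = mx + k$ needs only two nested loops; the class name `LinearGridSearch`, the coefficient names m and k, and the search range are assumptions chosen for illustration. It uses the same sum-of-absolute-errors cost as the quadratic version.

```java
import java.util.Arrays;

// Hypothetical linear counterpart of quadraticEquation(): fits y = m*x + k by brute force.
class LinearGridSearch {
    // Sum of absolute errors for the line y = m*x + k.
    static double cost(int[] xs, int[] ys, double m, double k) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            total += Math.abs(ys[i] - (m * xs[i] + k));
        }
        return total;
    }

    // Tries every m and k in tenths from -9.9 to 9.9 and returns {m, k} with the lowest cost.
    static double[] fit(int[] xs, int[] ys) {
        double[] best = new double[2];
        double min = Double.MAX_VALUE;
        for (int i = -99; i < 100; i++) {
            double m = i / 10.0;
            for (int j = -99; j < 100; j++) {
                double k = j / 10.0;
                double current = cost(xs, ys, m, k);
                if (current < min) {
                    min = current;
                    best[0] = m;
                    best[1] = k;
                }
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] x = {0, 1, 2, 3};
        int[] y = {1, 3, 5, 7}; // points on y = 2x + 1
        System.out.println(Arrays.toString(fit(x, y))); // prints [2.0, 1.0]
    }
}
```

Dropping one loop reduces the work from cubic to quadratic in the number of values per coefficient; a higher-degree polynomial would need one more loop per extra coefficient.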

Published by clhruv

Programmer, Data Scientist and Full Stack Developer
