import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.commons.math3.stat.regression.SimpleRegression;
/**
 * Fits four regression models (linear, quadratic, logarithmic, power) to (x, y)
 * observations and reports the coefficient of determination and a p-value for each.
 *
 * <p>Input data is always a {@code double[n][2]} array whose rows are {@code {x, y}} pairs.
 */
public class RegressionTest {

    /**
     * Fits every model to a small built-in sample and prints R² and p-value per model.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Sample observations as {x, y} pairs.
        double[][] data = {
            {1, 2}, {2, 3.1}, {3, 5.2}, {4, 7.8}, {5, 11.3}, {6, 15.1}
        };

        // Fit each model.
        RegressionResult linear = linearRegression(data);
        RegressionResult quadratic = quadraticRegression(data);
        RegressionResult log = logRegression(data);
        RegressionResult power = powerRegression(data);

        // Report the results.
        System.out.println("단순 선형: R² = " + linear.rSquared + ", p-value = " + linear.pValue);
        System.out.println("제곱: R² = " + quadratic.rSquared + ", p-value = " + quadratic.pValue);
        System.out.println("로그: R² = " + log.rSquared + ", p-value = " + log.pValue);
        System.out.println("거듭제곱: R² = " + power.rSquared + ", p-value = " + power.pValue);
    }

    /** Immutable holder for the goodness-of-fit figures of one fitted model. */
    static class RegressionResult {
        /** Coefficient of determination of the fit. */
        final double rSquared;
        /** Significance (p-value) of the fitted model. */
        final double pValue;

        RegressionResult(double rSquared, double pValue) {
            this.rSquared = rSquared;
            this.pValue = pValue;
        }
    }

    /**
     * 1. Simple linear regression: {@code y = a + b*x}.
     *
     * @param data rows of {@code {x, y}}
     * @return R² and the two-sided p-value of the slope
     */
    public static RegressionResult linearRegression(double[][] data) {
        SimpleRegression regression = new SimpleRegression();
        for (double[] point : data) {
            regression.addData(point[0], point[1]);
        }
        return new RegressionResult(regression.getRSquare(), calculatePValue(regression, data.length));
    }

    /**
     * 2. Quadratic regression: {@code y = a + b*x + c*x^2}.
     *
     * @param data rows of {@code {x, y}}
     * @return R² and the p-value of the overall model (both x and x² terms)
     */
    public static RegressionResult quadraticRegression(double[][] data) {
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        double[][] x = new double[data.length][2]; // columns: [x, x^2]
        double[] y = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            x[i][0] = data[i][0];
            x[i][1] = Math.pow(data[i][0], 2);
            y[i] = data[i][1];
        }
        regression.newSampleData(y, x);
        return new RegressionResult(regression.calculateRSquared(), calculatePValue(regression, data.length));
    }

    /**
     * 3. Logarithmic regression: {@code y = a + b*ln(x)}.
     *
     * @param data rows of {@code {x, y}}; every x must be strictly positive
     * @return R² and p-value of the fit
     * @throws IllegalArgumentException if any x is not strictly positive
     */
    public static RegressionResult logRegression(double[][] data) {
        return fitSinglePredictor(logColumn(data, 0, "x"), rawColumn(data, 1));
    }

    /**
     * 4. Power regression: {@code y = a * x^b}, fitted in its linearised form
     * {@code ln(y) = ln(a) + b*ln(x)}. R² and p-value therefore refer to the log-log fit.
     *
     * @param data rows of {@code {x, y}}; every x and y must be strictly positive
     * @return R² and p-value of the log-log fit
     * @throws IllegalArgumentException if any x or y is not strictly positive
     */
    public static RegressionResult powerRegression(double[][] data) {
        return fitSinglePredictor(logColumn(data, 0, "x"), logColumn(data, 1, "y"));
    }

    /**
     * Two-sided p-value of the slope of a simple regression, from its t statistic.
     *
     * @param regression fitted simple regression
     * @param dataSize number of observations; the test uses {@code dataSize - 2} degrees of freedom
     * @return the p-value, or 1.0 when there are no residual degrees of freedom
     */
    public static double calculatePValue(SimpleRegression regression, int dataSize) {
        double tStatistic = regression.getSlope() / regression.getSlopeStdErr();
        return calculatePValueFromT(tStatistic, dataSize - 2);
    }

    /**
     * P-value of the overall F-test of an OLS model with an intercept (null hypothesis:
     * every non-intercept coefficient is zero).
     *
     * <p>With k predictors and n observations, F = (R²/k) / ((1-R²)/(n-k-1)) and the upper
     * tail of F(k, n-k-1) equals the regularized incomplete beta function
     * I<sub>1-R²</sub>((n-k-1)/2, k/2). For a single predictor this is exactly the
     * two-sided t-test of the slope, so single-predictor results match that test.
     *
     * @param regression fitted model (default intercept term assumed)
     * @param dataSize number of observations used for the fit
     * @return the p-value, or 1.0 when the model has no predictor or no residual degrees of freedom
     */
    public static double calculatePValue(OLSMultipleLinearRegression regression, int dataSize) {
        int parameterCount = regression.estimateRegressionParameters().length;
        int modelDf = parameterCount - 1;            // predictors, excluding the intercept
        int residualDf = dataSize - parameterCount;
        if (modelDf < 1 || residualDf <= 0) {
            return 1.0; // nothing to test, or not enough data
        }
        double unexplained = 1.0 - regression.calculateRSquared();
        // Guard against rounding pushing R² marginally outside [0, 1]; NaN falls through
        // both comparisons and is propagated by regularizedBeta.
        if (unexplained <= 0.0) {
            return 0.0;
        }
        if (unexplained >= 1.0) {
            return 1.0;
        }
        return Beta.regularizedBeta(unexplained, 0.5 * residualDf, 0.5 * modelDf);
    }

    /**
     * Two-sided p-value of a t statistic.
     *
     * @param tStatistic observed t statistic
     * @param degreesOfFreedom degrees of freedom of the t distribution
     * @return the p-value, or 1.0 when {@code degreesOfFreedom <= 0}
     */
    public static double calculatePValueFromT(double tStatistic, int degreesOfFreedom) {
        if (degreesOfFreedom <= 0) {
            return 1.0;
        }
        TDistribution tDist = new TDistribution(degreesOfFreedom);
        // Use the lower tail directly: 1 - cdf(|t|) cancels to exactly 0 for large |t|.
        return 2 * tDist.cumulativeProbability(-Math.abs(tStatistic));
    }

    /** Fits {@code response = a + b*predictor} by OLS and returns its R² and p-value. */
    private static RegressionResult fitSinglePredictor(double[] predictor, double[] response) {
        double[][] design = new double[predictor.length][1];
        for (int i = 0; i < predictor.length; i++) {
            design[i][0] = predictor[i];
        }
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(response, design);
        return new RegressionResult(regression.calculateRSquared(), calculatePValue(regression, response.length));
    }

    /** Copies one column of the {x, y} rows into a flat array. */
    private static double[] rawColumn(double[][] data, int column) {
        double[] values = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            values[i] = data[i][column];
        }
        return values;
    }

    /** Natural log of one column, rejecting values for which the log is undefined. */
    private static double[] logColumn(double[][] data, int column, String axis) {
        double[] logs = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            double value = data[i][column];
            if (!(value > 0)) { // also rejects NaN
                throw new IllegalArgumentException(
                    axis + " must be strictly positive to take its logarithm, but row " + i + " has " + value);
            }
            logs[i] = Math.log(value);
        }
        return logs;
    }
}
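// Blog page residue, not part of the program:
// 카테고리 없음 (Uncategorized)
// 통계 함수 개발 (Statistical function development)
// 728x90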