본문 바로가기

카테고리 없음

통계 함수 개발

728x90

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.apache.commons.math3.distribution.TDistribution;

public class RegressionTest {
    public static void main(String[] args) {
        // 예제 데이터 (X, Y)
        double[][] data = {
            {1, 2}, {2, 3.1}, {3, 5.2}, {4, 7.8}, {5, 11.3}, {6, 15.1}
        };

        // 각 회귀 모델 실행 및 결과 출력
        RegressionResult linear = linearRegression(data);
        RegressionResult quadratic = quadraticRegression(data);
        RegressionResult log = logRegression(data);
        RegressionResult power = powerRegression(data);

        // 결과 출력
        System.out.println("단순 선형: R² = " + linear.rSquared + ", p-value = " + linear.pValue);
        System.out.println("제곱: R² = " + quadratic.rSquared + ", p-value = " + quadratic.pValue);
        System.out.println("로그: R² = " + log.rSquared + ", p-value = " + log.pValue);
        System.out.println("거듭제곱: R² = " + power.rSquared + ", p-value = " + power.pValue);
    }

    // 결과를 저장할 클래스
    static class RegressionResult {
        double rSquared;
        double pValue;

        RegressionResult(double rSquared, double pValue) {
            this.rSquared = rSquared;
            this.pValue = pValue;
        }
    }

    // 1. 단순 선형 회귀 (y = a + bx)
    public static RegressionResult linearRegression(double[][] data) {
        SimpleRegression regression = new SimpleRegression();
        for (double[] point : data) {
            regression.addData(point[0], point[1]);
        }
        return new RegressionResult(regression.getRSquare(), calculatePValue(regression, data.length));
    }

    // 2. 제곱 회귀 (y = a + bx²)
    public static RegressionResult quadraticRegression(double[][] data) {
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        double[][] x = new double[data.length][2]; // [X, X^2]
        double[] y = new double[data.length];

        for (int i = 0; i < data.length; i++) {
            x[i][0] = data[i][0];       // X
            x[i][1] = Math.pow(data[i][0], 2); // X^2
            y[i] = data[i][1];          // Y
        }

        regression.newSampleData(y, x);
        return new RegressionResult(regression.calculateRSquared(), calculatePValue(regression, data.length));
    }

    // 3. 로그 회귀 (y = a + b log(x))
    public static RegressionResult logRegression(double[][] data) {
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        double[][] x = new double[data.length][1];
        double[] y = new double[data.length];

        for (int i = 0; i < data.length; i++) {
            x[i][0] = Math.log(data[i][0]); // log(X)
            y[i] = data[i][1];              // Y
        }

        regression.newSampleData(y, x);
        return new RegressionResult(regression.calculateRSquared(), calculatePValue(regression, data.length));
    }

    // 4. 거듭제곱 회귀 (y = a * x^b -> log(y) = log(a) + b log(x))
    public static RegressionResult powerRegression(double[][] data) {
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        double[][] x = new double[data.length][1];
        double[] y = new double[data.length];

        for (int i = 0; i < data.length; i++) {
            x[i][0] = Math.log(data[i][0]); // log(X)
            y[i] = Math.log(data[i][1]);    // log(Y)
        }

        regression.newSampleData(y, x);
        return new RegressionResult(regression.calculateRSquared(), calculatePValue(regression, data.length));
    }

    // p-value 계산 (t-분포 기반)
    public static double calculatePValue(SimpleRegression regression, int dataSize) {
        double slope = regression.getSlope();
        double standardError = regression.getSlopeStdErr();
        double tStatistic = slope / standardError;
        return calculatePValueFromT(tStatistic, dataSize - 2); // 자유도 = n - 2
    }

    public static double calculatePValue(OLSMultipleLinearRegression regression, int dataSize) {
        double[] coefficients = regression.estimateRegressionParameters();
        double[] standardErrors = regression.estimateRegressionParametersStandardErrors();

        if (coefficients.length < 2) return 1.0; // 데이터 부족 시 예외 처리

        double tStatistic = coefficients[1] / standardErrors[1];
        return calculatePValueFromT(tStatistic, dataSize - coefficients.length);
    }

    // t-분포를 이용한 p-value 계산
    public static double calculatePValueFromT(double tStatistic, int degreesOfFreedom) {
        if (degreesOfFreedom <= 0) return 1.0; // 자유도가 0 이하일 경우 예외 처리

        TDistribution tDist = new TDistribution(degreesOfFreedom);
        double pValue = 2 * (1 - tDist.cumulativeProbability(Math.abs(tStatistic))); // 양측 검정

        return pValue;
    }
}