/*
 * Decompiled with CFR 0.152.
 */
package com.github.stanfordfuturedata.momentsketch.optimizer;

import com.github.stanfordfuturedata.momentsketch.MathUtil;
import com.github.stanfordfuturedata.momentsketch.optimizer.FunctionWithHessian;
import com.github.stanfordfuturedata.momentsketch.optimizer.GenericOptimizer;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;

/**
 * Damped Newton's method for minimizing a {@link FunctionWithHessian}.
 *
 * <p>Each iteration computes the Newton direction {@code s = -H^{-1} g}. It tries a Cholesky solve
 * first and falls back to an SVD least-squares solve when Cholesky rejects the Hessian. Unless the
 * predicted change {@code s . g} is already within tolerance, the step is then shortened by a
 * backtracking (Armijo) line search.
 *
 * <p>{@link #solve} writes its statistics into instance fields and nothing here is synchronized,
 * so one instance should not be used for concurrent solves.
 */
public class NewtonOptimizer
implements GenericOptimizer {
    /** Smallest line-search scale factor tried; below it the damped step is accepted as is. */
    private static final double MIN_STEP_SCALE = 0.001;

    /** Objective being minimized; values are read back after each {@code computeAll} call. */
    protected FunctionWithHessian P;
    /** Maximum number of Newton iterations per {@link #solve} call (200 by default). */
    protected int maxIter;
    /** Number of Newton steps taken by the most recent {@link #solve} call. */
    protected int stepCount;
    /** Whether the most recent {@link #solve} call reached the gradient tolerance. */
    protected boolean converged;
    /** Number of steps in the most recent {@link #solve} call shortened by the line search. */
    protected int dampedStepCount;
    /** Armijo sufficient-decrease fraction applied to the predicted change. */
    private final double alpha = 0.3;
    /** Factor by which the line search shrinks the step on each backtrack. */
    private final double beta = 0.25;
    /** When true, progress is printed to standard output. */
    private boolean verbose = false;

    /**
     * Creates an optimizer for the given objective with the default iteration limit of 200.
     *
     * @param P objective providing value, gradient and Hessian
     */
    public NewtonOptimizer(FunctionWithHessian P) {
        this.P = P;
        this.maxIter = 200;
        this.stepCount = 0;
        this.dampedStepCount = 0;
        this.converged = false;
    }

    /** Enables or disables per-iteration progress output on standard output. */
    @Override
    public void setVerbose(boolean flag) {
        this.verbose = flag;
    }

    /** Sets the maximum number of Newton iterations performed by {@link #solve}. */
    @Override
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /** Returns the number of Newton steps taken by the most recent {@link #solve} call. */
    @Override
    public int getStepCount() {
        return this.stepCount;
    }

    /** Returns whether the most recent {@link #solve} call met the gradient tolerance. */
    @Override
    public boolean isConverged() {
        return this.converged;
    }

    /** Returns how many steps of the most recent {@link #solve} call were shortened. */
    public int getDampedStepCount() {
        return this.dampedStepCount;
    }

    /** Returns the objective this optimizer minimizes. */
    @Override
    public FunctionWithHessian getP() {
        return this.P;
    }

    /**
     * Minimizes {@code P} starting from {@code start}.
     *
     * <p>Iteration stops when {@code MathUtil.getMSE(gradient) < gradTol * gradTol} or after
     * {@code maxIter} steps. On return, {@link #isConverged()}, {@link #getStepCount()} and
     * {@link #getDampedStepCount()} describe this call. {@code P} holds the values computed at the
     * returned point.
     *
     * @param start initial point; not modified (a copy is iterated)
     * @param gradTol gradient tolerance; it is also the slack allowed in the line-search test
     * @return the final iterate
     */
    @Override
    public double[] solve(double[] start, double gradTol) {
        final int k = this.P.dim();
        final double[] x = start.clone();
        // Ask P for values an order of magnitude more precise than the tolerance tested against.
        final double requiredPrecision = gradTol / 10.0;
        final double gradTol2 = gradTol * gradTol;

        this.P.computeAll(x, requiredPrecision);
        this.converged = false;
        // Per-solve statistic, reset like stepCount/converged so repeated solves do not accumulate.
        this.dampedStepCount = 0;

        int step;
        for (step = 0; step < this.maxIter; ++step) {
            double pVal = this.P.getValue();
            double[] grad = this.P.getGradient();
            double[][] hess = this.P.getHessian();
            double mse = MathUtil.getMSE(grad);
            if (this.verbose) {
                System.out.println(String.format("Step: %3d GradRMSE: %10.5g P: %10.5g", step, Math.sqrt(mse), pVal));
            }
            if (mse < gradTol2) {
                this.converged = true;
                break;
            }

            RealVector direction = solveNewtonDirection(hess, grad);
            // Predicted first-order change of P along the full Newton step.
            double dfdx = 0.0;
            for (int i = 0; i < k; ++i) {
                dfdx += direction.getEntry(i) * grad[i];
            }

            double stepScale = 1.0;
            double[] newX = new double[k];
            fillStep(x, direction, stepScale, newX);
            this.P.computeAll(newX, requiredPrecision);
            // Only line-search when the predicted change is larger than the tolerance.
            if (dfdx * dfdx > gradTol2) {
                // Backtrack until f(newX) <= f(x) + alpha * t * dfdx + gradTol. The !(… >= …) form
                // makes a NaN objective value count as a failure, so it also triggers backtracking.
                // Once the scale drops below MIN_STEP_SCALE, the current damped step is accepted.
                while (!(pVal + this.alpha * stepScale * dfdx - this.P.getValue() >= -gradTol)
                        && !(stepScale < MIN_STEP_SCALE)) {
                    stepScale *= this.beta;
                    fillStep(x, direction, stepScale, newX);
                    this.P.computeAll(newX, requiredPrecision);
                }
            }
            if (stepScale < 1.0) {
                ++this.dampedStepCount;
                if (this.verbose) {
                    System.out.println("Step Size: " + stepScale);
                }
            }
            System.arraycopy(newX, 0, x, 0, k);
        }

        // If the iteration limit ended the loop, P already holds the gradient at the final x;
        // test it so a last step that reached the tolerance is reported as converged.
        if (!this.converged && MathUtil.getMSE(this.P.getGradient()) < gradTol2) {
            this.converged = true;
        }
        this.stepCount = step;
        return x;
    }

    /**
     * Computes the Newton direction {@code -H^{-1} g}.
     *
     * <p>Cholesky is tried first with zero symmetry and positivity thresholds, so any asymmetry or
     * non-positive pivot makes it throw. In that case the SVD solver is used; it solves the system
     * in the least-squares sense rather than failing.
     *
     * @param hess Hessian matrix, wrapped without copying
     * @param grad gradient vector
     * @return the Newton direction
     */
    private static RealVector solveNewtonDirection(double[][] hess, double[] grad) {
        RealMatrix hessMat = new Array2DRowRealMatrix(hess, false);
        RealVector direction;
        try {
            direction = new CholeskyDecomposition(hessMat, 0.0, 0.0).getSolver().solve(new ArrayRealVector(grad));
        }
        catch (RuntimeException choleskyFailure) {
            direction = new SingularValueDecomposition(hessMat).getSolver().solve(new ArrayRealVector(grad));
        }
        direction.mapMultiplyToSelf(-1.0);
        return direction;
    }

    /** Writes {@code x + scale * direction} into {@code out}, element by element. */
    private static void fillStep(double[] x, RealVector direction, double scale, double[] out) {
        for (int i = 0; i < out.length; ++i) {
            out[i] = x[i] + scale * direction.getEntry(i);
        }
    }
}

