package edu.washington.gs.maccoss.encyclopedia.utils.math;

import java.util.Arrays;

/* loaded from: input_file:edu/washington/gs/maccoss/encyclopedia/utils/math/Sigmoid.class */
public class Sigmoid {
    /** Step size used when no learning rate is supplied to the constructor. */
    private static final float DEFAULT_ALPHA = 1.0E-7f;

    /** Offset of the linear argument {@code u = a + b*x}. */
    protected float a;
    /** Slope of the linear argument {@code u = a + b*x}. */
    protected float b;
    /** Half-amplitude: the curve approaches {@code d + c} as u grows and {@code d - c} as u falls. */
    protected float c;
    /** Vertical offset of the curve (its value at {@code u = 0}). */
    protected float d;
    /** Gradient-descent step size used by {@link #train}. */
    private final float alpha;

    /**
     * Creates an algebraic sigmoid
     * {@code f(x) = c * u / sqrt(1 + u^2) + d} with {@code u = a + b*x}
     * and every parameter set to zero, using the default learning rate.
     */
    public Sigmoid() {
        this(0.0f, 0.0f, 0.0f, 0.0f);
    }

    /**
     * Creates an all-zero sigmoid with a custom learning rate.
     *
     * @param alpha step size used by {@link #train}
     */
    public Sigmoid(float alpha) {
        this(0.0f, 0.0f, 0.0f, 0.0f, alpha);
    }

    /**
     * Creates a sigmoid with the given parameters and the default learning rate.
     *
     * @param a offset of the linear argument
     * @param b slope of the linear argument
     * @param c half-amplitude
     * @param d vertical offset
     */
    public Sigmoid(float a, float b, float c, float d) {
        this(a, b, c, d, DEFAULT_ALPHA);
    }

    /**
     * Creates a sigmoid with the given parameters and learning rate.
     *
     * @param a offset of the linear argument
     * @param b slope of the linear argument
     * @param c half-amplitude
     * @param d vertical offset
     * @param alpha step size used by {@link #train}
     */
    public Sigmoid(float a, float b, float c, float d, float alpha) {
        this.a = a;
        this.b = b;
        this.c = c;
        this.d = d;
        this.alpha = alpha;
    }

    /**
     * Fits the curve to {@code (x, y)} with every point weighted 1.
     *
     * @param x independent values
     * @param y observed values, index-aligned with {@code x}
     * @param iterations maximum number of descent steps
     */
    public void train(float[] x, float[] y, int iterations) {
        float[] weights = new float[x.length];
        Arrays.fill(weights, 1.0f);
        train(x, y, weights, iterations);
    }

    /**
     * Fits the curve to weighted points by gradient descent.
     *
     * <p>Each iteration moves a, b, c and d in turn by {@code -alpha * gradient}. Each
     * gradient is evaluated after the preceding parameters in that iteration have
     * already moved (sequential, not simultaneous, updates). Training stops early
     * once an iteration leaves all four parameters bit-for-bit unchanged.
     *
     * @param x independent values
     * @param y observed values, index-aligned with {@code x}
     * @param weights per-point weights, index-aligned with {@code x}
     * @param iterations maximum number of descent steps
     */
    public void train(float[] x, float[] y, float[] weights, int iterations) {
        for (int iter = 0; iter < iterations; iter++) {
            float oldA = this.a;
            float oldB = this.b;
            float oldC = this.c;
            float oldD = this.d;
            adjustA(this.alpha * gradientA(x, y, weights));
            adjustB(this.alpha * gradientB(x, y, weights));
            adjustC(this.alpha * gradientC(x, y, weights));
            adjustD(this.alpha * gradientD(x, y, weights));
            if (this.a == oldA && this.b == oldB && this.c == oldC && this.d == oldD) {
                return; // converged: the step no longer changes anything
            }
        }
    }

    /**
     * Subtracts {@code step} from d.
     *
     * @param step amount to subtract
     * @return the new value of d
     */
    protected float adjustD(float step) {
        this.d -= step;
        return this.d;
    }

    /**
     * Subtracts {@code step} from c.
     *
     * @param step amount to subtract
     * @return the new value of c
     */
    protected float adjustC(float step) {
        this.c -= step;
        return this.c;
    }

    /**
     * Subtracts {@code step} from b.
     *
     * @param step amount to subtract
     * @return the new value of b
     */
    protected float adjustB(float step) {
        this.b -= step;
        return this.b;
    }

    /**
     * Subtracts {@code step} from a.
     *
     * @param step amount to subtract
     * @return the new value of a
     */
    protected float adjustA(float step) {
        this.a -= step;
        return this.a;
    }

    /** Returns the linear argument {@code u = a + b*x} at the current parameters. */
    private float argument(float x) {
        return this.a + this.b * x;
    }

    /** Returns {@code sqrt(u^2 + 1)}, the sigmoid's normalizer (always at least 1). */
    private static float norm(float u) {
        return (float) Math.sqrt(u * u + 1.0f);
    }

    /**
     * Returns {@code -2c * ((y - d) * s - u*c) / s^4} with {@code s = sqrt(1 + u^2)}.
     * This is the per-point derivative of {@code (y - f(x))^2} with respect to u.
     * The a- and b-gradients share it: du/da = 1 and du/db = x.
     */
    private float dLossDu(float xi, float yi) {
        float u = argument(xi);
        float s2 = u * u + 1.0f;
        float mismatch = (yi - this.d) * (float) Math.sqrt(s2) - u * this.c;
        return mismatch * (-2.0f * this.c) / (s2 * s2);
    }

    /** Weighted-squared-error gradient with respect to a. */
    private float gradientA(float[] x, float[] y, float[] weights) {
        float sum = 0.0f;
        for (int i = 0; i < x.length; i++) {
            sum += weights[i] * dLossDu(x[i], y[i]);
        }
        return sum;
    }

    /** Weighted-squared-error gradient with respect to b. */
    private float gradientB(float[] x, float[] y, float[] weights) {
        float sum = 0.0f;
        for (int i = 0; i < x.length; i++) {
            sum += weights[i] * (x[i] * dLossDu(x[i], y[i]));
        }
        return sum;
    }

    /**
     * Descent direction used for c.
     *
     * <p>NOTE(review): the analytic derivative of {@code sum w*(y - f(x))^2} with respect
     * to c is {@code sum w*2*(u*c - s*(y - d)) * u / (1 + u^2)}. This method instead
     * divides by {@code (u + 1)}, which is also undefined when {@code u == -1}. It is
     * kept unchanged deliberately. At the all-zero start produced by {@link #Sigmoid()}
     * and {@link #Sigmoid(float)}, u = 0 and c = 0, so the analytic a-, b- and
     * c-gradients are all exactly zero. Switching to them would leave such models able
     * to fit only d. Confirm the intended behaviour before changing.
     */
    private float gradientC(float[] x, float[] y, float[] weights) {
        float sum = 0.0f;
        for (int i = 0; i < x.length; i++) {
            float u = argument(x[i]);
            float mismatch = u * this.c - norm(u) * (y[i] - this.d);
            sum += weights[i] * (mismatch * 2.0f / (u + 1.0f));
        }
        return sum;
    }

    /** Weighted-squared-error gradient with respect to d: {@code -2 * sum w*(y - f(x))}. */
    private float gradientD(float[] x, float[] y, float[] weights) {
        float sum = 0.0f;
        for (int i = 0; i < x.length; i++) {
            float u = argument(x[i]);
            float residual = (y[i] - this.d) - u * this.c / norm(u);
            sum += weights[i] * (residual * -2.0f);
        }
        return sum;
    }

    /**
     * Evaluates {@code c * u / sqrt(1 + u^2) + d} with {@code u = a + b*x}.
     *
     * @param x independent value
     * @return the curve's value at {@code x}
     */
    public float getValue(float x) {
        float u = argument(x);
        return u * this.c / norm(u) + this.d;
    }

    /**
     * Evaluates the curve at every element of {@code x}.
     *
     * @param x independent values
     * @return a new array of the same length holding {@link #getValue(float)} of each element
     */
    public float[] getValues(float[] x) {
        float[] values = new float[x.length];
        for (int i = 0; i < values.length; i++) {
            values[i] = getValue(x[i]);
        }
        return values;
    }

    /** @return offset of the linear argument */
    public float getA() {
        return this.a;
    }

    /** @return slope of the linear argument */
    public float getB() {
        return this.b;
    }

    /** @return half-amplitude */
    public float getC() {
        return this.c;
    }

    /** @return vertical offset */
    public float getD() {
        return this.d;
    }
}
