package edu.washington.gs.maccoss.encyclopedia.utils.math;

import edu.washington.gs.maccoss.encyclopedia.utils.math.distributions.Distribution;

/* loaded from: input_file:edu/washington/gs/maccoss/encyclopedia/utils/math/ProphetMixtureModel.class */
/**
 * Two-component (positive / negative) mixture model over a single scalar value.
 *
 * <p>{@link #getProbability(float)} returns the positive component's share of the combined
 * probability, {@code p+(x) / (p+(x) + p-(x))}. {@link #train(float[], int)} refines both
 * components with a fixed number of expectation-maximization style iterations. Each iteration
 * computes per-point responsibilities from the current components and then refits every component
 * from those responsibilities.
 *
 * <p>Instances are mutable ({@code train} replaces both components) and are not synchronized.
 */
public class ProphetMixtureModel implements RTProbabilityModel {
    /** Minimum responsibility for a point to take part in the trimmed (hard-assignment) estimates. */
    private static final float TRIM_THRESHOLD = 0.5f;

    /** Minimum number of points at or above {@link #TRIM_THRESHOLD} needed to use trimmed estimates. */
    private static final int MIN_TRIMMED_POINTS = 3;

    private Distribution positive;
    private Distribution negative;
    private final boolean fixedMeans;

    /**
     * Creates a model from initial component estimates.
     *
     * @param positive initial positive component
     * @param negative initial negative component
     * @param fixedMeans if {@code true}, training keeps each component's current mean and only
     *     re-estimates the other parameters passed to {@code Distribution.clone}
     */
    public ProphetMixtureModel(Distribution positive, Distribution negative, boolean fixedMeans) {
        this.positive = positive;
        this.negative = negative;
        this.fixedMeans = fixedMeans;
    }

    /**
     * Interface adapter. Only the second argument is scored; the first is ignored.
     *
     * @param ignored unused by this model
     * @param value the value to score
     * @return the same result as {@link #getProbability(float)} for {@code value}
     */
    @Override
    public float getProbability(float ignored, float value) {
        return getProbability(value);
    }

    /**
     * Returns the positive component's share of the total probability at {@code value}.
     *
     * @param value the value to score
     * @return {@code p+ / (p+ + p-)}, or {@code 0} when both components give zero probability
     */
    public float getProbability(float value) {
        double pos = positive.getProbability(value);
        double total = pos + negative.getProbability(value);
        return total == 0.0d ? 0.0f : (float) (pos / total);
    }

    /** Returns the current positive component (replaced by every training iteration). */
    public Distribution getPositive() {
        return positive;
    }

    /** Returns the current negative component (replaced by every training iteration). */
    public Distribution getNegative() {
        return negative;
    }

    /**
     * Runs {@code iterations} refinement iterations over {@code values}.
     *
     * @param values observed values
     * @param iterations number of iterations to run; zero or negative does nothing
     */
    public void train(float[] values, int iterations) {
        for (int iteration = 0; iteration < iterations; iteration++) {
            runIteration(values);
        }
    }

    /**
     * Runs one iteration. It computes each point's responsibility under both components, then
     * refits each component from its responsibilities.
     */
    void runIteration(float[] values) {
        int n = values.length;
        float[] positiveWeights = new float[n];
        float[] negativeWeights = new float[n];
        for (int i = 0; i < n; i++) {
            double pos = positive.getProbability(values[i]);
            double neg = negative.getProbability(values[i]);
            double total = pos + neg;
            if (total == 0.0d) {
                // Both components give zero probability: attribute the point entirely to the
                // negative component.
                positiveWeights[i] = 0.0f;
                negativeWeights[i] = 1.0f;
            } else {
                positiveWeights[i] = (float) (pos / total);
                negativeWeights[i] = (float) (neg / total);
            }
        }
        positive = getNewDistribution(positive, values, positiveWeights, fixedMeans);
        negative = getNewDistribution(negative, values, negativeWeights, fixedMeans);
    }

    /**
     * Re-estimates one component from per-point weights.
     *
     * <p>If at least {@link #MIN_TRIMMED_POINTS} points have weight >= {@link #TRIM_THRESHOLD},
     * mean and standard deviation are the plain (unweighted) statistics of just those points.
     * Otherwise, weighted statistics over all points are used. The total weight is passed as the
     * third {@code clone} argument (presumably the component's prior/mixing weight; confirm against
     * {@code Distribution}).
     *
     * @param distribution the current component, used as the template for {@code clone}
     * @param values observed values
     * @param weights per-point responsibilities for this component (same length as {@code values})
     * @param fixedMeans if {@code true}, the component keeps {@code distribution.getMean()}
     * @return the refitted component, or {@code distribution} unchanged when the data cannot
     *     support a finite estimate (e.g. empty input, all-zero weights, a single weighted point)
     */
    static Distribution getNewDistribution(
            Distribution distribution, float[] values, float[] weights, boolean fixedMeans) {
        float totalWeight = General.sum(weights);

        int confident = 0;
        for (float w : weights) {
            if (w >= TRIM_THRESHOLD && ++confident >= MIN_TRIMMED_POINTS) {
                break; // only need to know whether the minimum is reached
            }
        }

        // NOTE(review): when fixedMeans is set, the spread is still measured around the
        // re-estimated mean rather than the fixed one; kept as-is, confirm this is intended.
        float mean;
        float stdev;
        if (confident >= MIN_TRIMMED_POINTS) {
            mean = trimmedMean(values, weights, TRIM_THRESHOLD);
            stdev = trimmedStdev(values, weights, mean, TRIM_THRESHOLD);
        } else {
            mean = weightedMean(values, weights);
            stdev = weightedStdev(values, weights, mean);
        }

        if (!Float.isFinite(stdev) || (!fixedMeans && !Float.isFinite(mean))) {
            // Too little weighted data for a usable estimate; keep the previous component rather
            // than poisoning every later probability with NaN/Infinity.
            return distribution;
        }
        return fixedMeans
                ? distribution.clone(distribution.getMean(), stdev, totalWeight)
                : distribution.clone(mean, stdev, totalWeight);
    }

    /**
     * Weighted arithmetic mean {@code sum(w_i * x_i) / sum(w_i)}.
     *
     * @param values observed values
     * @param weights per-point weights (same length as {@code values})
     * @return the weighted mean; NaN when the weights sum to zero
     */
    static float weightedMean(float[] values, float[] weights) {
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (int i = 0; i < values.length; i++) {
            weightedSum += (double) weights[i] * values[i];
            totalWeight += weights[i];
        }
        return (float) (weightedSum / totalWeight);
    }

    /**
     * Weighted standard deviation around {@code mean}. It uses
     * {@code sqrt(sum(w*(x-mean)^2) / (((M-1)/M) * sum(w)))}, where {@code M} is the number of
     * points with positive weight. Points with weight <= 0 are skipped.
     *
     * @return the weighted standard deviation; non-finite when {@code M <= 1}
     */
    static float weightedStdev(float[] values, float[] weights, float mean) {
        float weightedSquares = 0.0f;
        float totalWeight = 0.0f;
        int count = 0;
        for (int i = 0; i < values.length; i++) {
            float w = weights[i];
            if (w > 0.0f) {
                count++;
                totalWeight += w;
                float deviation = values[i] - mean;
                weightedSquares += w * deviation * deviation;
            }
        }
        return (float) Math.sqrt(weightedSquares / (((count - 1) * totalWeight) / count));
    }

    /**
     * Unweighted mean of the values whose weight is at least {@code threshold}.
     *
     * @return the mean; NaN when no weight reaches {@code threshold}
     */
    static float trimmedMean(float[] values, float[] weights, float threshold) {
        float sum = 0.0f;
        int count = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] >= threshold) {
                sum += values[i];
                count++;
            }
        }
        return sum / count;
    }

    /**
     * Sample standard deviation (divisor {@code n - 1}) around {@code mean}. It uses only the
     * values whose weight is at least {@code threshold}.
     *
     * @return the standard deviation; non-finite when fewer than two weights reach
     *     {@code threshold}
     */
    static float trimmedStdev(float[] values, float[] weights, float mean, float threshold) {
        float squares = 0.0f;
        int count = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] >= threshold) {
                float deviation = values[i] - mean;
                squares += deviation * deviation;
                count++;
            }
        }
        return (float) Math.sqrt(squares / (count - 1));
    }
}
