/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.util;

import ai.djl.ndarray.NDArray;

/**
 * Rescales data into the target interval {@code [minRange, maxRange]} (default {@code [0, 1]})
 * using the minimum and maximum observed by {@link #fit(NDArray, int[])}.
 *
 * <p>The forward transform is {@code (x - min) / (max - min) * (maxRange - minRange) + minRange}
 * and {@link #inverseTransform(NDArray)} undoes it. Entries whose fitted range {@code max - min}
 * is zero (constant features) are divided by 1 instead of 0, so they map to {@code minRange}
 * rather than NaN.
 *
 * <p>The fitted arrays are owned by this scaler and are released by {@link #close()}.
 */
public class MinMaxScaler
implements AutoCloseable {
    /** Minimum of the fitted data along the fit axes; {@code null} until fitted or after close. */
    private NDArray fittedMin;
    /** Maximum of the fitted data along the fit axes; {@code null} until fitted or after close. */
    private NDArray fittedMax;
    /**
     * Divisor used for scaling: {@code fittedMax - fittedMin} with zero entries replaced by 1.
     * Private to this class (never returned to callers); {@code null} until fitted or after close.
     */
    private NDArray fittedRange;
    /** Lower bound of the output interval. */
    private float minRange;
    /** Upper bound of the output interval. */
    private float maxRange = 1.0f;
    /** When set, arrays produced by future fits are detached as soon as they are computed. */
    private boolean detached;

    /**
     * Computes the minimum and maximum of {@code data} along {@code axises} for later scaling.
     *
     * <p>Refitting replaces the previous statistics. The previous internal range array is closed;
     * the previous min/max arrays are left open because callers may still hold them via
     * {@link #getMin()} / {@link #getMax()}.
     *
     * @param data data used to compute the minimum and maximum
     * @param axises axes along which the minimum and maximum are computed
     * @return this scaler, fitted
     */
    public MinMaxScaler fit(NDArray data, int[] axises) {
        NDArray min = data.min(axises);
        NDArray max = data.max(axises);
        NDArray range = max.sub(min);
        // Avoid 0/0 = NaN in transform for constant features (same convention as scikit-learn).
        replaceZerosWithOnes(range);

        NDArray staleRange = fittedRange;
        fittedMin = min;
        fittedMax = max;
        fittedRange = range;
        // The old range array is never handed out, so nothing else can reference it.
        if (staleRange != null) {
            staleRange.close();
        }
        if (detached) {
            detach();
        }
        return this;
    }

    /**
     * Computes the minimum and maximum of {@code data} along axis 0 for later scaling.
     *
     * @param data data used to compute the minimum and maximum
     * @return this scaler, fitted
     */
    public MinMaxScaler fit(NDArray data) {
        return fit(data, new int[] {0});
    }

    /**
     * Scales a copy of {@code data} into {@code [minRange, maxRange]}. If the scaler has not been
     * fitted yet, it is first fitted on {@code data} along axis 0.
     *
     * @param data data to transform; not modified
     * @return a new array with the scaled values
     */
    public NDArray transform(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
        NDArray std = data.sub(fittedMin).divi(fittedRange);
        return scale(std);
    }

    /**
     * Scales {@code data} in place into {@code [minRange, maxRange]}. If the scaler has not been
     * fitted yet, it is first fitted on {@code data} along axis 0 (before any modification).
     *
     * @param data data to transform; modified in place
     * @return the transformed array
     */
    public NDArray transformi(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
        NDArray std = data.subi(fittedMin).divi(fittedRange);
        return scale(std);
    }

    /**
     * Undoes a previous {@link #transform(NDArray)} on a copy of {@code data}.
     *
     * @param data scaled data; not modified
     * @return a new array in the original value space
     * @throws IllegalStateException if the scaler is not fitted
     */
    public NDArray inverseTransform(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        NDArray result = inverseScale(data);
        return result.muli(fittedRange).addi(fittedMin);
    }

    /**
     * Undoes a previous transform on {@code data} in place.
     *
     * @param data scaled data; modified in place
     * @return the array in the original value space
     * @throws IllegalStateException if the scaler is not fitted
     */
    public NDArray inverseTransformi(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        NDArray result = inverseScalei(data);
        return result.muli(fittedRange).addi(fittedMin);
    }

    /**
     * Detaches the current fitted arrays and marks the scaler so that future fits detach their
     * results as well. Detached arrays are released by {@link #close()}.
     *
     * @return this scaler
     */
    public MinMaxScaler detach() {
        detached = true;
        if (fittedMin != null) {
            fittedMin.detach();
        }
        if (fittedMax != null) {
            fittedMax.detach();
        }
        if (fittedRange != null) {
            fittedRange.detach();
        }
        return this;
    }

    /**
     * Sets the target output interval (default {@code [0, 1]}).
     *
     * <p>If {@code minRange == maxRange}, every transformed value equals that bound, and the
     * inverse transforms divide by zero.
     *
     * @param minRange lower bound of the output interval
     * @param maxRange upper bound of the output interval
     * @return this scaler
     */
    public MinMaxScaler optRange(float minRange, float maxRange) {
        this.minRange = minRange;
        this.maxRange = maxRange;
        return this;
    }

    /**
     * Returns the fitted minimum. The array is owned by this scaler and closed by {@link #close()}.
     *
     * @return the fitted minimum
     * @throws IllegalStateException if the scaler is not fitted
     */
    public NDArray getMin() {
        throwsIllegalStateWhenNotFitted();
        return fittedMin;
    }

    /**
     * Returns the fitted maximum. The array is owned by this scaler and closed by {@link #close()}.
     *
     * @return the fitted maximum
     * @throws IllegalStateException if the scaler is not fitted
     */
    public NDArray getMax() {
        throwsIllegalStateWhenNotFitted();
        return fittedMax;
    }

    /**
     * Closes the fitted arrays and resets the scaler to the unfitted state, so that later calls
     * refit (transform) or throw (inverse transform, getters) instead of using freed arrays.
     * Calling this more than once is harmless.
     */
    @Override
    public void close() {
        if (fittedMin != null) {
            fittedMin.close();
            fittedMin = null;
        }
        if (fittedMax != null) {
            fittedMax.close();
            fittedMax = null;
        }
        if (fittedRange != null) {
            fittedRange.close();
            fittedRange = null;
        }
    }

    /** Maps values from {@code [0, 1]} into {@code [minRange, maxRange]}, in place. */
    private NDArray scale(NDArray std) {
        if (isDefaultRange()) {
            return std;
        }
        return std.muli(maxRange - minRange).addi(minRange);
    }

    /** Maps values from {@code [minRange, maxRange]} back to {@code [0, 1]} on a new array. */
    private NDArray inverseScale(NDArray std) {
        if (isDefaultRange()) {
            // Copy so the caller's array is not modified by the in-place ops that follow.
            return std.duplicate();
        }
        return std.sub(minRange).divi(maxRange - minRange);
    }

    /** Maps values from {@code [minRange, maxRange]} back to {@code [0, 1]}, in place. */
    private NDArray inverseScalei(NDArray std) {
        if (isDefaultRange()) {
            return std;
        }
        return std.subi(minRange).divi(maxRange - minRange);
    }

    /** Returns true when the output interval is {@code [0, 1]}, i.e. scaling is the identity. */
    private boolean isDefaultRange() {
        return maxRange == 1.0f && minRange == 0.0f;
    }

    /** Adds 1 to every zero entry of {@code range}, in place; temporaries are closed. */
    private static void replaceZerosWithOnes(NDArray range) {
        try (NDArray isZero = range.eq(0);
                NDArray zeroIndicator = isZero.toType(range.getDataType(), false)) {
            range.addi(zeroIndicator);
        }
    }

    /** Throws if {@link #fit(NDArray, int[])} has not been called since construction or close. */
    private void throwsIllegalStateWhenNotFitted() {
        if (fittedRange == null) {
            throw new IllegalStateException("Min Max Scaler is not fitted");
        }
    }
}

