package com.example.crehairsegment;

import android.graphics.Bitmap;
import android.util.Size;
import com.example.crehairsegment.utils.SegmentMaskType;
import com.example.crehairsegment.vision.CreImage;
import com.example.crehairsegment.vision.CrePreparedImage;
import com.example.crehairsegment.vision.CreVisionSegmentResult;
import com.example.crehairsegment.vision.base.SegmentPredictorOptions;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;
import org.tensorflow.lite.Interpreter;

/* loaded from: classes2.dex */
/**
 * Runs a TensorFlow Lite segmentation model on a {@link CreImage} and turns the raw per-pixel
 * class scores into a {@link CreVisionSegmentResult}.
 *
 * <p>The model input is written as a {@value #INPUT_SIZE}x{@value #INPUT_SIZE} image with
 * {@value #NUM_CHANNELS} float32 values per pixel. The model output is read as
 * {@value #OUTPUT_SIZE}x{@value #OUTPUT_SIZE} cells, each holding one float per entry in
 * {@link #MODEL_CLASSES}.
 *
 * <p>The input/output buffers are allocated once and reused by every {@link #predict} call
 * without any synchronization. A single instance must therefore not be used from several
 * threads at once.
 */
public final class Predictor {
    private static final int NUM_CHANNELS = 3;
    private static final String TAG = "Predictor";
    private static final int BYTES_PER_FLOAT = 4;
    private static final int INPUT_SIZE = 224;
    private static final int OUTPUT_SIZE = 224;

    /**
     * Every class the model emits, in output-channel order. The per-pixel argmax channel index
     * stored in results is paired with (a filtered copy of) this array. It is never mutated;
     * target filtering always works on a copy.
     */
    private static final SegmentMaskType[] MODEL_CLASSES = {
        SegmentMaskType.SEG_BG, SegmentMaskType.SEG_HAIR
    };

    /** Model input: INPUT_SIZE * INPUT_SIZE pixels * NUM_CHANNELS floats, native byte order. */
    private final ByteBuffer inputByteBuffer;
    /** Model output: OUTPUT_SIZE * OUTPUT_SIZE cells * MODEL_CLASSES.length floats. */
    private final ByteBuffer outputByteBuffer;
    /** Scratch array receiving the packed ARGB pixels of the model bitmap. */
    private final int[] intValues;
    private Interpreter interpreter;
    private SegmentPredictorOptions options;
    /** Class labels handed to results; entries not in the current targets are SEG_BG. */
    private SegmentMaskType[] segmentClassifications = MODEL_CLASSES.clone();

    /**
     * Allocates the model input/output buffers and stores the options.
     *
     * <p>NOTE(review): unlike {@link #setOptions}, this does not apply
     * {@code getTargetSegments()} filtering. Preserved as-is; confirm whether that is intended.
     *
     * @param segmentPredictorOptions options used by {@link #predict}
     */
    public Predictor(SegmentPredictorOptions segmentPredictorOptions) {
        inputByteBuffer =
                ByteBuffer.allocateDirect(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS * BYTES_PER_FLOAT)
                        .order(ByteOrder.nativeOrder());
        outputByteBuffer =
                ByteBuffer.allocateDirect(
                                OUTPUT_SIZE * OUTPUT_SIZE * MODEL_CLASSES.length * BYTES_PER_FLOAT)
                        .order(ByteOrder.nativeOrder());
        intValues = new int[INPUT_SIZE * INPUT_SIZE];
        options = segmentPredictorOptions;
    }

    /**
     * Decodes the filled output buffer into per-pixel class indices and confidences.
     *
     * <p>For every output cell the channel with the highest score wins. Ties go to the lowest
     * channel index. The search is seeded from channel 0's own score rather than 0, so negative
     * scores (e.g. logits) are handled correctly.
     */
    private CreVisionSegmentResult postprocess(
            CreImage creImage, CrePreparedImage crePreparedImage) {
        int numClasses = segmentClassifications.length;
        int[][] classIndices = new int[OUTPUT_SIZE][OUTPUT_SIZE];
        float[][] confidences = new float[OUTPUT_SIZE][OUTPUT_SIZE];

        for (int y = 0; y < OUTPUT_SIZE; y++) {
            for (int x = 0; x < OUTPUT_SIZE; x++) {
                // Cells are laid out row-major, each holding numClasses consecutive floats.
                int cellStart = (y * OUTPUT_SIZE + x) * numClasses;
                int bestClass = 0;
                float bestScore = outputByteBuffer.getFloat(cellStart * BYTES_PER_FLOAT);
                for (int c = 1; c < numClasses; c++) {
                    float score = outputByteBuffer.getFloat((cellStart + c) * BYTES_PER_FLOAT);
                    if (score > bestScore) {
                        bestClass = c;
                        bestScore = score;
                    }
                }
                classIndices[y][x] = bestClass;
                confidences[y][x] = bestScore;
            }
        }

        return new CreVisionSegmentResult(
                creImage,
                crePreparedImage,
                options,
                segmentClassifications,
                crePreparedImage.getTargetInferenceSize(),
                new Size(OUTPUT_SIZE, OUTPUT_SIZE),
                crePreparedImage.getOffsetX(),
                crePreparedImage.getOffsetY(),
                classIndices,
                confidences);
    }

    /**
     * Writes {@code bitmap} into the model input buffer, each 8-bit channel scaled to
     * [-0.5, 0.5].
     *
     * @throws IllegalArgumentException if the bitmap is not exactly INPUT_SIZE x INPUT_SIZE.
     *     Any other size would be read with the wrong row stride or leave stale pixels from a
     *     previous frame in the buffer.
     */
    private void preprocess(Bitmap bitmap) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        if (width != INPUT_SIZE || height != INPUT_SIZE) {
            throw new IllegalArgumentException(
                    "Model bitmap must be " + INPUT_SIZE + "x" + INPUT_SIZE
                            + " but was " + width + "x" + height);
        }
        bitmap.getPixels(intValues, 0, INPUT_SIZE, 0, 0, INPUT_SIZE, INPUT_SIZE);

        inputByteBuffer.rewind();
        for (int pixel : intValues) {
            // NOTE(review): channels are written R, B, G, not the usual R, G, B. Kept unchanged
            // because the model may have been trained on this order; confirm against training.
            inputByteBuffer.putFloat(normalizeChannel(pixel >> 16));
            inputByteBuffer.putFloat(normalizeChannel(pixel));
            inputByteBuffer.putFloat(normalizeChannel(pixel >> 8));
        }
        // Leave the buffer readable from the start in case the consumer reads from position.
        inputByteBuffer.rewind();
    }

    /** Maps the low 8 bits of {@code shiftedPixel} from [0, 255] to [-0.5, 0.5]. */
    private static float normalizeChannel(int shiftedPixel) {
        return ((shiftedPixel & 0xFF) / 255.0f) - 0.5f;
    }

    /**
     * Returns a copy of {@code classes} in which every entry not contained in {@code targets} is
     * replaced by {@link SegmentMaskType#SEG_BG}. The input array is never modified.
     *
     * @param classes full class list, in model output-channel order
     * @param targets classes to keep, or {@code null} to keep all of them
     * @return a new array of the same length as {@code classes}
     */
    private static SegmentMaskType[] maskNonTargetClasses(
            SegmentMaskType[] classes, List<?> targets) {
        SegmentMaskType[] masked = classes.clone();
        if (targets == null) {
            return masked;
        }
        for (int i = 0; i < masked.length; i++) {
            if (!targets.contains(masked[i])) {
                masked[i] = SegmentMaskType.SEG_BG;
            }
        }
        return masked;
    }

    /**
     * Prepares {@code creImage}, runs the interpreter, and decodes the segmentation result.
     *
     * @param creImage image to segment
     * @return per-pixel class indices and confidences plus the preparation metadata
     * @throws IllegalStateException if {@link #setInterpreter} has not supplied an interpreter
     */
    public CreVisionSegmentResult predict(CreImage creImage) {
        Interpreter model = interpreter;
        if (model == null) {
            throw new IllegalStateException(
                    "Interpreter not set; call setInterpreter() before predict()");
        }
        CrePreparedImage prepared =
                CrePreparedImage.create(
                        creImage,
                        options.getCropAndScaleOption(),
                        new Size(INPUT_SIZE, INPUT_SIZE));
        preprocess(prepared.getBitmapForModel());
        outputByteBuffer.rewind();
        model.run(inputByteBuffer, outputByteBuffer);
        return postprocess(creImage, prepared);
    }

    /**
     * Sets the interpreter used by {@link #predict}. This class never closes it; the caller
     * keeps ownership.
     */
    public void setInterpreter(Interpreter interpreter) {
        this.interpreter = interpreter;
    }

    /**
     * Replaces the options used by later {@link #predict} calls and rebuilds the class list from
     * {@code getTargetSegments()}. Every model class not in that list is reported as
     * {@link SegmentMaskType#SEG_BG}; a {@code null} list keeps all model classes.
     *
     * <p>The list is always derived from the full model class set, so a class hidden by an
     * earlier call can be re-enabled by a later one. The array previously passed to results is
     * not modified in place.
     *
     * @param segmentPredictorOptions new options; must not be {@code null}
     * @throws NullPointerException if {@code segmentPredictorOptions} is {@code null}
     */
    public void setOptions(SegmentPredictorOptions segmentPredictorOptions) {
        // Read the targets before touching any state so a null argument leaves this unchanged.
        List<?> targets = segmentPredictorOptions.getTargetSegments();
        this.options = segmentPredictorOptions;
        this.segmentClassifications = maskNonTargetClasses(MODEL_CLASSES, targets);
    }
}
