package org.pytorch.torchvision;

import android.graphics.Bitmap;
import android.media.Image;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Locale;
import org.pytorch.MemoryFormat;
import org.pytorch.Tensor;
import u5.a;
import u5.c;

/* loaded from: classes2.dex */
/**
 * Static helpers that convert Android {@link Bitmap} and YUV_420_888 {@link Image} data into
 * normalised float32 buffers and {@link Tensor}s with three (R, G, B) channels.
 *
 * <p>Each colour channel value {@code c} (0..255) is written as
 * {@code (c / 255.0f - normMeanRGB[channel]) / normStdRGB[channel]}.
 */
public final class TensorImageUtils {
    /**
     * Per-channel (R, G, B) mean offered as a normalisation default.
     * NOTE(review): the array is public and mutable; callers should treat it as read-only.
     */
    public static float[] TORCHVISION_NORM_MEAN_RGB = {0.485f, 0.456f, 0.406f};

    /**
     * Per-channel (R, G, B) standard deviation offered as a normalisation default.
     * NOTE(review): the array is public and mutable; callers should treat it as read-only.
     */
    public static float[] TORCHVISION_NORM_STD_RGB = {0.229f, 0.224f, 0.225f};

    /** Numeric value of {@code android.graphics.ImageFormat.YUV_420_888} (0x23). */
    private static final int IMAGE_FORMAT_YUV_420_888 = 35;

    /** Number of colour channels written per pixel. */
    private static final int CHANNELS = 3;

    /**
     * Holder for the JNI entry point; its static initialiser loads the
     * {@code pytorch_vision_jni} library the first time the class is used.
     *
     * <p>The decompiler reports that this class and its native method were originally private.
     */
    public static class NativePeer {
        static {
            // NOTE(review): u5.a / u5.c are obfuscated names; this looks like
            // "initialise the native loader if needed, then load the library" - confirm.
            if (!a.b()) {
                a.a(new c());
            }
            a.c("pytorch_vision_jni");
        }

        private NativePeer() {
        }

        /**
         * Native conversion of a YUV image into a normalised float buffer. Parameter names follow
         * the arguments passed by the Java caller in this class.
         *
         * @param memoryFormatCode 1 for {@code MemoryFormat.CONTIGUOUS}, 2 for
         *     {@code MemoryFormat.CHANNELS_LAST}, 0 for anything else
         */
        public static native void imageYUV420CenterCropToFloatBuffer(ByteBuffer yBuffer, int yRowStride, int yPixelStride, ByteBuffer uBuffer, ByteBuffer vBuffer, int uvRowStride, int uvPixelStride, int imageWidth, int imageHeight, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, Buffer outBuffer, int outBufferOffset, int memoryFormatCode);
    }

    /**
     * Converts a rectangular region of {@code bitmap} into a float32 tensor using
     * {@link MemoryFormat#CONTIGUOUS}.
     *
     * @see #bitmapToFloat32Tensor(Bitmap, int, int, int, int, float[], float[], MemoryFormat)
     */
    public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB) {
        return bitmapToFloat32Tensor(bitmap, x, y, width, height, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
    }

    /**
     * Converts a rectangular region of {@code bitmap} into a float32 tensor of shape
     * {@code {1, 3, height, width}}.
     *
     * @param bitmap source bitmap
     * @param x left edge of the region, in bitmap pixels
     * @param y top edge of the region, in bitmap pixels
     * @param width region width in pixels
     * @param height region height in pixels
     * @param normMeanRGB per-channel mean, length 3
     * @param normStdRGB per-channel standard deviation, length 3
     * @param memoryFormat layout of the produced data
     * @return a tensor backed by a newly allocated float buffer
     * @throws IllegalArgumentException if a normalisation array does not have length 3 or the
     *     memory format is unsupported
     */
    public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB, MemoryFormat memoryFormat) {
        checkNormMeanArg(normMeanRGB);
        checkNormStdArg(normStdRGB);
        FloatBuffer outBuffer = Tensor.allocateFloatBuffer(width * CHANNELS * height);
        bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, outBuffer, 0, memoryFormat);
        return Tensor.fromBlob(outBuffer, new long[]{1, CHANNELS, height, width}, memoryFormat);
    }

    /**
     * Converts the whole {@code bitmap} into a float32 tensor using
     * {@link MemoryFormat#CONTIGUOUS}.
     *
     * @see #bitmapToFloat32Tensor(Bitmap, int, int, int, int, float[], float[], MemoryFormat)
     */
    public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, float[] normMeanRGB, float[] normStdRGB) {
        return bitmapToFloat32Tensor(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
    }

    /**
     * Converts the whole {@code bitmap} into a float32 tensor with the requested memory format.
     *
     * @see #bitmapToFloat32Tensor(Bitmap, int, int, int, int, float[], float[], MemoryFormat)
     */
    public static Tensor bitmapToFloat32Tensor(Bitmap bitmap, float[] normMeanRGB, float[] normStdRGB, MemoryFormat memoryFormat) {
        // Validated here as well so a bad array is reported before the bitmap is touched.
        checkNormMeanArg(normMeanRGB);
        checkNormStdArg(normStdRGB);
        return bitmapToFloat32Tensor(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), normMeanRGB, normStdRGB, memoryFormat);
    }

    /**
     * Writes a normalised region of {@code bitmap} into {@code outBuffer} using
     * {@link MemoryFormat#CONTIGUOUS}.
     *
     * @see #bitmapToFloatBuffer(Bitmap, int, int, int, int, float[], float[], FloatBuffer, int,
     *     MemoryFormat)
     */
    public static void bitmapToFloatBuffer(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset) {
        bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, outBuffer, outBufferOffset, MemoryFormat.CONTIGUOUS);
    }

    /**
     * Writes a normalised region of {@code bitmap} into {@code outBuffer}, starting at absolute
     * index {@code outBufferOffset}. The buffer's position is not used or changed.
     *
     * <p>With {@link MemoryFormat#CONTIGUOUS} the output is three consecutive planes (all R, then
     * all G, then all B); with {@link MemoryFormat#CHANNELS_LAST} the R, G, B values of each pixel
     * are adjacent.
     *
     * @param bitmap source bitmap
     * @param x left edge of the region, in bitmap pixels
     * @param y top edge of the region, in bitmap pixels
     * @param width region width in pixels
     * @param height region height in pixels
     * @param normMeanRGB per-channel mean, length 3
     * @param normStdRGB per-channel standard deviation, length 3
     * @param outBuffer destination buffer
     * @param outBufferOffset index of the first element to write; must be non-negative
     * @param memoryFormat {@code CONTIGUOUS} or {@code CHANNELS_LAST}
     * @throws IllegalStateException if {@code outBuffer} cannot hold {@code 3 * width * height}
     *     values starting at {@code outBufferOffset}
     * @throws IllegalArgumentException if the offset is negative, a normalisation array does not
     *     have length 3, or the memory format is unsupported
     */
    public static void bitmapToFloatBuffer(Bitmap bitmap, int x, int y, int width, int height, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset, MemoryFormat memoryFormat) {
        checkOutBufferCapacity(outBuffer, outBufferOffset, width, height);
        checkNormMeanArg(normMeanRGB);
        checkNormStdArg(normStdRGB);
        if (memoryFormat != MemoryFormat.CONTIGUOUS && memoryFormat != MemoryFormat.CHANNELS_LAST) {
            throw new IllegalArgumentException("Unsupported memory format " + memoryFormat);
        }

        final int pixelCount = height * width;
        final int[] pixels = new int[pixelCount];
        bitmap.getPixels(pixels, 0, width, x, y, width, height);

        // Both layouts share one loop; only the distance between consecutive pixels and between
        // the channels of a single pixel differs.
        final boolean planar = memoryFormat == MemoryFormat.CONTIGUOUS;
        final int pixelStride = planar ? 1 : CHANNELS;
        final int channelStride = planar ? pixelCount : 1;

        for (int i = 0; i < pixelCount; i++) {
            final int color = pixels[i];
            final int base = outBufferOffset + i * pixelStride;
            outBuffer.put(base, ((((color >> 16) & 255) / 255.0f) - normMeanRGB[0]) / normStdRGB[0]);
            outBuffer.put(base + channelStride, ((((color >> 8) & 255) / 255.0f) - normMeanRGB[1]) / normStdRGB[1]);
            outBuffer.put(base + 2 * channelStride, (((color & 255) / 255.0f) - normMeanRGB[2]) / normStdRGB[2]);
        }
    }

    /** Ensures the mean array has exactly one entry per colour channel. */
    private static void checkNormMeanArg(float[] normMeanRGB) {
        if (normMeanRGB.length != CHANNELS) {
            throw new IllegalArgumentException("normMeanRGB length must be 3");
        }
    }

    /** Ensures the standard-deviation array has exactly one entry per colour channel. */
    private static void checkNormStdArg(float[] normStdRGB) {
        if (normStdRGB.length != CHANNELS) {
            throw new IllegalArgumentException("normStdRGB length must be 3");
        }
    }

    /**
     * Ensures {@code outBuffer} can hold {@code 3 * tensorWidth * tensorHeight} floats starting at
     * {@code outBufferOffset}.
     *
     * @throws IllegalArgumentException if {@code outBufferOffset} is negative
     * @throws IllegalStateException if the buffer is too small
     */
    private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) {
        if (outBufferOffset < 0) {
            // A negative offset can never address a valid element and must not reach native code.
            throw new IllegalArgumentException("outBufferOffset must be non-negative");
        }
        // Computed in 64 bits: the 32-bit product wraps for large dimensions and would let an
        // undersized buffer through. The pixel count is clamped to the int range first so that
        // multiplying by 3 cannot overflow a long; the clamp never changes the outcome because
        // the capacity itself is an int.
        final long pixelCount = (long) tensorWidth * tensorHeight;
        final long clamped = Math.max((long) Integer.MIN_VALUE, Math.min((long) Integer.MAX_VALUE, pixelCount));
        final long required = outBufferOffset + CHANNELS * clamped;
        if (required > outBuffer.capacity()) {
            throw new IllegalStateException("Buffer underflow");
        }
    }

    /** Ensures the clockwise rotation is a multiple of 90 degrees in the range 0..270. */
    private static void checkRotateCWDegrees(int rotateCWDegrees) {
        if (rotateCWDegrees != 0 && rotateCWDegrees != 90 && rotateCWDegrees != 180 && rotateCWDegrees != 270) {
            throw new IllegalArgumentException("rotateCWDegrees must be one of 0, 90, 180, 270");
        }
    }

    /** Ensures both output tensor dimensions are strictly positive. */
    private static void checkTensorSize(int tensorWidth, int tensorHeight) {
        if (tensorHeight <= 0 || tensorWidth <= 0) {
            throw new IllegalArgumentException("tensorHeight and tensorWidth must be positive");
        }
    }

    /** Ensures {@code image} reports the YUV_420_888 format. */
    private static void checkImageFormat(Image image) {
        if (image.getFormat() != IMAGE_FORMAT_YUV_420_888) {
            throw new IllegalArgumentException(String.format(Locale.US, "Image format %d != ImageFormat.YUV_420_888", Integer.valueOf(image.getFormat())));
        }
    }

    /**
     * Converts a YUV_420_888 image into a float32 tensor using {@link MemoryFormat#CONTIGUOUS}.
     *
     * @see #imageYUV420CenterCropToFloat32Tensor(Image, int, int, int, float[], float[],
     *     MemoryFormat)
     */
    public static Tensor imageYUV420CenterCropToFloat32Tensor(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB) {
        return imageYUV420CenterCropToFloat32Tensor(image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, MemoryFormat.CONTIGUOUS);
    }

    /**
     * Converts a YUV_420_888 image into a float32 tensor of shape
     * {@code {1, 3, tensorHeight, tensorWidth}}. The pixel conversion itself is done by native
     * code.
     *
     * @param image source image; must report the YUV_420_888 format
     * @param rotateCWDegrees clockwise rotation, one of 0, 90, 180, 270
     * @param tensorWidth output width, positive
     * @param tensorHeight output height, positive
     * @param normMeanRGB per-channel mean, length 3
     * @param normStdRGB per-channel standard deviation, length 3
     * @param memoryFormat layout of the produced data
     * @return a tensor backed by a newly allocated float buffer
     * @throws IllegalArgumentException if any argument fails the checks above
     */
    public static Tensor imageYUV420CenterCropToFloat32Tensor(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, MemoryFormat memoryFormat) {
        checkImageFormat(image);
        checkNormMeanArg(normMeanRGB);
        checkNormStdArg(normStdRGB);
        checkRotateCWDegrees(rotateCWDegrees);
        checkTensorSize(tensorWidth, tensorHeight);
        FloatBuffer outBuffer = Tensor.allocateFloatBuffer(tensorWidth * CHANNELS * tensorHeight);
        imageYUV420CenterCropToFloatBuffer(image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, outBuffer, 0, memoryFormat);
        return Tensor.fromBlob(outBuffer, new long[]{1, CHANNELS, tensorHeight, tensorWidth}, memoryFormat);
    }

    /**
     * Writes a converted YUV_420_888 image into {@code outBuffer} using
     * {@link MemoryFormat#CONTIGUOUS}.
     *
     * @see #imageYUV420CenterCropToFloatBuffer(Image, int, int, int, float[], float[],
     *     FloatBuffer, int, MemoryFormat)
     */
    public static void imageYUV420CenterCropToFloatBuffer(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset) {
        imageYUV420CenterCropToFloatBuffer(image, rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, outBuffer, outBufferOffset, MemoryFormat.CONTIGUOUS);
    }

    /**
     * Validates the arguments and hands the image planes to the native converter, which fills
     * {@code outBuffer} starting at {@code outBufferOffset}.
     *
     * @param image source image; must report the YUV_420_888 format
     * @param rotateCWDegrees clockwise rotation, one of 0, 90, 180, 270
     * @param tensorWidth output width, positive
     * @param tensorHeight output height, positive
     * @param normMeanRGB per-channel mean, length 3
     * @param normStdRGB per-channel standard deviation, length 3
     * @param outBuffer destination buffer
     * @param outBufferOffset index of the first element to write; must be non-negative
     * @param memoryFormat layout of the produced data
     * @throws IllegalStateException if {@code outBuffer} cannot hold
     *     {@code 3 * tensorWidth * tensorHeight} values starting at {@code outBufferOffset}
     * @throws IllegalArgumentException if any other argument fails validation
     */
    public static void imageYUV420CenterCropToFloatBuffer(Image image, int rotateCWDegrees, int tensorWidth, int tensorHeight, float[] normMeanRGB, float[] normStdRGB, FloatBuffer outBuffer, int outBufferOffset, MemoryFormat memoryFormat) {
        checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight);
        checkImageFormat(image);
        checkNormMeanArg(normMeanRGB);
        checkNormStdArg(normStdRGB);
        checkRotateCWDegrees(rotateCWDegrees);
        checkTensorSize(tensorWidth, tensorHeight);

        final Image.Plane[] planes = image.getPlanes();
        final Image.Plane yPlane = planes[0];
        final Image.Plane uPlane = planes[1];
        final Image.Plane vPlane = planes[2];

        // Any format other than the two below is forwarded as code 0.
        // NOTE(review): how the native side treats code 0 is not visible from this file.
        final int memoryFormatCode;
        if (MemoryFormat.CONTIGUOUS == memoryFormat) {
            memoryFormatCode = 1;
        } else if (MemoryFormat.CHANNELS_LAST == memoryFormat) {
            memoryFormatCode = 2;
        } else {
            memoryFormatCode = 0;
        }

        // The U plane's strides are passed for both chroma planes.
        NativePeer.imageYUV420CenterCropToFloatBuffer(yPlane.getBuffer(), yPlane.getRowStride(), yPlane.getPixelStride(), uPlane.getBuffer(), vPlane.getBuffer(), uPlane.getRowStride(), uPlane.getPixelStride(), image.getWidth(), image.getHeight(), rotateCWDegrees, tensorWidth, tensorHeight, normMeanRGB, normStdRGB, outBuffer, outBufferOffset, memoryFormatCode);
    }
}
