package deepboof.io.torch7;

import com.xshield.dc;
import deepboof.PaddingType;
import deepboof.factory.FactoryForwards;
import deepboof.forward.ConfigConvolve2D;
import deepboof.forward.ConfigPadding;
import deepboof.forward.ConfigSpatial;
import deepboof.forward.SpatialPadding2D_F32;
import deepboof.forward.SpatialPadding2D_F64;
import deepboof.graph.InputAddress;
import deepboof.graph.Node;
import deepboof.impl.forward.standard.BaseSpatialPadding2D;
import deepboof.impl.forward.standard.FunctionBatchNorm_F32;
import deepboof.impl.forward.standard.FunctionBatchNorm_F64;
import deepboof.impl.forward.standard.FunctionElementWiseMult_F32;
import deepboof.impl.forward.standard.FunctionElementWiseMult_F64;
import deepboof.impl.forward.standard.SpatialAveragePooling_F32;
import deepboof.impl.forward.standard.SpatialAveragePooling_F64;
import deepboof.impl.forward.standard.SpatialBatchNorm_F32;
import deepboof.impl.forward.standard.SpatialBatchNorm_F64;
import deepboof.impl.forward.standard.SpatialConvolve2D_F32;
import deepboof.impl.forward.standard.SpatialConvolve2D_F64;
import deepboof.impl.forward.standard.SpatialMaxPooling_F32;
import deepboof.impl.forward.standard.SpatialMaxPooling_F64;
import deepboof.io.torch7.struct.TorchBoolean;
import deepboof.io.torch7.struct.TorchGeneric;
import deepboof.io.torch7.struct.TorchList;
import deepboof.io.torch7.struct.TorchNumber;
import deepboof.io.torch7.struct.TorchObject;
import deepboof.io.torch7.struct.TorchReferenceable;
import deepboof.io.torch7.struct.TorchString;
import deepboof.io.torch7.struct.TorchTensor;
import deepboof.tensors.Tensor_F32;
import deepboof.tensors.Tensor_F64;
import deepboof.tensors.Tensor_S64;
import deepboof.tensors.Tensor_U8;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.ddogleg.struct.Tuple2;

/* loaded from: classes4.dex */
/**
 * Converts deserialized Torch7 module objects ({@link TorchGeneric} and friends) into DeepBoof
 * forward functions together with their parameter tensors.
 *
 * <p>This source was recovered by a decompiler from {@code classes4.dex}. Most field names and
 * string keys are obfuscated: the keys are produced at runtime by {@code com.xshield.dc} calls
 * whose plaintext is not visible here, so comments naming specific Torch fields are hedged.
 */
public class ConvertTorchToBoofForward {

    /**
     * Compiler-generated switch map from {@link b} ordinals to dense case labels.
     *
     * <p>No longer used by this class (the pooling converter switches on the enum directly) but
     * retained because it is a nested type of this class and may be referenced elsewhere.
     */
    public static /* synthetic */ class a {

        static final /* synthetic */ int[] f853a;

        static {
            int[] iArr = new int[b.values().length];
            f853a = iArr;
            try {
                iArr[b.MAX.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Synthetic switch-map pattern: tolerates the constant disappearing from the enum.
            }
            try {
                f853a[b.AVE.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // See above.
            }
        }
    }

    /** Pooling operator to build when converting a spatial pooling module. */
    public enum b {
        MAX,
        AVE
    }

    /**
     * Converts a Torch spatial pooling module into a DeepBoof max- or average-pooling function.
     *
     * <p>Six integer fields are read. The first two configure padding; each value is written to
     * two sides of {@link ConfigPadding}. The remaining four configure {@link ConfigSpatial}.
     * NOTE(review): the keys are obfuscated; they are presumably Torch's padW/padH/kW/kH/dW/dH —
     * confirm. Max pooling pads with {@link PaddingType#CLIPPED}, average pooling with
     * {@link PaddingType#ZERO}.
     *
     * @param torchGeneric deserialized Torch module
     * @param bVar         pooling operator to build
     * @param str          tensor type; only "torch.FloatTensor" and "torch.DoubleTensor" are accepted
     * @return the pooling function; no parameter tensors are attached
     * @throws IllegalArgumentException if {@code bVar} is not a known pooling operator
     * @throws RuntimeException         if {@code str} is not a supported tensor type
     */
    private static FunctionAndParameters a(TorchGeneric torchGeneric, b bVar, String str) {
        FunctionAndParameters functionAndParameters = new FunctionAndParameters();
        int f = f(torchGeneric, dc.m1348(-1477534333));
        int f2 = f(torchGeneric, dc.m1351(-1499420732));
        int f3 = f(torchGeneric, dc.m1355(-482168102));
        int f4 = f(torchGeneric, dc.m1348(-1477534461));
        int f5 = f(torchGeneric, dc.m1350(-1227280346));
        int f6 = f(torchGeneric, dc.m1348(-1477534397));

        ConfigPadding configPadding = new ConfigPadding();
        configPadding.d = f;
        configPadding.b = f;
        configPadding.c = f2;
        configPadding.f844a = f2;

        String unknownPoolingMessage = dc.m1353(-905203595);
        switch (bVar) {
            case MAX:
                configPadding.e = PaddingType.CLIPPED;
                break;
            case AVE:
                configPadding.e = PaddingType.ZERO;
                break;
            default:
                throw new IllegalArgumentException(unknownPoolingMessage);
        }

        ConfigSpatial configSpatial = new ConfigSpatial();
        configSpatial.d = f5;
        configSpatial.c = f6;
        configSpatial.b = f3;
        configSpatial.f845a = f4;

        // bVar has been validated above, so anything other than MAX is AVE below.
        if (str.equals("torch.FloatTensor")) {
            BaseSpatialPadding2D padding = FactoryForwards.a(configPadding, Tensor_F32.class);
            if (bVar == b.MAX) {
                functionAndParameters.f855a = new SpatialMaxPooling_F32(configSpatial, (SpatialPadding2D_F32) padding);
            } else {
                functionAndParameters.f855a = new SpatialAveragePooling_F32(configSpatial, (SpatialPadding2D_F32) padding);
            }
        } else if (str.equals("torch.DoubleTensor")) {
            BaseSpatialPadding2D padding = FactoryForwards.a(configPadding, Tensor_F64.class);
            if (bVar == b.MAX) {
                functionAndParameters.f855a = new SpatialMaxPooling_F64(configSpatial, (SpatialPadding2D_F64) padding);
            } else {
                functionAndParameters.f855a = new SpatialAveragePooling_F64(configSpatial, (SpatialPadding2D_F64) padding);
            }
        } else {
            throw new RuntimeException(dc.m1353(-905203283) + str);
        }
        return functionAndParameters;
    }

    /**
     * Converts a Torch (non-spatial) batch-normalization module into a DeepBoof batch-norm function.
     *
     * <p>The packed statistics tensor produced by {@link #b(TorchGeneric)} or
     * {@link #c(TorchGeneric)} becomes the single parameter. The function is told whether that
     * tensor has four columns (presumably {@code e(1)} is the size of dimension 1 — confirm).
     *
     * @param torchGeneric deserialized Torch module
     * @param str          tensor type; only "torch.FloatTensor" and "torch.DoubleTensor" are accepted
     * @return the batch-norm function and its packed parameter tensor
     * @throws RuntimeException if {@code str} is not a supported tensor type
     */
    private static FunctionAndParameters a(TorchGeneric torchGeneric, String str) {
        FunctionAndParameters functionAndParameters = new FunctionAndParameters();
        if (str.equals("torch.FloatTensor")) {
            Tuple2<Tensor_F32, Float> packed = b(torchGeneric);
            FunctionBatchNorm_F32 functionBatchNorm_F32 = new FunctionBatchNorm_F32(packed.data0.e(1) == 4);
            functionBatchNorm_F32.a(packed.data1.floatValue());
            functionAndParameters.f855a = functionBatchNorm_F32;
            functionAndParameters.b.add(packed.data0);
        } else if (str.equals("torch.DoubleTensor")) {
            Tuple2<Tensor_F64, Double> packed = c(torchGeneric);
            FunctionBatchNorm_F64 functionBatchNorm_F64 = new FunctionBatchNorm_F64(packed.data0.e(1) == 4);
            functionBatchNorm_F64.a(packed.data1.doubleValue());
            functionAndParameters.f855a = functionBatchNorm_F64;
            functionAndParameters.b.add(packed.data0);
        } else {
            throw new RuntimeException(dc.m1353(-905203283) + str);
        }
        return functionAndParameters;
    }

    /**
     * Converts a Torch tensor whose storage holds a {@code float[]} into a {@link Tensor_F32}.
     *
     * <p>A tensor with a null or empty shape becomes an empty {@code Tensor_F32}. If the storage
     * offset ({@code torchTensor.e}) is zero, or {@code torchTensor.a()} equals
     * {@code torchTensor.f.c()}, the storage array is shared. Otherwise {@code torchTensor.a()}
     * elements are copied starting at the offset.
     * NOTE(review): no strides are read, so this presumably assumes a contiguous tensor — confirm.
     */
    private static Tensor_F32 a(TorchTensor torchTensor) {
        int[] iArr = torchTensor.d;
        if (iArr == null || iArr.length == 0) {
            return new Tensor_F32();
        }
        Tensor_F32 tensor_F32 = new Tensor_F32();
        tensor_F32.f840a = torchTensor.d;
        tensor_F32.e(); // presumably recomputes derived layout from the new shape — confirm
        if (torchTensor.e == 0 || torchTensor.a() == torchTensor.f.c()) {
            tensor_F32.e = (float[]) torchTensor.f.a();
        } else {
            tensor_F32.e = new float[torchTensor.a()];
            System.arraycopy(torchTensor.f.a(), torchTensor.e, tensor_F32.e, 0, tensor_F32.e.length);
        }
        return tensor_F32;
    }

    /**
     * Dispatches a deserialized Torch object to the matching converter.
     *
     * <p>The decompiler could not recover this method's body (838 instructions). As shipped in
     * this source it always throws {@link UnsupportedOperationException}. The batch-norm,
     * convolution and sequence converters in this class call it, so they fail at runtime until
     * the body is restored from the original DeepBoof sources or the instruction dump.
     */
    public static <T> T a(deepboof.io.torch7.struct.TorchObject r13) {
        /*
            Method dump skipped, instructions count: 838
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: deepboof.io.torch7.ConvertTorchToBoofForward.a(deepboof.io.torch7.struct.TorchObject):java.lang.Object");
    }

    /**
     * Determines the tensor type string of a Torch module.
     *
     * <p>If the module has the type key (obfuscated), its string value is used. Otherwise the
     * module's entries are scanned in map iteration order, and the scan stops at the first hit:
     * <ul>
     *   <li>a {@link TorchTensor} contributes its class name;</li>
     *   <li>a {@link TorchList} is searched recursively through its {@link TorchGeneric} entries;</li>
     *   <li>a {@link TorchGeneric} that itself has the type key contributes that value.</li>
     * </ul>
     * "torch.CudaTensor" is reported as "torch.FloatTensor".
     *
     * @return the tensor type, or {@code null} if none was found
     */
    private static String a(TorchGeneric torchGeneric) {
        String typeKey = dc.m1355(-482177262);
        String str = null;
        if (torchGeneric.d.containsKey(typeKey)) {
            str = ((TorchString) torchGeneric.d.get(typeKey)).f865a;
        } else {
            for (Object key : torchGeneric.d.keySet()) {
                TorchObject torchObject = torchGeneric.d.get(key);
                if (torchObject instanceof TorchTensor) {
                    str = ((TorchTensor) torchObject).f864a;
                    break;
                }
                if (torchObject instanceof TorchList) {
                    // Stop at the first hit, as the other branches do. Previously a later list
                    // with no typed entries could overwrite a type already found here with null.
                    String found = findTensorTypeInList(((TorchList) torchObject).d);
                    if (found != null) {
                        str = found;
                        break;
                    }
                } else if (torchObject instanceof TorchGeneric) {
                    TorchGeneric child = (TorchGeneric) torchObject;
                    if (child.d.containsKey(typeKey)) {
                        str = ((TorchString) child.d.get(typeKey)).f865a;
                        break;
                    }
                }
            }
        }
        return (str == null || !str.equals("torch.CudaTensor")) ? str : "torch.FloatTensor";
    }

    /**
     * Searches the {@link TorchGeneric} entries of a Torch list, in order, for a tensor type.
     *
     * @return the first non-null type reported by {@link #a(TorchGeneric)}, or {@code null}
     */
    private static String findTensorTypeInList(List<TorchObject> list) {
        for (TorchObject item : list) {
            if (item instanceof TorchGeneric) {
                String type = a((TorchGeneric) item);
                if (type != null) {
                    return type;
                }
            }
        }
        return null;
    }

    /**
     * Converts a Torch dropout module into an element-wise multiplication by {@code 1 - p}.
     *
     * <p>{@code p} is read from a numeric field (key obfuscated). Two boolean fields are
     * consulted first (keys obfuscated):
     * <ul>
     *   <li>if the first is present and true, {@code null} is returned (no function);</li>
     *   <li>if the second is present and true, an exception is thrown.</li>
     * </ul>
     * The exception's message names {@code stochastic_inference}.
     * NOTE(review): the first flag is presumably Torch's {@code train}. If so, this matches legacy
     * (v1) dropout inference scaling and ignores Torch's {@code v2} flag — confirm.
     *
     * @param torchGeneric deserialized Torch module
     * @param str          tensor type; only "torch.FloatTensor" and "torch.DoubleTensor" are accepted
     * @return the scaling function, or {@code null} when the first flag is set
     * @throws IllegalArgumentException if stochastic inference is requested
     * @throws RuntimeException         if {@code str} is not a supported tensor type
     */
    private static FunctionAndParameters b(TorchGeneric torchGeneric, String str) {
        Map<Object, TorchObject> map = torchGeneric.d;
        String flagKey = dc.m1351(-1499424220);
        boolean enabled = !map.containsKey(flagKey) || !((TorchBoolean) map.get(flagKey)).f861a;
        String stochasticKey = dc.m1355(-482177206);
        if (map.containsKey(stochasticKey) && ((TorchBoolean) map.get(stochasticKey)).f861a) {
            throw new IllegalArgumentException("stochastic_inference is not yet supported.  This means that it should always behave as if it's in training mode");
        }
        if (!enabled) {
            return null;
        }
        FunctionAndParameters functionAndParameters = new FunctionAndParameters();
        double d = 1.0d - ((TorchNumber) map.get(dc.m1350(-1227270490))).f862a;
        if (str.equals("torch.FloatTensor")) {
            functionAndParameters.f855a = new FunctionElementWiseMult_F32((float) d);
        } else if (str.equals("torch.DoubleTensor")) {
            functionAndParameters.f855a = new FunctionElementWiseMult_F64(d);
        } else {
            throw new RuntimeException(dc.m1355(-481292334) + str);
        }
        return functionAndParameters;
    }

    /**
     * Converts a Torch tensor whose storage holds a {@code double[]} into a {@link Tensor_F64}.
     * It shares or copies the storage under the same rules as the {@code Tensor_F32} converter.
     */
    private static Tensor_F64 b(TorchTensor torchTensor) {
        int[] iArr = torchTensor.d;
        if (iArr == null || iArr.length == 0) {
            return new Tensor_F64();
        }
        Tensor_F64 tensor_F64 = new Tensor_F64();
        tensor_F64.f840a = torchTensor.d;
        tensor_F64.e();
        if (torchTensor.e == 0 || torchTensor.a() == torchTensor.f.c()) {
            tensor_F64.e = (double[]) torchTensor.f.a();
        } else {
            tensor_F64.e = new double[torchTensor.a()];
            System.arraycopy(torchTensor.f.a(), torchTensor.e, tensor_F64.e, 0, tensor_F64.e.length);
        }
        return tensor_F64;
    }

    /**
     * Packs the batch-norm statistics of a Torch module into one {@code float} tensor.
     *
     * <p>Two tensors and a scalar are always read (keys obfuscated; presumably
     * running_mean/running_var/eps — confirm). If an optional key is present, it and a fourth key
     * are read as tensors. Those are the same two keys the convolution converter reads its
     * parameters from (presumably weight/bias).
     *
     * <p>The result has {@code h} rows, where {@code h = first.h()}. Row {@code i} holds
     * element {@code i} of each tensor in read order. There are 4 columns when the optional
     * tensors exist and 2 otherwise.
     *
     * @return the packed tensor and the scalar as a {@code Float}
     */
    private static Tuple2<Tensor_F32, Float> b(TorchGeneric torchGeneric) {
        Tensor_F32 first = (Tensor_F32) a(torchGeneric.d.get(dc.m1351(-1499423404)));
        Tensor_F32 second = (Tensor_F32) a(torchGeneric.d.get(dc.m1350(-1227289706)));
        float scalar = ((Double) a(torchGeneric.d.get(dc.m1350(-1227289842)))).floatValue();
        int h = first.h();
        String optionalKey = dc.m1350(-1227289442);
        Tensor_F32 packed;
        if (torchGeneric.d.containsKey(optionalKey)) {
            Tensor_F32 third = (Tensor_F32) a(torchGeneric.d.get(optionalKey));
            Tensor_F32 fourth = (Tensor_F32) a(torchGeneric.d.get(dc.m1351(-1499423828)));
            packed = new Tensor_F32(h, 4);
            float[] out = packed.e;
            for (int i = 0; i < h; i++) {
                int row = i * 4;
                out[row] = first.e[i];
                out[row + 1] = second.e[i];
                out[row + 2] = third.e[i];
                out[row + 3] = fourth.e[i];
            }
        } else {
            packed = new Tensor_F32(h, 2);
            float[] out = packed.e;
            for (int i = 0; i < h; i++) {
                int row = i * 2;
                out[row] = first.e[i];
                out[row + 1] = second.e[i];
            }
        }
        return new Tuple2<>(packed, Float.valueOf(scalar));
    }

    /**
     * Converts a Torch container module into a linear chain of DeepBoof nodes.
     *
     * <p>Each entry of the module's "modules" list is converted with {@link #a(TorchObject)}, and
     * {@code null} results are skipped.
     * <ul>
     *   <li>A {@link FunctionAndParameters} becomes one node. Its name is an obfuscated prefix
     *       followed by the module's {@code TorchReferenceable} index, and its parameters are
     *       stored under that name.</li>
     *   <li>A nested {@link SequenceAndParameters} has its nodes and parameters appended.</li>
     * </ul>
     * Every appended node except the first takes input from the node appended before it. For a
     * nested sequence, only its first node gets that extra link; the others keep their own links.
     *
     * @param torchGeneric deserialized Torch container module
     * @param str          tensor type; only "torch.FloatTensor" and "torch.DoubleTensor" are accepted
     * @return the flattened sequence with its per-node parameters
     * @throws RuntimeException if {@code str} is unsupported or an entry converts to an unexpected type
     */
    private static SequenceAndParameters c(TorchGeneric torchGeneric, String str) {
        SequenceAndParameters sequenceAndParameters = new SequenceAndParameters();
        TorchList torchList = (TorchList) torchGeneric.d.get("modules");
        if (str.equals("torch.FloatTensor")) {
            sequenceAndParameters.c = Tensor_F32.class;
        } else if (str.equals("torch.DoubleTensor")) {
            sequenceAndParameters.c = Tensor_F64.class;
        } else {
            throw new RuntimeException(dc.m1355(-481292334) + str);
        }
        for (int i = 0; i < torchList.d.size(); i++) {
            TorchObject torchObject = torchList.d.get(i);
            Object converted = a(torchObject);
            if (converted == null) {
                continue;
            }
            if (converted instanceof FunctionAndParameters) {
                FunctionAndParameters functionAndParameters = (FunctionAndParameters) converted;
                Node node = new Node();
                node.d = functionAndParameters.f855a;
                String name = dc.m1348(-1477532125) + ((TorchReferenceable) torchObject).c;
                node.b = name;
                sequenceAndParameters.b.put(name, functionAndParameters.b);
                linkToPrevious(sequenceAndParameters, node);
                sequenceAndParameters.f858a.add(node);
            } else if (converted instanceof SequenceAndParameters) {
                SequenceAndParameters nested = (SequenceAndParameters) converted;
                for (int j = 0; j < nested.f858a.size(); j++) {
                    Node nestedNode = (Node) nested.f858a.get(j);
                    if (j == 0) {
                        linkToPrevious(sequenceAndParameters, nestedNode);
                    }
                    sequenceAndParameters.f858a.add(nestedNode);
                    sequenceAndParameters.b.put(nestedNode.b, nested.b.get(nestedNode.b));
                }
            } else {
                throw new RuntimeException("Unexpected type");
            }
        }
        return sequenceAndParameters;
    }

    /**
     * Adds an input edge to {@code node} from the last node already in {@code seq}.
     * It does nothing when the sequence is empty. This replaces the decompiler's undefined
     * {@code r3}/{@code r6} register references.
     */
    private static void linkToPrevious(SequenceAndParameters seq, Node node) {
        int size = seq.f858a.size();
        if (size > 0) {
            InputAddress inputAddress = new InputAddress();
            inputAddress.f847a = ((Node) seq.f858a.get(size - 1)).b;
            node.f848a.add(inputAddress);
        }
    }

    /**
     * Converts a Torch tensor whose storage holds a {@code long[]} into a {@link Tensor_S64}.
     * It shares or copies the storage under the same rules as the {@code Tensor_F32} converter.
     */
    private static Tensor_S64 c(TorchTensor torchTensor) {
        int[] iArr = torchTensor.d;
        if (iArr == null || iArr.length == 0) {
            return new Tensor_S64();
        }
        Tensor_S64 tensor_S64 = new Tensor_S64();
        tensor_S64.f840a = torchTensor.d;
        tensor_S64.e();
        if (torchTensor.e == 0 || torchTensor.a() == torchTensor.f.c()) {
            tensor_S64.e = (long[]) torchTensor.f.a();
        } else {
            tensor_S64.e = new long[torchTensor.a()];
            System.arraycopy(torchTensor.f.a(), torchTensor.e, tensor_S64.e, 0, tensor_S64.e.length);
        }
        return tensor_S64;
    }

    /**
     * Packs the batch-norm statistics of a Torch module into one {@code double} tensor.
     * It reads the same keys and uses the same packing as the {@code float} variant
     * ({@code h} rows × 4 or 2 columns).
     *
     * @return the packed tensor and the scalar as a {@code Double}
     */
    private static Tuple2<Tensor_F64, Double> c(TorchGeneric torchGeneric) {
        Tensor_F64 first = (Tensor_F64) a(torchGeneric.d.get(dc.m1351(-1499423404)));
        Tensor_F64 second = (Tensor_F64) a(torchGeneric.d.get(dc.m1350(-1227289706)));
        double scalar = ((Double) a(torchGeneric.d.get(dc.m1350(-1227289842)))).doubleValue();
        int h = first.h();
        String optionalKey = dc.m1350(-1227289442);
        Tensor_F64 packed;
        if (torchGeneric.d.containsKey(optionalKey)) {
            Tensor_F64 third = (Tensor_F64) a(torchGeneric.d.get(optionalKey));
            Tensor_F64 fourth = (Tensor_F64) a(torchGeneric.d.get(dc.m1351(-1499423828)));
            packed = new Tensor_F64(h, 4);
            double[] out = packed.e;
            for (int i = 0; i < h; i++) {
                int row = i * 4;
                out[row] = first.e[i];
                out[row + 1] = second.e[i];
                out[row + 2] = third.e[i];
                out[row + 3] = fourth.e[i];
            }
        } else {
            packed = new Tensor_F64(h, 2);
            double[] out = packed.e;
            for (int i = 0; i < h; i++) {
                int row = i * 2;
                out[row] = first.e[i];
                out[row + 1] = second.e[i];
            }
        }
        return new Tuple2<>(packed, Double.valueOf(scalar));
    }

    /**
     * Converts a Torch spatial batch-normalization module into a DeepBoof spatial batch-norm function.
     * The parameter handling mirrors the non-spatial batch-norm converter.
     *
     * @param torchGeneric deserialized Torch module
     * @param str          tensor type; only "torch.FloatTensor" and "torch.DoubleTensor" are accepted
     * @return the spatial batch-norm function and its packed parameter tensor
     * @throws RuntimeException if {@code str} is not a supported tensor type
     */
    private static FunctionAndParameters d(TorchGeneric torchGeneric, String str) {
        FunctionAndParameters functionAndParameters = new FunctionAndParameters();
        if (str.equals("torch.FloatTensor")) {
            Tuple2<Tensor_F32, Float> packed = b(torchGeneric);
            SpatialBatchNorm_F32 spatialBatchNorm_F32 = new SpatialBatchNorm_F32(packed.data0.e(1) == 4);
            spatialBatchNorm_F32.a(packed.data1.floatValue());
            functionAndParameters.f855a = spatialBatchNorm_F32;
            functionAndParameters.b.add(packed.data0);
        } else if (str.equals("torch.DoubleTensor")) {
            Tuple2<Tensor_F64, Double> packed = c(torchGeneric);
            SpatialBatchNorm_F64 spatialBatchNorm_F64 = new SpatialBatchNorm_F64(packed.data0.e(1) == 4);
            spatialBatchNorm_F64.a(packed.data1.doubleValue());
            functionAndParameters.f855a = spatialBatchNorm_F64;
            functionAndParameters.b.add(packed.data0);
        } else {
            throw new RuntimeException(dc.m1353(-905203283) + str);
        }
        return functionAndParameters;
    }

    /**
     * Converts a Torch tensor whose storage holds a {@code byte[]} into a {@link Tensor_U8}.
     * It shares or copies the storage under the same rules as the {@code Tensor_F32} converter.
     */
    private static Tensor_U8 d(TorchTensor torchTensor) {
        int[] iArr = torchTensor.d;
        if (iArr == null || iArr.length == 0) {
            return new Tensor_U8();
        }
        Tensor_U8 tensor_U8 = new Tensor_U8();
        tensor_U8.f840a = torchTensor.d;
        tensor_U8.e();
        if (torchTensor.e == 0 || torchTensor.a() == torchTensor.f.c()) {
            tensor_U8.e = (byte[]) torchTensor.f.a();
        } else {
            tensor_U8.e = new byte[torchTensor.a()];
            System.arraycopy(torchTensor.f.a(), torchTensor.e, tensor_U8.e, 0, tensor_U8.e.length);
        }
        return tensor_U8;
    }

    /**
     * Converts a Torch spatial convolution module into a DeepBoof 2D convolution.
     *
     * <p>It reads the same six integer fields as the pooling converter plus a seventh
     * (presumably the number of output planes — confirm). Padding is always
     * {@link PaddingType#ZERO}. Two parameter tensors are attached, in order, from two
     * obfuscated keys (presumably weight then bias).
     *
     * @param torchGeneric deserialized Torch module
     * @param str          tensor type; only "torch.FloatTensor" and "torch.DoubleTensor" are accepted
     * @return the convolution function and its two parameter tensors
     * @throws RuntimeException if {@code str} is not a supported tensor type
     */
    private static FunctionAndParameters e(TorchGeneric torchGeneric, String str) {
        FunctionAndParameters functionAndParameters = new FunctionAndParameters();
        int f = f(torchGeneric, dc.m1348(-1477534333));
        int f2 = f(torchGeneric, dc.m1351(-1499420732));
        int f3 = f(torchGeneric, dc.m1355(-482168102));
        int f4 = f(torchGeneric, dc.m1348(-1477534461));
        int f5 = f(torchGeneric, dc.m1350(-1227280346));
        int f6 = f(torchGeneric, dc.m1348(-1477534397));
        int f7 = f(torchGeneric, dc.m1352(778736217));

        ConfigPadding configPadding = new ConfigPadding();
        configPadding.d = f;
        configPadding.b = f;
        configPadding.c = f2;
        configPadding.f844a = f2;
        configPadding.e = PaddingType.ZERO;

        ConfigConvolve2D configConvolve2D = new ConfigConvolve2D();
        configConvolve2D.d = f5;
        configConvolve2D.c = f6;
        configConvolve2D.e = f7;
        configConvolve2D.b = f3;
        configConvolve2D.f845a = f4;

        if (str.equals("torch.FloatTensor")) {
            functionAndParameters.f855a = new SpatialConvolve2D_F32(configConvolve2D, (SpatialPadding2D_F32) FactoryForwards.a(configPadding, Tensor_F32.class));
        } else if (str.equals("torch.DoubleTensor")) {
            functionAndParameters.f855a = new SpatialConvolve2D_F64(configConvolve2D, (SpatialPadding2D_F64) FactoryForwards.a(configPadding, Tensor_F64.class));
        } else {
            throw new RuntimeException(dc.m1353(-905203283) + str);
        }
        functionAndParameters.b.add(a(torchGeneric.d.get(dc.m1350(-1227289442))));
        functionAndParameters.b.add(a(torchGeneric.d.get(dc.m1351(-1499423828))));
        return functionAndParameters;
    }

    /**
     * Reads a numeric field of a Torch module and narrows it to {@code int} with a cast.
     *
     * @throws NullPointerException if {@code str} is not a key of the module's map
     */
    private static int f(TorchGeneric torchGeneric, String str) {
        return (int) ((TorchNumber) torchGeneric.d.get(str)).f862a;
    }
}
