package edu.ucla.sspace.purandare;

import edu.ucla.sspace.clustering.Assignments;
import edu.ucla.sspace.clustering.ClutoClustering;
import edu.ucla.sspace.common.SemanticSpace;
import edu.ucla.sspace.common.Statistics;
import edu.ucla.sspace.matrix.AtomicGrowingMatrix;
import edu.ucla.sspace.matrix.AtomicMatrix;
import edu.ucla.sspace.matrix.Matrix;
import edu.ucla.sspace.matrix.SparseRowMaskedMatrix;
import edu.ucla.sspace.matrix.YaleSparseMatrix;
import edu.ucla.sspace.text.IteratorFactory;
import edu.ucla.sspace.util.WorkerThread;
import edu.ucla.sspace.vector.CompactSparseVector;
import edu.ucla.sspace.vector.DoubleVector;
import edu.ucla.sspace.vector.VectorMath;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOError;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: classes.dex */
public class PurandareFirstOrder implements SemanticSpace {
    /** Logger used for the progress and diagnostic messages emitted by this class. */
    private static final Logger LOGGER = Logger.getLogger(PurandareFirstOrder.class.getName());
    /** Name of the property holding the maximum number of contexts per word; read by the {@code Properties} constructor. */
    public static final String MAX_CONTEXTS_PER_WORD = "edu.ucla.sspace.purandare.PurandareFirstOrder.maxContexts";
    /** Common prefix of this class's property names (not referenced elsewhere in this file). */
    private static final String PROPERTY_PREFIX = "edu.ucla.sspace.purandare.PurandareFirstOrder";
    /** Temporary file, created in the constructor, to which every processed document is appended as a sequence of ints. */
    private File compressedDocuments;
    /** Writer over {@link #compressedDocuments}; writes are synchronized on this object and it is closed when the space is processed. */
    private DataOutputStream compressedDocumentsWriter;
    /** Number of tokens on each side of a focus word considered when building a context row; set to 20 in the constructor. */
    private final int contextWindowSize;
    /** Co-occurrence counts, indexed [focus term index][neighbouring term index]. */
    private final AtomicMatrix cooccurrenceMatrix;
    /** Number of documents processed; used as the document count when the temporary file is read back. */
    private final AtomicInteger documentCounter;
    /** Upper bound on the number of context rows returned per word; {@code Integer.MAX_VALUE} when the property is unset. */
    private final int maxContextsPerWord;
    /** Occurrence counter for each term, positioned by the term's index. */
    private final List<AtomicInteger> termCounts;
    /** Maps each term to the index assigned to it by {@code getIndexFor}. */
    private final Map<String, Integer> termToIndex;
    /** Maps each word (or word sense) to its final vector; filled in by {@code senseInduce}. */
    private final Map<String, DoubleVector> termToVector;
    /** Number of tokens kept in the look-ahead and look-behind queues while counting co-occurrences; set to 5 in the constructor. */
    private final int windowSize;
    /** Next unassigned term index; incremented only inside the synchronized section of {@code getIndexFor}. */
    private int wordIndexCounter;

    /**
     * Creates an instance configured from the JVM system properties, by
     * delegating to the {@code Properties} constructor.
     */
    public PurandareFirstOrder() {
        this(System.getProperties());
    }

    /**
     * Creates an instance configured from the given properties and opens the
     * temporary file that buffers the corpus between the two processing passes.
     *
     * <p>The co-occurrence window is fixed at 5 tokens and the context window
     * at 20 tokens.  If {@link #MAX_CONTEXTS_PER_WORD} is set it must be a
     * positive integer; otherwise the number of contexts is unbounded.
     *
     * @param properties the properties to read {@link #MAX_CONTEXTS_PER_WORD} from
     * @throws IllegalArgumentException if the configured maximum is not
     *         positive (a {@code NumberFormatException}, itself an
     *         {@code IllegalArgumentException}, is thrown when it is not an
     *         integer)
     * @throws IOError if the temporary document file cannot be created or opened
     */
    public PurandareFirstOrder(Properties properties) {
        this.cooccurrenceMatrix = new AtomicGrowingMatrix();
        this.termToIndex = new ConcurrentHashMap<String, Integer>();
        this.termToVector = new ConcurrentHashMap<String, DoubleVector>();
        this.termCounts = new CopyOnWriteArrayList<AtomicInteger>();
        this.windowSize = 5;
        this.contextWindowSize = 20;
        this.documentCounter = new AtomicInteger(0);

        String configuredMax = properties.getProperty(MAX_CONTEXTS_PER_WORD);
        if (configuredMax == null) {
            this.maxContextsPerWord = Integer.MAX_VALUE;
        } else {
            // Tolerate surrounding whitespace, which Integer.parseInt rejects.
            int parsedMax = Integer.parseInt(configuredMax.trim());
            if (parsedMax <= 0) {
                throw new IllegalArgumentException("The number of contexts must be a positive number");
            }
            this.maxContextsPerWord = parsedMax;
        }

        try {
            this.compressedDocuments = File.createTempFile("petersen-documents", ".dat");
            // The file is only scratch space for this instance; do not leave it
            // behind in the temp directory once the JVM exits.
            this.compressedDocuments.deleteOnExit();
            this.compressedDocumentsWriter = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(this.compressedDocuments)));
        } catch (IOException e) {
            throw new IOError(e);
        }
    }

    /**
     * Selects the features of a term: the co-occurring terms whose
     * log-likelihood score against the term exceeds 3.841.
     *
     * <p>For every term with a non-zero co-occurrence count, a 2x2 contingency
     * table is built from the co-occurrence count, the two terms' individual
     * counts and the corpus total, and scored with {@code logLikelihood}.
     * NOTE(review): 3.841 matches the chi-squared critical value for
     * p = 0.05 with one degree of freedom; presumably that is the intended
     * significance cut-off.
     *
     * @param str the term whose features are computed; must already have an index
     * @param i the total number of term occurrences in the corpus
     * @return a bit set with bit {@code k} set when the term with index
     *         {@code k} is a feature of {@code str}
     */
    private BitSet calculateTermFeatures(String str, int i) {
        int termIndex = this.termToIndex.get(str).intValue();
        // Only pay for the formatting when the message will actually be logged.
        if (LOGGER.isLoggable(Level.FINE)) {
            LOGGER.fine(String.format("Calculating feature set for %6d/%d: %s", Integer.valueOf(termIndex), Integer.valueOf(this.cooccurrenceMatrix.rows()), str));
        }
        DoubleVector cooccurrences = this.cooccurrenceMatrix.getRowVector(termIndex);
        double termCount = this.termCounts.get(termIndex).get();
        double corpusTotal = i;
        BitSet features = new BitSet(this.wordIndexCounter);
        for (int other = 0; other < cooccurrences.length(); other++) {
            double both = cooccurrences.get(other);
            if (both == 0.0d) {
                continue;
            }
            // Contingency table cells: both terms, other term only, this term
            // only, and neither term.
            double otherOnly = this.termCounts.get(other).get() - both;
            double termOnly = termCount - both;
            double neither = corpusTotal - ((both + otherOnly) + termOnly);
            if (logLikelihood(both, otherOnly, termOnly, neither) > 3.841d) {
                features.set(other);
            }
        }
        if (LOGGER.isLoggable(Level.FINE)) {
            LOGGER.fine(str + " had " + features.cardinality() + " features");
        }
        return features;
    }

    /**
     * Returns the index of the given term, assigning the next free index the
     * first time the term is seen.
     *
     * <p>The map is consulted first without holding a lock; a new index is
     * only assigned while synchronized on this instance, after checking the
     * map again.  The term's counter is appended to {@code termCounts} before
     * the index is published in {@code termToIndex}.
     *
     * @param str the term to look up
     * @return the index associated with {@code str}
     */
    private final int getIndexFor(String str) {
        Integer known = this.termToIndex.get(str);
        if (known != null) {
            return known.intValue();
        }
        synchronized (this) {
            Integer rechecked = this.termToIndex.get(str);
            if (rechecked != null) {
                return rechecked.intValue();
            }
            int assigned = this.wordIndexCounter++;
            this.termCounts.add(new AtomicInteger(0));
            this.termToIndex.put(str, Integer.valueOf(assigned));
            return assigned;
        }
    }

    /**
     * Re-reads the buffered corpus and builds the context matrix of one term:
     * one row per occurrence of the term, one column per term index.
     *
     * <p>Each document in the temporary file is stored as two header ints
     * followed by that many token indices as given by the first header int.
     * When more rows were produced than {@code maxContextsPerWord} allows, the
     * result is wrapped in a {@code SparseRowMaskedMatrix} using the bit set
     * returned by {@code Statistics.randomDistribution} (presumably a random
     * selection of {@code maxContextsPerWord} rows; confirm against that
     * helper).
     *
     * <p>Visibility is {@code public} only because the decompiler widened it
     * for access from the anonymous task class.
     *
     * @param i the index of the term whose contexts are extracted
     * @param bitSet the term's feature set; only these columns are counted
     * @return the (possibly row-masked) context matrix for the term
     * @throws IOException if the temporary document file cannot be read
     */
    public Matrix getTermContexts(int i, BitSet bitSet) throws IOException {
        int documents = this.documentCounter.get();
        YaleSparseMatrix contexts = new YaleSparseMatrix(this.termCounts.get(i).get(), this.termToIndex.size());
        DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(this.compressedDocuments)));
        try {
            int nextRow = 0;
            for (int doc = 0; doc < documents; doc++) {
                int tokenCount = in.readInt();
                // The second header int is not needed here, but must be
                // consumed to reach the token data.
                in.readInt();
                int[] tokens = new int[tokenCount];
                for (int t = 0; t < tokenCount; t++) {
                    tokens[t] = in.readInt();
                }
                nextRow += processIntDocument(i, tokens, contexts, nextRow, bitSet);
            }
        } finally {
            // Always release the file handle, including when a read fails.
            in.close();
        }
        if (this.maxContextsPerWord < Integer.MAX_VALUE && contexts.rows() > this.maxContextsPerWord) {
            return new SparseRowMaskedMatrix(contexts, Statistics.randomDistribution(this.maxContextsPerWord, contexts.rows()));
        }
        return contexts;
    }

    /**
     * Computes the log-likelihood score of a 2x2 contingency table.
     *
     * <p>With cells {@code d, d2} in the first row and {@code d3, d4} in the
     * second, the expected value of each cell is its row total divided by the
     * grand total, multiplied by its column total.  The result is twice the
     * sum over the four cells of {@code observed * ln(observed / expected)},
     * where a cell that is exactly zero contributes zero.
     *
     * @param d first row, first column
     * @param d2 first row, second column
     * @param d3 second row, first column
     * @param d4 second row, second column
     * @return the log-likelihood score of the table
     */
    private static double logLikelihood(double d, double d2, double d3, double d4) {
        final double firstColumn = d + d3;
        final double secondColumn = d2 + d4;
        final double firstRow = d + d2;
        final double secondRow = d3 + d4;
        final double total = firstRow + secondRow;
        final double firstRowShare = firstRow / total;
        final double secondRowShare = secondRow / total;

        double sum = weightedLogRatio(d, firstRowShare * firstColumn);
        sum += weightedLogRatio(d2, firstRowShare * secondColumn);
        sum += weightedLogRatio(d3, firstColumn * secondRowShare);
        sum += weightedLogRatio(d4, secondRowShare * secondColumn);
        return sum * 2.0d;
    }

    /** Returns {@code observed * ln(observed / expected)}, or zero when {@code observed} is zero. */
    private static double weightedLogRatio(double observed, double expected) {
        if (observed == 0.0d) {
            return 0.0d;
        }
        return Math.log(observed / expected) * observed;
    }

    /**
     * Extracts every context of one term from a single document and writes
     * each context as a row of the given matrix.
     *
     * <p>This body was reconstructed from the bytecode listing that the
     * decompiler emitted in place of source; the previous body unconditionally
     * threw {@code UnsupportedOperationException}.
     *
     * <p>For each position holding {@code termIndex}, the tokens in the
     * surrounding window are counted if they are non-negative (negative
     * values mark filtered-out tokens) and are set in {@code featuresForTerm}.
     * The counts are then stored in row {@code rowStart + k}, where {@code k}
     * is the number of contexts already found in this document.
     *
     * @param termIndex the index of the term whose contexts are extracted
     * @param document the document as a sequence of term indices
     * @param contextMatrix the matrix receiving one row per context
     * @param rowStart the first row of {@code contextMatrix} to write to
     * @param featuresForTerm the term indices that count as features
     * @return the number of contexts (rows) written for this document
     */
    private int processIntDocument(int termIndex, int[] document, Matrix contextMatrix, int rowStart, BitSet featuresForTerm) {
        int contexts = 0;
        for (int position = 0; position < document.length; position++) {
            // Only occurrences of the focus term start a context.
            if (document[position] != termIndex) {
                continue;
            }

            // Per-context feature counts, keyed by feature (term) index.
            edu.ucla.sspace.util.SparseHashArray<Integer> contextCounts =
                    new edu.ucla.sspace.util.SparseHashArray<Integer>();

            // Tokens preceding the focus term.
            int leftStart = Math.max(position - this.contextWindowSize, 0);
            for (int left = leftStart; left < position; left++) {
                int token = document[left];
                if (token >= 0 && featuresForTerm.get(token)) {
                    Integer seen = contextCounts.get(token);
                    contextCounts.set(token, Integer.valueOf(seen == null ? 1 : seen.intValue() + 1));
                }
            }

            // Tokens following the focus term.  The exclusive bound mirrors the
            // bytecode exactly, so this side covers one token fewer than the
            // left side.
            int rightEnd = Math.min(this.contextWindowSize + position, document.length);
            for (int right = position + 1; right < rightEnd; right++) {
                int token = document[right];
                if (token >= 0 && featuresForTerm.get(token)) {
                    Integer seen = contextCounts.get(token);
                    contextCounts.set(token, Integer.valueOf(seen == null ? 1 : seen.intValue() + 1));
                }
            }

            // Each occurrence of the focus term gets its own row.
            int row = rowStart + contexts;
            for (int feature : contextCounts.getElementIndices()) {
                contextMatrix.set(row, feature, (double) contextCounts.get(feature).intValue());
            }
            contexts++;
        }
        return contexts;
    }

    /**
     * Runs the second pass over the corpus: computes the feature set of every
     * term, then builds and clusters each term's contexts on worker threads.
     *
     * <p>The document writer is closed first so the temporary file is complete
     * before it is read back.  One task per term is queued; this method blocks
     * until every task has released the shared semaphore.
     *
     * @throws IOException if the document writer cannot be closed
     * @throws Error if the calling thread is interrupted while waiting; the
     *         thread's interrupt status is restored before throwing
     */
    private void processSpace() throws IOException {
        this.compressedDocumentsWriter.close();

        // Invert the term-to-index mapping so terms can be found by index.
        String[] indexToTerm = new String[this.termToIndex.size()];
        for (Map.Entry<String, Integer> entry : this.termToIndex.entrySet()) {
            indexToTerm[entry.getValue().intValue()] = entry.getKey();
        }

        int totalOccurrences = 0;
        for (AtomicInteger count : this.termCounts) {
            totalOccurrences += count.get();
        }

        final int rows = this.cooccurrenceMatrix.rows();
        LOGGER.info("calculating term features");
        final BitSet[] termFeatures = new BitSet[this.wordIndexCounter];
        for (int term = 0; term < rows; term++) {
            termFeatures[term] = calculateTermFeatures(indexToTerm[term], totalOccurrences);
        }

        LOGGER.info("reprocessing corpus to generate feature vectors");
        LinkedBlockingQueue<Runnable> workQueue = new LinkedBlockingQueue<Runnable>();
        int workerCount = Runtime.getRuntime().availableProcessors();
        for (int worker = 0; worker < workerCount; worker++) {
            new WorkerThread(workQueue).start();
        }

        // Every task releases one permit when it ends, whether or not it succeeded.
        final Semaphore finishedTerms = new Semaphore(0);
        for (int term = 0; term < rows; term++) {
            final String termName = indexToTerm[term];
            final int termIndex = term;
            workQueue.offer(new Runnable() {
                @Override
                public void run() {
                    try {
                        LOGGER.fine(String.format("processing term %6d/%d: %s", Integer.valueOf(termIndex), Integer.valueOf(rows), termName));
                        PurandareFirstOrder.this.senseInduce(termName, PurandareFirstOrder.this.getTermContexts(termIndex, termFeatures[termIndex]));
                    } catch (IOException e) {
                        // The term is skipped, but the failure is reported
                        // through the logger instead of stderr.
                        LOGGER.log(Level.SEVERE, "failed to process term " + termName, e);
                    } finally {
                        finishedTerms.release();
                    }
                }
            });
        }

        try {
            finishedTerms.acquire(rows);
            LOGGER.info("finished reprocessing all terms");
        } catch (InterruptedException e) {
            // Preserve the interrupt for code further up the stack.
            Thread.currentThread().interrupt();
            throw new Error("interrupted while waiting for terms to finish reprocessing", e);
        }
    }

    /**
     * Builds the final vector(s) for a term from its context matrix and
     * stores them in {@code termToVector}.
     *
     * <p>When the term is not purely alphabetic, or has fewer than 7
     * contexts, all context rows are summed into a single vector stored under
     * the term itself.  Otherwise the contexts are clustered into 7 clusters
     * and the rows of each cluster are summed; a cluster is kept as a sense
     * only if it holds more than 2% of the contexts.  The first kept sense is
     * stored under the term, later ones under {@code term-1}, {@code term-2},
     * and so on.
     *
     * <p>Visibility is {@code public} only because the decompiler widened it
     * for access from the anonymous task class.
     *
     * @param str the term being processed
     * @param matrix the term's contexts, one per row
     * @throws IOException declared for the clustering call (presumably it
     *         performs file I/O; confirm against {@code ClutoClustering})
     */
    public void senseInduce(String str, Matrix matrix) throws IOException {
        final int contextCount = matrix.rows();
        LOGGER.fine("Clustering " + contextCount + " contexts for " + str);
        final int numClusters = Math.min(7, contextCount);

        // The range was previously written "A-z", which also admits the
        // punctuation characters between 'Z' and 'a' in ASCII.
        if (!str.matches("[a-zA-Z]+") || numClusters <= 6) {
            CompactSparseVector summed = new CompactSparseVector(this.termToIndex.size());
            for (int row = 0; row < contextCount; row++) {
                VectorMath.add((DoubleVector) summed, matrix.getRowVector(row));
            }
            this.termToVector.put(str, summed);
            return;
        }

        Assignments cluster = new ClutoClustering().cluster(matrix, numClusters, ClutoClustering.Method.AGGLOMERATIVE, ClutoClustering.Criterion.UPGMA);
        LOGGER.fine("Generative sense vectors for " + str);

        int[] clusterSizes = new int[numClusters];
        CompactSparseVector[] senseVectors = new CompactSparseVector[numClusters];
        for (int c = 0; c < senseVectors.length; c++) {
            senseVectors[c] = new CompactSparseVector(this.termToIndex.size());
        }

        // Add each context row to the vector of the first cluster it was
        // assigned to; rows without an assignment are ignored.
        for (int row = 0; row < cluster.size(); row++) {
            int[] assigned = cluster.get(row).assignments();
            if (assigned.length != 0) {
                int target = assigned[0];
                clusterSizes[target]++;
                VectorMath.add((DoubleVector) senseVectors[target], matrix.getRowVector(row));
            }
        }

        int senses = 0;
        for (int c = 0; c < numClusters; c++) {
            double share = (double) clusterSizes[c] / (double) contextCount;
            if (share > 0.02d) {
                String senseName = senses == 0 ? str : str + "-" + senses;
                senses++;
                this.termToVector.put(senseName, senseVectors[c]);
            }
        }
        LOGGER.fine("Discovered " + senses + " senses for " + str);
    }

    /**
     * Returns the fixed name of this semantic space.
     *
     * @return the string {@code "purandare-petersen"}
     */
    @Override // edu.ucla.sspace.common.SemanticSpace
    public String getSpaceName() {
        return "purandare-petersen";
    }

    /**
     * Looks up the vector stored for a word or word sense.
     *
     * @param str the word (or sense name such as {@code word-1}) to look up
     * @return the stored vector, or {@code null} if none has been stored
     */
    @Override // edu.ucla.sspace.common.SemanticSpace
    public DoubleVector getVector(String str) {
        return this.termToVector.get(str);
    }

    /**
     * Returns the vector length, which is the number of distinct terms that
     * have been assigned an index.
     *
     * @return the current size of the term-to-index mapping
     */
    @Override // edu.ucla.sspace.common.SemanticSpace
    public int getVectorLength() {
        return this.termToIndex.size();
    }

    /**
     * Returns the words and word senses that currently have a vector.
     *
     * @return an unmodifiable view of the key set of the term-to-vector mapping
     */
    @Override // edu.ucla.sspace.common.SemanticSpace
    public Set<String> getWords() {
        return Collections.unmodifiableSet(this.termToVector.keySet());
    }

    /**
     * Processes one document: updates term counts and co-occurrence counts,
     * and appends the document to the temporary file for the second pass.
     *
     * <p>Tokens are read through a look-ahead queue and a look-behind queue,
     * each bounded by {@code windowSize}.  Every non-empty token is counted
     * and its co-occurrence with each non-empty token in both queues is
     * incremented.  Empty tokens are written as {@code -1}.  The document is
     * stored as the total token count, the non-empty token count and then one
     * int per token.
     *
     * <p>The document counter is only advanced once the document has been
     * written, so a document that fails part-way is not counted and cannot
     * cause the second pass to read beyond the stored data.
     *
     * @param bufferedReader the reader supplying the document text
     * @throws IOException if the document cannot be encoded or written to the
     *         temporary file
     */
    @Override // edu.ucla.sspace.common.SemanticSpace
    public void processDocument(BufferedReader bufferedReader) throws IOException {
        ArrayDeque<String> nextWords = new ArrayDeque<String>();
        ArrayDeque<String> prevWords = new ArrayDeque<String>();
        Iterator<String> tokens = IteratorFactory.tokenizeOrdered(bufferedReader);
        ByteArrayOutputStream encodedBytes = new ByteArrayOutputStream(4096);
        DataOutputStream encoded = new DataOutputStream(encodedBytes);

        // Prime the look-ahead queue.
        for (int primed = 0; primed < this.windowSize && tokens.hasNext(); primed++) {
            nextWords.offer(tokens.next());
        }

        int tokenCount = 0;
        int nonEmptyCount = 0;
        while (!nextWords.isEmpty()) {
            tokenCount++;
            String focus = nextWords.remove();
            if (tokens.hasNext()) {
                nextWords.offer(tokens.next());
            }

            if (focus.equals("")) {
                encoded.writeInt(-1);
            } else {
                int focusIndex = getIndexFor(focus);
                encoded.writeInt(focusIndex);
                this.termCounts.get(focusIndex).incrementAndGet();
                nonEmptyCount++;
                for (String next : nextWords) {
                    if (!next.equals("")) {
                        this.cooccurrenceMatrix.addAndGet(focusIndex, getIndexFor(next), 1.0d);
                    }
                }
                for (String prev : prevWords) {
                    if (!prev.equals("")) {
                        this.cooccurrenceMatrix.addAndGet(focusIndex, getIndexFor(prev), 1.0d);
                    }
                }
            }

            // Empty tokens still occupy a slot in the look-behind window.
            prevWords.offer(focus);
            if (prevWords.size() > this.windowSize) {
                prevWords.remove();
            }
        }
        encoded.close();

        byte[] documentBytes = encodedBytes.toByteArray();
        synchronized (this.compressedDocumentsWriter) {
            this.compressedDocumentsWriter.writeInt(tokenCount);
            this.compressedDocumentsWriter.writeInt(nonEmptyCount);
            this.compressedDocumentsWriter.write(documentBytes, 0, documentBytes.length);
            // Count the document only after it has been stored.
            this.documentCounter.getAndIncrement();
        }
    }

    /**
     * Runs the post-processing pass by calling the no-argument
     * {@code processSpace()}.
     *
     * @param properties not read by this method
     * @throws IOError wrapping any {@code IOException} raised by the pass
     */
    @Override // edu.ucla.sspace.common.SemanticSpace
    public void processSpace(Properties properties) {
        try {
            processSpace();
        } catch (IOException e) {
            // The interface method declares no checked exception, so the I/O
            // failure is rethrown unchecked.
            throw new IOError(e);
        }
    }
}
