package edu.ucla.sspace.util;

import edu.ucla.sspace.common.SemanticSpace;
import edu.ucla.sspace.common.Similarity;
import edu.ucla.sspace.common.VectorMapSemanticSpace;
import edu.ucla.sspace.vector.DenseVector;
import edu.ucla.sspace.vector.DoubleVector;
import edu.ucla.sspace.vector.Vector;
import edu.ucla.sspace.vector.VectorMath;
import edu.ucla.sspace.vector.Vectors;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.logging.Logger;

/* loaded from: classes2.dex */
/**
 * A {@link NearestNeighborFinder} that answers nearest-neighbor queries over a
 * {@link SemanticSpace} by first partitioning its words around a fixed number of
 * <i>principle vectors</i>.
 *
 * <p>At construction, every word is assigned to a partition uniformly at random; the
 * assignment is then refined by recomputing each partition's principle vector as the
 * sum of its members' vectors and reassigning every word to the principle vector with
 * the highest cosine similarity. A query for the {@code k} nearest neighbors compares
 * the query vector against the principle vectors, keeps the {@code k} most similar
 * partitions, and compares the query only against the words in those partitions.
 * Words in other partitions are never examined, so results are approximate.
 *
 * <p>If the backing semantic space is not itself {@link Serializable}, serialization
 * writes a {@link VectorMapSemanticSpace} copy of its word vectors in its place.
 */
public class PartitioningNearestNeighborFinder implements NearestNeighborFinder, Serializable {

    private static final Logger LOGGER = Logger.getLogger(PartitioningNearestNeighborFinder.class.getName());

    private static final long serialVersionUID = 1;

    /** Number of reassignment rounds applied to the initial random partitioning. */
    private static final int NUM_REFINEMENT_ROUNDS = 1;

    /** Maps each principle vector to the set of words assigned to its partition. */
    private final MultiMap<DoubleVector, String> principleVectorToNearestTerms;

    /** The space whose words are searched; written and restored by the custom serialization hooks. */
    private transient SemanticSpace sspace;

    /** Work queue to which the similarity computations are submitted as tasks. */
    private transient WorkQueue workQueue;

    /**
     * Creates a finder that partitions {@code semanticSpace} into
     * {@code max(1, ceil(ln(number of words)))} groups.
     *
     * @param semanticSpace the space to search
     * @throws NullPointerException if {@code semanticSpace} is {@code null}
     * @throws IllegalArgumentException if the space contains no words
     */
    public PartitioningNearestNeighborFinder(SemanticSpace semanticSpace) {
        this(semanticSpace, defaultPrincipleVectorCount(semanticSpace));
    }

    /**
     * Creates a finder that partitions {@code semanticSpace} into
     * {@code numPrincipleVectors} groups.
     *
     * @param semanticSpace the space to search
     * @param numPrincipleVectors the number of principle vectors (partitions) to compute
     * @throws NullPointerException if {@code semanticSpace} is {@code null}
     * @throws IllegalArgumentException if {@code numPrincipleVectors} is less than 1 or
     *     greater than the number of words in the space
     */
    public PartitioningNearestNeighborFinder(SemanticSpace semanticSpace, int numPrincipleVectors) {
        if (semanticSpace == null) {
            throw new NullPointerException();
        }
        if (numPrincipleVectors > semanticSpace.getWords().size()) {
            throw new IllegalArgumentException(
                    "Cannot have more principle vectors than word vectors: " + numPrincipleVectors);
        }
        if (numPrincipleVectors < 1) {
            throw new IllegalArgumentException("Must have at least one principle vector");
        }
        this.sspace = semanticSpace;
        this.principleVectorToNearestTerms = new HashMultiMap();
        this.workQueue = new WorkQueue();
        computePrincipleVectors(numPrincipleVectors);
    }

    /**
     * Returns the default partition count, {@code ceil(ln(n))} for {@code n} words, clamped
     * to at least 1. Without the clamp, a one-word space yields 0 and is rejected.
     */
    private static int defaultPrincipleVectorCount(SemanticSpace semanticSpace) {
        if (semanticSpace == null) {
            throw new NullPointerException();
        }
        int numWords = semanticSpace.getWords().size();
        // For an empty space this still returns 1, which the two-argument constructor rejects.
        return Math.max(1, (int) Math.ceil(Math.log(numWords)));
    }

    /**
     * Partitions the words of {@link #sspace} into {@code numPrincipleVectors} groups and
     * stores each non-empty group's principle vector, together with its words, in
     * {@link #principleVectorToNearestTerms}.
     */
    private void computePrincipleVectors(int numPrincipleVectors) {
        int numWords = sspace.getWords().size();
        List<DoubleVector> termVectors = new ArrayList<DoubleVector>(numWords);
        String[] terms = new String[numWords];
        int next = 0;
        for (String word : sspace.getWords()) {
            termVectors.add(Vectors.asDouble(sspace.getVector(word)));
            terms[next++] = word;
        }

        DoubleVector[] principles = new DoubleVector[numPrincipleVectors];
        for (int p = 0; p < principles.length; p++) {
            principles[p] = new DenseVector(sspace.getVectorLength());
        }

        // partition index -> indices (into termVectors/terms) of the words it holds
        HashMultiMap<Integer, Integer> assignments = new HashMultiMap<Integer, Integer>();
        Random random = new Random();
        for (int term = 0; term < numWords; term++) {
            assignments.put(random.nextInt(numPrincipleVectors), term);
        }

        for (int round = 0; round < NUM_REFINEMENT_ROUNDS; round++) {
            LoggerUtil.verbose(LOGGER, "Computing principle vectors (round %d/%d)",
                    round + 1, NUM_REFINEMENT_ROUNDS);
            recomputePrinciples(assignments, principles, termVectors);
            assignments.clear();
            assignToNearestPrinciples(termVectors, principles, assignments);
        }

        recordPartitions(assignments, principles, terms, numPrincipleVectors);
    }

    /**
     * Replaces the principle vector of every non-empty partition with the sum of the
     * vectors of the words currently assigned to it.
     */
    private void recomputePrinciples(HashMultiMap<Integer, Integer> assignments,
            DoubleVector[] principles, List<DoubleVector> termVectors) {
        for (Map.Entry<Integer, Set<Integer>> entry : assignments.asMap().entrySet()) {
            DoubleVector centroid = new DenseVector(sspace.getVectorLength());
            principles[entry.getKey()] = centroid;
            for (Integer term : entry.getValue()) {
                VectorMath.add(centroid, termVectors.get(term));
            }
        }
    }

    /**
     * Adds to {@code assignments} a mapping from each term's nearest principle vector to
     * the term's index. The terms are split into contiguous chunks, and each chunk is
     * processed as one task on the work queue.
     */
    private void assignToNearestPrinciples(final List<DoubleVector> termVectors,
            final DoubleVector[] principles, final HashMultiMap<Integer, Integer> assignments) {
        int numWords = termVectors.size();
        int chunkSize = numWords / workQueue.availableThreads() + 1;
        int numChunks = (numWords + chunkSize - 1) / chunkSize;
        if (numChunks == 0) {
            return;
        }
        // Register exactly as many tasks as are submitted below so that awaiting the group
        // cannot block on tasks that are never added. NOTE(review): assumes WorkQueue.await
        // waits for the registered task count.
        Object taskGroup = workQueue.registerTaskGroup(numChunks);
        for (int start = 0; start < numWords; start += chunkSize) {
            final int from = start;
            final int to = Math.min(numWords, start + chunkSize);
            workQueue.add(taskGroup, new Runnable() {
                @Override
                public void run() {
                    // Collect locally, then merge once under the lock to keep contention low.
                    HashMultiMap<Integer, Integer> local = new HashMultiMap<Integer, Integer>();
                    for (int term = from; term < to; term++) {
                        local.put(nearestPrinciple(termVectors.get(term), principles), term);
                    }
                    synchronized (assignments) {
                        assignments.putAll(local);
                    }
                }
            });
        }
        workQueue.await(taskGroup);
    }

    /**
     * Returns the index of the principle vector with the highest cosine similarity to
     * {@code vector}. Returns -1 if no similarity is greater than {@code -Double.MAX_VALUE},
     * for example when every similarity is NaN.
     */
    private static int nearestPrinciple(DoubleVector vector, DoubleVector[] principles) {
        double bestSimilarity = -Double.MAX_VALUE;
        int best = -1;
        for (int p = 0; p < principles.length; p++) {
            double similarity = Similarity.cosineSimilarity(vector, principles[p]);
            if (similarity > bestSimilarity) {
                bestSimilarity = similarity;
                best = p;
            }
        }
        return best;
    }

    /**
     * Stores each non-empty partition's principle vector and word set in
     * {@link #principleVectorToNearestTerms}, and logs partition-size statistics.
     */
    private void recordPartitions(HashMultiMap<Integer, Integer> assignments,
            DoubleVector[] principles, String[] terms, int numPrincipleVectors) {
        double expectedSize = (double) terms.length / numPrincipleVectors;
        double sumSquaredDeviation = 0.0d;
        Map<Integer, Set<Integer>> partitions = assignments.asMap();
        for (Map.Entry<Integer, Set<Integer>> entry : partitions.entrySet()) {
            Set<String> members = new HashSet<String>();
            for (Integer term : entry.getValue()) {
                members.add(terms[term]);
            }
            LoggerUtil.verbose(LOGGER, "Principle vector %d is closest to %d terms",
                    entry.getKey(), members.size());
            double deviation = expectedSize - members.size();
            sumSquaredDeviation += deviation * deviation;
            principleVectorToNearestTerms.putMany(principles[entry.getKey()], members);
        }
        // Partitions that ended up empty are absent from the map but still deviate from
        // the expected size by expectedSize, so include them in the statistic.
        int emptyPartitions = numPrincipleVectors - partitions.size();
        sumSquaredDeviation += emptyPartitions * expectedSize * expectedSize;
        LoggerUtil.verbose(LOGGER, "Average number terms per principle vector: %f, (%f stddev)",
                Double.valueOf(expectedSize),
                Double.valueOf(Math.sqrt(sumSquaredDeviation / numPrincipleVectors)));
    }

    /**
     * Restores the non-transient state, then recreates the work queue and reads back the
     * semantic space written by {@link #writeObject}.
     */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        objectInputStream.defaultReadObject();
        this.workQueue = new WorkQueue();
        this.sspace = (SemanticSpace) objectInputStream.readObject();
    }

    /**
     * Writes the non-transient state followed by the semantic space. If the space is
     * not {@link Serializable}, a {@link VectorMapSemanticSpace} holding a copy of every
     * word's vector is written instead.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        SemanticSpace semanticSpace = this.sspace;
        if (semanticSpace instanceof Serializable) {
            objectOutputStream.writeObject(semanticSpace);
            return;
        }
        LoggerUtil.verbose(LOGGER, "%s is not serializable, so writing a copy of the data", semanticSpace.getSpaceName());
        // Kept raw: the exact generic signature of the VectorMapSemanticSpace constructor is
        // not visible here.
        HashMap wordToVector = new HashMap(semanticSpace.getWords().size());
        for (String word : semanticSpace.getWords()) {
            wordToVector.put(word, semanticSpace.getVector(word));
        }
        objectOutputStream.writeObject(new VectorMapSemanticSpace(wordToVector,
                "copy of " + semanticSpace.getSpaceName(), semanticSpace.getVectorLength()));
    }

    /**
     * Finds approximately the {@code k} words most cosine-similar to {@code vector}. Only
     * the words in the {@code k} partitions whose principle vectors are most similar to
     * {@code vector} are compared.
     *
     * @param vector the query vector
     * @param k the number of neighbors to return
     * @return a bounded map from similarity to word, or {@code null} if {@code vector} is
     *     {@code null}
     */
    @Override // edu.ucla.sspace.util.NearestNeighborFinder
    public SortedMultiMap<Double, String> getMostSimilar(final Vector vector, final int k) {
        if (vector == null) {
            return null;
        }

        // Rank the partitions by how similar their principle vector is to the query.
        BoundedSortedMultiMap closestPartitions = new BoundedSortedMultiMap(k, false);
        for (Map.Entry<DoubleVector, Set<String>> entry : principleVectorToNearestTerms.asMap().entrySet()) {
            closestPartitions.put(Double.valueOf(Similarity.cosineSimilarity(vector, entry.getKey())), entry);
        }

        List<Set<String>> candidateSets = new ArrayList<Set<String>>();
        int numCompared = 0;
        Iterator partitionIter = closestPartitions.values2().iterator();
        while (partitionIter.hasNext()) {
            @SuppressWarnings("unchecked")
            Set<String> candidates = (Set<String>) ((Map.Entry) partitionIter.next()).getValue();
            candidateSets.add(candidates);
            numCompared += candidates.size();
        }

        final BoundedSortedMultiMap<Double, String> kMostSimilar = new BoundedSortedMultiMap<Double, String>(k, false);
        if (!candidateSets.isEmpty()) {
            // One task per selected partition; the group is sized to match exactly.
            Object taskGroup = workQueue.registerTaskGroup(candidateSets.size());
            for (final Set<String> candidates : candidateSets) {
                workQueue.add(taskGroup, new Runnable() {
                    @Override // java.lang.Runnable
                    public void run() {
                        BoundedSortedMultiMap<Double, String> localBest = new BoundedSortedMultiMap<Double, String>(k, false);
                        for (String term : candidates) {
                            localBest.put(Double.valueOf(Similarity.cosineSimilarity(vector, sspace.getVector(term))), term);
                        }
                        synchronized (kMostSimilar) {
                            for (Map.Entry<Double, String> entry : localBest.entrySet()) {
                                kMostSimilar.put(entry.getKey(), entry.getValue());
                            }
                        }
                    }
                });
            }
            workQueue.await(taskGroup);
        }

        LoggerUtil.verbose(LOGGER, "Compared %d of the total %d terms to find the %d-nearest neighbors",
                Integer.valueOf(numCompared), Integer.valueOf(sspace.getWords().size()), Integer.valueOf(k));
        return kMostSimilar;
    }

    /**
     * Finds approximately the {@code k} words most similar to {@code str}, excluding
     * {@code str} itself.
     *
     * @return the neighbors, or {@code null} if the space has no vector for {@code str}
     */
    @Override // edu.ucla.sspace.util.NearestNeighborFinder
    public SortedMultiMap<Double, String> getMostSimilar(String str, int k) {
        Vector vector = sspace.getVector(str);
        if (vector == null) {
            return null;
        }
        // Request one extra neighbor to make room for the query word itself, which is then removed.
        SortedMultiMap<Double, String> mostSimilar = getMostSimilar(vector, k + 1);
        Iterator<Map.Entry<Double, String>> it = mostSimilar.entrySet().iterator();
        while (it.hasNext()) {
            if (str.equals(it.next().getValue())) {
                it.remove();
                break;
            }
        }
        return mostSimilar;
    }

    /**
     * Finds approximately the {@code k} words most similar to the sum of the vectors of
     * the words in {@code set}, excluding the words of {@code set}. Words without a vector
     * are logged and skipped.
     *
     * @return the neighbors, or {@code null} if {@code set} is empty or none of its words
     *     has a vector
     */
    @Override // edu.ucla.sspace.util.NearestNeighborFinder
    public SortedMultiMap<Double, String> getMostSimilar(Set<String> set, int k) {
        if (set.isEmpty()) {
            return null;
        }
        DenseVector combined = new DenseVector(sspace.getVectorLength());
        int numFound = 0;
        for (String term : set) {
            Vector vector = sspace.getVector(term);
            if (vector == null) {
                // Pass the term as a format argument so a '%' inside it is not parsed as a
                // format directive. NOTE(review): assumes LoggerUtil.info formats its message
                // the way LoggerUtil.verbose's %d/%s usage in this file suggests.
                LoggerUtil.info(LOGGER, "No vector for term %s", term);
            } else {
                VectorMath.add(combined, vector);
                numFound++;
            }
        }
        if (numFound == 0) {
            return null;
        }
        // Over-fetch by |set| so that enough neighbors remain after the query words are removed.
        SortedMultiMap<Double, String> mostSimilar = getMostSimilar(combined, set.size() + k);
        Iterator<Map.Entry<Double, String>> it = mostSimilar.entrySet().iterator();
        while (it.hasNext()) {
            if (set.contains(it.next().getValue())) {
                it.remove();
            }
        }
        while (mostSimilar.size() > k) {
            mostSimilar.remove(mostSimilar.firstKey());
        }
        return mostSimilar;
    }
}
