package edu.berkeley.nlp.util;

import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.util.MapFactory;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/util/BinaryCounterTable.class */
public class BinaryCounterTable implements Serializable {
    private static final long serialVersionUID = 1;

    /** Backing storage mapping each rule to its 3-D count array. */
    Map<BinaryRule, double[][][]> entries;

    /**
     * Number of substates per state index. Used to size inner count arrays that are allocated on
     * demand (indexed by the rule's parent state).
     */
    short[] numSubStates;

    /**
     * Reusable lookup key for {@link #getCount(short, short, short)}. It is mutated on every such
     * call, so that method is not safe for concurrent use.
     */
    BinaryRule searchKey;

    /** Returns a live view of the rules currently stored in this table. */
    public Set<BinaryRule> keySet() {
        return this.entries.keySet();
    }

    /** Returns the number of rules stored in this table. */
    public int size() {
        return this.entries.size();
    }

    /** Returns {@code true} if no rule is stored in this table. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Returns {@code true} if {@code binaryRule} has an entry in this table. */
    public boolean containsKey(BinaryRule binaryRule) {
        return this.entries.containsKey(binaryRule);
    }

    /**
     * Returns the count array stored for {@code binaryRule}, or {@code null} if there is none. The
     * returned array is the stored instance, not a copy.
     */
    public double[][][] getCount(BinaryRule binaryRule) {
        return this.entries.get(binaryRule);
    }

    /**
     * Looks up a count array by state indices without allocating a new key.
     *
     * <p>The three indices are forwarded unchanged to {@code BinaryRule.setNodes} on the shared
     * {@link #searchKey}. NOTE(review): their meaning (presumably parent/left/right) is defined by
     * {@code BinaryRule}; confirm there. Not thread-safe because the shared key is mutated.
     *
     * @return the stored array, or {@code null} if there is none
     */
    public double[][][] getCount(short s, short s2, short s3) {
        this.searchKey.setNodes(s, s2, s3);
        return this.entries.get(this.searchKey);
    }

    /** Stores {@code dArr} (by reference) as the count array for {@code binaryRule}. */
    public void setCount(BinaryRule binaryRule, double[][][] dArr) {
        this.entries.put(binaryRule, dArr);
    }

    /**
     * Adds {@code dArr} element-wise to the counts stored for {@code binaryRule}.
     *
     * <p>If the rule has no entry yet, {@code dArr} itself is stored by reference. Later increments
     * then mutate the caller's array in place. Null inner arrays in {@code dArr} are skipped. A null
     * inner array in the stored counts is allocated with the length of the corresponding increment.
     */
    public void incrementCount(BinaryRule binaryRule, double[][][] dArr) {
        double[][][] count = getCount(binaryRule);
        if (count == null) {
            setCount(binaryRule, dArr);
            return;
        }
        for (int i = 0; i < count.length; i++) {
            for (int j = 0; j < count[i].length; j++) {
                double[] increment = dArr[i][j];
                if (increment == null) {
                    continue;
                }
                if (count[i][j] == null) {
                    count[i][j] = new double[increment.length];
                }
                double[] target = count[i][j];
                for (int k = 0; k < target.length; k++) {
                    target[k] += increment[k];
                }
            }
        }
        setCount(binaryRule, count);
    }

    /**
     * Adds the scalar {@code d} to every cell of the counts stored for {@code binaryRule}.
     *
     * <p>If the rule has no entry yet, a new array shaped like {@code binaryRule.getScores2()} is
     * created and every cell is set to {@code d}. Null inner arrays in existing counts are allocated
     * with length {@code numSubStates[binaryRule.getParentState()]} before being incremented.
     */
    public void incrementCount(BinaryRule binaryRule, double d) {
        double[][][] count = getCount(binaryRule);
        if (count == null) {
            double[][][] scores = binaryRule.getScores2();
            // Inner score arrays may be null (this class guards against null inner arrays elsewhere).
            // Fall back to the parent substate count, the size used for lazily allocated inner arrays
            // below, instead of dereferencing scores[0][0] unconditionally.
            int innerLength = scores[0][0] != null
                    ? scores[0][0].length
                    : this.numSubStates[binaryRule.getParentState()];
            double[][][] fresh = new double[scores.length][scores[0].length][innerLength];
            ArrayUtil.fill(fresh, d);
            setCount(binaryRule, fresh);
            return;
        }
        for (int i = 0; i < count.length; i++) {
            for (int j = 0; j < count[i].length; j++) {
                if (count[i][j] == null) {
                    count[i][j] = new double[this.numSubStates[binaryRule.getParentState()]];
                }
                double[] target = count[i][j];
                for (int k = 0; k < target.length; k++) {
                    target[k] += d;
                }
            }
        }
        setCount(binaryRule, count);
    }

    /**
     * Creates a table backed by a hash map.
     *
     * @param sArr number of substates per state index (stored by reference)
     */
    public BinaryCounterTable(short[] sArr) {
        this(new MapFactory.HashMapFactory(), sArr);
    }

    /**
     * Creates a table whose backing map is produced by {@code mapFactory}.
     *
     * @param mapFactory factory used to create the backing map
     * @param sArr number of substates per state index (stored by reference)
     */
    public BinaryCounterTable(MapFactory<BinaryRule, double[][][]> mapFactory, short[] sArr) {
        this.entries = mapFactory.newMap();
        this.searchKey = new BinaryRule((short) 0, (short) 0, (short) 0);
        this.numSubStates = sArr;
    }

    /**
     * Console demo. NOTE(review): this exercises {@code Counter}, not this class, and looks like a
     * leftover copy of {@code Counter.main}.
     */
    public static void main(String[] strArr) {
        Counter counter = new Counter();
        System.out.println(counter);
        counter.incrementCount("planets", 7.0d);
        System.out.println(counter);
        counter.incrementCount("planets", 1.0d);
        System.out.println(counter);
        counter.setCount("suns", 1.0d);
        System.out.println(counter);
        counter.setCount("aliens", 0.0d);
        System.out.println(counter);
        System.out.println(counter.toString(2));
        System.out.println("Total: " + counter.totalCount());
    }
}
