package edu.berkeley.nlp.util;

import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.util.MapFactory;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/util/UnaryCounterTable.class */
/**
 * Accumulates count arrays for unary rules, keyed by {@link UnaryRule}.
 *
 * <p>Each value is a {@code double[][]}; individual rows may be {@code null}. The
 * {@code incrementCount} methods mutate stored arrays in place.
 *
 * <p>Not thread-safe: {@link #getCount(short, short)} mutates a shared lookup key, and the
 * backing map and stored arrays are modified without synchronization.
 */
public class UnaryCounterTable implements Serializable {
    private static final long serialVersionUID = 1;

    /** Backing map from rule to its count array, created by the constructor's MapFactory. */
    Map<UnaryRule, double[][]> entries;

    /**
     * Array indexed by a rule's parent state (see {@link UnaryRule#getParentState()}). It gives
     * the length of rows allocated by {@link #incrementCount(UnaryRule, double)}.
     */
    short[] numSubStates;

    /** Reused lookup key; {@link #getCount(short, short)} overwrites it via setNodes. */
    UnaryRule searchKey;

    /**
     * Returns the rules that currently have an entry.
     *
     * @return the backing map's key set (a live view, not a copy)
     */
    public Set<UnaryRule> keySet() {
        return this.entries.keySet();
    }

    /** Returns the number of rules that have an entry. */
    public int size() {
        return this.entries.size();
    }

    /** Returns {@code true} if no rule has an entry. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Returns {@code true} if the table has an entry (possibly {@code null}) for the rule. */
    public boolean containsKey(UnaryRule unaryRule) {
        return this.entries.containsKey(unaryRule);
    }

    /**
     * Returns the stored count array for a rule.
     *
     * @param unaryRule the rule to look up
     * @return the stored array itself (mutations affect the table), or {@code null} if absent
     */
    public double[][] getCount(UnaryRule unaryRule) {
        return this.entries.get(unaryRule);
    }

    /**
     * Looks up counts by the two state ids without allocating a new key.
     *
     * <p>The ids are written into the shared {@code searchKey} via {@code setNodes(s, s2)}. This
     * makes the method unsafe for concurrent use. NOTE(review): the argument order presumably
     * matches UnaryRule's (parent, child) node order; confirm against UnaryRule.setNodes.
     *
     * @return the stored array, or {@code null} if absent
     */
    public double[][] getCount(short s, short s2) {
        this.searchKey.setNodes(s, s2);
        return this.entries.get(this.searchKey);
    }

    /**
     * Stores {@code dArr} as the counts for {@code unaryRule}, replacing any existing entry. The
     * array is stored by reference, not copied.
     */
    public void setCount(UnaryRule unaryRule, double[][] dArr) {
        this.entries.put(unaryRule, dArr);
    }

    /**
     * Adds {@code increment} element-wise to the counts stored for {@code unaryRule}.
     *
     * <p>If the rule has no entry yet, a copy of {@code increment} becomes the entry. Rows are
     * copied individually and {@code null} rows are preserved. The copy is needed because later
     * increments mutate the stored arrays in place. Storing the caller's array directly would
     * silently modify it.
     *
     * <p>For an existing entry, rows {@code 0 .. stored.length - 1} are visited:
     * <ul>
     *   <li>{@code null} rows of {@code increment} are skipped;</li>
     *   <li>a {@code null} stored row is first allocated with the length of the matching
     *       {@code increment} row;</li>
     *   <li>each stored cell then receives the matching increment value.</li>
     * </ul>
     *
     * @param unaryRule the rule whose counts are incremented
     * @param increment per-row increments. It must have at least as many rows as the stored entry,
     *     and each non-null row must be at least as long as the corresponding stored row.
     */
    public void incrementCount(UnaryRule unaryRule, double[][] increment) {
        double[][] count = getCount(unaryRule);
        if (count == null) {
            setCount(unaryRule, copyRows(increment));
            return;
        }
        for (int i = 0; i < count.length; i++) {
            double[] addRow = increment[i];
            if (addRow == null) {
                continue; // nothing to add for this row
            }
            if (count[i] == null) {
                count[i] = new double[addRow.length];
            }
            double[] target = count[i];
            for (int j = 0; j < target.length; j++) {
                target[j] += addRow[j];
            }
        }
        setCount(unaryRule, count);
    }

    /**
     * Adds the scalar {@code d} to every cell of the counts stored for {@code unaryRule}.
     *
     * <p>If the rule has no entry yet, a new array is created and stored. Its dimensions are
     * {@code scores2.length x scores2[0].length}, where {@code scores2} is
     * {@link UnaryRule#getScores2()}. The array is passed to {@code ArrayUtil.fill(array, d)}
     * (presumably setting every cell to {@code d}).
     *
     * <p>For an existing entry, every row is incremented by {@code d}. A {@code null} row is
     * first allocated with length {@code numSubStates[unaryRule.getParentState()]}.
     *
     * @param unaryRule the rule whose counts are incremented
     * @param d the amount added to each cell
     */
    public void incrementCount(UnaryRule unaryRule, double d) {
        double[][] count = getCount(unaryRule);
        if (count == null) {
            double[][] scores2 = unaryRule.getScores2();
            double[][] fresh = new double[scores2.length][scores2[0].length];
            ArrayUtil.fill(fresh, d);
            setCount(unaryRule, fresh);
            return;
        }
        for (int i = 0; i < count.length; i++) {
            if (count[i] == null) {
                // Only looked up when a row is missing, so the index is not touched otherwise.
                count[i] = new double[this.numSubStates[unaryRule.getParentState()]];
            }
            double[] target = count[i];
            for (int j = 0; j < target.length; j++) {
                target[j] += d;
            }
        }
        setCount(unaryRule, count);
    }

    /**
     * Returns a row-by-row copy of {@code source}, keeping {@code null} rows as {@code null}.
     * Returns {@code null} for a {@code null} input.
     */
    private static double[][] copyRows(double[][] source) {
        if (source == null) {
            return null;
        }
        double[][] copy = new double[source.length][];
        for (int i = 0; i < source.length; i++) {
            copy[i] = source[i] == null ? null : source[i].clone();
        }
        return copy;
    }

    /**
     * Creates an empty table backed by a map from {@link MapFactory.HashMapFactory}.
     *
     * @param sArr row lengths indexed by parent state; stored by reference
     */
    public UnaryCounterTable(short[] sArr) {
        // NOTE(review): raw construction (likely a decompilation artifact) is an unchecked
        // conversion; add explicit type arguments if HashMapFactory is generic.
        this(new MapFactory.HashMapFactory(), sArr);
    }

    /**
     * Creates an empty table whose backing map is produced by {@code mapFactory.newMap()}.
     *
     * @param mapFactory supplies the backing map
     * @param sArr row lengths indexed by parent state; stored by reference
     */
    public UnaryCounterTable(MapFactory<UnaryRule, double[][]> mapFactory, short[] sArr) {
        this.entries = mapFactory.newMap();
        this.searchKey = new UnaryRule((short) 0, (short) 0);
        this.numSubStates = sArr;
    }
}
