package edu.berkeley.nlp.util;

import edu.berkeley.nlp.util.MapFactory;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/berkeley/nlp/util/CounterMap.class */
/**
 * A two-level counter: maps each key {@code K} to a {@link Counter} over values {@code V}.
 *
 * <p>Inner counters are created on demand. Note that {@link #getCounter(Object)} returns the live
 * inner counter, so changes made through it are reflected in this map.
 *
 * @param <K> outer key type
 * @param <V> inner (counted) value type
 */
public class CounterMap<K, V> implements Serializable {
    private static final long serialVersionUID = 4230378533059209021L;
    /** Factory passed to every newly created inner {@link Counter}. */
    MapFactory<V, Double> mf;
    /** Outer map from key to that key's counter. */
    Map<K, Counter<V>> counterMap;
    /** Incremented by the mutating methods of this class. */
    int currentModCount;
    /** Value of {@link #currentModCount} when {@link #totalCount()} last ran. */
    int cacheModCount;
    /** Result of the most recent {@link #totalCount()} call. */
    double cacheTotalCount;

    /**
     * Returns the counter for {@code k}, creating and storing an empty one if none exists yet.
     *
     * @param k outer key
     * @return the live counter stored under {@code k}
     */
    protected Counter<V> ensureCounter(K k) {
        Counter<V> counter = this.counterMap.get(k);
        if (counter == null) {
            counter = new Counter<>(this.mf);
            this.counterMap.put(k, counter);
        }
        return counter;
    }

    /**
     * Returns the outer keys, as reported by the backing map's {@code keySet()} view.
     *
     * @return the set of keys that currently have a counter
     */
    public Set<K> keySet() {
        return this.counterMap.keySet();
    }

    /**
     * Sets the count of {@code (k, v)} to {@code d}, creating the counter for {@code k} if needed.
     *
     * @param k outer key
     * @param v inner value
     * @param d new count
     */
    public void setCount(K k, V v, double d) {
        ensureCounter(k).setCount(v, d);
        this.currentModCount++;
    }

    /**
     * Adds {@code d} to the count of {@code (k, v)}, creating the counter for {@code k} if needed.
     *
     * @param k outer key
     * @param v inner value
     * @param d amount to add
     */
    public void incrementCount(K k, V v, double d) {
        ensureCounter(k).incrementCount(v, d);
        this.currentModCount++;
    }

    /**
     * Adds {@code d} to the count of every {@code (key, value)} entry of {@code map}.
     *
     * @param map pairs to increment
     * @param d amount to add to each pair
     */
    public void incrementAll(Map<K, V> map, double d) {
        for (Map.Entry<K, V> entry : map.entrySet()) {
            incrementCount(entry.getKey(), entry.getValue(), d);
        }
    }

    /**
     * Adds {@code d} to the count of every pair in {@code collection}.
     *
     * @param collection pairs to increment (duplicates are incremented once per occurrence)
     * @param d amount to add to each pair
     */
    public void incrementAll(Collection<Pair<K, V>> collection, double d) {
        for (Pair<K, V> pair : collection) {
            incrementCount(pair.getFirst(), pair.getSecond(), d);
        }
    }

    /**
     * Returns the count of {@code (k, v)}, or {@code 0.0} when {@code k} has no counter.
     * Unlike {@link #getCounter(Object)}, this never inserts a counter.
     *
     * @param k outer key
     * @param v inner value
     * @return the stored count
     */
    public double getCount(K k, V v) {
        Counter<V> counter = this.counterMap.get(k);
        if (counter == null) {
            return 0.0d;
        }
        return counter.getCount(v);
    }

    /**
     * Returns the live counter for {@code k}, creating and storing an empty one if absent.
     * Changes made through the returned counter modify this map.
     *
     * @param k outer key
     * @return the counter stored under {@code k}
     */
    public Counter<V> getCounter(K k) {
        return ensureCounter(k);
    }

    /**
     * Reports whether a counter exists for {@code k} (including empty ones created by
     * {@link #getCounter(Object)}).
     *
     * @param k outer key
     * @return {@code true} if {@code k} is present in the outer map
     */
    public boolean containsKey(K k) {
        return this.counterMap.containsKey(k);
    }

    /**
     * Returns the sum of {@code totalCount()} over all inner counters.
     *
     * <p>The sum is recomputed on every call. The inner counters are handed out live by
     * {@link #getCounter(Object)}, and changes made through them do not update
     * {@link #currentModCount}, so a result cached against that counter could be stale.
     *
     * @return the grand total of all counts
     */
    public double totalCount() {
        double total = 0.0d;
        for (Counter<V> counter : this.counterMap.values()) {
            total += counter.totalCount();
        }
        // Keep the legacy cache fields in sync with the latest computed total.
        this.cacheTotalCount = total;
        this.cacheModCount = this.currentModCount;
        return total;
    }

    /**
     * Returns the sum of {@code size()} over all inner counters.
     *
     * @return the number of stored {@code (key, value)} entries
     */
    public int totalSize() {
        int total = 0;
        for (Counter<V> counter : this.counterMap.values()) {
            total += counter.size();
        }
        return total;
    }

    /** Calls {@code normalize()} on every inner counter. */
    public void normalize() {
        for (Counter<V> counter : this.counterMap.values()) {
            counter.normalize();
        }
        this.currentModCount++;
    }

    /**
     * Returns the number of outer keys.
     *
     * @return the size of the outer map
     */
    public int size() {
        return this.counterMap.size();
    }

    /**
     * Reports whether there are no outer keys.
     *
     * @return {@code true} when {@link #size()} is zero
     */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Renders each key and its counter on its own indented line, wrapped in brackets.
     *
     * @return a multi-line string representation
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[\n");
        for (Map.Entry<K, Counter<V>> entry : this.counterMap.entrySet()) {
            sb.append("  ");
            sb.append(entry.getKey());
            sb.append(" -> ");
            sb.append(entry.getValue());
            sb.append("\n");
        }
        sb.append("]");
        return sb.toString();
    }

    /** Creates an empty map, using {@code MapFactory.HashMapFactory} for both levels. */
    public CounterMap() {
        this(new MapFactory.HashMapFactory(), new MapFactory.HashMapFactory());
    }

    /**
     * Creates an empty map.
     *
     * @param mapFactory factory for the outer key-to-counter map
     * @param mapFactory2 factory handed to each inner counter
     */
    public CounterMap(MapFactory<K, Counter<V>> mapFactory, MapFactory<V, Double> mapFactory2) {
        this.currentModCount = 0;
        this.cacheModCount = -1;
        this.cacheTotalCount = 0.0d;
        this.mf = mapFactory2;
        this.counterMap = mapFactory.newMap();
    }

    /** Small demonstration that prints a few counts and summaries to standard output. */
    public static void main(String[] strArr) {
        CounterMap<String, String> counterMap = new CounterMap<>();
        counterMap.incrementCount("people", "run", 1.0d);
        counterMap.incrementCount("cats", "growl", 2.0d);
        counterMap.incrementCount("cats", "scamper", 3.0d);
        System.out.println(counterMap);
        System.out.println("Entries for cats: " + counterMap.getCounter("cats"));
        System.out.println("Entries for dogs: " + counterMap.getCounter("dogs"));
        System.out.println("Count of cats scamper: " + counterMap.getCount("cats", "scamper"));
        System.out.println("Count of snakes slither: " + counterMap.getCount("snakes", "slither"));
        System.out.println("Total size: " + counterMap.totalSize());
        System.out.println("Total count: " + counterMap.totalCount());
        System.out.println(counterMap);
    }

    /**
     * Returns the {@code (key, value)} pair with the largest count across all inner counters.
     *
     * <p>Empty inner counters (for example those created by {@link #getCounter(Object)}) are
     * skipped, so the returned pair never carries a null value produced by an empty counter.
     *
     * @return the best pair, or {@code null} if no inner counter has any entries
     */
    public Pair<K, V> argMax() {
        double bestCount = Double.NEGATIVE_INFINITY;
        Pair<K, V> best = null;
        for (Map.Entry<K, Counter<V>> entry : this.counterMap.entrySet()) {
            Counter<V> counter = entry.getValue();
            // An empty counter has no real argmax; considering it could yield (key, null).
            if (counter.size() == 0) {
                continue;
            }
            V value = counter.argMax();
            double count = counter.getCount(value);
            if (best == null || count > bestCount) {
                best = new Pair<>(entry.getKey(), value);
                bestCount = count;
            }
        }
        return best;
    }
}
