package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.io.AbstractMapLabel;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.CounterMap;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.SubIndexer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/classify/MaximumEntropyClassifier.class */
/**
 * A trained maximum-entropy classifier over labels that may carry several sub-states.
 *
 * <p>Inputs of type {@code I} are turned into feature counters of type {@code F} by a
 * {@link FeatureExtractor}. Probabilities are reported per (label, sub-state index) pair.
 * Instances are normally produced by {@link Factory#trainClassifier(List)}.
 *
 * @param <I> input type
 * @param <F> feature type
 * @param <L> label type
 */
public class MaximumEntropyClassifier<I, F, L> {
    /** Weight vector returned by the optimizer, addressed through {@link #indexLinearizer}. */
    private final double[] weights;
    /** Feature and label/sub-label index mappings built from the training data. */
    private final Encoding<F, L> encoding;
    /** Built from the feature count and sub-label count of {@link #encoding}. */
    private final IndexLinearizer indexLinearizer;
    /** Converts raw inputs into feature counters, both at training and prediction time. */
    private final FeatureExtractor<I, F> featureExtractor;
    /**
     * Counters reported by {@link #displaySavings()}. Nothing in this class increments them, so
     * they remain zero unless code elsewhere is added to update them.
     */
    private static double numLogs = 0.0d;
    private static double numLogsSaved = 0.0d;
    /** Objective used during training; reused at prediction time to compute log-probabilities. */
    private final ObjectiveFunction objective;

    /**
     * Trains {@link MaximumEntropyClassifier} instances with a fixed regularization setting,
     * iteration budget, and feature extractor.
     */
    /* loaded from: input_file:edu/berkeley/nlp/classify/MaximumEntropyClassifier$Factory.class */
    public static class Factory<I, F, L> {
        /** Convergence tolerance handed to the L-BFGS minimizer. */
        private static final double CONVERGENCE_TOLERANCE = 1.0E-4d;
        /** Step size for the forward-difference gradient check in {@link #testDerivatives}. */
        private static final double DERIVATIVE_STEP = 1.0E-4d;

        /** Passed to the objective function; presumably the prior/regularization width — confirm. */
        double sigma;
        /** Iteration budget for the L-BFGS minimizer. */
        int iterations;
        /** Extractor used to featurize training inputs (and later, prediction inputs). */
        FeatureExtractor<I, F> featureExtractor;

        /**
         * Creates a factory.
         *
         * @param sigma value forwarded to the objective function as {@code sigma}
         * @param iterations maximum number of L-BFGS iterations
         * @param featureExtractor converts inputs to feature counters
         */
        public Factory(double sigma, int iterations, FeatureExtractor<I, F> featureExtractor) {
            this.sigma = sigma;
            this.iterations = iterations;
            this.featureExtractor = featureExtractor;
        }

        /**
         * Trains a classifier on labeled data.
         *
         * <p>Steps:
         * <ol>
         *   <li>Index all features and labels.</li>
         *   <li>Encode every instance.</li>
         *   <li>Minimize the objective with L-BFGS, starting from all-zero weights.</li>
         * </ol>
         *
         * @param trainingData labeled training instances
         * @return the trained classifier
         */
        public MaximumEntropyClassifier<I, F, L> trainClassifier(List<LabeledInstance<I, L>> trainingData) {
            Encoding<F, L> encoding = buildEncoding(trainingData);
            IndexLinearizer indexLinearizer = buildIndexLinearizer(encoding);
            double[] initialWeights = buildInitialWeights(indexLinearizer);
            EncodedDatum[] encodedData = encodeData(trainingData, encoding);
            LBFGSMinimizer minimizer = new LBFGSMinimizer(this.iterations);
            ProperNameObjectiveFunction objective =
                    new ProperNameObjectiveFunction(encoding, encodedData, indexLinearizer, this.sigma);
            double[] trainedWeights = minimizer.minimize(objective, initialWeights, CONVERGENCE_TOLERANCE);
            return new MaximumEntropyClassifier<>(
                    trainedWeights, encoding, indexLinearizer, this.featureExtractor, objective);
        }

        /**
         * Debugging aid: prints each analytic partial derivative at the origin next to a
         * forward-difference estimate, followed by the two function values used for the estimate.
         *
         * @param function the function whose gradient is checked
         */
        private void testDerivatives(DifferentiableFunction function) {
            double[] origin = DoubleArrays.constantArray(0.0d, function.dimension());
            double[] analyticGradient = function.derivativeAt(origin);
            double baseValue = function.valueAt(origin);
            for (int i = 0; i < origin.length; i++) {
                double[] shifted = DoubleArrays.clone(origin);
                shifted[i] += DERIVATIVE_STEP;
                double shiftedValue = function.valueAt(shifted);
                double numericDerivative = (shiftedValue - baseValue) / DERIVATIVE_STEP;
                System.out.println(analyticGradient[i] + " " + numericDerivative
                        + "(" + shiftedValue + ", " + baseValue + ")");
            }
        }

        /** Returns an all-zero starting point sized to the linearized weight space. */
        private double[] buildInitialWeights(IndexLinearizer indexLinearizer) {
            return DoubleArrays.constantArray(0.0d, indexLinearizer.getNumLinearIndexes());
        }

        /** Builds the linearizer from the encoding's feature count and sub-label count. */
        private IndexLinearizer buildIndexLinearizer(Encoding<F, L> encoding) {
            return new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumSubLabels());
        }

        /**
         * Indexes every label (with its number of sub-states) and every feature seen in the
         * training data.
         */
        private Encoding<F, L> buildEncoding(List<LabeledInstance<I, L>> trainingData) {
            Indexer featureIndexer = new Indexer();
            SubIndexer labelIndexer = new SubIndexer();
            for (LabeledInstance<I, L> instance : trainingData) {
                L label = instance.getLabel();
                Counter<F> features = this.featureExtractor.extractFeatures(instance.getInput());
                // Pass the label through uncast. Casting it to another type (e.g. SubIndexer) would
                // throw ClassCastException for ordinary label types such as String.
                labelIndexer.add(label, instance.getNumSubStates());
                for (F feature : features.keySet()) {
                    featureIndexer.add(feature);
                }
            }
            return new Encoding<>(featureIndexer, labelIndexer);
        }

        /**
         * Encodes each training instance (features, label, and per-instance weights) against
         * the given encoding. The output preserves input order.
         */
        private EncodedDatum[] encodeData(List<LabeledInstance<I, L>> trainingData, Encoding<F, L> encoding) {
            EncodedDatum[] encoded = new EncodedDatum[trainingData.size()];
            int position = 0;
            // Iterate rather than index so non-random-access lists are not traversed quadratically.
            for (LabeledInstance<I, L> instance : trainingData) {
                Counter<F> features = this.featureExtractor.extractFeatures(instance.getInput());
                encoded[position++] = EncodedDatum.encodeLabeledDatum(
                        encoding, features, instance.getLabel(), instance.getWeights());
            }
            return encoded;
        }
    }

    /**
     * Prints the percentage of {@code log()} calls saved, according to the static counters.
     * Prints 0% when no calls have been counted, instead of the {@code NaN} that 0/0 would give.
     */
    public static void displaySavings() {
        double percentSaved = numLogs == 0.0d ? 0.0d : (100.0d * numLogsSaved) / numLogs;
        System.out.println("Saved " + percentSaved + "% calls to log()");
    }

    /**
     * Computes the probability of every (label, sub-state index) pair for an input.
     *
     * @param input the input to classify
     * @return probabilities keyed by label and 0-based sub-state index within that label
     */
    public CounterMap<L, Integer> getProbabilities(I input) {
        Counter<F> features = this.featureExtractor.extractFeatures(input);
        double[] logProbabilities = this.objective.getLogProbabilities(
                EncodedDatum.encodeDatum(this.encoding, features), this.weights, this.encoding, this.indexLinearizer);
        return logProbabilityArrayToProbabilityCounter(logProbabilities);
    }

    /**
     * Exponentiates a flat array of log-probabilities into a (label, sub-state) counter map.
     * Each label's sub-states occupy the range
     * [{@code getLabelSubindexBegin}, {@code getLabelSubindexEnd}) of the array. They are
     * re-keyed relative to that label's first sub-index.
     */
    private CounterMap<L, Integer> logProbabilityArrayToProbabilityCounter(double[] logProbabilities) {
        CounterMap<L, Integer> probabilities = new CounterMap<>();
        for (int labelIndex = 0; labelIndex < this.encoding.getNumLabels(); labelIndex++) {
            L label = this.encoding.getLabel(labelIndex);
            int begin = this.encoding.getLabelSubindexBegin(labelIndex);
            int end = this.encoding.getLabelSubindexEnd(labelIndex);
            for (int subIndex = begin; subIndex < end; subIndex++) {
                probabilities.setCount(label, Integer.valueOf(subIndex - begin), Math.exp(logProbabilities[subIndex]));
            }
        }
        return probabilities;
    }

    /**
     * Returns the {@link CounterMap#argMax()} of {@link #getProbabilities(Object)} for the
     * input, i.e. presumably the most probable (label, sub-state index) pair.
     */
    public Pair<L, Integer> getLabel(I input) {
        return getProbabilities(input).argMax();
    }

    /**
     * Creates a classifier from already-trained parameters.
     *
     * @param weights trained weight vector
     * @param encoding feature/label encoding used during training
     * @param indexLinearizer linearizer matching {@code encoding}
     * @param featureExtractor extractor used to featurize inputs
     * @param objectiveFunction objective used to compute log-probabilities at prediction time
     */
    public MaximumEntropyClassifier(double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer,
            FeatureExtractor<I, F> featureExtractor, ObjectiveFunction objectiveFunction) {
        this.weights = weights;
        this.encoding = encoding;
        this.indexLinearizer = indexLinearizer;
        this.featureExtractor = featureExtractor;
        this.objective = objectiveFunction;
    }

    /**
     * Small demo: trains on three toy instances and prints the probabilities for a held-out
     * fourth instance.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static void main(String[] args) {
        // NOTE(review): AbstractMapLabel.CATEGORY_KEY is used as a class label here. This looks like
        // a decompiler substitution for a string literal — confirm the intended label.
        LabeledInstance fuzzySmall = new LabeledInstance(AbstractMapLabel.CATEGORY_KEY, new String[]{"fuzzy", "claws", "small"});
        LabeledInstance fuzzyBig = new LabeledInstance("bear", new String[]{"fuzzy", "claws", "big"});
        LabeledInstance medium = new LabeledInstance(AbstractMapLabel.CATEGORY_KEY, new String[]{"claws", "medium"});
        // Held out from training; used only to query the trained classifier.
        LabeledInstance testInstance = new LabeledInstance(AbstractMapLabel.CATEGORY_KEY, new String[]{"claws", "small"});

        List trainingData = new ArrayList();
        trainingData.add(fuzzySmall);
        trainingData.add(fuzzyBig);
        trainingData.add(medium);

        // Each token of the input array becomes a feature.
        FeatureExtractor<String[], String> tokenExtractor = new FeatureExtractor<String[], String>() {
            @Override // edu.berkeley.nlp.classify.FeatureExtractor
            public Counter<String> extractFeatures(String[] tokens) {
                return new Counter<>(Arrays.asList(tokens));
            }
        };
        Factory factory = new Factory(1.0d, 20, tokenExtractor);
        MaximumEntropyClassifier classifier = factory.trainClassifier(trainingData);
        System.out.println("Probabilities on test instance: " + classifier.getProbabilities((String[]) testInstance.getInput()));
    }
}
