package org.jpmml.evaluator.mining;

import com.google.common.base.Function;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.PMMLException;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.aj;
import org.jpmml.evaluator.ba;
import org.jpmml.evaluator.bb;
import org.jpmml.evaluator.bf;
import org.jpmml.evaluator.bg;
import org.jpmml.evaluator.bh;
import org.jpmml.evaluator.bi;
import org.jpmml.evaluator.bk;
import org.jpmml.evaluator.bt;
import org.jpmml.evaluator.bz;
import org.jpmml.evaluator.cd;
import org.jpmml.evaluator.n;
import org.jpmml.evaluator.q;
import org.jpmml.evaluator.t;

/* loaded from: classes8.dex */
/**
 * Evaluator for {@link MiningModel} elements (summarised by {@link #getSummary()} as "Ensemble model").
 *
 * <p>Scoring evaluates every {@link Segment} whose predicate evaluates to {@code true}, delegating each
 * one to a nested {@link ModelEvaluator}. The per-segment results are then combined according to the
 * segmentation's {@code multipleModelMethod} and the mining function of this model. Nested evaluators
 * are created on first use and cached per segment id.
 */
public class MiningModelEvaluator extends ModelEvaluator<MiningModel> implements aj<Segment> {

    /** Segment id to segment mapping; transient, so it is re-resolved lazily (see getEntityRegistry). */
    private transient BiMap<String, Segment> entityRegistry;

    /** Factory used for nested evaluators; set via configure(...), may be null. */
    private ModelEvaluatorFactory modelEvaluatorFactory;

    /** Nested evaluators keyed by segment id, populated lazily; a concurrent map so racing evaluations are safe. */
    private final ConcurrentMap<String, SegmentHandler> segmentHandlers = new ConcurrentHashMap<>();

    /** Combination methods accepted when the mining function is REGRESSION. */
    private static final Set<Segmentation.MultipleModelMethod> REGRESSION_METHODS = EnumSet.of(
            Segmentation.MultipleModelMethod.AVERAGE,
            Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE,
            Segmentation.MultipleModelMethod.MEDIAN,
            Segmentation.MultipleModelMethod.WEIGHTED_MEDIAN,
            Segmentation.MultipleModelMethod.SUM,
            Segmentation.MultipleModelMethod.WEIGHTED_SUM);

    /** Combination methods accepted when the mining function is CLASSIFICATION. */
    private static final Set<Segmentation.MultipleModelMethod> CLASSIFICATION_METHODS = EnumSet.of(
            Segmentation.MultipleModelMethod.MAJORITY_VOTE,
            Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE,
            Segmentation.MultipleModelMethod.AVERAGE,
            Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE,
            Segmentation.MultipleModelMethod.MEDIAN,
            Segmentation.MultipleModelMethod.MAX);

    /** Combination methods accepted when the mining function is CLUSTERING. */
    private static final Set<Segmentation.MultipleModelMethod> CLUSTERING_METHODS = EnumSet.of(
            Segmentation.MultipleModelMethod.MAJORITY_VOTE,
            Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE);

    /** Shared cache of segment-id registries, keyed by MiningModel instance. */
    private static final LoadingCache<MiningModel, BiMap<String, Segment>> entityCache = org.jpmml.evaluator.c.buildLoadingCache(new CacheLoader<MiningModel, BiMap<String, Segment>>() {

        @Override
        public BiMap<String, Segment> load(MiningModel miningModel) {
            return n.buildBiMap(miningModel.getSegmentation().getSegments());
        }
    });

    /**
     * Cached per-segment state: the nested evaluator, plus a "compatible" flag.
     *
     * <p>The flag is true only when {@code ba.isDefault(...)} holds for the mining field of every input
     * field that is backed by a {@link DataField}. NOTE(review): presumably this means "the
     * MiningField uses default settings"; confirm against {@code ba.isDefault}.
     *
     * <p>Decompiler note: this class was reported as originally private.
     */
    public static class SegmentHandler implements Serializable {

        private boolean compatible;

        private ModelEvaluator<?> modelEvaluator;

        private SegmentHandler(ModelEvaluator<?> modelEvaluator) {
            setModelEvaluator(modelEvaluator);

            boolean allDefault = true;
            for (InputField inputField : modelEvaluator.getInputFields()) {
                // Only input fields backed by a DataField take part in the compatibility check.
                if (inputField.getField() instanceof DataField) {
                    allDefault &= ba.isDefault(inputField.getMiningField());
                }
            }
            setCompatible(allDefault);
        }

        private void setCompatible(boolean compatible) {
            this.compatible = compatible;
        }

        private void setModelEvaluator(ModelEvaluator<?> modelEvaluator) {
            this.modelEvaluator = modelEvaluator;
        }

        public ModelEvaluator<?> getModelEvaluator() {
            return this.modelEvaluator;
        }

        public boolean isCompatible() {
            return this.compatible;
        }
    }

    /**
     * Creates an evaluator for the MiningModel selected from the given PMML document.
     *
     * @param pmml the PMML document
     */
    public MiningModelEvaluator(PMML pmml) {
        this(pmml, (MiningModel) selectModel(pmml, MiningModel.class));
    }

    /**
     * Creates an evaluator for the given MiningModel, validating its structure eagerly.
     *
     * @param pmml the enclosing PMML document
     * @param miningModel the model to evaluate
     * @throws UnsupportedElementException if the model has embedded models or the segmentation has local transformations
     * @throws MissingElementException if the segmentation, or all of its segments, are missing
     * @throws MissingAttributeException if the segmentation has no multipleModelMethod
     */
    public MiningModelEvaluator(PMML pmml, MiningModel miningModel) {
        super(pmml, miningModel);

        if (miningModel.hasEmbeddedModels()) {
            throw new UnsupportedElementException((EmbeddedModel) Iterables.getFirst(miningModel.getEmbeddedModels(), null));
        }

        Segmentation segmentation = miningModel.getSegmentation();
        if (segmentation == null) {
            throw new MissingElementException(miningModel, bh.MININGMODEL_SEGMENTATION);
        }
        if (segmentation.getMultipleModelMethod() == null) {
            throw new MissingAttributeException(segmentation, bg.SEGMENTATION_MULTIPLEMODELMETHOD);
        }
        if (!segmentation.hasSegments()) {
            throw new MissingElementException(segmentation, bh.SEGMENTATION_SEGMENTS);
        }

        LocalTransformations localTransformations = segmentation.getLocalTransformations();
        if (localTransformations != null) {
            throw new UnsupportedElementException(localTransformations);
        }
    }

    /**
     * Verifies that a segment model declares the expected mining function.
     *
     * @throws InvalidAttributeException if the model's mining function differs from (or is missing compared to) the expected one
     */
    private static void checkMiningFunction(Model model, MiningFunction miningFunction) {
        MiningFunction actualMiningFunction = model.getMiningFunction();
        if (!miningFunction.equals(actualMiningFunction)) {
            // Guard against a missing attribute, which previously surfaced as a NullPointerException.
            String actualValue = (actualMiningFunction != null ? actualMiningFunction.value() : null);
            throw new InvalidAttributeException(InvalidAttributeException.formatMessage(cd.formatElement(model.getClass()) + "@miningFunction=" + actualValue), model);
        }
    }

    /**
     * Returns the segment's model, failing if the Segment element has none.
     *
     * @throws MissingElementException if the segment has no model
     */
    private static Model requireModel(Segment segment) {
        Model model = segment.getModel();
        if (model == null) {
            throw new MissingElementException(MissingElementException.formatMessage(cd.formatElement(segment.getClass()) + "/<Model>"), segment);
        }
        return model;
    }

    /**
     * Collects output fields exposed by nested models.
     *
     * <p>Only SELECT_FIRST (the active head of the segment list) and MODEL_CHAIN (the active tail) expose
     * nested output fields; all other methods expose none.
     */
    private List<OutputField> createNestedOutputFields() {
        Segmentation segmentation = getModel().getSegmentation();
        List<Segment> segments = segmentation.getSegments();

        switch (segmentation.getMultipleModelMethod()) {
            case SELECT_FIRST:
                return createNestedOutputFields(getActiveHead(segments));
            case MODEL_CHAIN:
                return createNestedOutputFields(getActiveTail(segments));
            case SELECT_ALL:
            default:
                return Collections.emptyList();
        }
    }

    /** Copies the output fields of every given segment's nested evaluator, in segment order. */
    private List<OutputField> createNestedOutputFields(List<Segment> segments) {
        BiMap<String, Segment> entityRegistry = getEntityRegistry();

        List<OutputField> outputFields = new ArrayList<>();
        for (Segment segment : segments) {
            Model model = requireModel(segment);

            String segmentId = n.getId(segment, entityRegistry);
            SegmentHandler segmentHandler = getSegmentHandler(segmentId, model);

            for (OutputField outputField : segmentHandler.getModelEvaluator().getOutputFields()) {
                outputFields.add(new OutputField(outputField));
            }
        }
        return ImmutableList.copyOf(outputFields);
    }

    /**
     * Returns the cached handler for a segment id, creating one for {@code model} on first use.
     *
     * <p>If several threads race to create the handler, all of them use the instance that won
     * {@code putIfAbsent}. Any losing instance is discarded.
     */
    private SegmentHandler getSegmentHandler(String segmentId, Model model) {
        SegmentHandler segmentHandler = this.segmentHandlers.get(segmentId);
        if (segmentHandler == null) {
            SegmentHandler created = createSegmentHandler(model);
            SegmentHandler existing = this.segmentHandlers.putIfAbsent(segmentId, created);
            segmentHandler = (existing != null ? existing : created);
        }
        return segmentHandler;
    }

    /** Builds a handler for a segment model, using the configured factory or a fresh default one. */
    private SegmentHandler createSegmentHandler(Model model) {
        ModelEvaluatorFactory modelEvaluatorFactory = getModelEvaluatorFactory();
        if (modelEvaluatorFactory == null) {
            modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        }
        return new SegmentHandler(modelEvaluatorFactory.newModelEvaluator(getPMML(), model));
    }

    /** Scoring path for mining functions without a dedicated aggregator; only selection methods are accepted. */
    private Map<FieldName, ?> evaluateAny(a context) {
        return getSegmentationResult(Collections.<Segmentation.MultipleModelMethod>emptySet(), evaluateSegmentation(context));
    }

    /** Scoring path for CLASSIFICATION: aggregates segment probabilities or votes into one classification. */
    private <V extends Number> Map<FieldName, ?> evaluateClassification(ValueFactory<V> valueFactory, a context) {
        MiningModel miningModel = getModel();

        List<c> segmentResults = evaluateSegmentation(context);

        Map<FieldName, ?> selectionResult = getSegmentationResult(CLASSIFICATION_METHODS, segmentResults);
        if (selectionResult != null) {
            return selectionResult;
        }

        TargetField targetField = getTargetField();

        Segmentation segmentation = miningModel.getSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();

        bk classification;
        switch (multipleModelMethod) {
            case SELECT_ALL:
            case SELECT_FIRST:
            case MODEL_CHAIN:
            case WEIGHTED_MEDIAN:
            case SUM:
            case WEIGHTED_SUM:
                throw new InvalidAttributeException(segmentation, multipleModelMethod);
            case AVERAGE:
            case WEIGHTED_AVERAGE:
            case MEDIAN:
            case MAX:
                classification = new bk(b.aggregateProbabilities(valueFactory, segmentResults, t.getTargetCategories(targetField.getDataField()), multipleModelMethod));
                break;
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE: {
                ValueMap votes = b.aggregateVotes(valueFactory, segmentResults, multipleModelMethod);
                bz.normalizeSimpleMax(votes);
                classification = new bk(votes);
                break;
            }
            default:
                throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
        }
        return bt.evaluateClassification(targetField, classification);
    }

    /** Scoring path for CLUSTERING: only (weighted) majority vote is accepted. */
    private <V extends Number> Map<FieldName, ?> evaluateClustering(ValueFactory<V> valueFactory, a context) {
        MiningModel miningModel = getModel();

        List<c> segmentResults = evaluateSegmentation(context);

        Map<FieldName, ?> selectionResult = getSegmentationResult(CLUSTERING_METHODS, segmentResults);
        if (selectionResult != null) {
            return selectionResult;
        }

        Segmentation segmentation = miningModel.getSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();

        switch (multipleModelMethod) {
            case SELECT_ALL:
            case SELECT_FIRST:
            case MODEL_CHAIN:
            case AVERAGE:
            case WEIGHTED_AVERAGE:
            case MEDIAN:
            case WEIGHTED_MEDIAN:
            case SUM:
            case WEIGHTED_SUM:
            case MAX:
                throw new InvalidAttributeException(segmentation, multipleModelMethod);
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE: {
                d clusterVote = new d(b.aggregateVotes(valueFactory, segmentResults, multipleModelMethod));
                clusterVote.computeResult(DataType.STRING);
                return Collections.singletonMap(getTargetFieldName(), clusterVote);
            }
            default:
                throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
        }
    }

    /** Scoring path for REGRESSION: averages, medians or sums the segment predictions. */
    private <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, a context) {
        MiningModel miningModel = getModel();

        List<c> segmentResults = evaluateSegmentation(context);

        Map<FieldName, ?> selectionResult = getSegmentationResult(REGRESSION_METHODS, segmentResults);
        if (selectionResult != null) {
            return selectionResult;
        }

        Segmentation segmentation = miningModel.getSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();

        switch (multipleModelMethod) {
            case SELECT_ALL:
            case SELECT_FIRST:
            case MODEL_CHAIN:
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE:
            case MAX:
                throw new InvalidAttributeException(segmentation, multipleModelMethod);
            case AVERAGE:
            case WEIGHTED_AVERAGE:
            case MEDIAN:
            case WEIGHTED_MEDIAN:
            case SUM:
            case WEIGHTED_SUM:
                return bt.evaluateRegression(getTargetField(), b.aggregateValues(valueFactory, segmentResults, multipleModelMethod));
            default:
                throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
        }
    }

    /**
     * Evaluates every segment whose predicate is true, in document order.
     *
     * <p>A segment is skipped when its predicate is false or unknown ({@code null}). For SELECT_FIRST,
     * evaluation stops at the first selected segment. For MODEL_CHAIN, each segment's own output fields
     * (depth <= 0) are declared in {@code context} for use by later segments. Only the last chained
     * model's mining function is checked; for every other method, each selected model is checked.
     *
     * @param context the evaluation context of this mining model
     * @return one result per evaluated segment
     */
    private List<c> evaluateSegmentation(a context) {
        MiningModel miningModel = getModel();
        BiMap<String, Segment> entityRegistry = getEntityRegistry();

        MiningFunction miningFunction = miningModel.getMiningFunction();
        Segmentation segmentation = miningModel.getSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();

        List<Segment> segments = segmentation.getSegments();
        List<c> segmentResults = new ArrayList<>(segments.size());

        // Last model evaluated in a MODEL_CHAIN.
        Model lastModel = null;

        // Sticky flag passed to reset(...) of reused child contexts. It becomes true once any selected
        // segment is incompatible or has a non-primitive evaluator.
        // NOTE(review): presumably tells reset() to clear cached values; confirm against bb.reset.
        boolean resetFlag = false;

        // Child contexts are allocated lazily once, then reset and reused for later segments.
        a miningModelContext = null;
        bb modelContext = null;

        for (Segment segment : segments) {
            Boolean status = bi.evaluatePredicateContainer(segment, context);
            if (status == null || !status.booleanValue()) {
                continue;
            }

            Model model = requireModel(segment);

            switch (multipleModelMethod) {
                case MODEL_CHAIN:
                    lastModel = model;
                    break;
                default:
                    checkMiningFunction(model, miningFunction);
                    break;
            }

            String segmentId = n.getId(segment, entityRegistry);
            SegmentHandler segmentHandler = getSegmentHandler(segmentId, model);
            ModelEvaluator<?> modelEvaluator = segmentHandler.getModelEvaluator();
            boolean compatible = segmentHandler.isCompatible();

            if (!compatible || !modelEvaluator.isPrimitive()) {
                resetFlag = true;
            }

            bb segmentContext;
            if (modelEvaluator instanceof MiningModelEvaluator) {
                MiningModelEvaluator miningModelEvaluator = (MiningModelEvaluator) modelEvaluator;
                if (miningModelContext == null) {
                    miningModelContext = new a(context, miningModelEvaluator);
                } else {
                    miningModelContext.reset(miningModelEvaluator, resetFlag);
                }
                segmentContext = miningModelContext;
            } else {
                if (modelContext == null) {
                    modelContext = new bb(context, modelEvaluator);
                } else {
                    modelContext.reset(modelEvaluator, resetFlag);
                }
                segmentContext = modelContext;
            }
            segmentContext.setCompatible(compatible);

            try {
                c segmentResult = new c(segment, segmentId, modelEvaluator.evaluate(segmentContext), modelEvaluator.getTargetFieldName());
                context.a(segmentId, segmentResult);

                if (multipleModelMethod == Segmentation.MultipleModelMethod.MODEL_CHAIN) {
                    // Make this segment's own output values visible to the segments that follow.
                    for (OutputField outputField : modelEvaluator.getOutputFields()) {
                        FieldName name = outputField.getName();
                        if (outputField.getDepth() > 0) {
                            continue;
                        }
                        context.a(outputField.getOutputField());

                        FieldValue value = segmentContext.getField(name);
                        if (value == null) {
                            throw new MissingValueException(name, segment);
                        }
                        context.declare(name, value);
                    }
                }

                // Propagate the segment's warnings to the enclosing context.
                for (String warning : segmentContext.getWarnings()) {
                    context.addWarning(warning);
                }

                if (multipleModelMethod == Segmentation.MultipleModelMethod.SELECT_FIRST) {
                    return Collections.singletonList(segmentResult);
                }
                segmentResults.add(segmentResult);
            } catch (PMMLException e) {
                throw e.ensureContext(segment);
            }
        }

        if (multipleModelMethod == Segmentation.MultipleModelMethod.MODEL_CHAIN && lastModel != null) {
            checkMiningFunction(lastModel, miningFunction);
        }
        return segmentResults;
    }

    /**
     * Returns the prefix of {@code segments} up to and including the first segment whose predicate is
     * the constant {@link True}, or the whole list if there is no such segment.
     */
    private List<Segment> getActiveHead(List<Segment> segments) {
        for (int i = 0, size = segments.size(); i < size; i++) {
            if (bi.ensurePredicate(segments.get(i)) instanceof True) {
                return segments.subList(0, i + 1);
            }
        }
        return segments;
    }

    /** Mirror of getActiveHead: the suffix starting at the last segment with a constant True predicate. */
    private List<Segment> getActiveTail(List<Segment> segments) {
        return Lists.reverse(getActiveHead(Lists.reverse(segments)));
    }

    /**
     * Handles the combination methods that do not need mining-function-specific aggregation.
     *
     * @param supportedMethods aggregation methods the caller can handle itself
     * @param segmentResults results of the evaluated segments
     * @return the final result for SELECT_ALL, SELECT_FIRST and MODEL_CHAIN (when results exist), a
     *     singleton map to {@code null} when no segment was evaluated, or {@code null} when the
     *     caller must aggregate
     * @throws UnsupportedAttributeException if the method is neither a selection method nor supported
     */
    private Map<FieldName, ?> getSegmentationResult(Set<Segmentation.MultipleModelMethod> supportedMethods, List<c> segmentResults) {
        Segmentation segmentation = getModel().getSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();

        switch (multipleModelMethod) {
            case SELECT_ALL:
                return selectAll(segmentResults);
            case SELECT_FIRST:
                if (!segmentResults.isEmpty()) {
                    return (Map) segmentResults.get(0);
                }
                break;
            case MODEL_CHAIN:
                if (!segmentResults.isEmpty()) {
                    return (Map) segmentResults.get(segmentResults.size() - 1);
                }
                break;
            default:
                if (!supportedMethods.contains(multipleModelMethod)) {
                    throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
                }
                break;
        }

        // No segment was evaluated: report a missing value for the target field.
        if (segmentResults.isEmpty()) {
            return Collections.singletonMap(getTargetFieldName(), null);
        }
        return null;
    }

    /**
     * SELECT_ALL: merges all segment results into one map of field name to the list of per-segment values.
     *
     * @throws EvaluationException if the segment results do not all have the same set of fields
     */
    private static Map<FieldName, ?> selectAll(List<c> segmentResults) {
        ArrayListMultimap<FieldName, Object> valuesByField = ArrayListMultimap.create();

        LinkedHashSet<FieldName> fieldNames = null;
        for (c segmentResult : segmentResults) {
            if (fieldNames == null) {
                fieldNames = new LinkedHashSet(segmentResult.keySet());
            }

            if (!fieldNames.equals(segmentResult.keySet())) {
                Function<Object, String> function = new Function<Object, String>() {

                    @Override
                    public String apply(Object obj) {
                        return PMMLException.formatKey(obj);
                    }
                };
                throw new EvaluationException("Field sets " + Iterables.transform(fieldNames, function) + " and " + Iterables.transform(segmentResult.keySet(), function) + " do not match");
            }

            for (FieldName fieldName : fieldNames) {
                valuesByField.put(fieldName, segmentResult.get(fieldName));
            }
        }
        return valuesByField.asMap();
    }

    private void setModelEvaluatorFactory(ModelEvaluatorFactory modelEvaluatorFactory) {
        this.modelEvaluatorFactory = modelEvaluatorFactory;
    }

    /**
     * Configures this evaluator and remembers the factory for creating nested segment evaluators.
     * Decompiler note: reported as originally protected.
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public void configure(ModelEvaluatorFactory modelEvaluatorFactory) {
        super.configure(modelEvaluatorFactory);
        setModelEvaluatorFactory(modelEvaluatorFactory);
    }

    /**
     * Returns the nested segment output fields (if any) followed by this model's own output fields.
     * Decompiler note: reported as originally protected.
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public List<OutputField> createOutputFields() {
        List<OutputField> outputFields = super.createOutputFields();
        List<OutputField> nestedOutputFields = createNestedOutputFields();
        if (nestedOutputFields.isEmpty()) {
            return outputFields;
        }
        return ImmutableList.copyOf(Iterables.concat(nestedOutputFields, outputFields));
    }

    /** Evaluates the model against the given argument map using a fresh evaluation context. */
    @Override // org.jpmml.evaluator.ModelEvaluator, org.jpmml.evaluator.q
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> map) {
        a context = new a(this);
        context.setArguments(map);
        return evaluate(context);
    }

    /** Evaluates with a generic context, which must actually be a mining-model context ({@code a}). */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public Map<FieldName, ?> evaluate(bb bbVar) {
        return evaluate((a) bbVar);
    }

    /**
     * Evaluates the model within the given mining-model context.
     *
     * @throws UnsupportedAttributeException if the model's mathContext is neither FLOAT nor DOUBLE
     */
    public Map<FieldName, ?> evaluate(a aVar) {
        MiningModel miningModel = ensureScorableModel();

        MathContext mathContext = miningModel.getMathContext();
        switch (mathContext) {
            case FLOAT:
            case DOUBLE:
                break;
            default:
                throw new UnsupportedAttributeException(miningModel, mathContext);
        }

        ValueFactory<?> valueFactory = getValueFactory();

        Map<FieldName, ?> predictions;
        switch (miningModel.getMiningFunction()) {
            case REGRESSION:
                predictions = evaluateRegression(valueFactory, aVar);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(valueFactory, aVar);
                break;
            case CLUSTERING:
                predictions = evaluateClustering(valueFactory, aVar);
                break;
            default:
                predictions = evaluateAny(aVar);
                break;
        }
        return bf.evaluate(predictions, aVar);
    }

    /**
     * Returns {@code null} for SELECT_ALL, SELECT_FIRST and MODEL_CHAIN, and otherwise the
     * superclass's data field. Decompiler note: reported as originally protected.
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public DataField getDataField() {
        switch (getModel().getSegmentation().getMultipleModelMethod()) {
            case SELECT_ALL:
            case SELECT_FIRST:
            case MODEL_CHAIN:
                return null;
            default:
                return super.getDataField();
        }
    }

    /** Returns the segment id registry, resolving it lazily from the shared cache. */
    @Override // org.jpmml.evaluator.aj
    public BiMap<String, Segment> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = (BiMap) getValue(entityCache);
        }
        return this.entityRegistry;
    }

    public ModelEvaluatorFactory getModelEvaluatorFactory() {
        return this.modelEvaluatorFactory;
    }

    @Override // org.jpmml.evaluator.q
    public String getSummary() {
        return "Ensemble model";
    }

    /** Falls back to {@code q.DEFAULT_TARGET_NAME} when the model declares no target fields. */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public FieldName getTargetFieldName() {
        if (super.getTargetFields().size() == 0) {
            return q.DEFAULT_TARGET_NAME;
        }
        return super.getTargetFieldName();
    }

    @Override // org.jpmml.evaluator.ModelEvaluator
    public boolean isPrimitive() {
        return false;
    }
}
