/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.converter;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelStats;
import org.dmg.pmml.PMML;
import org.dmg.pmml.UnivariateStats;
import org.dmg.pmml.Visitable;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.visitors.FeatureExpander;
import org.jpmml.converter.visitors.ModelCleanerBattery;
import org.jpmml.converter.visitors.PMMLCleanerBattery;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link PMMLEncoder} that additionally collects model-level metadata (transformer models,
 * mining field decorators, feature importances and univariate statistics) and applies it when
 * the final PMML document is assembled by {@link #encodePMML(Model)}.
 */
public class ModelEncoder
extends PMMLEncoder {
    /** Models that are chained in front of the main model, in registration order. */
    private final List<Model> transformers = new ArrayList<Model>();
    /** Decorators to apply to the mining field of the same name. */
    private final Map<FieldName, List<Decorator>> decorators = new LinkedHashMap<FieldName, List<Decorator>>();
    /** Raw feature importances per model; the {@code null} key holds importances not yet bound to a model. */
    private final Map<Model, ListMultimap<FieldName, Number>> featureImportances = new LinkedHashMap<Model, ListMultimap<FieldName, Number>>();
    /** Univariate statistics to attach to the model's ModelStats, keyed by field name. */
    private final Map<FieldName, UnivariateStats> univariateStats = new LinkedHashMap<FieldName, UnivariateStats>();
    private static final Logger logger = LoggerFactory.getLogger(ModelEncoder.class);

    /**
     * Assembles the final PMML document.
     *
     * <p>If transformer models have been registered, they are combined (in registration order,
     * followed by {@code model} when it is non-null) into a model chain via
     * {@link MiningModelUtil#createModelChain}, and that chain replaces {@code model}.
     * The resulting model, if any, is added to the document, the model cleaner battery is run,
     * its mining fields are decorated, and feature importances are encoded. Finally the PMML
     * cleaner battery is run over the whole document.
     *
     * @param model the model to encode; may be {@code null}
     * @return the encoded PMML document
     * @throws IllegalArgumentException if a mining field has no corresponding data field
     * @throws IllegalStateException if feature importances were never bound to a model
     */
    public PMML encodePMML(Model model) {
        PMML pmml = this.encodePMML();
        List<Model> transformers = this.getTransformers();
        if (!transformers.isEmpty()) {
            List<Model> models = new ArrayList<Model>(transformers);
            if (model != null) {
                models.add(model);
            }
            model = MiningModelUtil.createModelChain(models);
        }
        if (model != null) {
            pmml.addModels(new Model[]{model});
            ModelCleanerBattery modelCleanerBattery = new ModelCleanerBattery();
            modelCleanerBattery.applyTo(pmml);
            this.decorateMiningFields(model);
            this.encodeFeatureImportances(pmml);
        }
        PMMLCleanerBattery pmmlCleanerBattery = new PMMLCleanerBattery();
        pmmlCleanerBattery.applyTo(pmml);
        return pmml;
    }

    /** Returns the live (mutable) list of registered transformer models. */
    public List<Model> getTransformers() {
        return this.transformers;
    }

    /** Appends a transformer model to be chained in front of the main model. */
    public void addTransformer(Model transformer) {
        this.transformers.add(transformer);
    }

    /**
     * Returns the decorators registered for the named field.
     *
     * @return the live decorator list, or {@code null} if none were registered
     */
    public List<Decorator> getDecorators(FieldName name) {
        return this.decorators.get(name);
    }

    /** Registers a decorator for the mining field that corresponds to {@code dataField}. */
    public void addDecorator(DataField dataField, Decorator decorator) {
        this.addDecorator(dataField.getName(), decorator);
    }

    /** Registers a decorator for the mining field of the given name. */
    public void addDecorator(FieldName name, Decorator decorator) {
        this.decorators.computeIfAbsent(name, key -> new ArrayList<Decorator>()).add(decorator);
    }

    /**
     * Records a feature importance that is not yet bound to a model.
     * It must later be moved to a model with {@link #transferFeatureImportances(Model)}.
     */
    public void addFeatureImportance(FieldName name, Number featureImportance) {
        this.addFeatureImportance(null, name, featureImportance);
    }

    /**
     * Records a feature importance for the given model. Repeated calls for the same feature
     * accumulate; the values are summed when the PMML document is encoded.
     */
    public void addFeatureImportance(Model model, FieldName name, Number featureImportance) {
        this.featureImportances
            .computeIfAbsent(model, key -> ArrayListMultimap.<FieldName, Number>create())
            .put(name, featureImportance);
    }

    /** Moves the unbound feature importances to the given model. */
    public void transferFeatureImportances(Model model) {
        this.transferFeatureImportances(null, model);
    }

    /**
     * Moves all feature importances recorded for {@code left} to {@code right}.
     * If {@code right} already has importances, the moved ones are merged into them
     * rather than replacing them.
     */
    public void transferFeatureImportances(Model left, Model right) {
        ListMultimap<FieldName, Number> moved = this.featureImportances.remove(left);
        if (moved == null || moved.isEmpty()) {
            return;
        }
        ListMultimap<FieldName, Number> existing = this.featureImportances.get(right);
        if (existing != null) {
            existing.putAll(moved);
        } else {
            this.featureImportances.put(right, moved);
        }
    }

    /** Returns the live (mutable) map of feature importances per model. */
    public Map<Model, ListMultimap<FieldName, Number>> getFeatureImportances() {
        return this.featureImportances;
    }

    /** Returns the univariate statistics registered for the named field, or {@code null}. */
    public UnivariateStats getUnivariateStats(FieldName name) {
        return this.univariateStats.get(name);
    }

    /** Registers univariate statistics under the field name they declare. */
    public void putUnivariateStats(UnivariateStats univariateStats) {
        this.putUnivariateStats(univariateStats.getField(), univariateStats);
    }

    /** Registers univariate statistics under the given field name, replacing any previous entry. */
    public void putUnivariateStats(FieldName name, UnivariateStats univariateStats) {
        this.univariateStats.put(name, univariateStats);
    }

    /**
     * Applies registered decorators to each mining field of {@code model}. Any registered
     * univariate statistics are added to the model's ModelStats.
     *
     * @throws IllegalArgumentException if a mining field has no corresponding data field
     */
    private void decorateMiningFields(Model model) {
        MiningSchema miningSchema = model.getMiningSchema();
        List<MiningField> miningFields = miningSchema.getMiningFields();
        for (MiningField miningField : miningFields) {
            FieldName name = miningField.getName();
            DataField dataField = this.getDataField(name);
            if (dataField == null) {
                throw new IllegalArgumentException("Field " + name.getValue() + " is not referentiable");
            }
            List<Decorator> decorators = this.getDecorators(name);
            if (decorators != null) {
                for (Decorator decorator : decorators) {
                    decorator.decorate(miningField);
                }
            }
            UnivariateStats univariateStats = this.getUnivariateStats(name);
            if (univariateStats != null) {
                ModelStats modelStats = ModelUtil.ensureModelStats(model);
                modelStats.addUnivariateStats(new UnivariateStats[]{univariateStats});
            }
        }
    }

    /**
     * Converts recorded feature importances into {@code MiningField@importance} values.
     *
     * <p>A {@link FeatureExpander} is run over the document to find the fields that each feature
     * expands to. Each feature's importance is then spread across those fields.
     *
     * @throws IllegalStateException if some importances were never bound to a model
     * @throws IllegalArgumentException if the expander reports no features for a model
     */
    private void encodeFeatureImportances(PMML pmml) {
        Map<Model, ListMultimap<FieldName, Number>> importances = this.getFeatureImportances();
        if (importances.isEmpty()) {
            return;
        }
        if (importances.containsKey(null)) {
            throw new IllegalStateException("Feature importances have not been transferred to a model");
        }
        Map<Model, Set<FieldName>> expandableFeatures = importances.entrySet().stream()
            .collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue().keySet()));
        FeatureExpander featureExpander = new FeatureExpander(expandableFeatures);
        featureExpander.applyTo(pmml);
        for (Map.Entry<Model, ListMultimap<FieldName, Number>> entry : importances.entrySet()) {
            Model model = entry.getKey();
            Map<FieldName, Set<Field<?>>> featureFields = featureExpander.getExpandedFeatures(model);
            if (featureFields == null) {
                throw new IllegalArgumentException("Failed to expand features of model " + model);
            }
            ListMultimap<FieldName, Number> fieldImportances = computeFieldImportances(entry.getValue(), featureFields);
            applyFieldImportances(model, fieldImportances);
        }
    }

    /**
     * Spreads each feature's summed importance equally over the fields it expands to.
     * Features whose summed importance is zero are skipped. Features that expand to no field
     * are logged as a warning and skipped.
     */
    private static ListMultimap<FieldName, Number> computeFieldImportances(ListMultimap<FieldName, Number> featureImportances, Map<FieldName, Set<Field<?>>> featureFields) {
        ListMultimap<FieldName, Number> fieldImportances = ArrayListMultimap.create();
        for (Map.Entry<FieldName, Collection<Number>> entry : featureImportances.asMap().entrySet()) {
            FieldName featureName = entry.getKey();
            Double featureImportanceSum = sum(entry.getValue());
            if (ValueUtil.isZero(featureImportanceSum)) {
                continue;
            }
            Set<Field<?>> fields = featureFields.get(featureName);
            if (fields == null || fields.isEmpty()) {
                logger.warn("Unused feature '{}' has non-zero importance", featureName.getValue());
                continue;
            }
            Double fieldImportance = featureImportanceSum / (double)fields.size();
            for (Field<?> field : fields) {
                fieldImportances.put(field.getName(), fieldImportance);
            }
        }
        return fieldImportances;
    }

    /** Sets the summed importance on every ACTIVE mining field of {@code model} that has one. */
    private static void applyFieldImportances(Model model, ListMultimap<FieldName, Number> fieldImportances) {
        MiningSchema miningSchema = model.getMiningSchema();
        if (miningSchema == null || !miningSchema.hasMiningFields()) {
            return;
        }
        for (MiningField miningField : miningSchema.getMiningFields()) {
            if (miningField.getUsageType() != MiningField.UsageType.ACTIVE) {
                continue;
            }
            // Multimap#get never returns null. An empty list means no importance was recorded,
            // so the field is left untouched instead of being stamped with 0.0.
            List<Number> fieldImportance = fieldImportances.get(miningField.getName());
            if (fieldImportance.isEmpty()) {
                continue;
            }
            miningField.setImportance(sum(fieldImportance));
        }
    }

    /** Sums the values as doubles, using the compensated summation of {@link Collectors#summingDouble}. */
    private static Double sum(Collection<Number> values) {
        return values.stream().collect(Collectors.summingDouble(Number::doubleValue));
    }
}

