/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.resource.enumeration;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.hadoop.util.Lists;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.resource.CloudInstance;
import org.apache.sysds.resource.ResourceCompiler;
import org.apache.sysds.resource.enumeration.EnumerationUtils;
import org.apache.sysds.resource.enumeration.Enumerator;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;

/**
 * Enumerator that walks the driver and executor instance search spaces while pruning
 * configurations that cannot contribute new candidates.
 * <p>
 * Three pruning mechanisms are applied during {@link #processing()}:
 * <ul>
 *   <li>if a single-node configuration yields an invalid cost estimate, single-node
 *       execution is not evaluated again for the same driver memory
 *       ({@link #insufficientSingleNodeMemory});</li>
 *   <li>if the program recompiled for a cluster setup contains no relevant Spark
 *       instructions, distributed execution is no longer enumerated for driver memories
 *       greater than or equal to that driver memory ({@link #singleNodeOnlyMemory});</li>
 *   <li>per executor (memory, cores) combination, the executor count of the locally best
 *       configuration becomes the upper bound of executor counts tried later
 *       ({@link #maxExecutorsPerInstanceMap}).</li>
 * </ul>
 */
public class PruneBasedEnumerator extends Enumerator {
    /**
     * Driver memory for which a single-node configuration produced an invalid cost estimate.
     * Single-node evaluation is skipped for the remaining core options of exactly this memory.
     * {@code -1} while no such memory was found.
     */
    long insufficientSingleNodeMemory = -1L;
    /**
     * Driver memory at which the program recompiled for a cluster setup contained no relevant
     * Spark instructions. Distributed configurations are not enumerated for driver memories
     * greater than or equal to this value. {@link Long#MAX_VALUE} while no such memory was found.
     */
    long singleNodeOnlyMemory = Long.MAX_VALUE;
    /**
     * Upper bound of executor counts to consider per executor type, keyed by
     * {@link #combineHash(long, int)} of the executor memory and cores. Initialised with
     * {@code maxExecutors} in {@link #preprocessing()} and replaced by the executor count of
     * the locally best configuration found for that executor type. The bound never grows,
     * because only counts up to the current bound are tried.
     */
    HashMap<Long, Integer> maxExecutorsPerInstanceMap = new HashMap<>();

    public PruneBasedEnumerator(Enumerator.Builder builder) {
        super(builder);
    }

    /**
     * Initialises the driver and executor search spaces from the available instances.
     * Every executor (memory, cores) combination starts with {@code maxExecutors} as its
     * executor-count bound.
     */
    @Override
    public void preprocessing() {
        driverSpace.initSpace(instances);
        executorSpace.initSpace(instances);
        for (Map.Entry<Long, TreeMap<Integer, LinkedList<CloudInstance>>> eMemoryEntry : executorSpace.entrySet()) {
            long executorMemory = eMemoryEntry.getKey();
            for (Integer eCores : eMemoryEntry.getValue().keySet()) {
                maxExecutorsPerInstanceMap.put(combineHash(executorMemory, eCores), maxExecutors);
            }
        }
    }

    /**
     * Enumerates the search space in the iteration order of the driver and executor spaces.
     * <p>
     * For every driver (memory, cores) option, single-node execution is evaluated first
     * (if {@link #evaluateSingleNodeExecution(long, int)} allows it). Then, unless the
     * driver memory was pruned via {@link #singleNodeOnlyMemory}, cluster execution is
     * evaluated for every executor type and every executor count returned by
     * {@link #estimateRangeExecutors(int, long, int)}.
     */
    @Override
    public void processing() {
        for (Map.Entry<Long, TreeMap<Integer, LinkedList<CloudInstance>>> dMemoryEntry : driverSpace.entrySet()) {
            long driverMemory = dMemoryEntry.getKey();
            for (Map.Entry<Integer, LinkedList<CloudInstance>> dCoresEntry : dMemoryEntry.getValue().entrySet()) {
                int driverCores = dCoresEntry.getKey();
                List<CloudInstance> driverInstances = dCoresEntry.getValue();
                if (evaluateSingleNodeExecution(driverMemory, driverCores)) {
                    enumerateSingleNodeConfigurations(driverMemory, driverCores, driverInstances);
                }
                if (driverMemory < singleNodeOnlyMemory) {
                    enumerateClusterConfigurations(driverMemory, driverCores, driverInstances);
                }
            }
        }
    }

    /**
     * Recompiles the program for single-node execution and evaluates every driver instance
     * with the given characteristics. Stops at the first invalid estimate and marks the
     * driver memory as insufficient for single-node execution.
     */
    private void enumerateSingleNodeConfigurations(long driverMemory, int driverCores,
            List<CloudInstance> driverInstances) {
        ResourceCompiler.setSingleNodeResourceConfigs(driverMemory, driverCores);
        program = ResourceCompiler.doFullRecompilation(program);
        for (CloudInstance dInstance : driverInstances) {
            EnumerationUtils.ConfigurationPoint point = new EnumerationUtils.ConfigurationPoint(dInstance);
            double[] estimates = getCostEstimate(point);
            if (isInvalidConfiguration(estimates)) {
                insufficientSingleNodeMemory = driverMemory;
                break;
            }
            updateOptimalSolution(estimates[0], estimates[1], point);
        }
    }

    /**
     * Evaluates cluster configurations for the given driver across all executor types.
     * Returns as soon as this driver memory gets pruned via {@link #singleNodeOnlyMemory}.
     * The bound only decreases, so once the driver memory is pruned it stays pruned.
     */
    private void enumerateClusterConfigurations(long driverMemory, int driverCores,
            List<CloudInstance> driverInstances) {
        for (Map.Entry<Long, TreeMap<Integer, LinkedList<CloudInstance>>> eMemoryEntry : executorSpace.entrySet()) {
            long executorMemory = eMemoryEntry.getKey();
            for (Map.Entry<Integer, LinkedList<CloudInstance>> eCoresEntry : eMemoryEntry.getValue().entrySet()) {
                if (driverMemory >= singleNodeOnlyMemory) {
                    return;
                }
                enumerateExecutorType(driverMemory, driverCores, driverInstances,
                    executorMemory, eCoresEntry.getKey(), eCoresEntry.getValue());
            }
        }
    }

    /**
     * Evaluates all (driver instance, executor instance, executor count) combinations for one
     * executor (memory, cores) type. Afterwards it tightens that type's executor-count bound
     * to the count of the best configuration seen for it.
     */
    private void enumerateExecutorType(long driverMemory, int driverCores, List<CloudInstance> driverInstances,
            long executorMemory, int executorCores, List<CloudInstance> executorInstances) {
        double localBestScore = Double.MAX_VALUE;
        int localBestNumberExecutors = -1;
        for (int numberExecutors : estimateRangeExecutors(driverCores, executorMemory, executorCores)) {
            try {
                ResourceCompiler.setSparkClusterResourceConfigs(
                    driverMemory, driverCores, numberExecutors, executorMemory, executorCores);
            } catch (IllegalArgumentException ignored) {
                // Deliberate: stop at the first executor count the resource compiler rejects;
                // larger counts of this executor type are not attempted.
                break;
            }
            program = ResourceCompiler.doFullRecompilation(program);
            if (!hasSparkInstructions(program)) {
                // No relevant Spark instructions were generated: prune distributed
                // execution for this and all larger driver memories.
                singleNodeOnlyMemory = driverMemory;
                break;
            }
            for (CloudInstance dInstance : driverInstances) {
                for (CloudInstance eInstance : executorInstances) {
                    EnumerationUtils.ConfigurationPoint point =
                        new EnumerationUtils.ConfigurationPoint(dInstance, eInstance, numberExecutors);
                    double[] estimates = getCostEstimate(point);
                    updateOptimalSolution(estimates[0], estimates[1], point);
                    double score = strategyScore(estimates);
                    if (localBestScore > score) {
                        localBestScore = score;
                        localBestNumberExecutors = point.numberExecutors;
                    }
                }
            }
        }
        if (localBestScore < Double.MAX_VALUE && localBestNumberExecutors > 0) {
            maxExecutorsPerInstanceMap.put(combineHash(executorMemory, executorCores), localBestNumberExecutors);
        }
    }

    /**
     * Maps a cost estimate to the scalar minimised under the current optimization strategy:
     * <ul>
     *   <li>{@code MinCosts} uses {@link #linearScoringFunction} of both entries;</li>
     *   <li>{@code MinTime} uses {@code estimates[0]};</li>
     *   <li>any other strategy uses {@code estimates[1]}.</li>
     * </ul>
     * NOTE(review): index 0 is presumably time and index 1 monetary cost. Confirm against
     * {@code Enumerator#getCostEstimate}.
     */
    private double strategyScore(double[] estimates) {
        if (optStrategy == Enumerator.OptimizationStrategy.MinCosts) {
            return linearScoringFunction(estimates[0], estimates[1]);
        }
        if (optStrategy == Enumerator.OptimizationStrategy.MinTime) {
            return estimates[0];
        }
        return estimates[1];
    }

    /**
     * Decides whether single-node execution should be evaluated for a driver. It is skipped if:
     * <ul>
     *   <li>the driver cores exceed {@code CPU_QUOTA};</li>
     *   <li>executors are mandatory ({@code minExecutors > 0});</li>
     *   <li>this exact driver memory was already found insufficient for single-node execution.</li>
     * </ul>
     */
    @Override
    public boolean evaluateSingleNodeExecution(long driverMemory, int cores) {
        if (cores > CPU_QUOTA || minExecutors > 0) {
            return false;
        }
        return insufficientSingleNodeMemory != driverMemory;
    }

    /**
     * Returns the ascending executor counts {@code 1..n} to try for an executor type.
     * {@code n} is limited by {@code maxExecutors}, by the cores left within
     * {@code CPU_QUOTA} after the driver, and by the pruning bound stored for this executor type.
     * Executor types without a stored bound fall back to {@code maxExecutors} instead of
     * failing on a missing map entry.
     */
    @Override
    public ArrayList<Integer> estimateRangeExecutors(int driverCores, long executorMemory, int executorCores) {
        int maxAchievableLevelOfParallelism = CPU_QUOTA - driverCores;
        int currentMax = Math.min(maxExecutors, maxAchievableLevelOfParallelism / executorCores);
        int maxExecutorsToConsider =
            maxExecutorsPerInstanceMap.getOrDefault(combineHash(executorMemory, executorCores), maxExecutors);
        currentMax = Math.min(currentMax, maxExecutorsToConsider);
        ArrayList<Integer> result = new ArrayList<>(Math.max(currentMax, 0));
        for (int i = 1; i <= currentMax; i++) {
            result.add(i);
        }
        return result;
    }

    /**
     * Combines executor memory and cores into one map key by plain summation.
     * <p>
     * Distinct pairs can collide, e.g. {@code (m, c)} and {@code (m + 1, c - 1)}. Keys stay
     * unique only while memory sizes are spaced further apart than possible core counts.
     * NOTE(review): this presumably holds because memory is given in bytes; confirm.
     */
    public static long combineHash(long executorMemory, int cores) {
        return executorMemory + (long) cores;
    }

    /**
     * Returns whether both estimate entries equal {@link Double#MAX_VALUE}. This class treats
     * that as an infeasible configuration (presumably the sentinel produced by
     * {@code getCostEstimate}; verify).
     */
    public static boolean isInvalidConfiguration(double[] estimates) {
        return estimates[0] == Double.MAX_VALUE && estimates[1] == Double.MAX_VALUE;
    }

    /**
     * Checks whether the program contains at least one relevant Spark instruction, i.e. a
     * Spark instruction whose opcode contains neither {@code Opcodes.RBLK} nor
     * {@code "chkpoint"}. Blocks are inspected in this order:
     * <ol>
     *   <li>the function blocks from {@code getFunctionProgramBlocks()};</li>
     *   <li>where present, the per-function variant from {@code getFunctionProgramBlock(name, false)};</li>
     *   <li>the main program blocks.</li>
     * </ol>
     */
    public static boolean hasSparkInstructions(Program program) {
        HashMap<String, FunctionProgramBlock> funcMap = program.getFunctionProgramBlocks();
        if (funcMap != null) {
            for (Map.Entry<String, FunctionProgramBlock> entry : funcMap.entrySet()) {
                String fkey = entry.getKey();
                if (anyHasSparkInstructions(entry.getValue().getChildBlocks())) {
                    return true;
                }
                if (program.containsFunctionProgramBlock(fkey, false)
                        && anyHasSparkInstructions(program.getFunctionProgramBlock(fkey, false).getChildBlocks())) {
                    return true;
                }
            }
        }
        return anyHasSparkInstructions(program.getProgramBlocks());
    }

    /** Returns true if any of the given program blocks contains a relevant Spark instruction. */
    private static boolean anyHasSparkInstructions(Iterable<? extends ProgramBlock> blocks) {
        if (blocks == null) {
            return false;
        }
        for (ProgramBlock block : blocks) {
            if (hasSparkInstructions(block)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Recursively checks one program block, including loop/branch predicates, children and
     * exit instructions, for relevant Spark instructions.
     */
    private static boolean hasSparkInstructions(ProgramBlock pb) {
        if (pb instanceof FunctionProgramBlock) {
            return anyHasSparkInstructions(((FunctionProgramBlock) pb).getChildBlocks());
        } else if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            return hasSparkInstructions(wpb.getPredicate())
                || anyHasSparkInstructions(wpb.getChildBlocks())
                || isRelevantSparkInstruction(wpb.getExitInstruction());
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            return hasSparkInstructions(ipb.getPredicate())
                || anyHasSparkInstructions(ipb.getChildBlocksIfBody())
                || anyHasSparkInstructions(ipb.getChildBlocksElseBody())
                || isRelevantSparkInstruction(ipb.getExitInstruction());
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            return hasSparkInstructions(fpb.getFromInstructions())
                || hasSparkInstructions(fpb.getToInstructions())
                || hasSparkInstructions(fpb.getIncrementInstructions())
                || anyHasSparkInstructions(fpb.getChildBlocks())
                || isRelevantSparkInstruction(fpb.getExitInstruction());
        } else if (pb instanceof BasicProgramBlock) {
            return hasSparkInstructions(((BasicProgramBlock) pb).getInstructions());
        }
        return false;
    }

    /** Returns true if the (possibly null) instruction list contains a relevant Spark instruction. */
    private static boolean hasSparkInstructions(List<Instruction> instructions) {
        if (instructions == null) {
            return false;
        }
        for (Instruction inst : instructions) {
            if (isRelevantSparkInstruction(inst)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true for a non-null Spark instruction that is not a reblock or checkpoint.
     * Opcodes containing {@code Opcodes.RBLK} or {@code "chkpoint"} alone do not count
     * as distributed work.
     */
    private static boolean isRelevantSparkInstruction(Instruction inst) {
        if (inst == null || inst.getType() != Instruction.IType.SPARK) {
            return false;
        }
        String opcode = inst.getOpcode();
        return !opcode.contains(Opcodes.RBLK.toString()) && !opcode.contains("chkpoint");
    }
}

