/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.metrics_correlation;

import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors;
import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelRequest;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.utils.IndexUtils;
import org.opensearch.ml.engine.algorithms.DLModelExecute;
import org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelationTranslator;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;

/**
 * Executable for {@link FunctionName#METRICS_CORRELATION}.
 *
 * <p>On each {@link #execute} call it makes sure a metrics-correlation model is registered and
 * deployed. If no model document is found, it registers the TorchScript artifact at
 * {@link #MCORR_MODEL_URL}. It then polls until the model reports a deployed state and runs the
 * local predictor on the request's input data.
 */
@Function(value = FunctionName.METRICS_CORRELATION)
public class MetricsCorrelation extends DLModelExecute {
    @Generated
    private static final Logger log = LogManager.getLogger(MetricsCorrelation.class);

    /** Upper bound, in milliseconds, of a single back-off sleep inside {@link #waitUntil}. */
    private static final int AWAIT_BUSY_THRESHOLD = 1000;
    /** Index queried for the metrics-correlation model document. */
    private static final String MODEL_INDEX = ".plugins-ml-model";
    /** Index that receives the metrics-correlation model-group document. */
    private static final String MODEL_GROUP_INDEX = ".plugins-ml-model-group";
    /** Classpath resource holding the mapping used when the model-group index has to be created. */
    private static final String MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml_model_group.json";
    /** How long {@link #execute} polls for the model to become deployed. */
    private static final long DEPLOY_WAIT_SECONDS = 120L;
    /** Timeout (ms) for the blocking create-index call on the model-group index. */
    private static final long CREATE_INDEX_TIMEOUT_MILLIS = 1000L;

    public static final String MODEL_CONTENT_HASH = "fa7c832e458b085e242f05fbe8938570f97b11aa9155dcd4ad3fbac07af85d3b";
    private Client client;
    private final Settings settings;
    private final ClusterService clusterService;
    public static final String MCORR_ML_VERSION = "1.0.0b2";
    public static final String MODEL_TYPE = "in-house";
    public static final String MCORR_MODEL_URL = "https://artifacts.opensearch.org/models/ml-models/amazon/metrics_correlation/1.0.0b2/torch_script/metrics_correlation-1.0.0b2-torch_script.zip";

    public MetricsCorrelation(Client client, Settings settings, ClusterService clusterService) {
        this.client = client;
        this.settings = settings;
        this.clusterService = clusterService;
    }

    /**
     * Runs the metrics-correlation model on {@code input} and reports the result to {@code listener}.
     *
     * <p>If no model id is known yet, this method first:
     * <ul>
     *   <li>makes sure the model-group index exists, and</li>
     *   <li>looks up, registers or deploys the model.</li>
     * </ul>
     * It then blocks the calling thread for up to {@value #DEPLOY_WAIT_SECONDS} seconds, polling until
     * the model reaches {@code DEPLOYED} or {@code PARTIALLY_DEPLOYED}. Errors are thrown to the
     * caller rather than passed to {@code listener.onFailure}.
     *
     * @param input must be a {@link MetricsCorrelationInput}
     * @param listener receives a {@link MetricsCorrelationOutput} on success
     * @throws ExecuteException if the input has the wrong type, no model id could be obtained,
     *     prediction fails, or the waiting thread is interrupted
     */
    @Override
    public void execute(Input input, ActionListener<org.opensearch.ml.common.output.Output> listener) {
        if (!(input instanceof MetricsCorrelationInput)) {
            throw new ExecuteException("wrong input");
        }
        MetricsCorrelationInput metricsCorrelation = (MetricsCorrelationInput) input;
        List<float[]> inputData = metricsCorrelation.getInputData();
        float[][] processedInputData = this.processedInput(inputData);

        if (this.modelId == null) {
            this.createModelGroupIndexIfAbsent();
            this.locateOrRegisterModel();
        } else {
            MLModel model = this.getModel(this.modelId);
            if (!isDeployedState(model.getModelState())) {
                this.deployAndTrackModel(this.modelId);
            }
        }

        // The boolean result is not needed: the modelId check below decides whether we can predict.
        MetricsCorrelation.waitUntil(this::pollModelDeployment, DEPLOY_WAIT_SECONDS, TimeUnit.SECONDS);

        Output djlOutput;
        try {
            if (this.modelId == null) {
                throw new ExecuteException("Model is not loaded yet. Please try again.");
            }
            djlOutput = (Output) this.getPredictor().predict(processedInputData);
        } catch (TranslateException translateException) {
            throw new ExecuteException(translateException);
        }
        List<MCorrModelTensors> tensorOutputs = new ArrayList<>();
        tensorOutputs.add(this.parseModelTensorOutput(djlOutput, null));
        listener.onResponse(new MetricsCorrelationOutput(tensorOutputs));
    }

    /** Creates the model-group index (under a stashed thread context) when the cluster does not have it yet. */
    private void createModelGroupIndexIfAbsent() {
        if (this.clusterService.state().getMetadata().hasIndex(MODEL_GROUP_INDEX)) {
            return;
        }
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            String mappingContent = IndexUtils.getMappingFromFile(MODEL_GROUP_INDEX_MAPPING_PATH);
            CreateIndexRequest request = new CreateIndexRequest(MODEL_GROUP_INDEX).mapping(mappingContent, XContentType.JSON);
            CreateIndexResponse createIndexResponse = this.client.admin().indices().create(request).actionGet(CREATE_INDEX_TIMEOUT_MILLIS);
            if (!createIndexResponse.isAcknowledged()) {
                throw new MLException("Failed to create model group index");
            }
        } catch (IOException e) {
            throw new MLException("Failed to load model group index mapping", e);
        }
    }

    /**
     * Finds the model document and ensures it is registered and deployed.
     *
     * <p>If the model index is missing, the model is registered directly. Otherwise the document
     * with id {@code METRICS_CORRELATION} is fetched asynchronously. An existing document has its
     * id adopted and is deployed if its stored state is not deployed. A missing document triggers
     * registration.
     */
    private void locateOrRegisterModel() {
        if (!this.clusterService.state().getMetadata().hasIndex(MODEL_INDEX)) {
            log.warn("Model Index Not found. Register metric correlation model");
            try {
                this.registerAndTrackModel("Exception during registering the Metrics correlation model");
            } catch (InterruptedException ex) {
                // Preserve the interrupt status for code further up the stack.
                Thread.currentThread().interrupt();
                throw new RuntimeException(ex);
            }
            return;
        }
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            GetRequest getModelRequest = new GetRequest(MODEL_INDEX).id(FunctionName.METRICS_CORRELATION.name());
            ActionListener<GetResponse> actionListener = ActionListener.wrap(r -> {
                if (r.isExists()) {
                    this.modelId = r.getId();
                    Map<String, Object> sourceAsMap = r.getSourceAsMap();
                    String state = (String) sourceAsMap.get("model_state");
                    if (!MLModelState.DEPLOYED.name().equals(state) && !MLModelState.PARTIALLY_DEPLOYED.name().equals(state)) {
                        this.deployAndTrackModel(r.getId());
                    }
                } else {
                    log.info("metric correlation model not registered yet");
                    this.registerAndTrackModel("Metrics correlation model didn't get registered to the index successfully");
                }
            }, e -> log.error("Failed to get model", e));
            this.client.get(getModelRequest, ActionListener.runBefore(actionListener, context::restore));
        }
    }

    /**
     * One poll step for {@link #waitUntil}.
     *
     * <p>Returns {@code true} once the model is in a deployed state. If the model is
     * {@code UNDEPLOYED} or {@code DEPLOY_FAILED}, it asks for a new deployment and returns
     * {@code false}.
     */
    private boolean pollModelDeployment() {
        // Read the field once so the null check and the lookup see the same id.
        String currentModelId = this.modelId;
        if (currentModelId == null) {
            return false;
        }
        MLModelState modelState = this.getModel(currentModelId).getModelState();
        if (isDeployedState(modelState)) {
            log.info("Model deployed: {}", modelState);
            return true;
        }
        if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) {
            log.info("Model not deployed: {}", modelState);
            this.deployAndTrackModel(currentModelId);
        }
        return false;
    }

    /** Returns whether {@code state} is {@code DEPLOYED} or {@code PARTIALLY_DEPLOYED}. */
    private static boolean isDeployedState(MLModelState state) {
        return state == MLModelState.DEPLOYED || state == MLModelState.PARTIALLY_DEPLOYED;
    }

    /** Deploys {@code targetModelId}; on success, records the model id reported by the deploy task. */
    private void deployAndTrackModel(String targetModelId) {
        this.deployModel(targetModelId, ActionListener.wrap(
            deployModelResponse -> this.modelId = this.getTask(deployModelResponse.getTaskId()).getModelId(),
            e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)));
    }

    /**
     * Registers the model; on success, records the model id reported by the registration task.
     *
     * @param failureMessage message logged if registration fails
     */
    private void registerAndTrackModel(String failureMessage) throws InterruptedException {
        this.registerModel(ActionListener.wrap(
            registerModelResponse -> this.modelId = this.getTask(registerModelResponse.getTaskId()).getModelId(),
            e -> log.error(failureMessage, e)));
    }

    /**
     * Registers the bundled metrics-correlation model.
     *
     * <p>Steps:
     * <ol>
     *   <li>Index a PUBLIC model-group document with id {@code METRICS_CORRELATION} into the
     *       model-group index.</li>
     *   <li>If that succeeds, execute {@link MLRegisterModelAction} for the TorchScript artifact,
     *       with {@code deployModel(true)}.</li>
     * </ol>
     * Failures in either step are passed to {@code listener.onFailure}.
     *
     * @param listener notified with the registration response or failure
     * @throws InterruptedException declared for caller compatibility; this body does not throw it
     * @throws MLException if the model-group document cannot be serialized
     */
    @VisibleForTesting
    void registerModel(ActionListener<MLRegisterModelResponse> listener) throws InterruptedException {
        FunctionName functionName = FunctionName.METRICS_CORRELATION;
        MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT;
        MetricsCorrelationModelConfig modelConfig = MetricsCorrelationModelConfig.builder().modelType(MODEL_TYPE).allConfig(null).build();
        MLRegisterModelInput input = MLRegisterModelInput.builder()
            .functionName(functionName)
            .modelName(FunctionName.METRICS_CORRELATION.name())
            .version(MCORR_ML_VERSION)
            .modelGroupId(functionName.name())
            .modelFormat(modelFormat)
            .hashValue(MODEL_CONTENT_HASH)
            .modelConfig((MLModelConfig) modelConfig)
            .url(MCORR_MODEL_URL)
            .deployModel(true)
            .build();
        MLRegisterModelRequest registerRequest = MLRegisterModelRequest.builder().registerModelInput(input).build();
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            IndexRequest createModelGroupRequest = new IndexRequest(MODEL_GROUP_INDEX).id(functionName.name());
            MLModelGroup modelGroup = MLModelGroup.builder()
                .name(functionName.name())
                .access(AccessMode.PUBLIC.getValue())
                .createdTime(Instant.now())
                .build();
            XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
            modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS);
            createModelGroupRequest.source(builder);
            ActionListener<MLRegisterModelResponse> registerListener = ActionListener.wrap(listener::onResponse, e -> {
                log.error("Failed to Register Model", e);
                listener.onFailure(e);
            });
            this.client.index(
                createModelGroupRequest,
                ActionListener.runBefore(
                    ActionListener.wrap(
                        indexResponse -> this.client.execute(MLRegisterModelAction.INSTANCE, registerRequest, registerListener),
                        listener::onFailure),
                    context::restore));
        } catch (IOException e) {
            throw new MLException(e);
        }
    }

    /**
     * Executes {@link MLDeployModelAction} for {@code modelId}, with {@code async} and
     * {@code dispatchTask} both set to {@code false}. Failures are logged and then forwarded to
     * {@code listener}.
     */
    @VisibleForTesting
    void deployModel(String modelId, ActionListener<MLDeployModelResponse> listener) {
        MLDeployModelRequest loadRequest = MLDeployModelRequest.builder().modelId(modelId).async(false).dispatchTask(false).build();
        this.client.execute(MLDeployModelAction.INSTANCE, loadRequest, ActionListener.wrap(listener::onResponse, e -> {
            log.error("Failed to deploy Model", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Copies the list of rows into a {@code float[][]}. Each row is cloned, so the result does not
     * share arrays with {@code input}.
     */
    @VisibleForTesting
    float[][] processedInput(List<float[]> input) {
        float[][] processInput = new float[input.size()][];
        int row = 0;
        // Iterate instead of get(i) so non-random-access lists stay linear.
        for (float[] innerList : input) {
            processInput[row++] = innerList.clone();
        }
        return processInput;
    }

    @Override
    public MetricsCorrelationTranslator getTranslator() {
        return new MetricsCorrelationTranslator();
    }

    /**
     * Builds a search over the model index for the metrics-correlation model.
     *
     * <p>NOTE(review): both term queries are {@code should} clauses, so a document matches on name OR
     * version. Also, {@code model_content} is listed in both includes and excludes; the exclude takes
     * precedence. Both are kept unchanged so existing expectations stay valid.
     */
    @VisibleForTesting
    SearchRequest getSearchRequest() {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.fetchSource(
            new String[] {"model_id", "name", "model_state", "model_version", "model_content"},
            new String[] {"model_content"});
        BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
            .should(QueryBuilders.termQuery("name", FunctionName.METRICS_CORRELATION.name()))
            .should(QueryBuilders.termQuery("model_version", MCORR_ML_VERSION));
        searchSourceBuilder.query(boolQueryBuilder);
        return new SearchRequest().source(searchSourceBuilder).indices(MODEL_INDEX);
    }

    /**
     * Polls {@code breakSupplier} with exponential back-off until it returns {@code true} or
     * {@code maxWaitTime} has passed. After the deadline it checks one final time.
     *
     * <p>Sleep intervals start at 1 ms and double up to {@value #AWAIT_BUSY_THRESHOLD} ms. Elapsed
     * time is measured as the sum of completed sleeps.
     *
     * @return {@code true} if the supplier returned {@code true}
     * @throws ExecuteException if the thread is interrupted while sleeping; the interrupt flag is
     *     restored first
     */
    public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, TimeUnit unit) throws ExecuteException {
        long maxTimeInMillis = TimeUnit.MILLISECONDS.convert(maxWaitTime, unit);
        long sleepMillis = 1L;
        long elapsedMillis = 0L;
        while (elapsedMillis + sleepMillis < maxTimeInMillis) {
            if (breakSupplier.getAsBoolean()) {
                return true;
            }
            sleepOrThrow(sleepMillis);
            // Count the interval that was actually slept, then grow the next one.
            elapsedMillis += sleepMillis;
            log.info("Waiting... Time elapsed: {}ms", elapsedMillis);
            sleepMillis = Math.min(AWAIT_BUSY_THRESHOLD, sleepMillis * 2L);
        }
        sleepOrThrow(Math.max(maxTimeInMillis - elapsedMillis, 0L));
        return breakSupplier.getAsBoolean();
    }

    /** Sleeps for {@code millis}; on interruption, restores the interrupt flag and throws {@link ExecuteException}. */
    private static void sleepOrThrow(long millis) throws ExecuteException {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException interruptedException) {
            Thread.currentThread().interrupt();
            throw new ExecuteException(interruptedException);
        }
    }

    /** Fetches the ML task {@code taskId}, blocking up to 10000 ms for the response. */
    public MLTask getTask(String taskId) {
        MLTaskGetRequest getRequest = new MLTaskGetRequest(taskId, null);
        MLTaskGetResponse response = this.client.execute(MLTaskGetAction.INSTANCE, getRequest).actionGet(10000L);
        return response.getMlTask();
    }

    /** Fetches the ML model {@code modelId}, blocking up to 5000 ms for the response. */
    public MLModel getModel(String modelId) {
        MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, false, null);
        ActionFuture<MLModelGetResponse> future = this.client.execute(MLModelGetAction.INSTANCE, getRequest);
        MLModelGetResponse response = future.actionGet(5000L);
        return response.getMlModel();
    }

    /**
     * Decodes the DJL output bytes into {@link MCorrModelTensors}. If {@code resultFilter} is
     * non-null, it is applied to the result.
     *
     * @throws MLException if {@code output} is {@code null}
     */
    public MCorrModelTensors parseModelTensorOutput(Output output, ModelResultFilter resultFilter) {
        if (output == null) {
            throw new MLException("No output generated");
        }
        byte[] bytes = output.getData().getAsBytes();
        MCorrModelTensors tensorOutput = MCorrModelTensors.fromBytes(bytes);
        if (resultFilter != null) {
            tensorOutput.filter(resultFilter);
        }
        return tensorOutput;
    }
}

