/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.MLTaskRequest;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * Base class for running an ML task either on the local node or on a node selected by
 * {@link MLTaskDispatcher}.
 *
 * <p>Subclasses supply the transport action name and response handler used for remote
 * execution, plus the actual task execution logic ({@link #executeTask}).
 *
 * @param <Request>  the ML task request type
 * @param <Response> the transport response type delivered to the caller's listener
 */
public abstract class MLTaskRunner<Request extends MLTaskRequest, Response extends TransportResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskRunner.class);

    /** Timeout in milliseconds passed to {@link MLTaskManager#updateMLTask} when persisting async task state. */
    public static final int TIMEOUT_IN_MILLIS = 2000;

    protected final MLTaskManager mlTaskManager;
    protected final MLStats mlStats;
    protected final DiscoveryNodeHelper nodeHelper;
    protected final MLTaskDispatcher mlTaskDispatcher;
    protected final MLCircuitBreakerService mlCircuitBreakerService;
    private final ClusterService clusterService;

    public MLTaskRunner(MLTaskManager mlTaskManager, MLStats mlStats, DiscoveryNodeHelper nodeHelper, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService, ClusterService clusterService) {
        this.mlTaskManager = mlTaskManager;
        this.mlStats = mlStats;
        this.nodeHelper = nodeHelper;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.clusterService = clusterService;
    }

    /**
     * Records a failure on an async task. Synchronous tasks are left untouched.
     *
     * <p>Sets {@code state} to {@code FAILED} and {@code error} to the exception message. When the
     * exception has no message, the exception's class name is stored instead. Guava's
     * {@link ImmutableMap} rejects {@code null} values, so without this fallback the handler would
     * throw before the failure could be recorded.
     *
     * @param mlTask the task that failed
     * @param e      the failure cause
     */
    protected void handleAsyncMLTaskFailure(MLTask mlTask, Exception e) {
        if (!mlTask.isAsync()) {
            return;
        }
        String error = e.getMessage() != null ? e.getMessage() : e.getClass().getName();
        Map<String, Object> updatedFields = ImmutableMap.<String, Object>of("state", MLTaskState.FAILED.name(), "error", error);
        // NOTE(review): the semantics of the null second argument and the trailing `true` are defined
        // by MLTaskManager#updateMLTask; they are preserved unchanged here.
        this.mlTaskManager.updateMLTask(mlTask.getTaskId(), null, updatedFields, TIMEOUT_IN_MILLIS, true);
    }

    /**
     * Marks an async task as completed. Synchronous tasks are left untouched.
     *
     * <p>Also records the task's model id when one is set.
     *
     * @param mlTask the task that finished successfully
     */
    protected void handleAsyncMLTaskComplete(MLTask mlTask) {
        if (!mlTask.isAsync()) {
            return;
        }
        Map<String, Object> updatedFields = new HashMap<>();
        // NOTE(review): stores the enum constant, whereas the failure path stores FAILED.name().
        // Kept as-is in case MLTaskManager relies on the value type; confirm before unifying.
        updatedFields.put("state", MLTaskState.COMPLETED);
        if (mlTask.getModelId() != null) {
            updatedFields.put("model_id", mlTask.getModelId());
        }
        this.mlTaskManager.updateMLTask(mlTask.getTaskId(), null, updatedFields, TIMEOUT_IN_MILLIS, true);
    }

    /**
     * Entry point. Dispatches the request to a chosen node, or executes it locally when the
     * request is flagged as not-to-be-dispatched.
     *
     * @param functionName     the ML function being run; decides whether the circuit-breaker check applies
     * @param request          the task request
     * @param transportService used to forward the request when dispatched to a remote node
     * @param listener         receives the task response or failure
     */
    public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
        if (request.isDispatchTask()) {
            this.dispatchTask(functionName, request, transportService, listener);
            return;
        }
        log.debug("Run ML request {} locally", request.getRequestID());
        // checkCBAndExecute performs the circuit-breaker check itself (skipping it for REMOTE);
        // an additional unconditional check here would defeat that exemption.
        this.checkCBAndExecute(functionName, request, listener);
    }

    /**
     * Wraps a listener so that, after it completes (success or failure), the executing-task
     * counter is decremented and the task is removed from the task manager.
     *
     * @param listener the listener to wrap
     * @param taskId   id of the task to remove from {@link MLTaskManager} on completion
     * @return the wrapped listener
     */
    protected ActionListener<MLTaskResponse> wrappedCleanupListener(ActionListener<MLTaskResponse> listener, String taskId) {
        return ActionListener.runAfter(listener, () -> {
            this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
            this.mlTaskManager.remove(taskId);
        });
    }

    /**
     * Asks the dispatcher for a target node. If the node is the local node, the task runs locally
     * (subject to the circuit-breaker check). Otherwise the request is forwarded over transport with
     * its dispatch flag cleared, so the receiving node runs it instead of re-dispatching.
     *
     * <p>Exceptions thrown while handling the selected node are routed to {@code listener} by
     * {@link ActionListener#wrap}.
     *
     * @param functionName     the ML function being run
     * @param request          the task request
     * @param transportService used to send the request to a remote node
     * @param listener         receives the task response or failure
     */
    public void dispatchTask(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
        this.mlTaskDispatcher.dispatch(functionName, ActionListener.<DiscoveryNode>wrap(node -> {
            String nodeId = node.getId();
            if (this.clusterService.localNode().getId().equals(nodeId)) {
                log.debug("Execute ML request {} locally on node {}", request.getRequestID(), nodeId);
                // Same circuit-breaker policy as the non-dispatched path in run().
                this.checkCBAndExecute(functionName, request, listener);
            } else {
                log.debug("Execute ML request {} remotely on node {}", request.getRequestID(), nodeId);
                // Prevent the receiving node from dispatching the request again.
                request.setDispatchTask(false);
                transportService.sendRequest(node, this.getTransportActionName(), request, this.getResponseHandler(listener));
            }
        }, listener::onFailure));
    }

    /** @return the transport action name used when forwarding a request to another node */
    protected abstract String getTransportActionName();

    /**
     * @param var1 the caller's listener
     * @return the handler that delivers a remote node's response to {@code var1}
     */
    protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> var1);

    /**
     * Executes the task on this node.
     *
     * @param var1 the task request
     * @param var2 receives the task response or failure
     */
    protected abstract void executeTask(Request var1, ActionListener<Response> var2);

    /**
     * Runs the open-circuit-breaker check, except for {@link FunctionName#REMOTE}, then executes
     * the task locally.
     *
     * @param functionName the ML function being run
     * @param request      the task request
     * @param listener     receives the task response or failure
     */
    protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
        if (functionName != FunctionName.REMOTE) {
            // NOTE(review): presumably throws when a breaker is open, aborting execution -- confirm in MLNodeUtils.
            MLNodeUtils.checkOpenCircuitBreaker(this.mlCircuitBreakerService, this.mlStats);
        }
        this.executeTask(request, listener);
    }
}

