/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.execute;

import java.io.IOException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.Nullable;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.task.MLExecuteTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.tasks.Task;
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * Transport action that executes an ML task in streaming mode.
 *
 * <p>Streaming requests arrive on the {@link StreamTransportService} and are handled by
 * {@link #messageReceived}. That method attaches the inbound {@link TransportChannel} to the
 * request and re-sends it to the local node through the regular {@link TransportService} under
 * the same action name. From there it reaches {@link #doExecute(Task, ActionRequest, ActionListener)},
 * which hands it to the {@link MLExecuteTaskRunner} together with the stream transport service.
 */
public class TransportExecuteStreamTaskAction
extends HandledTransportAction<ActionRequest, MLExecuteTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportExecuteStreamTaskAction.class);

    /** Action name used for both the regular transport handler and the streaming handler. */
    private static final String ACTION_NAME = "cluster:admin/opensearch/ml/execute/stream";

    private final MLTaskRunner<MLExecuteTaskRequest, MLExecuteTaskResponse> mlExecuteTaskRunner;
    private final TransportService transportService;

    /**
     * Process-wide stream transport service passed to the task runner. It always mirrors
     * {@code streamTransportServiceInstance} after construction. It is public for existing callers.
     */
    public static StreamTransportService streamTransportService;

    /**
     * First non-null {@link StreamTransportService} injected into any instance of this class.
     * Later non-null services do not replace it.
     */
    private static StreamTransportService streamTransportServiceInstance;

    /**
     * Creates the action and, when a {@link StreamTransportService} is available, registers
     * {@link #messageReceived} as the streaming request handler for this action.
     *
     * @param transportService regular transport service, used to forward requests to the local node
     * @param actionFilters action filters applied by {@link HandledTransportAction}
     * @param mlExecuteTaskRunner runner that executes the task
     * @param streamTransportService streaming transport service; may be {@code null} when streaming
     *     is not available, in which case only a warning is logged
     */
    @Inject
    public TransportExecuteStreamTaskAction(
        TransportService transportService,
        ActionFilters actionFilters,
        MLExecuteTaskRunner mlExecuteTaskRunner,
        @Nullable StreamTransportService streamTransportService
    ) {
        super(ACTION_NAME, transportService, actionFilters, MLExecuteTaskRequest::new);
        this.mlExecuteTaskRunner = mlExecuteTaskRunner;
        this.transportService = transportService;
        // Keep the first non-null service process-wide; subsequent instances reuse it.
        if (streamTransportServiceInstance == null) {
            streamTransportServiceInstance = streamTransportService;
        }
        TransportExecuteStreamTaskAction.streamTransportService = streamTransportServiceInstance;
        // The unqualified name below is the constructor parameter (it shadows the static field).
        if (streamTransportService != null) {
            streamTransportService.registerRequestHandler(
                ACTION_NAME,
                "opensearch_ml_execute_stream",
                MLExecuteTaskRequest::new,
                this::messageReceived
            );
        } else {
            log.warn("StreamTransportService is not available.");
        }
    }

    /**
     * Returns the process-wide stream transport service.
     *
     * @return the service, or {@code null} if none has been injected yet
     */
    public static StreamTransportService getStreamTransportService() {
        return streamTransportService;
    }

    /**
     * Handles an execute request received on the {@link StreamTransportService}.
     *
     * <p>The method attaches {@code channel} to the request and forwards it to the local node via
     * the regular {@link TransportService} under the same action name. If that forward fails, the
     * failure is sent back on {@code channel}. A successful response is ignored here.
     *
     * @param request the incoming execute request
     * @param channel the streaming channel the request arrived on
     * @param task the task of the inbound request (not used here)
     */
    public void messageReceived(MLExecuteTaskRequest request, final TransportChannel channel, Task task) {
        request.setStreamingChannel(channel);
        transportService.sendRequest(
            transportService.getLocalNode(),
            ACTION_NAME,
            request,
            new TransportResponseHandler<MLExecuteTaskResponse>() {
                @Override
                public MLExecuteTaskResponse read(StreamInput in) throws IOException {
                    return new MLExecuteTaskResponse(in);
                }

                @Override
                public void handleResponse(MLExecuteTaskResponse response) {
                    // Intentionally empty. NOTE(review): output is presumably written to the
                    // streaming channel attached to the request above; confirm in the runner.
                }

                @Override
                public void handleException(TransportException exp) {
                    try {
                        channel.sendResponse(exp);
                    } catch (Exception e) {
                        // Best effort: the channel may already be closed; do not propagate.
                        log.error("Failed to send error response", e);
                    }
                }

                @Override
                public String executor() {
                    // Run the callbacks on the thread that completes the request.
                    return "same";
                }
            }
        );
    }

    /**
     * Framework entry point.
     *
     * <p>Only {@link MLExecuteTaskRequest}s that carry a streaming channel (set by
     * {@link #messageReceived}) are supported. Any other request fails {@code listener} with
     * {@link UnsupportedOperationException}; this includes requests that are not an
     * {@code MLExecuteTaskRequest} at all.
     *
     * @param task the current task
     * @param request the incoming request
     * @param listener notified with the result or failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLExecuteTaskResponse> listener) {
        // Check the type instead of casting blindly. Other ActionRequest subtypes carry no
        // streaming channel and would otherwise surface as a ClassCastException.
        TransportChannel channel = request instanceof MLExecuteTaskRequest
            ? ((MLExecuteTaskRequest) request).getStreamingChannel()
            : null;
        if (channel == null) {
            listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests"));
            return;
        }
        doExecute(task, request, listener, channel);
    }

    /**
     * Runs the execute task for a streaming request.
     *
     * @param task the current task
     * @param request the request; converted with {@link MLExecuteTaskRequest#fromActionRequest}
     * @param listener receives the result or failure from the runner
     * @param channel streaming channel to attach to the request; when it is non-null, the request
     *     is marked with {@code setDispatchTask(false)}
     */
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLExecuteTaskResponse> listener, TransportChannel channel) {
        MLExecuteTaskRequest mlExecuteTaskRequest = MLExecuteTaskRequest.fromActionRequest(request);
        mlExecuteTaskRequest.setStreamingChannel(channel);
        if (mlExecuteTaskRequest.getStreamingChannel() != null) {
            mlExecuteTaskRequest.setDispatchTask(false);
        }
        FunctionName functionName = mlExecuteTaskRequest.getFunctionName();
        // NOTE(review): this is the process-wide static and may be null if no
        // StreamTransportService was ever injected. The cast pins the TransportService overload.
        mlExecuteTaskRunner.run(functionName, mlExecuteTaskRequest, (TransportService) streamTransportService, listener);
    }
}

