/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterStateListener;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.search.SearchHit;
import org.opensearch.threadpool.ThreadPool;

public class TrainingJobClusterStateListener
implements ClusterStateListener {
    @Generated
    private static final Logger log = LogManager.getLogger(TrainingJobClusterStateListener.class);
    private static TrainingJobClusterStateListener INSTANCE;
    private static ModelDao modelDao;
    private static ThreadPool threadPool;
    private static ClusterService clusterService;
    private String oldClusterManagerNodeId = "";
    private String currentClusterManagerNodeId = "";
    private boolean clusterManagerNodeRemoved = false;

    public static synchronized TrainingJobClusterStateListener getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new TrainingJobClusterStateListener();
        }
        return INSTANCE;
    }

    public static synchronized void initialize(ThreadPool threadPool, ModelDao modelDao, ClusterService clusterService) {
        TrainingJobClusterStateListener.threadPool = threadPool;
        TrainingJobClusterStateListener.modelDao = modelDao;
        TrainingJobClusterStateListener.clusterService = clusterService;
    }

    public void clusterChanged(ClusterChangedEvent event) {
        if (event.localNodeClusterManager()) {
            if (event.isNewCluster()) {
                threadPool.schedule(() -> {
                    try {
                        this.updateModelsNewCluster();
                    }
                    catch (IOException | InterruptedException | ExecutionException e) {
                        throw new RuntimeException(e);
                    }
                }, TimeValue.timeValueSeconds((long)1L), "generic");
            } else if (event.nodesRemoved()) {
                List removedNodes = event.nodesDelta().removedNodes();
                threadPool.schedule(() -> {
                    try {
                        this.updateModelsNodesRemoved(removedNodes);
                    }
                    catch (IOException | InterruptedException | ExecutionException e) {
                        throw new RuntimeException(e);
                    }
                }, TimeValue.timeValueSeconds((long)0L), "generic");
            }
        }
    }

    protected void updateModelsNewCluster() throws IOException, InterruptedException, ExecutionException {
        if (modelDao.isCreated()) {
            List<String> modelIds = this.searchModelIds();
            for (String modelId : modelIds) {
                ModelMetadata modelMetadata = this.getModelMetadata(modelId);
                if (!modelMetadata.getState().equals((Object)ModelState.TRAINING)) continue;
                this.updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as cluster crashed");
            }
        }
    }

    protected void updateModelsNodesRemoved(List<DiscoveryNode> removedNodes) throws IOException, InterruptedException, ExecutionException {
        if (modelDao.isCreated()) {
            List<String> modelIds = this.searchModelIds();
            for (DiscoveryNode removedNode : removedNodes) {
                for (String modelId : modelIds) {
                    ModelMetadata modelMetadata = this.getModelMetadata(modelId);
                    if (!modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId()) || !modelMetadata.getState().equals((Object)ModelState.TRAINING)) continue;
                    this.updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as node dropped");
                }
            }
        }
    }

    private List<String> searchModelIds() throws IOException, InterruptedException {
        final ArrayList<String> modelIds = new ArrayList<String>();
        final CountDownLatch latch = new CountDownLatch(1);
        modelDao.search(new SearchRequest(), new ActionListener<SearchResponse>(){

            /*
             * WARNING - Removed try catching itself - possible behaviour change.
             */
            public void onResponse(SearchResponse searchResponse) {
                try {
                    for (SearchHit searchHit : searchResponse.getHits().getHits()) {
                        modelIds.add(searchHit.getId());
                    }
                }
                finally {
                    latch.countDown();
                }
            }

            public void onFailure(Exception e) {
                latch.countDown();
            }
        });
        latch.await();
        return modelIds;
    }

    private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException, InterruptedException {
        modelMetadata.setState(ModelState.FAILED);
        modelMetadata.setError(msg);
        final Model model = new Model(modelMetadata, null, modelId);
        modelDao.update(model, new ActionListener<IndexResponse>(){

            public void onResponse(IndexResponse indexResponse) {
                log.info("Model {} marked as {}", (Object)model.getModelID(), (Object)model.getModelMetadata().getState());
            }

            public void onFailure(Exception e) {
                log.error("Failed to update model state", (Throwable)e);
            }
        });
    }

    private ModelMetadata getModelMetadata(String modelId) throws ExecutionException, InterruptedException {
        ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
        if (modelMetadata == null) {
            log.info("Model metadata is null in cluster metadata. This can happen for models training on nodes prior to OpenSearch version 2.14.0.  Fetching model information from system index.");
            Model model = modelDao.get(modelId);
            return model.getModelMetadata();
        }
        return modelMetadata;
    }
}

