/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine;

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Progress;
import com.google.gson.stream.JsonReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.BaseModelConfig;
import org.opensearch.ml.common.model.MLDeploySetting;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.utils.FileUtils;

public class ModelHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(ModelHelper.class);
    public static final String CHUNK_FILES = "chunk_files";
    public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes";
    public static final String MODEL_FILE_HASH = "model_file_hash";
    public static final int CHUNK_SIZE = 10000000;
    public static final String PYTORCH_FILE_EXTENSION = ".pt";
    public static final String ONNX_FILE_EXTENSION = ".onnx";
    public static final String TOKENIZER_FILE_NAME = "tokenizer.json";
    public static final String PYTORCH_ENGINE = "PyTorch";
    public static final String ONNX_ENGINE = "OnnxRuntime";
    private final MLEngine mlEngine;

    public ModelHelper(MLEngine mlEngine) {
        this.mlEngine = mlEngine;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelInput> listener) {
        String modelName = registerModelInput.getModelName();
        FunctionName algorithm = registerModelInput.getFunctionName();
        String version = registerModelInput.getVersion();
        MLModelFormat modelFormat = registerModelInput.getModelFormat();
        Boolean isHidden = registerModelInput.getIsHidden();
        boolean deployModel = registerModelInput.isDeployModel();
        String[] modelNodeIds = registerModelInput.getModelNodeIds();
        String modelGroupId = registerModelInput.getModelGroupId();
        MLDeploySetting mlDeploySetting = registerModelInput.getDeploySetting();
        try {
            AccessController.doPrivileged(() -> {
                Path registerModelPath = this.mlEngine.getRegisterModelPath(taskId, modelName, version);
                String configCacheFilePath = registerModelPath.resolve("config.json").toString();
                String configFileUrl = this.mlEngine.getPrebuiltModelConfigPath(modelName, version, modelFormat);
                String modelZipFileUrl = this.mlEngine.getPrebuiltModelPath(modelName, version, modelFormat);
                DownloadUtils.download((String)configFileUrl, (String)configCacheFilePath, (Progress)new ProgressBar());
                Map config = null;
                try (JsonReader reader = new JsonReader((Reader)new FileReader(configCacheFilePath));){
                    config = (Map)StringUtils.gson.fromJson(reader, Map.class);
                }
                if (config == null) {
                    listener.onFailure((Exception)new IllegalArgumentException("model config not found"));
                    return null;
                }
                MLRegisterModelInput.MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder();
                String functionName = config.containsKey("function_name") ? (String)config.get("function_name") : (String)config.get("model_task_type");
                builder.modelName(modelName).version(version).url(modelZipFileUrl).deployModel(deployModel).modelNodeIds(modelNodeIds).isHidden(isHidden).modelGroupId(modelGroupId).functionName(FunctionName.from((String)functionName)).deploySetting(mlDeploySetting);
                config.entrySet().forEach(entry -> {
                    switch (entry.getKey().toString()) {
                        case "model_format": {
                            builder.modelFormat(MLModelFormat.from((String)entry.getValue().toString()));
                            break;
                        }
                        case "model_config": {
                            if (FunctionName.QUESTION_ANSWERING.equals((Object)algorithm)) {
                                QuestionAnsweringModelConfig.QuestionAnsweringModelConfigBuilder configBuilder = QuestionAnsweringModelConfig.builder();
                                Map configMap = (Map)entry.getValue();
                                for (Map.Entry configEntry : configMap.entrySet()) {
                                    switch (configEntry.getKey().toString()) {
                                        case "model_type": {
                                            configBuilder.modelType(configEntry.getValue().toString());
                                            break;
                                        }
                                        case "all_config": {
                                            configBuilder.allConfig(configEntry.getValue().toString());
                                            break;
                                        }
                                        case "framework_type": {
                                            configBuilder.frameworkType(QuestionAnsweringModelConfig.FrameworkType.from((String)configEntry.getValue().toString()));
                                            break;
                                        }
                                    }
                                }
                                builder.modelConfig((MLModelConfig)configBuilder.build());
                                break;
                            }
                            TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder configBuilder = TextEmbeddingModelConfig.builder();
                            Map configMap = (Map)entry.getValue();
                            for (Map.Entry configEntry : configMap.entrySet()) {
                                switch (configEntry.getKey().toString()) {
                                    case "model_type": {
                                        configBuilder.modelType(configEntry.getValue().toString());
                                        break;
                                    }
                                    case "all_config": {
                                        configBuilder.allConfig(configEntry.getValue().toString());
                                        break;
                                    }
                                    case "additional_config": {
                                        configBuilder.additionalConfig((Map)configEntry.getValue());
                                        break;
                                    }
                                    case "embedding_dimension": {
                                        configBuilder.embeddingDimension(Integer.valueOf(((Double)configEntry.getValue()).intValue()));
                                        break;
                                    }
                                    case "framework_type": {
                                        configBuilder.frameworkType(BaseModelConfig.FrameworkType.from((String)configEntry.getValue().toString()));
                                        break;
                                    }
                                    case "pooling_mode": {
                                        configBuilder.poolingMode(BaseModelConfig.PoolingMode.from((String)configEntry.getValue().toString().toUpperCase(Locale.ROOT)));
                                        break;
                                    }
                                    case "normalize_result": {
                                        configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
                                        break;
                                    }
                                    case "model_max_length": {
                                        configBuilder.modelMaxLength(Integer.valueOf(((Double)configEntry.getValue()).intValue()));
                                        break;
                                    }
                                    case "query_prefix": {
                                        configBuilder.queryPrefix(configEntry.getValue().toString());
                                        break;
                                    }
                                    case "passage_prefix": {
                                        configBuilder.passagePrefix(configEntry.getValue().toString());
                                        break;
                                    }
                                }
                            }
                            builder.modelConfig((MLModelConfig)configBuilder.build());
                            break;
                        }
                        case "model_content_hash_value": {
                            builder.hashValue(entry.getValue().toString());
                            break;
                        }
                    }
                });
                listener.onResponse((Object)builder.build());
                return null;
            });
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
        finally {
            FileUtils.deleteFileQuietly(this.mlEngine.getRegisterModelPath(taskId));
        }
    }

    public boolean isModelAllowed(MLRegisterModelInput registerModelInput, List modelMetaList) {
        String modelName = registerModelInput.getModelName();
        String version = registerModelInput.getVersion();
        MLModelFormat modelFormat = registerModelInput.getModelFormat();
        for (Object meta : modelMetaList) {
            String name = (String)((Map)meta).get("name");
            List versions = (List)((Map)meta).get("version");
            List formats = (List)((Map)meta).get("format");
            if (!name.equals(modelName) || !versions.contains(version.toLowerCase(Locale.ROOT)) || !formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) continue;
            return true;
        }
        return false;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput registerModelInput) throws PrivilegedActionException {
        String modelName = registerModelInput.getModelName();
        String version = registerModelInput.getVersion();
        try {
            List list = AccessController.doPrivileged(() -> {
                Path registerModelPath = this.mlEngine.getRegisterModelPath(taskId, modelName, version);
                String cacheFilePath = registerModelPath.resolve("model_meta_list.json").toString();
                String modelMetaListUrl = this.mlEngine.getPrebuiltModelMetaListPath();
                DownloadUtils.download((String)modelMetaListUrl, (String)cacheFilePath, (Progress)new ProgressBar());
                List config = null;
                try (JsonReader reader = new JsonReader((Reader)new FileReader(cacheFilePath));){
                    config = (List)StringUtils.gson.fromJson(reader, List.class);
                }
                return config;
            });
            return list;
        }
        finally {
            FileUtils.deleteFileQuietly(this.mlEngine.getRegisterModelPath(taskId));
        }
    }

    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
        try {
            AccessController.doPrivileged(() -> {
                Path registerModelPath = this.mlEngine.getRegisterModelPath(taskId, modelName, version);
                String modelPath = String.valueOf(registerModelPath) + ".zip";
                Path modelPartsPath = registerModelPath.resolve("chunks");
                File modelZipFile = new File(modelPath);
                log.debug("download model to file {}", (Object)modelZipFile.getAbsolutePath());
                DownloadUtils.download((String)url, (String)modelPath, (Progress)new ProgressBar());
                this.verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
                String hash = FileUtils.calculateFileHash(modelZipFile);
                if (modelContentHash == null) {
                    log.error("Hash code need to be provided when register via url.");
                    throw new IllegalArgumentException("Model content Hash code need to be provided when register via url. Please calculate sha 256 Hash code.");
                }
                if (hash.equals(modelContentHash)) {
                    List<String> chunkFiles = FileUtils.splitFileIntoChunks(modelZipFile, modelPartsPath, 10000000);
                    HashMap<String, Object> result = new HashMap<String, Object>();
                    result.put(CHUNK_FILES, chunkFiles);
                    result.put(MODEL_SIZE_IN_BYTES, modelZipFile.length());
                    result.put(MODEL_FILE_HASH, FileUtils.calculateFileHash(modelZipFile));
                    FileUtils.deleteFileQuietly(modelZipFile);
                    listener.onResponse(result);
                    return null;
                }
                log.error("Model content hash can't match original hash value when registering");
                throw new IllegalArgumentException("model content changed");
            });
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
        boolean hasPtFile = false;
        boolean hasOnnxFile = false;
        boolean hasTokenizerFile = false;
        try (ZipFile zipFile = new ZipFile(modelZipFilePath);){
            Enumeration<? extends ZipEntry> zipEntries = zipFile.entries();
            while (zipEntries.hasMoreElements()) {
                String fileName = zipEntries.nextElement().getName();
                hasPtFile = ModelHelper.hasModelFile(modelFormat, MLModelFormat.TORCH_SCRIPT, PYTORCH_FILE_EXTENSION, hasPtFile, fileName);
                hasOnnxFile = ModelHelper.hasModelFile(modelFormat, MLModelFormat.ONNX, ONNX_FILE_EXTENSION, hasOnnxFile, fileName);
                if (!fileName.equals(TOKENIZER_FILE_NAME)) continue;
                hasTokenizerFile = true;
            }
        }
        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.SPARSE_TOKENIZE) {
            throw new IllegalArgumentException("Can't find model file");
        }
        if (!hasTokenizerFile && modelName != FunctionName.METRICS_CORRELATION.toString()) {
            throw new IllegalArgumentException("No tokenizer file");
        }
    }

    private static boolean hasModelFile(MLModelFormat modelFormat, MLModelFormat targetModelFormat, String fileExtension, boolean hasModelFile, String fileName) {
        if (fileName.endsWith(fileExtension)) {
            if (modelFormat != targetModelFormat) {
                throw new IllegalArgumentException("Model format is " + String.valueOf(modelFormat) + ", but find " + fileExtension + " file");
            }
            if (hasModelFile) {
                throw new IllegalArgumentException("Find multiple model files, but expected only one");
            }
            return true;
        }
        return hasModelFile;
    }

    public void deleteFileCache(String modelId) {
        FileUtils.deleteFileQuietly(this.mlEngine.getModelCachePath(modelId));
        FileUtils.deleteFileQuietly(this.mlEngine.getDeployModelPath(modelId));
        FileUtils.deleteFileQuietly(this.mlEngine.getRegisterModelPath(modelId));
    }
}

