/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hertzbeat.collector.collect.database;

import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Pattern;
import org.apache.hertzbeat.collector.collect.AbstractCollect;
import org.apache.hertzbeat.collector.collect.common.cache.AbstractConnection;
import org.apache.hertzbeat.collector.collect.common.cache.CacheIdentifier;
import org.apache.hertzbeat.collector.collect.common.cache.GlobalConnectionCache;
import org.apache.hertzbeat.collector.collect.common.cache.JdbcConnect;
import org.apache.hertzbeat.collector.collect.common.ssh.SshTunnelHelper;
import org.apache.hertzbeat.collector.util.CollectUtil;
import org.apache.hertzbeat.common.entity.job.Metrics;
import org.apache.hertzbeat.common.entity.job.SshTunnel;
import org.apache.hertzbeat.common.entity.job.protocol.JdbcProtocol;
import org.apache.hertzbeat.common.entity.message.CollectRep;
import org.apache.hertzbeat.common.util.CommonUtil;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.channel.exception.SshChannelOpenException;
import org.postgresql.util.PSQLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.jdbc.datasource.init.ScriptUtils;
import org.springframework.util.StringUtils;

/**
 * Collector for the {@code jdbc} protocol.
 *
 * <p>Opens (or reuses from {@link GlobalConnectionCache}) a JDBC connection, optionally through an
 * SSH tunnel, runs the configured SQL and converts the result into metric value rows. A user
 * supplied JDBC URL is screened against keyword and regex blacklists (local-infile, script,
 * interceptor, JNDI and deserialization parameters) before it is handed to the driver.
 */
public class JdbcCommonCollect extends AbstractCollect {

    private static final Logger log = LoggerFactory.getLogger(JdbcCommonCollect.class);

    private static final String QUERY_TYPE_ONE_ROW = "oneRow";
    private static final String QUERY_TYPE_MULTI_ROW = "multiRow";
    private static final String QUERY_TYPE_COLUMNS = "columns";
    private static final String RUN_SCRIPT = "runScript";

    /** Pseudo column that is filled with the elapsed collection time in milliseconds. */
    private static final String RESPONSE_TIME_COLUMN = "responseTime";

    /** Placeholder written for SQL NULL values and for columns missing from the result. */
    private static final String NULL_PLACEHOLDER = "&nbsp;";

    /** Row cap applied to every statement created by this collector. */
    private static final int MAX_ROWS = 1000;

    /** Longest user supplied JDBC URL that is accepted. */
    private static final int MAX_URL_LENGTH = 2048;

    /** ASCII control characters, DEL and NO-BREAK SPACE. */
    private static final Pattern CONTROL_CHARS = Pattern.compile("[\\x00-\\x1F\\x7F\\xA0]");

    /** Overall shape a user supplied (lower-cased, decoded) JDBC URL must have. */
    private static final String VALID_URL_FORMAT = "^jdbc:[a-zA-Z0-9]+:([^\\s;]+)(;[^\\s;]+)*$";

    /** Driver parameters rejected already at {@link #preCheck(Metrics)} time. */
    private static final String[] VULNERABLE_KEYWORDS = {
        "allowLoadLocalInfile", "allowLoadLocalInfileInPath", "useLocalInfile"
    };

    /** Lower-case substrings that make a user supplied JDBC URL invalid. */
    private static final String[] BLACK_LIST = {
        "create trigger", "create alias", "runscript from", "shutdown", "drop table", "drop database",
        "create function", "alter system", "grant all", "revoke all", "allowloadlocalinfile",
        "allowloadlocalinfileinpath", "uselocalinfile", "init=", "javaobjectserializer=", "runscript",
        "serverstatusdiffinterceptor", "queryinterceptors=", "statementinterceptors=",
        "exceptioninterceptors=", "xp_cmdshell", "dbms_java", "sp_sysexecute", "load_file",
        "allowmultiqueries", "autodeserialize", "detectcustomcollations"
    };

    /**
     * Regex fragment for an (optionally whitespace padded) '/', '\', literal "\n", "/n" or newline
     * inserted between two halves of a keyword to evade the plain substring blacklist.
     */
    private static final String SEPARATOR = "\\s*([/\\\\]|\\\\n|/n|\\n)\\s*";

    /** Split-keyword patterns checked for every platform. */
    private static final String[] UNIVERSAL_BYPASS_PATTERNS = {
        ".*create" + SEPARATOR + "trigger.*",
        ".*create" + SEPARATOR + "function.*",
        ".*drop" + SEPARATOR + "table.*",
        ".*drop" + SEPARATOR + "database.*",
        ".*run" + SEPARATOR + "script.*",
        ".*alter" + SEPARATOR + "system.*",
        ".*grant" + SEPARATOR + "all.*",
        ".*revoke" + SEPARATOR + "all.*",
        ".*xp" + SEPARATOR + "cmdshell.*",
        ".*load" + SEPARATOR + "file.*",
    };

    /** JNDI / deserialization style parameters rejected for every platform. */
    private static final String[] JNDI_DESERIALIZATION_PATTERNS = {
        ".*jndi\\s*[:=].*",
        ".*ldap\\s*[:=].*",
        ".*rmi\\s*[:=].*",
        ".*java\\s*[:=].*",
        ".*serialization\\s*[:=].*",
        ".*deserializ.*\\s*[:=].*",
        ".*objectinputstream\\s*[:=].*",
        ".*readobject\\s*[:=].*",
    };

    /** Split-keyword patterns shared by the MySQL and MariaDB platforms. */
    private static final String[] MYSQL_BYPASS_PATTERNS = {
        ".*allow" + SEPARATOR + "load" + SEPARATOR + "local" + SEPARATOR + "infile.*",
        ".*allow" + SEPARATOR + "multi" + SEPARATOR + "queries.*",
        ".*query" + SEPARATOR + "interceptors.*",
        ".*statement" + SEPARATOR + "interceptors.*",
        ".*exception" + SEPARATOR + "interceptors.*",
        ".*auto" + SEPARATOR + "deserialize.*",
    };

    /** Additional patterns keyed by lower-case {@link JdbcProtocol#getPlatform()}. */
    private static final Map<String, String[]> PLATFORM_BYPASS_PATTERNS = Map.of(
        "h2", new String[]{
            ".*(\\\\\\\\|/|\\\\|\\\\n|/n|\\n)\\s*init\\s*=.*",
            ".*in" + SEPARATOR + "it\\s*=.*",
            ".*(\\\\\\\\|/|\\\\|\\\\n|/n|\\n)\\s*runscript\\s+from.*",
            ".*ru" + SEPARATOR + "script\\s+from.*"},
        "mysql", MYSQL_BYPASS_PATTERNS,
        "mariadb", MYSQL_BYPASS_PATTERNS,
        "postgresql", new String[]{
            ".*socket" + SEPARATOR + "factory.*",
            ".*logger" + SEPARATOR + "file.*",
            ".*ssl" + SEPARATOR + "mode.*",
            ".*logger" + SEPARATOR + "level.*"},
        "sqlserver", new String[]{
            ".*integrated" + SEPARATOR + "security.*",
            ".*authentication" + SEPARATOR + "scheme.*",
            ".*select" + SEPARATOR + "method.*",
            ".*send" + SEPARATOR + "string" + SEPARATOR + "parameters" + SEPARATOR + "as" + SEPARATOR + "unicode.*",
            ".*x" + SEPARATOR + "open" + SEPARATOR + "state.*",
            ".*application" + SEPARATOR + "intent.*"},
        "clickhouse", new String[]{
            ".*custom" + SEPARATOR + "http" + SEPARATOR + "params.*",
            ".*http" + SEPARATOR + "connection" + SEPARATOR + "provider.*",
            ".*check" + SEPARATOR + "all" + SEPARATOR + "nodes.*",
            ".*fail" + SEPARATOR + "over.*",
            ".*use" + SEPARATOR + "objects" + SEPARATOR + "in" + SEPARATOR + "arrays.*"},
        "oracle", new String[]{
            ".*oracle" + SEPARATOR + "jdbc.*",
            ".*oracle" + SEPARATOR + "net.*",
            ".*oracle\\.jdbc\\.timezoneinfotable\\s*=.*",
            ".*oracle\\.net\\.wallet_location\\s*=.*",
            ".*oracle\\.net\\.ssl_server_dn_match\\s*=\\s*false.*",
            ".*oracle\\.jdbc\\.enablesqlinjectionattack\\s*=\\s*true.*",
            ".*oracle\\.jdbc\\.implicitstatementcachesize\\s*=\\s*0.*"},
        "dm", new String[]{
            ".*login" + SEPARATOR + "mode.*",
            ".*compatible" + SEPARATOR + "mode.*",
            ".*en" + SEPARATOR + "crypt.*",
            ".*ci" + SEPARATOR + "pher.*"});

    private final GlobalConnectionCache connectionCommonCache = GlobalConnectionCache.getInstance();

    /**
     * Validates the task parameters before collection.
     *
     * @param metrics metrics definition; must carry jdbc params
     * @throws IllegalArgumentException if jdbc params are missing, the URL contains one of
     *     {@link #VULNERABLE_KEYWORDS} (case-insensitive), or the SSH tunnel params are invalid
     */
    public void preCheck(Metrics metrics) throws IllegalArgumentException {
        if (metrics == null || metrics.getJdbc() == null) {
            throw new IllegalArgumentException("Database collect must has jdbc params");
        }
        JdbcProtocol jdbcProtocol = metrics.getJdbc();
        if (StringUtils.hasText(jdbcProtocol.getUrl())) {
            // Locale.ROOT: with e.g. a Turkish default locale 'I' would lower-case to a dotless 'ı'
            // and "...InFile" would slip past the check.
            String url = jdbcProtocol.getUrl().toLowerCase(Locale.ROOT);
            for (String keyword : VULNERABLE_KEYWORDS) {
                if (url.contains(keyword.toLowerCase(Locale.ROOT))) {
                    throw new IllegalArgumentException("Jdbc url prohibit contains vulnerable param " + keyword);
                }
            }
        }
        SshTunnelHelper.checkTunnelParam(jdbcProtocol.getSshTunnel());
    }

    /**
     * Runs one collection cycle and records the outcome in {@code builder}.
     *
     * <p>Errors are not thrown: they are translated into a {@link CollectRep.Code} plus message on
     * the builder. The statement is always closed afterwards; its connection is closed too unless
     * connection reuse is enabled (in which case it stays in the global cache).
     *
     * @param builder receives value rows, or a failure code and message
     * @param metrics metrics definition carrying the jdbc params and alias fields
     */
    public void collect(CollectRep.MetricsData.Builder builder, Metrics metrics) {
        long startTime = System.currentTimeMillis();
        JdbcProtocol jdbcProtocol = metrics.getJdbc();
        SshTunnel sshTunnel = jdbcProtocol.getSshTunnel();
        int timeout = CollectUtil.getTimeout(jdbcProtocol.getTimeout());
        boolean reuseConnection = Boolean.parseBoolean(jdbcProtocol.getReuseConnection());
        Statement statement = null;
        try {
            String databaseUrl;
            if (sshTunnel != null && Boolean.parseBoolean(sshTunnel.getEnable())) {
                // Connect to the locally forwarded port instead of the remote database host.
                int localPort = SshTunnelHelper.localPortForward(sshTunnel, jdbcProtocol.getHost(), jdbcProtocol.getPort());
                databaseUrl = constructDatabaseUrl(jdbcProtocol, "localhost", String.valueOf(localPort));
            } else {
                databaseUrl = constructDatabaseUrl(jdbcProtocol, jdbcProtocol.getHost(), jdbcProtocol.getPort());
            }
            statement = getConnection(jdbcProtocol.getUsername(), jdbcProtocol.getPassword(), databaseUrl, timeout, reuseConnection);
            String queryType = jdbcProtocol.getQueryType();
            // A null query type falls through to the "not supported" branch instead of an NPE.
            switch (queryType == null ? "" : queryType) {
                case QUERY_TYPE_ONE_ROW ->
                    queryOneRow(statement, jdbcProtocol.getSql(), metrics.getAliasFields(), builder, startTime);
                case QUERY_TYPE_MULTI_ROW ->
                    queryMultiRow(statement, jdbcProtocol.getSql(), metrics.getAliasFields(), builder, startTime);
                case QUERY_TYPE_COLUMNS ->
                    queryOneRowByMatchTwoColumns(statement, jdbcProtocol.getSql(), metrics.getAliasFields(), builder, startTime);
                case RUN_SCRIPT -> {
                    // NOTE(review): here the "sql" field is used as a local file system path; confirm
                    // that this value cannot be supplied by untrusted users.
                    Connection connection = statement.getConnection();
                    ScriptUtils.executeSqlScript(connection, new FileSystemResource(jdbcProtocol.getSql()));
                }
                default -> {
                    builder.setCode(CollectRep.Code.FAIL);
                    builder.setMsg("Not support database query type: " + queryType);
                }
            }
        } catch (PSQLException psqlException) {
            // SQLState 08001: the client was unable to establish the connection.
            if ("08001".equals(psqlException.getSQLState())) {
                builder.setCode(CollectRep.Code.UN_REACHABLE);
            } else {
                builder.setCode(CollectRep.Code.FAIL);
            }
            builder.setMsg("Error: " + psqlException.getMessage() + " Code: " + psqlException.getSQLState());
        } catch (SQLException sqlException) {
            log.warn("Jdbc sql error: {}, code: {}.", sqlException.getMessage(), sqlException.getErrorCode());
            builder.setCode(CollectRep.Code.FAIL);
            builder.setMsg("Query Error: " + sqlException.getMessage() + " Code: " + sqlException.getErrorCode());
        } catch (SshException sshException) {
            if (sshException.getCause() instanceof SshChannelOpenException) {
                log.warn("[Jdbc collect] Remote ssh server no more session channel, please increase sshd_config MaxSessions.");
            }
            String errorMsg = CommonUtil.getMessageFromThrowable(sshException);
            builder.setCode(CollectRep.Code.UN_CONNECTABLE);
            builder.setMsg("Peer ssh connection failed: " + errorMsg);
        } catch (Exception e) {
            String errorMessage = CommonUtil.getMessageFromThrowable(e);
            log.error("Jdbc error: {}.", errorMessage, e);
            builder.setCode(CollectRep.Code.FAIL);
            builder.setMsg("Query Error: " + errorMessage);
        } finally {
            if (statement != null) {
                Connection connection = null;
                try {
                    connection = statement.getConnection();
                    statement.close();
                } catch (Exception e) {
                    log.error("Jdbc close statement error: {}", e.getMessage());
                }
                try {
                    // Cached connections stay open for the next collection cycle.
                    if (!reuseConnection && connection != null) {
                        connection.close();
                    }
                } catch (Exception e) {
                    log.error("Jdbc close connection error: {}", e.getMessage());
                }
            }
        }
    }

    /** @return the protocol handled by this collector, {@code "jdbc"} */
    public String supportProtocol() {
        return "jdbc";
    }

    /**
     * Returns a configured statement, taken from a cached connection when reuse is enabled and a
     * healthy one exists, otherwise from a newly opened read-only connection (which is cached when
     * reuse is enabled).
     *
     * @param timeout query timeout in milliseconds (rounded down to seconds, minimum 1 second)
     * @throws Exception if the connection or statement cannot be created
     */
    private Statement getConnection(String username, String password, String url, Integer timeout, boolean reuseConnection) throws Exception {
        CacheIdentifier identifier = CacheIdentifier.builder().ip(url).username(username).password(password).build();
        if (reuseConnection) {
            Statement cachedStatement = createStatementFromCache(identifier, timeout);
            if (cachedStatement != null) {
                return cachedStatement;
            }
        }
        Connection connection = DriverManager.getConnection(url, username, password);
        Statement statement;
        try {
            connection.setReadOnly(true);
            statement = createStatement(connection, timeout);
        } catch (SQLException | RuntimeException e) {
            // The connection is neither cached nor reachable through a statement yet; close it
            // here or it leaks.
            try {
                connection.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
        if (reuseConnection) {
            connectionCommonCache.addCache(identifier, new JdbcConnect(connection));
        }
        return statement;
    }

    /**
     * Creates a statement on the cached connection for {@code identifier}.
     *
     * @return the statement, or {@code null} when nothing is cached or the cached connection is
     *     unusable (in which case it is closed and evicted)
     */
    private Statement createStatementFromCache(CacheIdentifier identifier, int timeout) {
        Optional<?> cacheOption = connectionCommonCache.getCache(identifier, true);
        if (cacheOption.isEmpty()) {
            return null;
        }
        JdbcConnect jdbcConnect = (JdbcConnect) cacheOption.get();
        try {
            return createStatement(jdbcConnect.getConnection(), timeout);
        } catch (Exception e) {
            log.info("The jdbc connect from cache, create statement error: {}", e.getMessage());
            try {
                jdbcConnect.close();
            } catch (Exception closeError) {
                log.error(closeError.getMessage());
            }
            connectionCommonCache.removeCache(identifier);
            return null;
        }
    }

    /**
     * Creates a statement with the query timeout and {@link #MAX_ROWS} applied; the statement is
     * closed again if configuring it fails.
     *
     * @param timeoutMillis timeout in milliseconds; converted to whole seconds, at least 1
     */
    private static Statement createStatement(Connection connection, int timeoutMillis) throws SQLException {
        Statement statement = connection.createStatement();
        try {
            statement.setQueryTimeout(Math.max(timeoutMillis / 1000, 1));
            statement.setMaxRows(MAX_ROWS);
        } catch (SQLException | RuntimeException e) {
            try {
                statement.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
        return statement;
    }

    /** Runs {@code sql} and emits at most its first row, reading values by column label. */
    private void queryOneRow(Statement statement, String sql, List<String> columns, CollectRep.MetricsData.Builder builder, long startTime) throws Exception {
        statement.setMaxRows(1);
        try (ResultSet resultSet = statement.executeQuery(sql)) {
            if (resultSet.next()) {
                builder.addValueRow(buildRow(resultSet, columns, startTime));
            }
        }
    }

    /**
     * Runs a two-column (key, value) query and folds it into a single row: each requested column
     * takes the value whose key matches it case-insensitively (keys are also trimmed).
     */
    private void queryOneRowByMatchTwoColumns(Statement statement, String sql, List<String> columns, CollectRep.MetricsData.Builder builder, long startTime) throws Exception {
        try (ResultSet resultSet = statement.executeQuery(sql)) {
            Map<String, String> values = new HashMap<>(columns.size());
            while (resultSet.next()) {
                // Read each column once, left to right, as recommended for portable JDBC code.
                String key = resultSet.getString(1);
                if (key != null) {
                    values.put(key.toLowerCase(Locale.ROOT).trim(), resultSet.getString(2));
                }
            }
            CollectRep.ValueRow.Builder valueRowBuilder = CollectRep.ValueRow.newBuilder();
            for (String column : columns) {
                if (RESPONSE_TIME_COLUMN.equals(column)) {
                    valueRowBuilder.addColumn(String.valueOf(System.currentTimeMillis() - startTime));
                    continue;
                }
                String value = values.get(column.toLowerCase(Locale.ROOT));
                valueRowBuilder.addColumn(value == null ? NULL_PLACEHOLDER : value);
            }
            builder.addValueRow(valueRowBuilder.build());
        }
    }

    /** Runs {@code sql} and emits one value row per result row, reading values by column label. */
    private void queryMultiRow(Statement statement, String sql, List<String> columns, CollectRep.MetricsData.Builder builder, long startTime) throws Exception {
        try (ResultSet resultSet = statement.executeQuery(sql)) {
            while (resultSet.next()) {
                builder.addValueRow(buildRow(resultSet, columns, startTime));
            }
        }
    }

    /**
     * Builds a value row from the current result-set row: {@code responseTime} gets the elapsed
     * milliseconds, every other column its string value or {@link #NULL_PLACEHOLDER}.
     */
    private static CollectRep.ValueRow buildRow(ResultSet resultSet, List<String> columns, long startTime) throws SQLException {
        CollectRep.ValueRow.Builder valueRowBuilder = CollectRep.ValueRow.newBuilder();
        for (String column : columns) {
            if (RESPONSE_TIME_COLUMN.equals(column)) {
                valueRowBuilder.addColumn(String.valueOf(System.currentTimeMillis() - startTime));
                continue;
            }
            String value = resultSet.getString(column);
            valueRowBuilder.addColumn(value == null ? NULL_PLACEHOLDER : value);
        }
        return valueRowBuilder.build();
    }

    /**
     * URL-decodes {@code url} repeatedly (at most 5 passes) until it stops changing, so that
     * multiply-encoded keywords are exposed to the blacklist. Decoding stops at the first
     * malformed escape and the last successful result is returned.
     */
    private static String recursiveDecode(String url) {
        String previous;
        String decoded = url;
        int remainingPasses = 5;
        do {
            previous = decoded;
            try {
                decoded = URLDecoder.decode(previous, StandardCharsets.UTF_8);
            } catch (Exception e) {
                break;
            }
        } while (!previous.equals(decoded) && --remainingPasses > 0);
        return decoded;
    }

    /**
     * Returns the JDBC URL to connect to: the user supplied URL (after validation) when one starting
     * with {@code jdbc} is configured, otherwise a URL built from the platform, host, port and
     * database.
     *
     * @throws IllegalArgumentException if the custom URL fails validation, or the platform is
     *     missing or unsupported
     */
    private String constructDatabaseUrl(JdbcProtocol jdbcProtocol, String host, String port) {
        String customUrl = jdbcProtocol.getUrl();
        if (customUrl != null && !customUrl.isEmpty() && customUrl.startsWith("jdbc")) {
            return validateCustomUrl(customUrl, jdbcProtocol.getPlatform());
        }
        // NOTE(review): host and database are concatenated without sanitising; confirm they are
        // validated upstream, since e.g. a database value containing "?" could append parameters.
        String platform = jdbcProtocol.getPlatform();
        if (platform == null) {
            throw new IllegalArgumentException("Not support database platform: null");
        }
        String rawDatabase = jdbcProtocol.getDatabase();
        String database = rawDatabase == null ? "" : rawDatabase;
        return switch (platform) {
            case "mysql", "mariadb" -> "jdbc:mysql://" + host + ":" + port + "/" + database + "?useUnicode=true&characterEncoding=utf-8&useSSL=false";
            case "postgresql" -> "jdbc:postgresql://" + host + ":" + port + "/" + database;
            case "clickhouse" -> "jdbc:clickhouse://" + host + ":" + port + "/" + database;
            case "sqlserver" -> "jdbc:sqlserver://" + host + ":" + port + ";" + (rawDatabase == null ? "" : "DatabaseName=" + rawDatabase) + ";trustServerCertificate=true;";
            case "oracle" -> "jdbc:oracle:thin:@" + host + ":" + port + "/" + database;
            case "dm" -> "jdbc:dm://" + host + ":" + port;
            case "testcontainers" -> "jdbc:tc:" + host + ":" + port + ":///" + database + "?user=root&password=root";
            default -> throw new IllegalArgumentException("Not support database platform: " + platform);
        };
    }

    /**
     * Screens a user supplied JDBC URL and returns it with raw control characters removed.
     *
     * <p>All checks run on a fully URL-decoded, lower-cased copy so that encoding or case tricks
     * cannot hide a forbidden parameter. The returned URL keeps its original case and encoding:
     * lower-casing or decoding it would corrupt case-sensitive values such as passwords and
     * database names and could change how the driver splits parameters.
     *
     * @param rawUrl   URL as configured by the user
     * @param platform configured platform, may be {@code null}; selects extra platform patterns
     * @throws IllegalArgumentException if the URL is too long, malformed, or matches a blacklist
     */
    private String validateCustomUrl(String rawUrl, String platform) {
        if (rawUrl.length() > MAX_URL_LENGTH) {
            throw new IllegalArgumentException("JDBC URL length exceeds maximum limit of 2048 characters");
        }
        String cleanedUrl = CONTROL_CHARS.matcher(rawUrl).replaceAll("");
        String url = recursiveDecode(cleanedUrl).toLowerCase(Locale.ROOT);
        // Control characters that only appear after decoding (e.g. %00) have no legitimate use
        // and could split keywords, so they are rejected outright.
        if (CONTROL_CHARS.matcher(url).find() || !url.matches(VALID_URL_FORMAT)) {
            throw new IllegalArgumentException("Invalid JDBC URL format");
        }
        for (String keyword : BLACK_LIST) {
            if (url.contains(keyword)) {
                throw new IllegalArgumentException("Invalid JDBC URL: contains potentially malicious parameter: " + keyword);
            }
        }
        for (String pattern : JNDI_DESERIALIZATION_PATTERNS) {
            if (url.matches(pattern)) {
                throw new IllegalArgumentException("Invalid JDBC URL: contains potentially malicious JNDI or deserialization parameter");
            }
        }
        for (String pattern : UNIVERSAL_BYPASS_PATTERNS) {
            if (url.matches(pattern)) {
                throw new IllegalArgumentException("Invalid JDBC URL: contains potentially malicious bypass pattern");
            }
        }
        if (platform != null) {
            String normalizedPlatform = platform.toLowerCase(Locale.ROOT);
            String[] platformPatterns = PLATFORM_BYPASS_PATTERNS.get(normalizedPlatform);
            if (platformPatterns != null) {
                for (String pattern : platformPatterns) {
                    if (url.matches(pattern)) {
                        throw new IllegalArgumentException("Invalid " + normalizedPlatform.toUpperCase(Locale.ROOT) + " JDBC URL: contains potentially malicious bypass pattern");
                    }
                }
            }
        }
        return cleanedUrl;
    }
}

