diff --git a/h2o-core/src/main/java/water/jdbc/SQLManager.java b/h2o-core/src/main/java/water/jdbc/SQLManager.java index 855b20838687..8b148d706a81 100644 --- a/h2o-core/src/main/java/water/jdbc/SQLManager.java +++ b/h2o-core/src/main/java/water/jdbc/SQLManager.java @@ -5,10 +5,20 @@ import water.parser.ParseDataset; import water.util.Log; +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLDecoder; import java.sql.*; +import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.atomic.AtomicLong; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class SQLManager { @@ -30,6 +40,15 @@ public class SQLManager { private static final String TMP_TABLE_ENABLED = H2O.OptArgs.SYSTEM_PROP_PREFIX + "sql.tmp_table.enabled"; + private static final String DISALLOWED_JDBC_PARAMETERS_PARAM = H2O.OptArgs.SYSTEM_PROP_PREFIX + "sql.jdbc.disallowed.parameters"; + + private static final Pattern JDBC_PARAMETERS_REGEX_PATTERN = Pattern.compile("(?i)[?;&]([a-z]+)="); + + private static final List DEFAULT_JDBC_DISALLOWED_PARAMETERS = Stream.of( + "autoDeserialize", "queryInterceptors", "allowLoadLocalInfile", "allowMultiQueries", //mysql + "allowLoadLocalInfileInPath", "allowUrlInLocalInfile", "allowPublicKeyRetrieval", //mysql + "init", "script", "shutdown" //h2 + ).map(String::toLowerCase).collect(Collectors.toList()); private static AtomicLong NEXT_TABLE_NUM = new AtomicLong(0); static Key nextTableKey(String prefix, String postfix) { @@ -58,6 +77,7 @@ public static Job importSqlTable( final String username, final String password, final String columns, final Boolean useTempTable, final String tempTableName, final SqlFetchMode fetchMode, final Integer numChunksHint) { + validateJdbcUrl(connection_url); final Key destination_key = nextTableKey(table, "sql_to_hex"); final Job j = new Job<>(destination_key, Frame.class.getName(), "Import SQL Table"); @@ -533,6 +553,7 @@ private static int estimateConcurrentConnections(final int cloudSize, final shor * @throws SQLException if a database access error occurs or the url is */ public static Connection getConnectionSafe(String url, String username, String password) throws SQLException { + validateJdbcUrl(url); initializeDatabaseDriver(getDatabaseType(url)); try { return DriverManager.getConnection(url, username, password); @@ -588,6 +609,30 @@ static void initializeDatabaseDriver(String databaseType) { } } + public static void validateJdbcUrl(String jdbcUrl) throws IllegalArgumentException { + if (jdbcUrl == null || jdbcUrl.trim().isEmpty()) { + throw new IllegalArgumentException("JDBC URL is null or empty"); + } + + if (!jdbcUrl.toLowerCase().startsWith("jdbc:")) { + throw new IllegalArgumentException("JDBC URL must start with 'jdbc:'"); + } + + Matcher matcher = JDBC_PARAMETERS_REGEX_PATTERN.matcher(jdbcUrl); + String property = System.getProperty(DISALLOWED_JDBC_PARAMETERS_PARAM); + List disallowedParameters = property == null ? + DEFAULT_JDBC_DISALLOWED_PARAMETERS : + Arrays.stream(property.split(",")).map(String::toLowerCase).collect(Collectors.toList()); + + while (matcher.find()) { + String key = matcher.group(1); + if (disallowedParameters.contains(key.toLowerCase())) { + throw new IllegalArgumentException("Potentially dangerous JDBC parameter found: " + key + + ". That behavior can be altered by setting " + DISALLOWED_JDBC_PARAMETERS_PARAM + " env variable to another comma separated list."); + } + } + } + static class SqlTableToH2OFrameStreaming { final String _table, _columns, _databaseType; final int _numCol; diff --git a/h2o-core/src/test/java/water/jdbc/SQLManagerTest.java b/h2o-core/src/test/java/water/jdbc/SQLManagerTest.java index d52febaa50f8..179311eebeec 100644 --- a/h2o-core/src/test/java/water/jdbc/SQLManagerTest.java +++ b/h2o-core/src/test/java/water/jdbc/SQLManagerTest.java @@ -145,4 +145,24 @@ public void testBuildSelectChunkSql() { Assert.assertEquals("SELECT * FROM mytable LIMIT 1310 OFFSET 0", SQLManager.buildSelectChunkSql("", "mytable", 0, 1310, "*", null)); } + + @Test + public void testValidateJdbcConnectionStringH2() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Potentially dangerous JDBC parameter found: init"); + + String h2MaliciousJdbc = "jdbc:h2:mem:test;MODE=MSSQLServer;init=CREATE ALIAS RBT AS '@groovy.transform.ASTTest(value={ assert java.lang.Runtime.getRuntime().exec(\"reboot\")" + "})" + "def rbt" + "'"; + + SQLManager.validateJdbcUrl(h2MaliciousJdbc); + } + + @Test + public void testValidateJdbcConnectionStringMysql() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Potentially dangerous JDBC parameter found: autoDeserialize"); + + String mysqlMaliciousJdbc = "jdbc:mysql://domain:123/test?autoDeserialize=true&queryInterceptors=com.mysql.cj.jdbc.interceptors.ServerStatusDiffInterceptor&user=abcd"; + + SQLManager.validateJdbcUrl(mysqlMaliciousJdbc); + } }