diff --git a/cd4ml/train.py b/cd4ml/train.py index 07db03c..073db07 100644 --- a/cd4ml/train.py +++ b/cd4ml/train.py @@ -26,6 +26,6 @@ def get_trained_model(algorithm_name, logger.info('n_rows: %s, n_cols: %s' % (n_rows, n_cols)) trained_model = train_model(encoded_train_data, target_data, algorithm_name, - algorithm_params, seed=seed) + algorithm_params, seed=seed, max_features=5) return trained_model