From b0b007363e87ad0a370e43574c5de4aad0036b8b Mon Sep 17 00:00:00 2001 From: David Nowinsky Date: Thu, 12 Dec 2024 14:26:47 +0100 Subject: [PATCH] fix(common): prevent injection during jinja templating (#1850) * fix(common): prevent injection during jinja templating * refactor: avoid using global jinja_env Instanciating an env doesn't cost a lot, so it's best to avoid any globals --- tests/test_common.py | 20 +++++++++++++++++++ toucan_connectors/common.py | 16 +++++++++------ .../snowflake/snowflake_connector.py | 4 ++-- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index 2ec7ee4aa..7434d02be 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,6 +1,7 @@ from datetime import date, datetime, timedelta from typing import Any +import jinja2 import numpy as np import pandas as pd import pytest @@ -285,6 +286,25 @@ def test_nosql_apply_parameters_to_query_error_on_params(query: dict, params: di nosql_apply_parameters_to_query(query, params, handle_errors=True) +def test_nosql_apply_parameters_to_query_unsafe(): + """ + It should prevent any code execution, by using Jinja's sandboxed environement + """ + with pytest.raises(jinja2.exceptions.SecurityError): + nosql_apply_parameters_to_query( + { + "test": "{% for x in var.__class__.__base__.__subclasses__() %}" + + "{% if 'warning' in x.__name__ %}" + + "{{x()._module.__builtins__ ['__import__']" + + "('os').popen('ls').read()}}" + + "{% endif %}{% endfor %}" + }, + {"var": "plop"}, + ) + with pytest.raises(jinja2.exceptions.SecurityError): + nosql_apply_parameters_to_query({"test": "{{ var.__class__.mro()[-1] }}"}, {"var": "plop"}) + + def test_nosql_apply_parameters_to_query_dot(): """It should handle both `x["y"]` and `x.y`""" query1 = {"facet": "{{ facet.value }}", "sort": "{{ rank[0] }}", "rows": "{{ bibou[0].value }}"} diff --git a/toucan_connectors/common.py b/toucan_connectors/common.py index 6a70b0e91..fec88ffde 100644 --- a/toucan_connectors/common.py +++ b/toucan_connectors/common.py @@ -8,8 +8,9 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable -from jinja2 import Environment, Template, Undefined, UndefinedError, meta +from jinja2 import Environment, Undefined, UndefinedError, meta from jinja2.nativetypes import NativeEnvironment +from jinja2.sandbox import ImmutableSandboxedEnvironment from pydantic import Field from toucan_connectors.utils.slugify import slugify @@ -19,6 +20,9 @@ import sqlalchemy as sa +class NativeImmutableSandboxedEnvironment(NativeEnvironment, ImmutableSandboxedEnvironment): ... + + # Query interpolation RE_PARAM = r"%\(([^(%\()]*)\)s" @@ -64,7 +68,7 @@ def is_jinja_alone(s: str) -> bool: def _has_parameters(query: str) -> bool: - t = Environment().parse(query) # noqa: S701 + t = ImmutableSandboxedEnvironment().parse(query) # noqa: S701 return bool(meta.find_undeclared_variables(t) or re.search(RE_PARAM, query)) @@ -165,9 +169,9 @@ def _render_query(query: dict | list[dict] | tuple | str, parameters: dict | Non if is_jinja_alone(query): clean_p = _prepare_parameters(clean_p) # type:ignore[assignment] - env: Environment | NativeEnvironment = NativeEnvironment() + env: Environment | NativeEnvironment = NativeImmutableSandboxedEnvironment() else: - env = Environment() # noqa: S701 + env = ImmutableSandboxedEnvironment() # noqa: S701 try: res = env.from_string(query).render(clean_p) @@ -237,7 +241,7 @@ def _flatten_dict(p, parent_key=""): parameters.update(p_keep_type) logging.getLogger(__name__).debug(f"Render query: {query} with parameters {parameters}") - return Template(query).render(parameters) + return ImmutableSandboxedEnvironment().from_string(query).render(parameters) # jq filtering @@ -476,7 +480,7 @@ def pandas_read_sql( if convert_to_printf: query = convert_to_printf_templating_style(query) if render_user: - query = Template(query).render({"user": params.get("user", {})}) + query = ImmutableSandboxedEnvironment().from_string(query).render({"user": params.get("user", {})}) if convert_to_qmark: query, params = convert_to_qmark_paramstyle(query, params) if convert_to_numeric: diff --git a/toucan_connectors/snowflake/snowflake_connector.py b/toucan_connectors/snowflake/snowflake_connector.py index 9416287e6..2a1ad3941 100644 --- a/toucan_connectors/snowflake/snowflake_connector.py +++ b/toucan_connectors/snowflake/snowflake_connector.py @@ -29,7 +29,7 @@ import pandas as pd import requests import snowflake - from jinja2 import Template + from jinja2.sandbox import ImmutableSandboxedEnvironment from snowflake import connector as sf_connector from snowflake.connector import SnowflakeConnection from snowflake.connector.cursor import DictCursor as SfDictCursor @@ -268,7 +268,7 @@ def get_status(self) -> ConnectorStatus: def get_connection_params(self) -> dict[str, str | int | None]: params: dict[str, str | int | None] = { - "user": Template(self.user).render(), + "user": ImmutableSandboxedEnvironment().from_string(self.user).render(), "account": self.account, "authenticator": self.authentication_method, # hard Snowflake params