Commit

Merge pull request #14 from amoffat/dev
Release 1.0.0
amoffat authored Sep 11, 2023
2 parents f4c0e4b + 136c39c commit 9836d62
Showing 78 changed files with 2,777 additions and 613 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -3,7 +3,8 @@ __pycache__/
/.venv
/.env
/test_grammar.py
/test_transform.py
/dist
/.coverage
/TODO.md
TODO.md
/docs/build
2 changes: 1 addition & 1 deletion .vscode/settings.json
Expand Up @@ -3,7 +3,7 @@
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"python.testing.pytestArgs": ["heimdallm", "-s", "-x"],
"python.testing.pytestArgs": ["heimdallm", "-s"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"editor.rulers": [88],
Expand Down
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,15 @@
# Changelog

## 1.0.0 - 9/11/23

- Subquery support
- CTE (Common Table Expressions) support
- PostgreSQL support
- Renamed constraint method `required_constraints` to `parameterized_constraints`
- Renamed Bifrost method `mocked` to `validation_only`
- All exceptions include a `ctx` property for debugging
- MySQL `INTERVAL` syntax support

## 0.3.0 - 7/15/23

- Autofix non-qualified column names
Expand Down
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -101,6 +101,7 @@ the mitigations.

- Sqlite
- MySQL
- Postgres

There is active development for the other top relational SQL databases. To help me
prioritize, please vote on which database you would like to see supported:
Expand Down
5 changes: 5 additions & 0 deletions docs/source/api/abc/context.rst
@@ -0,0 +1,5 @@
Traverse Context
================

.. automodule:: heimdallm.context
:members:
1 change: 1 addition & 0 deletions docs/source/api/abc/index.rst
Expand Up @@ -10,5 +10,6 @@ intended for direct use.
envelope
validator
llm_integration
context

sql/index
2 changes: 1 addition & 1 deletion docs/source/attack_surface/sql.rst
Expand Up @@ -129,7 +129,7 @@ Optional conditions

When required conditions are defined, either as a :meth:`requester identity
<validator.ConstraintValidator.requester_identities>`, or as some other
:meth:`required constraint <validator.ConstraintValidator.required_constraints>`, an
:meth:`parameterized constraint <validator.ConstraintValidator.parameterized_constraints>`, an
attacker may attempt to bypass the condition by coaxing the LLM to produce a query that
includes the condition as part of an ``OR`` clause. For example:

Expand Down
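To make the `OR` bypass described in this hunk concrete, here is a minimal sketch using Python's standard `sqlite3` module. The `rental` table, its rows, and both queries are hypothetical and are not taken from the HeimdaLLM docs. The sketch only illustrates that a condition which is present in the query, but joined with `OR`, no longer restricts the result set.

```python
import sqlite3

# Hypothetical schema and data, for illustration only.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE rental (rental_id INTEGER PRIMARY KEY, customer_id INTEGER, title TEXT);
    INSERT INTO rental VALUES (1, 123, 'Alien'), (2, 123, 'Moon'), (3, 456, 'Gattaca');
    """
)

# The identity value is bound by the application, never by the LLM.
params = {"customer_id": 123}

# The identity condition restricts every returned row.
constrained = (
    "SELECT title FROM rental "
    "WHERE rental.customer_id = :customer_id ORDER BY rental_id"
)

# The same condition is present, but the OR makes it optional.
bypassed = (
    "SELECT title FROM rental "
    "WHERE rental.customer_id = :customer_id OR 1 = 1 ORDER BY rental_id"
)

constrained_titles = [row[0] for row in conn.execute(constrained, params)]
bypassed_titles = [row[0] for row in conn.execute(bypassed, params)]

print(constrained_titles)
print(bypassed_titles)
```

A check that only looks for the placeholder in the query text would accept both statements, so the position of the condition within the boolean expression has to be validated as well.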
8 changes: 8 additions & 0 deletions docs/source/blog/index.rst
@@ -0,0 +1,8 @@
📖 Blog Posts
=============

.. toctree::
:glob:
:maxdepth: 2

posts/*
437 changes: 437 additions & 0 deletions docs/source/blog/posts/safe-sql-execution.rst

Large diffs are not rendered by default.

Binary file added docs/source/images/bridge-of-death.jpg
Binary file added docs/source/images/gattaca.jpg
Binary file added docs/source/images/ghost.jpg
Binary file added docs/source/images/grail.jpg
Binary file added docs/source/images/sam-gerty.jpg
Binary file added docs/source/images/smiley.jpg
Binary file added docs/source/images/truman.jpg
1 change: 1 addition & 0 deletions docs/source/index.rst
Expand Up @@ -94,6 +94,7 @@ browse the navigation on the left.
:maxdepth: 5

quickstart
blog/index
bifrost
api/index
reconstruction
Expand Down
8 changes: 4 additions & 4 deletions docs/source/quickstart.rst
Expand Up @@ -23,7 +23,7 @@ First let's set up our imports.
from heimdallm.bifrosts.sql.sqlite.select.bifrost import Bifrost
from heimdallm.bifrosts.sql.sqlite.select.envelope import PromptEnvelope
from heimdallm.bifrosts.sql.sqlite.select.validator import ConstraintValidator
from heimdallm.bifrosts.sql.common import FqColumn, JoinCondition, RequiredConstraint
from heimdallm.bifrosts.sql.common import FqColumn, JoinCondition, ParameterizedConstraint
from heimdallm.llm_providers import openai
logging.basicConfig(level=logging.ERROR)
Expand Down Expand Up @@ -76,15 +76,15 @@ the methods that you can override in the derived class, look :doc:`here.
.. code-block:: python
class CustomerConstraintValidator(SQLConstraintValidator):
def requester_identities(self) -> Sequence[RequiredConstraint]:
def requester_identities(self) -> Sequence[ParameterizedConstraint]:
return [
RequiredConstraint(
ParameterizedConstraint(
column="customer.customer_id",
placeholder="customer_id",
),
]
def required_constraints(self) -> Sequence[RequiredConstraint]:
def parameterized_constraints(self) -> Sequence[ParameterizedConstraint]:
return []
def select_column_allowed(self, column: FqColumn) -> bool:
Expand Down
18 changes: 12 additions & 6 deletions docs/source/roadmap.rst
Expand Up @@ -19,8 +19,6 @@ More databases

I will be adding support for more SQL-based databases:

* MySQL
* PostgreSQL
* SQL Server
* Oracle
* Snowflake
Expand Down Expand Up @@ -48,10 +46,18 @@ Could produce a validated SQL query:
INSERT INTO calendar (title, when, user_id)
VALUES ('Dinner at 7 on friday', '2023-07-01 19:00:00', 123)
Generalized constraint spec
***************************

The current implementation requires a Python application to use HeimdaLLM, because
constraint validators are defined by subclassing a Python class. A future implementation
could be language-agnostic by providing an API and a JSON or YAML spec for constraining
LLM output.

More Bifrosts
*************

Bifrosts are not limited to converting human input to trusted SQL statements.
HeimdaLLM's has been deliberately designed to be general enough to support many kinds
of structured output. I intend to develop more powerful Bifrosts that supercharge your
ability to provide natural language interaction with your application. Stay tuned!
Bifrosts are not limited to converting human input to trusted SQL statements. HeimdaLLM
is generalized enough to support many kinds of structured output. I intend to develop
more Bifrosts that facilitate natural language interactions with your application.
Stay tuned!
40 changes: 28 additions & 12 deletions heimdallm/bifrost.py
Expand Up @@ -3,6 +3,8 @@
import structlog
from lark import Lark, ParseTree

from heimdallm.context import TraverseContext

if TYPE_CHECKING:
import heimdallm.constraints
import heimdallm.envelope
Expand Down Expand Up @@ -45,6 +47,7 @@ def __init__(
self.grammar = grammar
self.tree_producer = tree_producer
self.constraint_validators = constraint_validators
self.ctx = TraverseContext()

def traverse(
self,
Expand All @@ -64,7 +67,9 @@ def traverse(
:return: The trusted LLM output.
"""

log = LOG.bind(input=untrusted_human_input, autofix=autofix)
self.ctx.untrusted_human_input = untrusted_human_input

log = LOG.bind(autofix=autofix)
log.info("Traversing untrusted input")

# wrap the untrusted input in our prompt
Expand All @@ -75,7 +80,7 @@ def traverse(
# talk to our LLM
log.info("Sending envelope to LLM")
untrusted_llm_output: str = self.llm.complete(untrusted_llm_input)
log = log.bind(llm_output=untrusted_llm_output)
self.ctx.untrusted_llm_output = untrusted_llm_output
log.info("Received raw result from LLM")

# trim any cruft off of the LLM output
Expand All @@ -85,7 +90,7 @@ def traverse(
except Exception as e:
log.exception("Unwrap failed")
raise e
log = log.bind(unwrapped=untrusted_llm_output)
self.ctx.untrusted_llm_output = untrusted_llm_output
log.info("Unwrap succeeded")

# throws a parse error
Expand All @@ -106,11 +111,12 @@ def traverse(
)
try:
trusted_llm_output, tree = self._try_validator(
log,
validator,
autofix,
untrusted_llm_output,
tree,
log=log,
validator=validator,
untrusted_llm_output=untrusted_llm_output,
autofix=autofix,
ctx=self.ctx,
tree=tree,
)
except Exception as e:
validation_exc = e
Expand All @@ -128,27 +134,33 @@ def traverse(
log.exception("Validation failed")
raise e

log = log.bind(trusted=untrusted_llm_output)
log.info("Validation succeeded")

trusted_llm_output = self.post_transform(trusted_llm_output, tree)
self.ctx.trusted_llm_output = trusted_llm_output

return trusted_llm_output

def _try_validator(
self,
*,
log: structlog.BoundLogger,
validator: "heimdallm.constraints.ConstraintValidator",
autofix: bool,
untrusted_llm_output: str,
ctx: TraverseContext,
tree: ParseTree,
) -> tuple[str, ParseTree]:
"""Attempt validation with an individual constraint validator."""

if autofix:
log.info("Autofixing parse tree and reconstructing the input")
try:
untrusted_llm_output = validator.fix(self, self.grammar, tree)
untrusted_llm_output = validator.fix(
bifrost=self,
grammar=self.grammar,
tree=tree,
ctx=ctx,
)
except Exception as e:
log.exception("Autofix failed")
raise e
Expand All @@ -162,7 +174,11 @@ def _try_validator(

# throws a bifrost-specific exception
log.info("Validating parse tree")
validator.validate(self, untrusted_llm_output, tree)
validator.validate(
bifrost=self,
tree=tree,
ctx=ctx,
)
log.info("Validation succeeded")

return untrusted_llm_output, tree
Expand Down
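The `heimdallm.context` module itself is not shown in this diff, so the following is only a sketch of the pattern the `traverse` changes suggest: one context object is created per Bifrost, each stage records what it saw on that object, and the object is passed down by keyword. `TraverseContextSketch`, `traverse_sketch`, and the `complete`, `unwrap`, and `validate` callables are hypothetical stand-ins. Only the three attribute names are taken from the assignments visible above.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class TraverseContextSketch:
    """Assumed shape: only the attributes assigned in the diff above."""

    untrusted_human_input: Optional[str] = None
    untrusted_llm_output: Optional[str] = None
    trusted_llm_output: Optional[str] = None


def traverse_sketch(
    ctx: TraverseContextSketch,
    untrusted_human_input: str,
    *,  # keyword-only, mirroring the keyword-style calls in the diff
    complete: Callable[[str], str],
    unwrap: Callable[[str], str],
    validate: Callable[[str], None],
) -> str:
    # Record each stage on the context instead of binding it to the logger.
    ctx.untrusted_human_input = untrusted_human_input
    ctx.untrusted_llm_output = complete(untrusted_human_input)
    ctx.untrusted_llm_output = unwrap(ctx.untrusted_llm_output)

    # Raises if the output is not acceptable; ctx still holds the raw values.
    validate(ctx.untrusted_llm_output)

    ctx.trusted_llm_output = ctx.untrusted_llm_output
    return ctx.trusted_llm_output


ctx = TraverseContextSketch()
result = traverse_sketch(
    ctx,
    "show my rentals",
    complete=lambda prompt: "  SELECT 1;  ",  # stand-in for an LLM call
    unwrap=str.strip,
    validate=lambda query: None,  # stand-in validator that accepts anything
)
print(result)
```

Keeping the intermediate values on one object means a failure at any stage can be inspected afterwards without relying on log output.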
29 changes: 15 additions & 14 deletions heimdallm/bifrosts/sql/bifrost.py
Expand Up @@ -7,6 +7,8 @@

from heimdallm.bifrost import Bifrost as _BaseBifrost
from heimdallm.bifrosts.sql import exc
from heimdallm.bifrosts.sql.visitors.id_setter import IdSetter
from heimdallm.bifrosts.sql.visitors.parent import ParentSetter
from heimdallm.llm import LLMIntegration
from heimdallm.llm_providers.mock import EchoMockLLM

Expand All @@ -31,23 +33,20 @@ class Bifrost(_BaseBifrost, ABC):
needs to succeed for validation to pass.
"""

# for tests
@classmethod
def mocked(
def validation_only(
cls,
constraint_validators: Union[
"heimdallm.bifrosts.sql.validator.ConstraintValidator",
Sequence["heimdallm.bifrosts.sql.validator.ConstraintValidator"],
],
):
"""A convenience method for our tests. This creates a Bifrost that assumes its
untrusted input is a SQL query already, so it does not need to communicate with
the LLM, only parse and validate it.
"""A convenience method for doing just constraint validation. This creates a
Bifrost that assumes its untrusted input is a SQL query already, so it does not
need to communicate with the LLM, only parse and validate it.
:param constraint_validators: A constraint validator or sequence of constraint
validators to run on the untrusted input.
:meta private:
"""
if not isinstance(constraint_validators, Sequence):
constraint_validators = [constraint_validators]
Expand Down Expand Up @@ -92,8 +91,7 @@ def reserved_keywords(cls) -> set[str]:
"""
raise NotImplementedError

@classmethod
def build_tree_producer(cls) -> Callable[[Lark, str], ParseTree]:
def build_tree_producer(self) -> Callable[[Lark, str], ParseTree]:
"""
Produces a function that can create a single parse tree. May be implemented in a subclass
if you want to do custom ambiguity resolution.
Expand All @@ -106,13 +104,16 @@ def parse(grammar: Lark, untrusted_query: str) -> ParseTree:
ambig_tree = grammar.parse(untrusted_query)
try:
final_tree = AmbiguityResolver(
untrusted_query,
cls.reserved_keywords(),
ctx=self.ctx,
reserved_keywords=self.reserved_keywords(),
).transform(ambig_tree)
except VisitError as e:
if isinstance(e.orig_exc, exc.BaseException):
raise e.orig_exc
raise e

final_tree = ParentSetter().visit(final_tree)
final_tree = IdSetter().visit(final_tree)
return final_tree

return parse
Expand Down Expand Up @@ -178,10 +179,10 @@ def parse(self, untrusted_llm_output: str) -> ParseTree:
try:
return super().parse(untrusted_llm_output)
except lark.exceptions.UnexpectedEOF as e:
raise exc.InvalidQuery(query=untrusted_llm_output) from e
raise exc.InvalidQuery(ctx=self.ctx) from e
except lark.exceptions.UnexpectedCharacters as e:
raise exc.InvalidQuery(query=untrusted_llm_output) from e
raise exc.InvalidQuery(ctx=self.ctx) from e
except exc.BaseException as e:
raise e
except Exception as e:
raise exc.InvalidQuery(query=untrusted_llm_output) from e
raise exc.InvalidQuery(ctx=self.ctx) from e
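The `parse` hunk above replaces `query=...` with `ctx=...` when raising `InvalidQuery`, which matches the changelog entry about exceptions carrying a `ctx` property. The sketch below shows that pattern in isolation. `InvalidQuerySketch` and `parse_sketch` are hypothetical names, `SimpleNamespace` stands in for the real context class, and the real exception hierarchy in `heimdallm.bifrosts.sql.exc` is not reproduced here.

```python
from types import SimpleNamespace
from typing import Any, Callable


class InvalidQuerySketch(Exception):
    """Hypothetical exception that keeps the traversal context for debugging."""

    def __init__(self, *, ctx: Any):
        super().__init__("Query is not valid")
        self.ctx = ctx


def parse_sketch(ctx: Any, parser: Callable[[str], Any]) -> Any:
    try:
        return parser(ctx.untrusted_llm_output)
    except InvalidQuerySketch:
        # Already one of our own exceptions: let it through unchanged.
        raise
    except Exception as e:
        # Wrap anything else, chaining the original error as the cause.
        raise InvalidQuerySketch(ctx=ctx) from e


def failing_parser(text: str) -> None:
    raise ValueError("unexpected end of input")


ctx = SimpleNamespace(untrusted_llm_output="SELECT FROM")

try:
    parse_sketch(ctx, failing_parser)
except InvalidQuerySketch as e:
    caught = e
    print(caught.ctx.untrusted_llm_output)
```

A caller can then read both the offending LLM output from `caught.ctx` and the underlying parser error from `caught.__cause__`.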
16 changes: 8 additions & 8 deletions heimdallm/bifrosts/sql/common.py
@@ -1,9 +1,7 @@
from typing import Any, Optional, Sequence

from . import exc


class RequiredConstraint:
class ParameterizedConstraint:
"""This represents a constraint that *must* be applied to the query.
In the query, this comes in the form of ``table.column=:placeholder``. Enforced by
Expand All @@ -23,7 +21,7 @@ def __init__(self, *, column: str, placeholder: str):

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, RequiredConstraint)
isinstance(other, ParameterizedConstraint)
and other.fq_column == self.fq_column
and other.placeholder == self.placeholder
)
Expand Down Expand Up @@ -60,7 +58,9 @@ def from_string(cls, fq_column_name: str) -> "FqColumn":
:param fq_column_name: The fully-qualified column name.
:raises RuntimeError: If the string does not contain a period."""
if "." not in fq_column_name:
raise exc.UnqualifiedColumn(fq_column_name)
raise RuntimeError(
f"Expected fully-qualified column name: {fq_column_name}"
)
table, column = fq_column_name.split(".")
return cls(table=table, column=column)

Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(self, first: str, second: str, *, identity: Optional[str] = None):
self.identity_placeholder = identity

@property
def requester_identities(self) -> Sequence[RequiredConstraint]:
def requester_identities(self) -> Sequence[ParameterizedConstraint]:
"""If this join condition has been marked as an identity join,
construct the required constraints for both sides of the join. We'll use those
constraints when testing for the requester's identity.
Expand All @@ -121,11 +121,11 @@ def requester_identities(self) -> Sequence[RequiredConstraint]:
"""
if self.identity_placeholder:
return [
RequiredConstraint(
ParameterizedConstraint(
column=self.first.name,
placeholder=self.identity_placeholder,
),
RequiredConstraint(
ParameterizedConstraint(
column=self.second.name,
placeholder=self.identity_placeholder,
),
Expand Down
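The pieces touched in `common.py` fit together roughly as sketched below: a fully-qualified column is parsed from a `table.column` string, a parameterized constraint pairs that column with a placeholder name, and an identity join yields one constraint per side. All class and function names ending in `Sketch`, and the `identity_constraints_sketch` helper, are hypothetical simplifications. The `name` property and the construction of `fq_column` in `__init__` are assumptions, since those parts of the file are collapsed in this diff.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass(frozen=True)
class FqColumnSketch:
    table: str
    column: str

    @classmethod
    def from_string(cls, fq_column_name: str) -> "FqColumnSketch":
        if "." not in fq_column_name:
            raise RuntimeError(
                f"Expected fully-qualified column name: {fq_column_name}"
            )
        # Exactly one period is expected; more would fail the unpacking.
        table, column = fq_column_name.split(".")
        return cls(table=table, column=column)

    @property
    def name(self) -> str:
        return f"{self.table}.{self.column}"


class ParameterizedConstraintSketch:
    def __init__(self, *, column: str, placeholder: str):
        self.fq_column = FqColumnSketch.from_string(column)
        self.placeholder = placeholder

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, ParameterizedConstraintSketch)
            and other.fq_column == self.fq_column
            and other.placeholder == self.placeholder
        )


def identity_constraints_sketch(
    first: str, second: str, *, identity: Optional[str] = None
) -> Sequence[ParameterizedConstraintSketch]:
    """One constraint per side of the join, or nothing if it is not an identity join."""
    if not identity:
        return []
    return [
        ParameterizedConstraintSketch(column=first, placeholder=identity),
        ParameterizedConstraintSketch(column=second, placeholder=identity),
    ]


pair = identity_constraints_sketch(
    "customer.customer_id", "rental.customer_id", identity="customer_id"
)
print([c.fq_column.name for c in pair])
```

Note that with a plain `split(".")`, a three-part name such as `db.table.column` would fail at the tuple unpacking with a `ValueError` rather than the `RuntimeError` raised for a missing period.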