diff --git a/.gitignore b/.gitignore index 65711ab..9ffaecb 100644 --- a/.gitignore +++ b/.gitignore @@ -56,8 +56,6 @@ # eggs mandala.egg-info/* -mandala/demos/data/ -mandala/__archive__ scratchpad/ mandala/scipy/ @@ -83,4 +81,7 @@ mandala/tests/output/* .idea/** # ignore docs build -site/ \ No newline at end of file +site/ + +# Ignore previous version +mandala/_prev/ \ No newline at end of file diff --git a/README.md b/README.md index 84d04f8..4ab2e29 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,7 @@

- logo + logo
-Blog post | Install | Quickstart | Testimonials | @@ -12,520 +11,125 @@ Tutorials
-[![Gitter](https://img.shields.io/gitter/room/amakelov/mandala)](https://app.gitter.im/#/room/#mandala:gitter.im) - # Computations that save, query and version themselves -![tldr](https://user-images.githubusercontent.com/1467702/231244639-f318af0e-3993-4ad1-8822-e8d889003dc1.gif) -
-`mandala` is a framework for experiment tracking and [incremental -computing](https://en.wikipedia.org/wiki/Incremental_computing) with a simple -plain-Python interface. It automates away the pain of experiment data management -(*did I run this already? what's in this file? where's the result of that -computation?* ...) with the following mix of features: -- **plain-Python**: decorate the functions whose calls you want to save, and - just write ordinary Python code using them - including data structures -and control flow. The results are automatically accessible upon a -re-run, queriable, and versioned. No need to save, load, or name anything by yourself -- **never compute the same thing twice**: `mandala` saves the result of each - function call, and (hashes of) the inputs and the dependencies - (functions/globals) accessed by the call. If later the inputs and dependencies - are the same, it just loads the results from storage. -- **query by pattern-matching directly to computational code**: your code - already knows the relationships between variables in your project! `mandala` - lets you produce tables relating any variables by [directly pointing to - the code establishing the wanted relationship between them](#query-by-directly-pattern-matching-python-code). -- **fine-grained versioning that's under your control**: each function's source - code has its own "mini git repo" behind the scenes, and each call tracks the - versions of all dependencies that went into it. You can decide when a change - to a dependency (e.g. refactoring a function to improve readability) doesn't change its semantics (so calls dependent on it won't be recomputed). - - -## Install +`mandala` eliminates the effort and code overhead of ML experiment tracking with +two tools: + +1. The `@op` decorator: + - Automatically captures inputs, outputs and code (+dependencies) of Python function calls + - Automatically reuses past results & never computes the same call twice + - Designed to be composed into end-to-end persisted programs, enabling + efficient iterative development in plain-Python, without thinking about the + storage backend. + +2. The `ComputationFrame` data structure: + - Generalized dataframes: "columns" are a computation graph, "rows" are + (partial) executions of the graph + - Automates exploration, querying and high-level operations over + heterogeneous "webs" of `@op` calls + - Can be converted to a `DataFrame` of execution traces for downstream + analysis of relationships between variables in a project + +# Install ``` pip install git+https://github.com/amakelov/mandala ``` -## Testimonials - -> "`mandala` addresses a core challenge in my notebook workflow: being able to -> explore data with code, without having to worry about losing the results of -> expensive calculations." - *Adam Jermyn, Member of Technical Staff, Anthropic* - - -## Video walkthroughs -### Rapidly iterate on a project with memoization -![mem](https://user-images.githubusercontent.com/1467702/231246050-21855bb2-6ce0-43d6-b7c0-3e0ed2a68f28.gif) - -Decorate the functions you want to memoize with `@op`, and compose programs out -of them by chaining their inputs and outputs using ordinary **control flow** and -**collections**. 
Every such program is **end-to-end memoized**: -- it becomes an **imperative query interface** to its own results by - (quickly) re-executing the same code, or parts of it -- it is **incrementally extensible** with new logic and parameters in-place, - which makes it both easy and efficient to interact with experiments - -### Query by directly pattern-matching Python code -![query](https://user-images.githubusercontent.com/1467702/231246102-276d7ae9-3a7f-46f8-9899-ae9dcf4f0484.gif) - -Any computation in a `with storage.run():` block is also a -**declarative query interface** to analogous computations in the entire storage: -- **get a table of all values with the same computational history**: -`storage.similar(x, y, ...)` returns a table of all values in the storage that -were computed in the same ways as `x, y, ...`, but possibly starting from -different inputs -- **query through collections**: queries propagate provenance from a collection - to its elements and vice-versa. This means you can query through operations - that aggregate many objects into one, or produce many objects from a fixed - number. - - **NOTE**: a collection in a computation can pattern-match any collection in - the storage with the same *kinds* of elements (as given by their computational - history), but not necessarily in the same order or quantity. This ensures that - you don't only match to the specific computation you have, but all *analogous* - computations too. -- **define queries explicitly**: for full control, use the `with storage.query():` context manager. For more on how this works, see [below](#explicit-declarative-queries-with-storagequery) - -### Automatic per-call versioning and dependency tracking -![deps](https://user-images.githubusercontent.com/1467702/231246159-fc8996a1-0987-4cec-9f0d-f0408609886e.gif) - -`mandala` comes with a very fine-grained versioning system: -- **per-call dependency tracking**: automatically track the functions and global -variables accessed by each memoized call, and alert you to changes in them, so -you can (carefully) choose whether a change to a dependency requires -recomputation of dependent calls (like bug fixes and changes to logic) or not -(like refactoring, comments, and logging) -- **the code determines all versions automatically**: use the current state of -each dependency in your codebase to automatically determine the currently -compatible versions of each memoized function to use in computation and queries. -In particular, this means that: - - **you can go "back in time"** and access the storage relative to an earlier - state of the code (or even branch in a new direction like in `git`) by - just restoring this state - - **the code is the truth**: when in doubt about the meaning of a result, you - can just look at the current code. - -## Quickstart -```python -from mandala.imports import * - -# the storage saves calls and tracks dependencies, versions, etc. -storage = Storage( - deps_path='__main__' # track dependencies in current session - ) - -@op # memoization (and more) decorator -def increment(x: int) -> int: # always indicate number of outputs in return type - print('hi from increment!') - return x + 1 - -increment(23) # function acts normally - -with storage.run(): # context manager that triggers `mandala` - y = increment(23) # now it's memoized w.r.t. this version of `increment` - -print(y) # result wrapped with metadata. 
-print(unwrap(y)) # `unwrap` gets the raw value - -with storage.run(): - y = increment(23) # loads result from `storage`; doesn't execute `increment` - -@op # type-annotate data structures to store elts separately -def average(nums: list) -> float: - print('hi from average!') - return sum(nums) / len(nums) - -# memoized functions are designed to be composed! -with storage.run(): - # sliding averages of `increment`'s results over 3 elts - nums = [increment(i) for i in range(5)] - for i in range(3): - result = average(nums[i:i+3]) - -# get a table of all values similar to `result` in `storage`, -# i.e., computed as average([increment(something), ...]) -# read the message this prints out! -print(storage.similar(result, context=True)) - -# change implementation of `increment` and re-run -# you'll be asked if the change requires recomputing dependencies (say yes) -@op -def increment(x: int) -> int: - print('hi from new increment!') - return x + 2 - -with storage.run(): - nums = [increment(i) for i in range(5)] - for i in range(3): - # only one call to `average` is executed! - result = average(nums[i:i+3]) - -# query is ran against the *new* version of `increment` -print(storage.similar(result, context=True)) -``` - -## Basic usage -This is a quick guide on how to get up to speed with the core features and avoid -common pitfalls. - -- [defining storages and memoized functions](#storage-and-the-op-decorator) -- [memoization basics](#compute--memoize-with-storagerun) -- [query storage directly from computational code](#implicit-declarative-queries) -- [explicit query interface](#explicit-declarative-queries-with-storagequery) -- [versioning and dependency tracking](#versioning-and-dependency-tracking) - -### `Storage` and the `@op` decorator -A `Storage` instance holds all the data (saved calls and metadata) for a -collection of memoized functions. In a given project, you should have just one -`Storage` and many memoized functions connected to it. This way, the calls to -memoized functions create a queriable web of interlinked objects. - -```python -from mandala.all import Storage, op - -storage = Storage( - db_path='my_persistent_storage.db', # omit for an in-memory storage - deps_path='path_to_code_folder/', # omit to disable automatic dependency tracking - spillover_dir='spillover_dir/', # spillover storage for large objects - # see docs for more options -) -``` -The `@op` decorator marks a function `f` as memoizable. Some notes apply: -- `f` must have a **fixed number of arguments** (defaults are allowed, and - arguments can always be added backward-compatibly) -- `f` must have a **fixed number of return values**, and this must be annotated in - the function signature **or** declared as `@op(nout=...)` -- `f` must have a **compatible interface** throughout its life. - - **The only way to change `f`'s interface once it already has memoized calls - is to add new arguments with default values.** - - if you want to change the interface in an incompatible way, you should - either just make a new function (under a new name), or increment the - `@op(version=...)` argument. -- the `list`, `dict` and `set` collections, when used as argument/return - annotations, cause elements of these collections to be stored separately. To - avoid confusion, `tuple`s are reserved for specifying the number of outputs. 
- -```python -from sklearn.datasets import load_digits -from sklearn.ensemble import RandomForestClassifier -from typing import Tuple -import numpy as np - -@op # core mandala decorator -def load_data(n_class: int = 10) -> Tuple[np.ndarray, np.ndarray]: - return load_digits(n_class=n_class, return_X_y=True) - -@op -def train_model(X: np.ndarray, y: np.ndarray, - n_estimators:int = 5) -> RandomForestClassifier: - return RandomForestClassifier(n_estimators=n_estimators, - max_depth=2).fit(X, y) -``` - -**Calling an `@op`-decorated function "normally" does not memoize**. To actually -put data in the storage, you must put calls inside a `with storage.run():` -block. +# Quickstart -### Compute & memoize `with storage.run():` -**`@op`-decorated functions are designed to be composed** with one another -inside `storage.run()` blocks. This composability lets -you use the same piece of ordinary Python code to compute, save, load, *and any -combination of the three*: -```python -# generate the dataset. This saves `X, y` to storage. -with storage.run(): - X, y = load_data() +[Run in Colab](https://colab.research.google.com/github/amakelov/mandala/blob/master/mandala/_next/tutorials/hello.ipynb) -# later, train a model by directly adding on top of this code. `load_data` is -# not computed again -with storage.run(): - X, y = load_data() - model = train_model(X, y) - -# iterate on this with more parameters & logic -from sklearn.metrics import accuracy_score +# Documentation +TODO: link -@op -def get_acc(model:RandomForestClassifier, X:np.ndarray, y:np.ndarray) -> float: - return round(accuracy_score(y_pred=model.predict(X), y_true=y), 2) - -with storage.run(): - for n_class in (10, 5, 2): - X, y = load_data(n_class) - for n_estimators in (5, 10, 20): - model = train_model(X, y, n_estimators=n_estimators) - acc = get_acc(model, X, y) - print(acc) -``` -```python -ValueRef(0.66, uid=15a...) -ValueRef(0.73, uid=79e...) -ValueRef(0.81, uid=5a4...) -ValueRef(0.84, uid=6c4...) -ValueRef(0.89, uid=fb8...) -ValueRef(0.93, uid=c3d...) -ValueRef(1.0, uid=b67...) -ValueRef(1.0, uid=b67...) -ValueRef(1.0, uid=b67...) -``` - -Memoized functions return `Ref` instances (`ValueRef`, `ListRef`, -...), which bundle the actual return value with storage metadata. To get the -return value itself, use `unwrap`. It works recursively on collections (lists, -dicts, sets, tuples) as well. - -You can **imperatively query** storage just by retracing some code that's been -entirely memoized: -```python -from mandala.all import unwrap -# use composable memoization as imperative query interface -with storage.run(): - X, y = load_data(5) - for n_estimators in (5, 20): - model = train_model(X, y, n_estimators=n_estimators) - acc = get_acc(model, X, y) - print(unwrap(acc)) -``` -```python -0.84 -0.93 -``` -### Implicit declarative queries -You can point to local variables in memoized code, and get a table of all values -in storage with the same functional dependencies as these variables have in the -code. For example, the `storage.similar()` method can be used to get values with -the same **joint computational history** as the given variables. 
To be able to -use a local variable in `storage.similar()`, it needs to be `wrap`ped as a `Ref` -(which has no effect on computation): - -```python -from mandala.all import wrap -# use composable memoization as imperative query interface -with storage.run(): - n_class = wrap(5) - X, y = load_data(n_class) - for n_estimators in wrap((5, 20)): # `wrap` maps over list, set and tuple - model = train_model(X, y, n_estimators=n_estimators) - acc = get_acc(model, X, y) - -storage.similar(n_class, n_estimators, acc) -``` -```python -Pattern-matching to the following computational graph (all constraints apply): - n_estimators = Q() # input to computation; can match anything - n_class = Q() # input to computation; can match anything - X, y = load_data(n_class=n_class) - model = train_model(X=X, y=y, n_estimators=n_estimators) - acc = get_acc(model=model, X=X, y=y) - result = storage.df(n_class, n_estimators, acc) - n_class n_estimators acc -1 10 5 0.66 -0 10 10 0.73 -2 10 20 0.81 -7 5 5 0.84 -6 5 10 0.89 -8 5 20 0.93 -4 2 5 1.00 -3 2 10 1.00 -5 2 20 1.00 -``` - -The computational graph printed out by the query (default `verbose=True`) is -also a good starting point for running an explicit query where you directly -provide the computational graph instead of extracting it from a program. - -### Explicit declarative queries `with storage.query():` -The kind of printout above can be directly copy-pasted into a `with -storage.query():` block. Here it is with some more explanation: -```python -with storage.query(): - n_class = Q() # creates a variable that can match any value in storage - # @op calls impose constraints between the values variables can match - X, y = load_data(n_class) - n_estimators = Q() # another variable that can match anything - model = train_model(X, y, n_estimators=n_estimators) - acc = get_acc(model, X, y) - # get a table where each row is a matching of the given variables - # that satisfies the constraints - result = storage.df(n_class, n_estimators, acc) -``` -#### How the `.query()` context works -- a query variable, generated with `Q()` or as return value from an `@op` - (like `X`, `y`, ... above), can in -principle match any value in the storage. -- by chaining together calls to `@op`s, you impose constraints between the - inputs and outputs to the op. For exampe, `X, y = load_data(n_class)` imposes - the constraint that a matching of values `(a, b, c)` to `(n_class, X, y)` must - satisfy `b, c = load_data(a)`. -- You can omit even required function arguments. This leaves them unconstrained. -- the `df` method takes any sequence of variables, and returns a table -where each row is a matching of values to the respective variables that -satisfies **all** the constraints. - -**The query implementation has not been optimized for performance at this point**. Keep in mind that -- if your query is not sufficiently constrained, there may be a combinatorial -explosion of results; -- if you query involves many variables and constraints, the default `SQL` query -solver may have a hard time, or flat out raise an error. Try using -`engine='naive'` in the `storage.similar()` or `storage.df()` methods instead. - -### Versioning and dependency tracking -Passing a value to the `deps_path` parameter of the `Storage` class enables -dependency tracking and versioning. This means that any time a memoized function -*actually executes* (instead of loading an already saved call), it keeps track of -the functions and global variables it accesses along the way. 
- -The number of tracked functions should be limited for efficiency (you typically -don't want to track changes in installed libraries!). Setting `deps_path` to -`"__main__"` will only look for dependencies defined in the current interactive -session or process. Setting it to a folder will only look for dependencies -defined in this folder. - -#### NOTE: The `@track` decorator -The most efficient and reliable implementation of dependency tracking currently -requires you to explicitly put `@track` on non-memoized functions and classes -you want to track. This limitation may be lifted in the future, but at the cost -of more magic. - -#### What is a version? -A **version** for a memoized function is (to a first approximation) a set of -source codes for functions/methods/global variables accessed by some call to -this function. Even if you don't change anything in the code, a single function -can have multiple versions if it invokes different dependencies for different calls. For example, consider this code: -```python -import numpy as np -from sklearn.datasets import load_digits -from sklearn.linear_model import LogisticRegression -from sklearn.preprocessing import StandardScaler -from mandala.imports import Storage, op, track -from typing import Tuple, Any - -N_CLASS = 10 - -@track # to track a non-memoized function as a dependency -def scale_data(X): - return StandardScaler(with_mean=True, with_std=False).fit_transform(X) - -@op -def load_data() -> Tuple[np.ndarray, np.ndarray]: - X, y = load_digits(n_class=N_CLASS, return_X_y=True) - return X, y - -@op -def train_model(X, y, scale=False) -> LogisticRegression: - if scale: - X = scale_data(X) - return LogisticRegression().fit(X, y) - -@op -def eval_model(model, X, y, scale=False) -> Any: - if scale: - X = scale_data(X) - return model.score(X, y) - -storage = Storage(deps_path='__main__') - -with storage.run(): - X, y = load_data() - for scale in [False, True]: - model = train_model(X, y, scale=scale) - acc = eval_model(model, X, y, scale=scale) -``` -When you run it, `train_model` and `eval_model` will each have two versions - -one that depends on `scale_data` and one that doesn't. You can confirm this by -calling `storage.versions(train_model)`. Now suppose we make some changes -and re-run: -```python -N_CLASS = 5 - -@track -def scale_data(X): - return StandardScaler(with_mean=True, with_std=True).fit_transform(X) - -@op -def eval_model(model, X, y, scale=False) -> Any: - if scale: - X = scale_data(X) - return round(model.score(X, y), 2) - -with storage.run(): - X, y = load_data() - for scale in [False, True]: - model = train_model(X, y, scale=scale) - acc = eval_model(model, X, y, scale=scale) -``` -When entering the `storage.run()` block, the storage will detect the changes in -the tracked components, and for each change will present you with the functions -affected: -- `N_CLASS` is a dependency for `load_data`; -- `scale_data` is a dependency for the calls to `train_model` and `eval_model` - which had `scale=True`; -- `eval_model` is a dependency for itself. - -#### Semantic vs content changes and versions -For each change to the content of some dependency (the source code of a function -or the value of a global variable), you can choose whether this content change -is also a **semantic** change. A semantic change will cause all calls that -have accessed this dependency to not appear memoized **with respect to the new -state of the code**. 
The content versions of a single dependency are organized -in a `git`-like DAG (currently, tree) that can be inspected using -`storage.sources(f)` for functions. - -#### Going back in time -Since the versioning system is content-based, simply restoring an old state of -the code makes the storage automatically recognize which "world" it's in, and -which calls are memoized in this world. - -#### A warning about non-semantic changes -The main motivation for allowing non-semantic changes is to maintain clarity in -the storage when doing routine code improvements (refactoring, comments, -logging). **However**, non-semantic changes should be applied with care. Apart from -being prone to errors (you wrongly conclude that a change has no effect on -semantics when it does), they can also introduce **invisible dependencies**: -suppose you factor a function out of some dependency and mark the change -non-semantic. Then the newly extracted function may in reality be a dependency -of the existing calls, but this goes unnoticed by the system. - -## Other gotchas - -- **under development**: the biggest gotcha is that this project is under active -development, which means things can change unpredictably. -- **slow**: it hasn't been optimized for performance, so many things are quite -inefficient -- **pure functions**: you should probably only use it for functions with a - deterministic input-output behavior if you're new to this project: - - **changing a `Ref`'s object in-place will generally break things**. If you - really need to update an object in-place, wrap the update in an `@op` so - that you get instead a new `Ref` (with updated metadata) pointing to the - same (changed) object, and discard the old `Ref`. - - if a function does not have a **deterministic set of dependencies** - it invokes for each given call, this may break the versioning system's - invariants. -- **avoid long (e.g. > 50) chains of calls in queries**: you should keep your - workflows relatively shallow for queries to be efficient. This means e.g. no - long recursive chains of calling a function repeatedly on its output -- **don't rename anything (yet)**: there isn't good support yet for renaming -functions, or moving functions around files. It's possible to rename functions -and their arguments, but this is still undocumented. -- **deletion**: no interfaces are currently exposed for deleting results. -- **examine complex queries manually**: the color refinement algorithm used to - extract a declarative query from a computational graph can in rare cases fail - to realize that two vertices have different roles in the computational graph - when projecting to the query. When in doubt, you should examine the printout - of the query and tweak it if necessary. - -## Tutorials +# Tutorials - see the ["Hello world!" tutorial](https://github.com/amakelov/mandala/blob/master/tutorials/00_hello.ipynb) for a 2-minute introduction to the library's main features -- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_logistic.ipynb) +- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_random_forest.ipynb) for a more realistic example of a machine learning project managed by Mandala. -- [dependency - tracking](https://github.com/amakelov/mandala/blob/master/tutorials/02_dependencies.ipynb) - tutorial +- TODO: dependency tracking + +# FAQs + +## How is this different from other experiment tracking frameworks? 
+Compared to popular tools like W&B, MLFlow or Comet, `mandala`:
- **is tightly integrated with the actual Python code execution**, as opposed
to being an external logging framework. This makes it much easier to compose
and iterate on multi-step experiments with non-trivial control flow.
    - For instance, Python's collections can (if so desired) be made
    transparent to the storage system, so that individual elements are stored
    separately and can be reused across collections and calls.
- **is founded on memoization, which allows direct and transparent reuse of
results**.
    - In other frameworks, you instead have to invent arbitrary names for
    artifacts, which can later cause confusion.
- **allows reuse, queries and versioning on a more granular and flexible
level** - the function call - as opposed to entire scripts and/or notebooks.
- **provides the `ComputationFrame` data structure**, a much more natural way
to represent and manipulate persisted computations.
- **automatically resolves the version of every `@op`** from the current state
of the codebase.

## How is the `@op` cache invalidated?
- Given the inputs for a call to an `@op` `f`, the storage searches for a past
call to `f` whose inputs have the same contents (as determined by a hash
function) and whose accessed dependencies (including `f` itself) have versions
compatible with their current state.
- Compatibility between versions of a function is decided by the user: you
have the freedom to mark certain changes as compatible with past results.
- Internally, `mandala` uses slightly modified `joblib` hashing to compute a
content hash for Python objects. This is practical for many use cases, but not
perfect, as discussed in the "gotchas" notebook TODO.

## Can I change the code of `@op`s, and what happens if I do?
- A frequent use case: you have some `@op` you've been using, then want to
extend its functionality in a way that doesn't invalidate the past results.
The recommended way is to add a new argument `a` and give it a default value
wrapped with `NewArgDefault(x)`. When a value equal to `x` is passed for this
argument, the storage falls back on calls made before the argument was added,
so those earlier results are reused rather than recomputed.

## How self-contained is it?
- `mandala`'s core is simple (only a few kLoCs) and only depends on `pandas`
and `joblib`.
- For visualization of `ComputationFrame`s, you should have `dot` installed at
the system level, and/or the Python `graphviz` library installed.

# Limitations

# Roadmap

# Testimonials

> "`mandala` addresses a core challenge in my notebook workflow: being able to
> explore data with code, without having to worry about losing the results of
> expensive calculations." - *Adam Jermyn, Member of Technical Staff, Anthropic*

# Galaxybrained vision
Aspirationally, `mandala` is about much more than ML experiment tracking. The
main goal is to **make persistence logic & best practices a natural extension
of Python**. Once this is achieved, the purely "computational" code you must
write anyway doubles as a storage interface. It's hard to think of a simpler
or more reliable way to manage computational artifacts.
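To make the last point concrete, here is a minimal sketch of how memoized code
doubles as a storage interface. It only uses the `@op`/`Storage`/`unwrap`
interface documented in this repository's quickstart; exact entry points (e.g.
the `storage.run()` context manager) may differ in the current version:

```python
from mandala.imports import *  # provides Storage, op, unwrap

storage = Storage()  # in-memory by default; pass db_path=... to persist across sessions

@op  # memoization decorator; the return annotation declares a single output
def increment(x: int) -> int:
    return x + 1

# First run: the call is executed and saved to the storage.
with storage.run():
    y = increment(41)

# Re-running the *same* code is the retrieval interface:
# the saved call is loaded and `increment` is not executed again.
with storage.run():
    y = increment(41)

print(unwrap(y))  # 42 -- `unwrap` strips the storage metadata from the Ref
```

Retracing already-memoized code like this is what lets the storage answer the
"have I computed this?" and "give me what I computed before" questions listed
below, without naming, saving or loading anything by hand.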
+

## A first-principles approach to managing computational artifacts
What we want from our storage are ways to:
- refer to artifacts with short, unambiguous descriptions: "here's [big messy
Python object] I computed, which to me means [human-readable description]"
- save artifacts: "save [big messy Python object]"
- refer to artifacts and load them at a later time: "give me [human-readable
description] that I computed before"
- know when you've already computed something: "have I computed
[human-readable description]?"
- query results in more complicated ways: "give me all the things that satisfy
[higher-level human-readable description]", which in practice means some
predicate over combinations of artifacts.
- get a report of how artifacts were generated: "what code went into
[human-readable description]?"

The key observation is that **execution traces** can already answer ~all of
these questions.

# Related work
`mandala` combines ideas from, and shares similarities with, many technologies.
Here are some useful points of comparison:
- **memoization**:
@@ -544,16 +148,9 @@ Here are some useful points of comparison:
   computation data processing framework that unifies over different resource
   types (files or services). It also uses an analogous notion of hashing to
   keep track of computations.
-- **queries**:
-  - all queries in `mandala` are [conjunctive queries](https://en.wikipedia.org/wiki/Conjunctive_query), a
-    fundamental class of queries in relational algebra.
-  - conjunctive queries are also related to category theory, see e.g.
+- **computation frames**:
+  - computation frames are related to certain ideas from category theory, see e.g.
    [here](https://blog.algebraicjulia.org/post/2020/12/cset-conjunctive-queries/).
-  - the [color refinement
-    algorithm](https://en.wikipedia.org/wiki/Colour_refinement_algorithm) used
-    to extract a query from an arbitrary computational graph is a standard tool
-    for finding similar substructure in graphs and testing for graph
-    isomorphism.
- **versioning**:
  - the revision history of each function in the codebase is organized in a "mini-[`git`](https://git-scm.com/) repository" that shares only the most basic features with `git`: it is a
@@ -566,6 +163,5 @@ Here are some useful points of comparison:
    make backward-compatible changes to the interface and logic of dependencies.
    It is different in that versions are still labeled by content, instead of by
    "non-canonical" numbers.
-  - the [unison programming
-    language](https://www.unison-lang.org/learn/the-big-idea/) represents
+  - the [unison programming language](https://www.unison-lang.org/learn/the-big-idea/) represents
    functions by the hash of their content (syntax tree, to be exact).
diff --git a/mandala/_next/README.md b/mandala/_next/README.md
deleted file mode 100644
index 4ab2e29..0000000
--- a/mandala/_next/README.md
+++ /dev/null
@@ -1,167 +0,0 @@
-
-
- logo -
-Install | -Quickstart | -Testimonials | -Demos | -Usage | -Gotchas | -Tutorials -
- -# Computations that save, query and version themselves - -
- -`mandala` eliminates the effort and code overhead of ML experiment tracking with -two tools: - -1. The `@op` decorator: - - Automatically captures inputs, outputs and code (+dependencies) of Python function calls - - Automatically reuses past results & never computes the same call twice - - Designed to be composed into end-to-end persisted programs, enabling - efficient iterative development in plain-Python, without thinking about the - storage backend. - -2. The `ComputationFrame` data structure: - - Generalized dataframes: "columns" are a computation graph, "rows" are - (partial) executions of the graph - - Automates exploration, querying and high-level operations over - heterogeneous "webs" of `@op` calls - - Can be converted to a `DataFrame` of execution traces for downstream - analysis of relationships between variables in a project - -# Install -``` -pip install git+https://github.com/amakelov/mandala -``` - -# Quickstart - -[Run in Colab](https://colab.research.google.com/github/amakelov/mandala/blob/master/mandala/_next/tutorials/hello.ipynb) - -# Documentation -TODO: link - -# Tutorials -- see the ["Hello world!" - tutorial](https://github.com/amakelov/mandala/blob/master/tutorials/00_hello.ipynb) - for a 2-minute introduction to the library's main features -- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_random_forest.ipynb) -for a more realistic example of a machine learning project managed by Mandala. -- TODO: dependency tracking - -# FAQs - -## How is this different from other experiment tracking frameworks? -Compared to popular tools like W&B, MLFlow or Comet, `mandala`: -- **is tightly integrated with the actual Python code execution**, as -opposed to being an external logging framework. This makes it much easier to -compose and iterate on multi-step experiments with non-trivial control flow. - - For instance, Python's collections can be (if so desired) made - transparent to the storage system, so that individual elements are - stored separately and can be reused across collections and calls. -- **is founded on memoization, which allows direct and transparent reuse of -results**. - - While in other frameworks you need to come up with arbitrary names -for artifacts (which can later cause confusion) -- **allows reuse, queries and versioning on a more granular and flexible -level** - the function call - as opposed to entire scripts and/or notebooks. -- **provides the `ComputationFrame` data structure**, a much more natural way to -represent and manipulate persisted computations. -- **automatically resolves the version of every `@op`** from the current state -of the codebase. - -## How is the `@op` cache invalidated? -- given inputs for a call to an `@op`, e.g. `f`, it searches for a past call -to `f` on inputs with the same contents (as determined by a hash function) where the dependencies accessed by this call (including `f` -itself) have versions compatible with their current state. -- compatibility between versions of a function is decided by the user: you -have the freedom to mark certain changes as compatible with past results. -- internally, `mandala` uses slightly modified `joblib` hashing to compute a -content hash for Python objects. This is practical for many use cases, but -not perfect, as discussed in the "gotchas" notebook TODO. - -## Can I change the code of `@op`s, and what happens if I do? 
-- a frequent use case: you have some `@op` you've been using, then want to -extend its functionality in a way that doesn't invalidate the past results. -The recommended way is to add a new argument `a`, and provide a default -value for it wrapped with `NewArgDefault(x)`. When a value equal to `x` is -passed for this argument, the storage falls back on calls before - -## How self-contained is it? -- `mandala`'s core is simple (only a few kLoCs) and only depends on `pandas` -and `joblib`. -- for visualization of `ComputationFrame`s, you should have `dot` installed -on the system level, and/or the Python `graphviz` library installed. - -# Limitations - -# Roadmap - -# Testimonials - -> "`mandala` addresses a core challenge in my notebook workflow: being able to -> explore data with code, without having to worry about losing the results of -> expensive calculations." - *Adam Jermyn, Member of Technical Staff, Anthropic* - -# Galaxybrained vision -Aspirationally, `mandala` is about much more than ML experiment tracking. The -main goal is to **make persistence logic & best practices a natural extension of Python**. -Once this is achieved, the purely "computational" code you must write anyway -doubles as a storage interface. It's hard to think of a simpler and more -reliable way to manage computational artifacts. - -## A first-principles approach to managing computational artifacts -What we want from our storage are ways to -- refer to artifacts with short, unambiguous descriptions: "here's [big messy Python object] I computed, which to me -means [human-readable description]" -- save artifacts: "save [big messy Python object]" -- refer to artifacts and load them at a later time: "give me [human-readable description] that I computed before" -- know when you've already computed something: "have I computed [human-readable description]?" -- query results in more complicated ways: "give me all the things that satisfy -[higher-level human-readable description]", which in practice means some -predicate over combinations of artifacts. -- get a report of how artifacts were generated: "what code went into [human-readable description]?" - -The key observation is that **execution traces** can already answer ~all of -these questions. - -# Related work -`mandala` combines ideas from, and shares similarities with, many technologies. -Here are some useful points of comparison: -- **memoization**: - - standard Python memoization solutions are [`joblib.Memory`](https://joblib.readthedocs.io/en/latest/generated/joblib.Memory.html) - and - [`functools.lru_cache`](https://docs.python.org/3/library/functools.html#functools.lru_cache). - `mandala` uses `joblib` serialization and hashing under the hood. - - [`incpy`](https://github.com/pajju/IncPy) is a project that integrates - memoization with the python interpreter itself. - - [`funsies`](https://github.com/aspuru-guzik-group/funsies) is a - memoization-based distributed workflow executor that uses an analogous notion - of hashing to `mandala` to keep track of which computations have already been done. It - works on the level of scripts (not functions), and lacks queriability and - versioning. - - [`koji`](https://arxiv.org/abs/1901.01908) is a design for an incremental - computation data processing framework that unifies over different resource - types (files or services). It also uses an analogous notion of hashing to - keep track of computations. 
-- **computation frames**: - - computation frames are related to the idea of using certain functions category theory, see e.g. - [here](https://blog.algebraicjulia.org/post/2020/12/cset-conjunctive-queries/). -- **versioning**: - - the revision history of each function in the codebase is organized in a "mini-[`git`](https://git-scm.com/) repository" that shares only the most basic - features with `git`: it is a - [content-addressable](https://en.wikipedia.org/wiki/Content-addressable_storage) - tree, where each edge tracks a diff from the content at one endpoint to that - at the other. Additional metadata indicates equivalence classes of - semantically equivalent contents. - - [semantic versioning](https://semver.org/) is another popular code - versioning system. `mandala` is similar to `semver` in that it allows you to - make backward-compatible changes to the interface and logic of dependencies. - It is different in that versions are still labeled by content, instead of by - "non-canonical" numbers. - - the [unison programming language](https://www.unison-lang.org/learn/the-big-idea/) represents - functions by the hash of their content (syntax tree, to be exact). diff --git a/mandala/_next/common_imports.py b/mandala/_next/common_imports.py deleted file mode 100644 index c65c395..0000000 --- a/mandala/_next/common_imports.py +++ /dev/null @@ -1,63 +0,0 @@ -import time -import traceback -import random -import logging -import itertools -import copy -import hashlib -import io -import os -import shutil -import sys -import joblib -import inspect -import binascii -import asyncio -import ast -import types -import tempfile -from collections import defaultdict, OrderedDict -from typing import ( - Any, - Dict, - List, - Callable, - Tuple, - Iterable, - Optional, - Set, - Union, - TypeVar, - Literal, -) -from pathlib import Path - -import pandas as pd -import pyarrow as pa -import numpy as np - -try: - import rich - - has_rich = True -except ImportError: - has_rich = False - -if has_rich: - from rich.logging import RichHandler - - logger = logging.getLogger("mandala") - logging_handler = RichHandler(enable_link_path=False) - FORMAT = "%(message)s" - logging.basicConfig( - level="INFO", format=FORMAT, datefmt="[%X]", handlers=[logging_handler] - ) -else: - logger = logging.getLogger("mandala") - # logger.addHandler(logging.StreamHandler()) - FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" - logging.basicConfig(format=FORMAT) - logger.setLevel(logging.INFO) - - -from mandala.common_imports import sess diff --git a/mandala/_next/deps/crawler.py b/mandala/_next/deps/crawler.py deleted file mode 100644 index 25f109c..0000000 --- a/mandala/_next/deps/crawler.py +++ /dev/null @@ -1,109 +0,0 @@ -import types -from ..common_imports import * -from ..utils import unwrap_decorators -import importlib -from .model import ( - DepKey, - CallableNode, -) -from .utils import ( - is_callable_obj, - extract_func_obj, - unknown_function, -) - - -def crawl_obj( - obj: Any, - module_name: str, - include_methods: bool, - result: Dict[DepKey, CallableNode], - strict: bool, - objs_result: Dict[DepKey, Callable], -): - """ - Find functions and optionally methods native to the module of this object. 
- """ - if is_callable_obj(obj=obj, strict=strict): - if isinstance(unwrap_decorators(obj, strict=False), types.BuiltinFunctionType): - return - v = extract_func_obj(obj=obj, strict=strict) - if v is not unknown_function and v.__module__ != module_name: - # exclude non-local functions - return - dep_key = (module_name, v.__qualname__) - node = CallableNode.from_obj(obj=v, dep_key=dep_key) - result[dep_key] = node - objs_result[dep_key] = obj - if isinstance(obj, type): - if include_methods: - if obj.__module__ != module_name: - return - for k in obj.__dict__.keys(): - v = obj.__dict__[k] - crawl_obj( - obj=v, - module_name=module_name, - include_methods=include_methods, - result=result, - strict=strict, - objs_result=objs_result, - ) - - -def crawl_static( - root: Optional[Path], - strict: bool, - package_name: Optional[str] = None, - include_methods: bool = False, -) -> Tuple[Dict[DepKey, CallableNode], Dict[DepKey, Callable]]: - """ - Find all python files in the root directory, and use importlib to import - them, look for callable objects, and create callable nodes from them. - """ - result: Dict[DepKey, CallableNode] = {} - objs_result: Dict[DepKey, Callable] = {} - paths = [] - if root is not None: - if root.is_file(): - assert package_name is not None # needs this to be able to import - paths = [root] - else: - paths.extend(list(root.rglob("*.py"))) - paths.append("__main__") - for path in paths: - filename = path.name if path != "__main__" else "__main__" - if filename in ("setup.py", "console.py"): - continue - if path != "__main__" and root is not None: - if root.is_file(): - module_name = root.stem - else: - module_name = ( - path.with_suffix("").relative_to(root).as_posix().replace("/", ".") - ) - if package_name is not None: - module_name = ".".join([package_name, module_name]) - else: - module_name = "__main__" - try: - module = importlib.import_module(module_name) - except: - msg = f"Failed to import {module_name}:" - if strict: - raise ValueError(msg) - else: - logger.warning(msg) - continue - keys = list(module.__dict__.keys()) - for k in keys: - v = module.__dict__[k] - crawl_obj( - obj=v, - module_name=module_name, - strict=strict, - include_methods=include_methods, - result=result, - objs_result=objs_result, - ) - return result, objs_result diff --git a/mandala/_next/deps/deep_versions.py b/mandala/_next/deps/deep_versions.py deleted file mode 100644 index dc4a0ba..0000000 --- a/mandala/_next/deps/deep_versions.py +++ /dev/null @@ -1,188 +0,0 @@ -from ..common_imports import * -from .utils import DepKey, hash_dict -from .model import ( - Node, - CallableNode, - GlobalVarNode, - TerminalNode, -) - - -from .shallow_versions import DAG - - -class Version: - """ - Model of a "deep" version of a component that includes versions of its - dependencies. - """ - - def __init__( - self, - component: DepKey, - dynamic_deps_commits: Dict[DepKey, str], - memoized_deps_content_versions: Dict[DepKey, Set[str]], - ): - ### raw data from the trace - # the component whose dependencies are traced - self.component = component - # the content hashes of the direct dependencies - self.direct_deps_commits = dynamic_deps_commits - # pointers to content hashes of versions of memoized calls - self.memoized_deps_content_versions = memoized_deps_content_versions - - ### cached data. These are set against a dependency state - self._is_synced = False - # the expanded set of dependencies, including all transitive - # dependencies. 
Note this is a set of *content* hashes per dependency - self._content_expansion: Dict[DepKey, Set[str]] = None - # a hash uniquely identifying the content of dependencies of this version - self._content_version: str = None - # the semantic hashes of all dependencies for this version; - # the system enforces that the semantic hash of a dependency is the same - # for all commits of a component referenced by this version - self._semantic_expansion: Dict[DepKey, str] = None - # overall semantic hash of this version - self._semantic_version: str = None - - @property - def presentation(self) -> str: - return f'Version of "{self.component[1]}" from module "{self.component[0]}" (content: {self.content_version}, semantic: {self.semantic_version})' - - @staticmethod - def from_trace( - component: DepKey, nodes: Dict[DepKey, Node], strict: bool = True - ) -> "Version": - dynamic_deps_commits = {} - memoized_deps_content_versions = defaultdict(set) - for dep_key, node in nodes.items(): - if isinstance(node, (CallableNode, GlobalVarNode)): - dynamic_deps_commits[dep_key] = node.content_hash - elif isinstance(node, TerminalNode): - terminal_data = node.representation - pointer_dep_key = terminal_data.dep_key - version_content_hash = terminal_data.call_content_version - memoized_deps_content_versions[pointer_dep_key].add( - version_content_hash - ) - else: - raise ValueError(f"Unexpected node type {type(node)}") - return Version( - component=component, - dynamic_deps_commits=dynamic_deps_commits, - memoized_deps_content_versions=dict(memoized_deps_content_versions), - ) - - ############################################################################ - ### methods for setting cached data from a versioning state - ############################################################################ - def _set_content_expansion(self, all_versions: Dict[DepKey, Dict[str, "Version"]]): - result = defaultdict(set) - for dep_key, content_hash in self.direct_deps_commits.items(): - result[dep_key].add(content_hash) - for ( - dep_key, - memoized_content_versions, - ) in self.memoized_deps_content_versions.items(): - for memoized_content_version in memoized_content_versions: - referenced_version = all_versions[dep_key][memoized_content_version] - for ( - referenced_dep_key, - referenced_content_hashes, - ) in referenced_version.content_expansion.items(): - result[referenced_dep_key].update(referenced_content_hashes) - self._content_expansion = dict(result) - - def _set_content_version(self): - self._content_version = hash_dict( - { - dep_key: tuple(sorted(self.content_expansion[dep_key])) - for dep_key in self.content_expansion - } - ) - - def _set_semantic_expansion( - self, - component_dags: Dict[DepKey, DAG], - all_versions: Dict[DepKey, Dict[str, "Version"]], - ): - result = {} - # from own deps - for dep_key, dep_content_hash in self.direct_deps_commits.items(): - dag = component_dags[dep_key] - semantic_hash = dag.commits[dep_content_hash].semantic_hash - result[dep_key] = semantic_hash - # from pointers - for ( - dep_key, - memoized_content_versions, - ) in self.memoized_deps_content_versions.items(): - for memoized_content_version in memoized_content_versions: - dep_version_semantic_hashes = all_versions[dep_key][ - memoized_content_version - ].semantic_expansion - overlap = set(result.keys()).intersection( - dep_version_semantic_hashes.keys() - ) - if any(result[k] != dep_version_semantic_hashes[k] for k in overlap): - raise ValueError( - f"Version {self} has conflicting semantic hashes for {overlap}" - ) - 
result.update(dep_version_semantic_hashes) - self._semantic_expansion = result - self._semantic_version = hash_dict(result) - - def sync( - self, - component_dags: Dict[DepKey, DAG], - all_versions: Dict[DepKey, Dict[str, "Version"]], - ): - """ - Set all the cached things in the correct order - """ - self._set_content_expansion(all_versions=all_versions) - self._set_content_version() - self._set_semantic_expansion( - component_dags=component_dags, all_versions=all_versions - ) - self.set_synced() - - @property - def content_version(self) -> str: - assert self._content_version is not None - return self._content_version - - @property - def semantic_version(self) -> str: - assert self._semantic_version is not None - return self._semantic_version - - @property - def semantic_expansion(self) -> Dict[DepKey, str]: - assert self._semantic_expansion is not None - return self._semantic_expansion - - @property - def content_expansion(self) -> Dict[DepKey, Set[str]]: - assert self._content_expansion is not None - return self._content_expansion - - @property - def support(self) -> Iterable[DepKey]: - return self.content_expansion.keys() - - @property - def is_synced(self) -> bool: - return self._is_synced - - def set_synced(self): - # it can only go from unsynced to synced - if self._is_synced: - raise ValueError("Version is already synced") - self._is_synced = True - - def __repr__(self) -> str: - return f""" -Version( - dependencies={['.'.join(elt) for elt in self.support]}, -)""" diff --git a/mandala/_next/deps/model.py b/mandala/_next/deps/model.py deleted file mode 100644 index 731bb0e..0000000 --- a/mandala/_next/deps/model.py +++ /dev/null @@ -1,311 +0,0 @@ -import textwrap -from abc import abstractmethod, ABC -import types - -from ..common_imports import * -from ..utils import get_content_hash -from ..viz import ( - write_output, -) - -from .utils import ( - DepKey, - load_obj, - get_runtime_description, - extract_code, - unknown_function, - UNKNOWN_GLOBAL_VAR, -) - - -class Node(ABC): - def __init__(self, module_name: str, obj_name: str, representation: Any): - self.module_name = module_name - self.obj_name = obj_name - self.representation = representation - - @property - def key(self) -> DepKey: - return (self.module_name, self.obj_name) - - def present_key(self) -> str: - raise NotImplementedError() - - @staticmethod - @abstractmethod - def represent(obj: Any) -> Any: - raise NotImplementedError - - @abstractmethod - def content(self) -> Any: - raise NotImplementedError - - @abstractmethod - def readable_content(self) -> str: - raise NotImplementedError - - @property - @abstractmethod - def content_hash(self) -> str: - raise NotImplementedError - - def load_obj(self, allow_fallback: bool) -> Any: - obj, found = load_obj(module_name=self.module_name, obj_name=self.obj_name) - if not found: - msg = f"{self.present_key()} not found" - if allow_fallback: - logger.warning(msg) - if hasattr(self, "FALLBACK_OBJ"): - return self.FALLBACK_OBJ - else: - raise ValueError(f"No fallback object defined for {self.__class__}") - else: - raise ValueError(msg) - return obj - - -class CallableNode(Node): - FALLBACK_OBJ = unknown_function - - def __init__( - self, - module_name: str, - obj_name: str, - representation: Optional[str], - runtime_description: str, - ): - self.module_name = module_name - self.obj_name = obj_name - self.runtime_description = runtime_description - if representation is not None: - self._set_representation(value=representation) - else: - self._representation = None - 
self._content_hash = None - - @staticmethod - def from_obj(obj: Any, dep_key: DepKey) -> "CallableNode": - representation = CallableNode.represent(obj=obj) - code_obj = extract_code(obj) - runtime_description = get_runtime_description(code=code_obj) - return CallableNode( - module_name=dep_key[0], - obj_name=dep_key[1], - representation=representation, - runtime_description=runtime_description, - ) - - @staticmethod - def from_runtime( - module_name: str, - obj_name: str, - code_obj: types.CodeType, - ) -> "CallableNode": - return CallableNode( - module_name=module_name, - obj_name=obj_name, - representation=None, - runtime_description=get_runtime_description(code=code_obj), - ) - - @property - def representation(self) -> str: - return self._representation - - def _set_representation(self, value: str): - assert isinstance(value, str) - self._representation = value - self._content_hash = get_content_hash(value) - - @representation.setter - def representation(self, value: str): - self._set_representation(value) - - @property - def is_method(self) -> bool: - return "." in self.obj_name - - def present_key(self) -> str: - return f"function {self.obj_name} from module {self.module_name}" - - @property - def class_name(self) -> str: - assert self.is_method - return ".".join(self.obj_name.split(".")[:-1]) - - @staticmethod - def represent( - obj: Union[types.FunctionType, types.CodeType, Callable], - allow_fallback: bool = False, - ) -> str: - if type(obj).__name__ == "Op": - obj = obj.f - if not isinstance(obj, (types.FunctionType, types.MethodType, types.CodeType)): - logger.warning(f"Found {obj} of type {type(obj)}") - try: - source = inspect.getsource(obj) - except Exception as e: - msg = f"Could not get source for {obj} because {e}" - if allow_fallback: - source = inspect.getsource(CallableNode.FALLBACK_OBJ) - logger.warning(msg) - else: - raise RuntimeError(msg) - # strip whitespace to prevent different sources looking the same in the - # ui - lines = source.splitlines() - lines = [line.rstrip() for line in lines] - source = "\n".join(lines) - return source - - def content(self) -> str: - return self.representation - - def readable_content(self) -> str: - return self.representation - - @property - def content_hash(self) -> str: - assert isinstance(self._content_hash, str) - return self._content_hash - - -class GlobalVarNode(Node): - FALLBACK_OBJ = UNKNOWN_GLOBAL_VAR - - def __init__( - self, - module_name: str, - obj_name: str, - # (content hash, truncated repr) - representation: Tuple[str, str], - ): - self.module_name = module_name - self.obj_name = obj_name - self._representation = representation - - @staticmethod - def from_obj(obj: Any, dep_key: DepKey) -> "GlobalVarNode": - representation = GlobalVarNode.represent(obj=obj) - return GlobalVarNode( - module_name=dep_key[0], - obj_name=dep_key[1], - representation=representation, - ) - - @property - def representation(self) -> Tuple[str, str]: - return self._representation - - @staticmethod - def represent(obj: Any, allow_fallback: bool = False) -> Tuple[str, str]: - truncated_repr = textwrap.shorten(text=repr(obj), width=80) - try: - content_hash = get_content_hash(obj=obj) - except Exception as e: - shortened_exception = textwrap.shorten(text=str(e), width=80) - msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}" - if allow_fallback: - content_hash = UNKNOWN_GLOBAL_VAR - logger.warning(msg) - else: - raise RuntimeError(msg) - return content_hash, truncated_repr - - def 
present_key(self) -> str: - return f"global variable {self.obj_name} from module {self.module_name}" - - def content(self) -> str: - return self.representation - - def readable_content(self) -> str: - return self.representation[1] - - @property - def content_hash(self) -> str: - assert isinstance(self.representation, tuple) - return self.representation[0] - - -class TerminalData: - def __init__( - self, - op_internal_name: str, - op_version: int, - call_content_version: str, - call_semantic_version: str, - # data: Tuple[Tuple[str, int], Tuple[str, str]], - dep_key: DepKey, - ): - # ((internal name, version), (content_version, semantic_version)) - self.op_internal_name = op_internal_name - self.op_version = op_version - self.call_content_version = call_content_version - self.call_semantic_version = call_semantic_version - self.dep_key = dep_key - - -class TerminalNode(Node): - def __init__(self, module_name: str, obj_name: str, representation: TerminalData): - self.module_name = module_name - self.obj_name = obj_name - self.representation = representation - - @property - def key(self) -> DepKey: - return self.module_name, self.obj_name - - def present_key(self) -> str: - raise NotImplementedError - - @property - def content_hash(self) -> str: - raise NotImplementedError - - def content(self) -> Any: - raise NotImplementedError - - def readable_content(self) -> str: - raise NotImplementedError - - @staticmethod - def represent(obj: Any) -> Any: - raise NotImplementedError - - -class DependencyGraph: - def __init__(self): - self.nodes: Dict[DepKey, Node] = {} - self.roots: Set[DepKey] = set() - self.edges: Set[Tuple[DepKey, DepKey]] = set() - - def get_trace_state(self) -> Tuple[DepKey, Dict[DepKey, Node]]: - if len(self.roots) != 1: - raise ValueError(f"Expected exactly one root, got {len(self.roots)}") - component = list(self.roots)[0] - return component, self.nodes - - def show(self, path: Optional[Path] = None, how: str = "none"): - dot = to_dot(self) - output_ext = "svg" if how in ["browser"] else "png" - return write_output( - dot_string=dot, output_path=path, output_ext=output_ext, show_how=how - ) - - def __repr__(self) -> str: - if len(self.nodes) == 0: - return "DependencyGraph()" - return to_string(self) - - def add_node(self, node: Node): - self.nodes[node.key] = node - - def add_edge(self, source: Node, target: Node): - if source.key not in self.nodes: - self.add_node(source) - if target.key not in self.nodes: - self.add_node(target) - self.edges.add((source.key, target.key)) - - -from .viz import to_dot, to_string diff --git a/mandala/_next/deps/shallow_versions.py b/mandala/_next/deps/shallow_versions.py deleted file mode 100644 index 98282f8..0000000 --- a/mandala/_next/deps/shallow_versions.py +++ /dev/null @@ -1,463 +0,0 @@ -from typing import Literal -import textwrap -from ..common_imports import * -from ..utils import get_content_hash -from ..config import Config -from ..utils import ask_user -from ..viz import _get_colorized_diff, _get_diff - -if Config.has_rich: - from rich.tree import Tree - from rich.panel import Panel - from rich.text import Text - from rich.syntax import Syntax - - -# TODO: figure out how to apply compact diffs -def get_diff(a: str, b: str) -> Tuple[str, str]: - """ - Get a diff between two strings - """ - return (a, b) - - -def apply_diff(b: str, diff: Tuple[str, str]) -> str: - """ - Apply a diff to a string - """ - return diff[0] - - -class Commit: - """ - Tracks versions of the "shallow" content of a single component. 
For - functions/methods, this is just the function source code, without any - dependencies. For global variables, this is just (the content hash of) the - value of the variable. - """ - - def __init__( - self, - parents: List[str], - diffs: List[Any], - content_hash: str, - semantic_hash: str, - content: Optional[str], - ): - # content hashes of parent commits - self.parents = parents # currently there may be at most one parent - # diffs between this commit and its parents - self.diffs = diffs - self.content_hash = content_hash - # content hash of the semantic version this commit is associated with - self.semantic_hash = semantic_hash - # content of this commit, if it is a root commit - self._content = content - self.check_invariants() - - def check_invariants(self): - assert len(self.parents) == len(self.diffs) - assert len(self.parents) > 0 or self._content is not None - - def __repr__(self) -> str: - return f"Commit(content_hash={self.content_hash}, semantic_hash={self.semantic_hash}, parents={self.parents})" - - -T = TypeVar("T") -from typing import Generic - - -class ContentAdapter(Generic[T]): - def get_diff(self, a: T, b: T) -> Any: - """ - Get a diff between two objects - """ - raise NotImplementedError() - - def apply_diff(self, b: T, diff: Any) -> T: - """ - Apply a diff to an object - """ - raise NotImplementedError() - - def get_presentable_content(self, content: T) -> str: - """ - Get a presentable string representation of the content - """ - raise NotImplementedError() - - def get_content_hash(self, content: T) -> str: - raise NotImplementedError() - - -class StringContentAdapter(ContentAdapter[str]): - def get_diff(self, a: str, b: str) -> Tuple[str, str]: - """ - Get a diff between two strings - """ - return get_diff(a, b) - - def apply_diff(self, b: str, diff: Tuple[str, str]) -> str: - """ - Apply a diff to a string - """ - return apply_diff(b, diff) - - def get_presentable_content(self, content: str) -> str: - """ - Get a presentable string representation of the content - """ - return content - - def get_content_hash(self, content: str) -> str: - return get_content_hash(content) - - -GVContent = Tuple[str, str] # (content hash, repr) - - -class GlobalVariableContentAdapter(ContentAdapter[GVContent]): - def get_diff(self, a: GVContent, b: GVContent) -> Tuple[GVContent, GVContent]: - """ - Get a diff between two global variable contents - """ - return (a, b) - - def apply_diff(self, b: GVContent, diff: Tuple[GVContent, GVContent]) -> GVContent: - """ - Apply a diff to a global variable content - """ - return diff[0] - - def get_presentable_content(self, content: GVContent) -> str: - return content[1] - - def get_content_hash(self, content: GVContent) -> str: - return content[0] - - -class DAG(Generic[T]): - """ - Organizes the shallow versions of a single component in a `git`-like DAG. 
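    A minimal usage sketch (hypothetical content strings; in practice the DAG
    is driven by the Versioner/tracing machinery rather than by hand):

        dag = DAG(content_type="code")
        v1 = dag.init(initial_content="def f(x): return x + 1")
        # passing is_semantic_change explicitly skips the interactive prompt
        v2 = dag.commit("def f(x): return x + 2", is_semantic_change=True)
        v3 = dag.sync("def f(x): return x + 1")  # known content -> checkout
        assert v3 == v1 and dag.head == v1
        assert dag.size == 2 and dag.semantic_size == 2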
- """ - - def __init__(self, content_type: Literal["code", "global_variable"] = "code"): - # content hash of the current head - self.head: Optional[str] = None - self.commits: Dict[str, Commit] = {} - self._initial_commit: Optional[T] = None - if content_type == "code": - self.content_adapter = StringContentAdapter() - elif content_type == "global_variable": - self.content_adapter = GlobalVariableContentAdapter() - else: - raise ValueError(f"Invalid content_type: {content_type}") - self.check_invariants() - - @property - def initial_commit(self) -> str: - assert self.head is not None - return self._initial_commit - - def check_invariants(self): - if self.head is not None: - assert self.head in self.commits - for commit, commit_obj in self.commits.items(): - commit_obj.check_invariants() - assert all(p in self.commits for p in commit_obj.parents) - assert commit_obj.content_hash == commit - assert commit_obj.semantic_hash in self.commits - - def get_current_content(self) -> T: - # return the full content of the current head - assert self.head is not None - return self.get_content(commit=self.head) - - def get_presentable_content(self, commit: str) -> str: - return self.content_adapter.get_presentable_content( - content=self.get_content(commit=commit) - ) - - def get_content(self, commit: str) -> T: - # return the full content of a commit given its content hash - if commit not in self.commits: - raise ValueError(f"Commit {commit} not in DAG") - commit_obj = self.commits[commit] - if commit_obj._content is not None: - return commit_obj._content - else: - parent_content = self.get_content(commit_obj.parents[0]) - return self.content_adapter.apply_diff(parent_content, commit_obj.diffs[0]) - - def init(self, initial_content: T) -> str: - """ - Initialize the DAG with the initial content, and set the head to the - initial commit. Return the content hash of the initial commit. - """ - # initialize the DAG with the initial content - assert self.head is None - content_hash = self.content_adapter.get_content_hash(content=initial_content) - semantic_hash = content_hash - commit = Commit( - parents=[], - diffs=[], - content_hash=content_hash, - semantic_hash=semantic_hash, - content=initial_content, - ) - self.head = content_hash - self.commits[content_hash] = commit - self._initial_commit = content_hash - return content_hash - - def checkout(self, commit: str, implicit_merge: bool = False): - """ - checkout an *existing* commit, i.e. set the head to the given commit. 
- if implicit_merge is True, an edge will be added from the new head to - the old head - """ - if commit not in self.commits: - raise ValueError(f"Commit {commit} not in DAG") - if implicit_merge: - raise NotImplementedError - self.head = commit - - def commit(self, content: T, is_semantic_change: Optional[bool] = None) -> str: - """ - Commit a *new* version of the content and return the content hash - """ - assert self.head is not None - content_hash = self.content_adapter.get_content_hash(content=content) - assert content_hash not in self.commits - head_commit = self.commits[self.head] - if is_semantic_change is None: - presentable_diff = get_diff( - self.content_adapter.get_presentable_content(content), - self.get_presentable_content(commit=self.head), - ) - print( - _get_colorized_diff( - current=presentable_diff[1], new=presentable_diff[0] - ) - ) - answer = ask_user( - question="Does this change require recomputation of dependent calls?\nWARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\nAnswer: [y]es/[n]o/[a]bort", - valid_options=["y", "n", "a"], - ) - print(f'You answered: "{answer}"') - if answer == "a": - raise ValueError("Aborting commit") - is_semantic_change = answer == "y" - if is_semantic_change: - semantic_hash = content_hash - else: - semantic_hash = head_commit.semantic_hash - diff = self.content_adapter.get_diff( - content, self.get_content(commit=self.head) - ) - commit = Commit( - parents=[self.head], - diffs=[diff], - content_hash=content_hash, - semantic_hash=semantic_hash, - content=None, - ) - self.head = content_hash - self.commits[content_hash] = commit - return content_hash - - def sync(self, content: T, is_semantic_change: Optional[bool] = None) -> str: - """ - if the content: - - is the current head, do nothing - - is in the DAG, checkout the content - - is not in the DAG, commit the content - - return the content hash - """ - assert self.head is not None - content_hash = self.content_adapter.get_content_hash(content) - if self.head == content_hash: - result = self.head - elif content_hash in self.commits: - self.checkout(content_hash) - result = self.head - else: - result = self.commit(content, is_semantic_change=is_semantic_change) - assert self.head == content_hash - return result - - ############################################################################ - ### visualization and printing - ############################################################################ - def _get_tree_neighbors_representation(self) -> Dict[str, Set[str]]: - """ - Get a {parent: {children}} representation of the tree underlying the DAG - (obtained by following the first parent of each commit). 
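    For example (hypothetical truncated hashes), a root commit "abc" with two
    children "def" and "ghi" yields {"abc": {"def", "ghi"}}; commits without
    children do not appear as keys.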
- """ - result = defaultdict(set) - for commit in self.commits.values(): - if len(commit.parents) > 0: - parent = commit.parents[0] - result[parent].add(commit.content_hash) - return dict(result) - - def get_commit_presentation( - self, commit: str, diff_only: bool, include_metadata: bool = False - ) -> Tuple[str, str]: - if diff_only: - if commit == self.initial_commit: - content_to_show = self.get_presentable_content(commit) - content_type = "code" - else: - parent_content = self.get_content(self.commits[commit].parents[0]) - child_content = self.content_adapter.apply_diff( - b=parent_content, diff=self.commits[commit].diffs[0] - ) - parent_presentable_content = ( - self.content_adapter.get_presentable_content(content=parent_content) - ) - child_presentable_content = ( - self.content_adapter.get_presentable_content(content=child_content) - ) - content_to_show = _get_diff( - current=parent_presentable_content, - new=child_presentable_content, - ) - content_type = "diff" - else: - content_to_show = self.get_presentable_content(commit) - content_type = "code" - - content_version = commit - semantic_version = self.commits[commit].semantic_hash - - header_lines = [] - if commit == self.head: - header_lines.append(f"### ===HEAD===") - if include_metadata: - header_lines.append(f"### content_commit={content_version}") - header_lines.append(f"### semantic_commit={semantic_version}") - if len(header_lines) > 0: - header = "\n".join(header_lines) - content_to_show = f"{header}\n{content_to_show}" - return content_to_show, content_type - - if Config.has_rich: - from rich.panel import Panel - from rich.tree import Tree - - def get_commit_content_rich( - self, - commit: str, - diff_only: bool = False, - title: Optional[str] = None, - include_metadata: bool = False, - ) -> Panel: - """ - Get a rich panel representing the content and metadata of a commit. - """ - content_to_show, content_type = self.get_commit_presentation( - commit=commit, diff_only=diff_only, include_metadata=include_metadata - ) - if title is not None: - title = Text(title, style="bold") - lexer = "python" if content_type == "code" else "diff" - content = Syntax( - content_to_show, - lexer=lexer, - line_numbers=False, - theme="solarized-light", - ) - return Panel(renderable=content, title=title) - - def get_tree_rich( - self, compact: bool = False, include_metadata: bool = False - ) -> Tree: - """ - Get a rich tree representing the tree underlying the DAG (obtained - by following the first parent of each commit). 
- """ - assert Config.has_rich - if self.head is None: - return Tree(label="DAG(head=None)") - tree_neighbors = self._get_tree_neighbors_representation() - tree_objs = {} - initial_commit = self.initial_commit - - result = Tree( - label=self.get_commit_content_rich( - initial_commit, include_metadata=include_metadata - ) - ) - tree_objs[initial_commit] = result - - def grow(commit: str): - if commit in tree_neighbors: - for child in tree_neighbors[commit]: - current_tree = tree_objs[commit] - new_tree = current_tree.add( - self.get_commit_content_rich( - child, - diff_only=compact, - include_metadata=include_metadata, - ) - ) - tree_objs[child] = new_tree - grow(child) - - grow(initial_commit) - return result - - def __repr__(self) -> str: - num_content = len(self.commits) - num_semantic = len(set(c.semantic_hash for c in self.commits.values())) - return f"DAG(head={self.head}) with {num_content} content version(s) and {num_semantic} semantic version(s)" - - def show( - self, compact: bool = False, plain: bool = False, include_metadata: bool = False - ): - if Config.has_rich and not plain: - rich.print( - self.get_tree_rich(compact=compact, include_metadata=include_metadata) - ) - return - else: - if self.head is None: - return "DAG(head=None)" - commits = list(self.commits.keys()) - commits = [self.head] + [k for k in commits if k != self.head] - lines = [] - lines.append( - self.get_commit_presentation( - commit=self.initial_commit, - diff_only=compact, - include_metadata=include_metadata, - )[0] - ) - lines.append("--------") - tree_neighbors = self._get_tree_neighbors_representation() - - def visit(commit: str, depth: int): - if commit in tree_neighbors: - for child in tree_neighbors[commit]: - child_text = self.get_commit_presentation( - commit=child, - diff_only=compact, - include_metadata=include_metadata, - )[0] - child_text = textwrap.indent(child_text, " " * (depth + 1)) - lines.append(child_text) - lines.append("--------") - visit(child, depth + 1) - - visit(self.initial_commit, 0) - print("\n".join(lines)) - - @property - def size(self) -> int: - return len(self.commits) - - @property - def semantic_size(self) -> int: - return len(set(c.semantic_hash for c in self.commits.values())) diff --git a/mandala/_next/deps/tracers/__init__.py b/mandala/_next/deps/tracers/__init__.py deleted file mode 100644 index 5211fa9..0000000 --- a/mandala/_next/deps/tracers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tracer_base import TracerABC -from .dec_impl import DecTracer -from .sys_impl import SysTracer diff --git a/mandala/_next/deps/tracers/dec_impl.py b/mandala/_next/deps/tracers/dec_impl.py deleted file mode 100644 index 4de9a3a..0000000 --- a/mandala/_next/deps/tracers/dec_impl.py +++ /dev/null @@ -1,271 +0,0 @@ -import types -from functools import wraps, update_wrapper -from ...common_imports import * -from ...utils import unwrap_decorators -from ...config import Config -from ..model import ( - DependencyGraph, - CallableNode, - GlobalVarNode, - TerminalData, - TerminalNode, -) -from ..utils import ( - get_global_names_candidates, - is_global_val, - extract_code, - extract_func_obj, - DepKey, - GlobalsStrictness, -) -from .tracer_base import TracerABC, get_module_flow, KEEP, get_closure_names - - -class DecTracerConfig: - allow_class_tracking = True - restrict_global_accesses = True - allow_owned_class_accesses = False - allow_nonfunc_attributes = False - - -class TracerState: - tracer: Optional["DecTracer"] = None - registry: Dict[DepKey, Any] = {} - - @staticmethod - def 
is_tracked(f: Union[types.FunctionType, type]) -> bool: - assert isinstance(f, (types.FunctionType, type)) - dep_key = (f.__module__, f.__qualname__) - return dep_key in TracerState.registry and TracerState.registry[dep_key] is f - - -class TrackedDict(dict): - """ - A dictionary that tracks global variable accesses. - """ - def __init__(self, original: dict): - self.__original__ = original - - def __getitem__(self, __key: str) -> Any: - result = self.__original__.__getitem__(__key) - if TracerState.tracer is not None: - tracer = TracerState.tracer - unwrapped_result = unwrap_decorators(result, strict=False) - if isinstance(unwrapped_result, (types.FunctionType, type)): - is_owned = tracer.is_owned_obj(obj=unwrapped_result) - is_cls_access = isinstance(result, type) - if ( - is_owned - and is_cls_access - and not DecTracerConfig.allow_owned_class_accesses - ): - raise ValueError( - f"Attempting to access class {result} from module {unwrapped_result.__module__}." - ) - is_tracked = TracerState.is_tracked(unwrapped_result) - if is_owned and not is_tracked: - raise ValueError( - f"Function/class {result} from module {unwrapped_result.__module__} is accessed but not tracked" - ) - elif is_global_val(result): - TracerState.tracer.register_global_access(key=__key, value=result) - else: - if ( - DecTracerConfig.restrict_global_accesses - and not GlobalsStrictness.is_excluded(result) - ): - raise ValueError( - f"Accessing global value {result} of type {type(result)} is not allowed" - ) - return result - - -def make_tracked_copy(f: types.FunctionType) -> types.FunctionType: - result = types.FunctionType( - code=f.__code__, - globals=TrackedDict(f.__globals__), - name=f.__name__, - argdefs=f.__defaults__, - closure=f.__closure__, - ) - result = update_wrapper(result, f) - result.__module__ = f.__module__ - result.__kwdefaults__ = copy.deepcopy(f.__kwdefaults__) - result.__annotations__ = copy.deepcopy(f.__annotations__) - return result - - -def get_nonfunc_attributes(cls: type) -> Dict[str, Any]: - result = {} - for k, v in cls.__dict__.items(): - if not k.startswith("__") and not isinstance( - unwrap_decorators(v, strict=False), (types.FunctionType, type) - ): - result[k] = v - return result - - -def track(obj: Union[types.FunctionType, type]) -> "obj": - if isinstance(obj, type): - if not DecTracerConfig.allow_class_tracking: - raise ValueError("Class tracking is not allowed") - if not DecTracerConfig.allow_nonfunc_attributes: - nonfunc_attributes = get_nonfunc_attributes(obj) - if len(nonfunc_attributes) > 0: - raise ValueError( - f"Class tracking for {obj} is not allowed: found non-function attributes {nonfunc_attributes}" - ) - # decorate all methods/classes in the class - for k, v in obj.__dict__.items(): - if isinstance(v, (types.FunctionType, type)): - setattr(obj, k, track(v)) - TracerState.registry[(obj.__module__, obj.__qualname__)] = obj - return obj - elif isinstance(obj, types.FunctionType): - obj = make_tracked_copy(unwrap_decorators(obj, strict=True)) - - @wraps(obj) - def wrapper(*args, **kwargs) -> Any: - tracer = DecTracer.get_active_trace_obj() - if tracer is not None: - node = tracer.register_call(func=obj) - outcome = obj(*args, **kwargs) - tracer.register_return(node) - return outcome - else: - return obj(*args, **kwargs) - - TracerState.registry[(obj.__module__, obj.__qualname__)] = unwrap_decorators( - obj, strict=True - ) - return wrapper - elif type(obj).__name__ == Config.func_interface_cls_name: - obj.func_op.func = make_tracked_copy(f=obj.func_op.func) - 
TracerState.registry[ - (obj.func_op.func.__module__, obj.func_op.func.__qualname__) - ] = obj.func_op.func - return obj - else: - raise TypeError("Can only track callable objects") - - -class DecTracer(TracerABC): - """ - A decorator-based tracer that tracks function calls and global variable accesses. - """ - def __init__( - self, - paths: List[Path], - graph: Optional[DependencyGraph] = None, - strict: bool = True, - allow_methods: bool = False, - ): - self.call_stack: List[CallableNode] = [] - self.graph = DependencyGraph() if graph is None else graph - self.paths = paths - self.strict = strict - self.allow_methods = allow_methods - - self._traced = {} - self._traced_funcs = {} - - def is_owned_obj(self, obj: Union[types.FunctionType, type]) -> bool: - module_name = obj.__module__ - return get_module_flow(module_name=module_name, paths=self.paths) == KEEP - - @staticmethod - def get_active_trace_obj() -> Optional["DecTracer"]: - return TracerState.tracer - - @staticmethod - def set_active_trace_obj(trace_obj: Optional["DecTracer"]): - if trace_obj is not None and TracerState.tracer is not None: - raise ValueError("Tracer already active") - TracerState.tracer = trace_obj - - def get_globals(self, func: Callable) -> List[GlobalVarNode]: - result = [] - code_obj = extract_code(obj=func) - global_scope = extract_func_obj(obj=func, strict=self.strict).__globals__ - for name in get_global_names_candidates(code=code_obj): - # names used by the function; not all of them are global variables - if name in global_scope.keys(): - global_val = global_scope[name] - if not is_global_val(global_val): - continue - node = GlobalVarNode.from_obj( - obj=global_val, dep_key=(func.__module__, name) - ) - result.append(node) - return result - - def register_call(self, func: Callable) -> CallableNode: - module_name = func.__module__ - qualname = func.__qualname__ - # check for closure variables - closure_names = get_closure_names( - code_obj=func.__code__, func_qualname=qualname - ) - if len(closure_names) > 0: - msg = f"Found closure variables accessed by function {module_name}.{qualname}:\n{closure_names}" - raise ValueError(msg) - ### get call node - node = CallableNode.from_runtime( - module_name=module_name, obj_name=qualname, code_obj=extract_code(obj=func) - ) - self.call_stack.append(node) - self.graph.add_node(node) - if len(self.call_stack) > 1: - parent = self.call_stack[-2] - assert parent is not None - self.graph.add_edge(parent, node) - ### get globals - global_nodes = self.get_globals(func=func) - for global_node in global_nodes: - self.graph.add_edge(node, global_node) - if len(self.call_stack) == 1: - # this is the root of the graph - self.graph.roots.add(node.key) - return node - - def register_global_access(self, key: str, value: Any): - assert len(self.call_stack) > 0 - calling_node = self.call_stack[-1] - node = GlobalVarNode.from_obj( - obj=value, dep_key=(calling_node.module_name, key) - ) - self.graph.add_edge(calling_node, node) - - def register_return(self, node: CallableNode): - assert self.call_stack[-1] == node - self.call_stack.pop() - - @staticmethod - def register_leaf_event(trace_obj: "DecTracer", data: TerminalData): - unique_id = "_".join( - [ - data.op_internal_name, - str(data.op_version), - data.call_content_version, - data.call_semantic_version, - ] - ) - module_name = data.dep_key[0] - node = TerminalNode( - module_name=module_name, obj_name=unique_id, representation=data - ) - if len(trace_obj.call_stack) > 0: - trace_obj.graph.add_edge(trace_obj.call_stack[-1], node) 
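        # if the call stack is empty there is no tracked caller to attach the
        # terminal node to, so the leaf event is simply dropped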
- return - - @staticmethod - def leaf_signal(data): - # a way to detect the end of a trace - raise NotImplementedError - - def __enter__(self): - DecTracer.set_active_trace_obj(self) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - DecTracer.set_active_trace_obj(None) diff --git a/mandala/_next/deps/tracers/sys_impl.py b/mandala/_next/deps/tracers/sys_impl.py deleted file mode 100644 index c3b0d05..0000000 --- a/mandala/_next/deps/tracers/sys_impl.py +++ /dev/null @@ -1,243 +0,0 @@ -import types -from ...common_imports import * -from ..utils import ( - get_func_qualname, - is_global_val, - get_global_names_candidates, -) -from ..model import ( - DependencyGraph, - CallableNode, - TerminalData, - TerminalNode, - GlobalVarNode, -) -import sys -import importlib -from .tracer_base import TracerABC, get_closure_names - -################################################################################ -### tracer -################################################################################ -# control flow constants -from .tracer_base import BREAK, CONTINUE, KEEP, MAIN, get_module_flow - -LEAF_SIGNAL = "leaf_signal" - -# constants for special Python function names -LAMBDA = "" -COMPREHENSIONS = ("", "", "", "") -SKIP_FRAMES = tuple(list(COMPREHENSIONS) + [LAMBDA]) - - -class SysTracer(TracerABC): - def __init__( - self, - paths: List[Path], - graph: Optional[DependencyGraph] = None, - strict: bool = True, - allow_methods: bool = False, - ): - self.call_stack: List[Optional[CallableNode]] = [] - self.graph = DependencyGraph() if graph is None else graph - self.paths = paths - self.path_strs = [str(path) for path in paths] - self.strict = strict - self.allow_methods = allow_methods - - @staticmethod - def leaf_signal(data): - # a way to detect the end of a trace - pass - - @staticmethod - def register_leaf_event(trace_obj: types.FunctionType, data: Any): - SysTracer.leaf_signal(data) - - @staticmethod - def get_active_trace_obj() -> Optional[Any]: - return sys.gettrace() - - @staticmethod - def set_active_trace_obj(trace_obj: Any): - sys.settrace(trace_obj) - - def _process_failure(self, msg: str): - if self.strict: - raise RuntimeError(msg) - else: - logger.warning(msg) - - def find_most_recent_call(self) -> Optional[CallableNode]: - if len(self.call_stack) == 0: - return None - else: - # return the most recent non-None obj on the stack - for i in range(len(self.call_stack) - 1, -1, -1): - call = self.call_stack[i] - if isinstance(call, CallableNode): - return call - return None - - def __enter__(self): - if sys.gettrace() is not None: - # pre-check this is used correctly - raise RuntimeError("Another tracer is already active") - - def tracer(frame: types.FrameType, event: str, arg: Any): - if event not in ("call", "return"): - return - module_name = frame.f_globals.get("__name__") - # fast check to rule out non-user code - if event == "call": - try: - module = importlib.import_module(module_name) - if not any( - [ - module.__file__.startswith(path_str) - for path_str in self.path_strs - ] - ): - return - except: - if module_name != MAIN: - return - code_obj = frame.f_code - func_name = code_obj.co_name - if event == "return": - logging.debug(f"Returning from {func_name}") - if len(self.call_stack) > 0: - popped = self.call_stack.pop() - logging.debug(f"Popped {popped} from call stack") - # some sanity checks - if func_name in SKIP_FRAMES: - if popped != func_name: - self._process_failure( - f"Expected to pop {func_name} from call stack, but popped {popped}" - ) - else: - if 
popped.obj_name.split(".")[-1] != func_name: - self._process_failure( - f"Expected to pop {func_name} from call stack, but popped {popped.obj_name}" - ) - else: - # something went wrong - raise RuntimeError("Call stack is empty") - return - - if func_name == LEAF_SIGNAL: - data: TerminalData = frame.f_locals["data"] - unique_id = "_".join( - [ - data.op_internal_name, - str(data.op_version), - data.call_content_version, - data.call_semantic_version, - ] - ) - node = TerminalNode( - module_name=module_name, obj_name=unique_id, representation=data - ) - most_recent_option = self.find_most_recent_call() - if most_recent_option is not None: - self.graph.add_edge(source=most_recent_option, target=node) - # self.call_stack.append(None) - return - - module_control_flow = get_module_flow( - paths=self.paths, module_name=module_name - ) - if module_control_flow in (BREAK, CONTINUE): - frame.f_trace = None - return - - logger.debug(f"Tracing call to {module_name}.{func_name}") - - ### get the qualified name of the function/method - func_qualname = get_func_qualname( - func_name=func_name, code=code_obj, frame=frame - ) - if "." in func_qualname: - if not self.allow_methods: - raise RuntimeError( - f"Methods are currently not supported: {func_qualname} from {module_name}" - ) - - ### detect use of closure variables - closure_names = get_closure_names( - code_obj=code_obj, func_qualname=func_qualname - ) - if len(closure_names) > 0 and func_name not in SKIP_FRAMES: - closure_values = { - var: frame.f_locals.get(var, frame.f_globals.get(var, None)) - for var in closure_names - } - msg = f"Found closure variables accessed by function {module_name}.{func_name}:\n{closure_values}" - self._process_failure(msg=msg) - - ### get the global variables used by the function - globals_nodes = [] - for name in get_global_names_candidates(code=code_obj): - # names used by the function; not all of them are global variables - if name in frame.f_globals: - global_val = frame.f_globals[name] - if not is_global_val(global_val): - continue - node = GlobalVarNode.from_obj( - obj=global_val, dep_key=(module_name, name) - ) - globals_nodes.append(node) - - ### if this is a comprehension call, add the globals to the most - ### recent tracked call - if func_name in SKIP_FRAMES: - most_recent_tracked_call = self.find_most_recent_call() - assert most_recent_tracked_call is not None - for global_node in globals_nodes: - self.graph.add_edge( - source=most_recent_tracked_call, target=global_node - ) - self.call_stack.append(func_name) - return tracer - - ### manage the call stack - call_node = CallableNode.from_runtime( - module_name=module_name, obj_name=func_qualname, code_obj=code_obj - ) - self.graph.add_node(node=call_node) - ### global variable edges from this function always exist - for global_node in globals_nodes: - self.graph.add_edge(source=call_node, target=global_node) - ### call edges exist only if there is a caller on the stack - if len(self.call_stack) > 0: - # find the most recent tracked call - most_recent_tracked_call = self.find_most_recent_call() - if most_recent_tracked_call is not None: - self.graph.add_edge( - source=most_recent_tracked_call, target=call_node - ) - self.call_stack.append(call_node) - if len(self.call_stack) == 1: - self.graph.roots.add(call_node.key) - return tracer - - sys.settrace(tracer) - - def __exit__(self, *exc_info): - sys.settrace(None) # Stop tracing - - -class SuspendSysTraceContext: - def __init__(self): - self.suspended_trace = None - - def __enter__(self) -> 
"SuspendSysTraceContext": - if sys.gettrace() is not None: - self.suspended_trace = sys.gettrace() - sys.settrace(None) - return self - - def __exit__(self, *exc_info): - if self.suspended_trace is not None: - sys.settrace(self.suspended_trace) - self.suspended_trace = None diff --git a/mandala/_next/deps/tracers/tracer_base.py b/mandala/_next/deps/tracers/tracer_base.py deleted file mode 100644 index 0081ca4..0000000 --- a/mandala/_next/deps/tracers/tracer_base.py +++ /dev/null @@ -1,91 +0,0 @@ -from ...common_imports import * -from ...config import Config -import importlib -from ..model import DependencyGraph, CallableNode -from abc import ABC, abstractmethod - - -class TracerABC(ABC): - def __init__( - self, - paths: List[Path], - strict: bool = True, - allow_methods: bool = False, - ): - self.call_stack: List[Optional[CallableNode]] = [] - self.graph = DependencyGraph() - self.paths = paths - self.strict = strict - self.allow_methods = allow_methods - - @abstractmethod - def __enter__(self): - raise NotImplementedError - - @abstractmethod - def __exit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError - - @staticmethod - @abstractmethod - def get_active_trace_obj() -> Optional[Any]: - raise NotImplementedError - - @staticmethod - @abstractmethod - def set_active_trace_obj(trace_obj: Any): - raise NotImplementedError - - @staticmethod - @abstractmethod - def register_leaf_event(trace_obj: Any, data: Any): - raise NotImplementedError - - -BREAK = "break" # stop tracing (currently doesn't really work b/c python) -CONTINUE = "continue" # continue tracing, but don't add call to dependencies -KEEP = "keep" # continue tracing and add call to dependencies -MAIN = "__main__" - - -def get_closure_names(code_obj: types.CodeType, func_qualname: str) -> Tuple[str]: - closure_vars = code_obj.co_freevars - if "." in func_qualname and "__class__" in closure_vars: - closure_vars = tuple([var for var in closure_vars if var != "__class__"]) - return closure_vars - - -def get_module_flow(module_name: Optional[str], paths: List[Path]) -> str: - if module_name is None: - return BREAK - if module_name == MAIN: - return KEEP - try: - module = importlib.import_module(module_name) - is_importable = True - except ModuleNotFoundError: - is_importable = False - if not is_importable: - return BREAK - try: - module_path = Path(inspect.getfile(module)) - except TypeError: - # this happens when the module is a built-in module - return BREAK - if ( - not any(root in module_path.parents for root in paths) - and module_path not in paths - ): - # module is not in the paths we're inspecting; stop tracing - logger.debug(f" Module {module_name} not in paths, BREAK") - return BREAK - elif module_name.startswith(Config.module_name) and not module_name.startswith( - Config.tests_module_name - ): - # this function is part of `mandala` functionality. 
Continue tracing - # but don't add it to the dependency state - logger.debug(f" Module {module_name} is mandala, CONTINUE") - return CONTINUE - else: - logger.debug(f" Module {module_name} is not mandala but in paths, KEEP") - return KEEP diff --git a/mandala/_next/deps/utils.py b/mandala/_next/deps/utils.py deleted file mode 100644 index bad7405..0000000 --- a/mandala/_next/deps/utils.py +++ /dev/null @@ -1,228 +0,0 @@ -import types -import dis -import importlib -import gc -from typing import Literal - -from ..common_imports import * -from ..utils import get_content_hash, unwrap_decorators -from ..config import Config - -DepKey = Tuple[str, str] # (module name, object address in module) - - -class GlobalsStrictness: - SCALARS = "scalars" - DATA = "data" - ALL = "all" - - @staticmethod - def is_excluded(obj: Any) -> bool: - return ( - inspect.ismodule(obj) # exclude modules - or isinstance(obj, type) # exclude classes - or inspect.isfunction(obj) # exclude functions - or callable(obj) # exclude callables - or type(obj).__name__ - == Config.func_interface_cls_name #! a hack to exclude memoized functions - ) - - @staticmethod - def is_scalar(obj: Any) -> bool: - result = isinstance(obj, (int, float, str, bool, type(None))) - return result - - @staticmethod - def is_data(obj: Any) -> bool: - if GlobalsStrictness.is_scalar(obj): - result = True - elif type(obj) in (tuple, list): - result = all(GlobalsStrictness.is_data(x) for x in obj) - elif type(obj) is dict: - result = all(GlobalsStrictness.is_data((x, y)) for (x, y) in obj.items()) - elif type(obj) in (np.ndarray, pd.DataFrame, pd.Series, pd.Index): - result = True - else: - result = False - # if not result and not GlobalsStrictness.is_callable(obj): - # logger.warning(f'Access to global variable "{obj}" is not tracked because it is not a scalar or a data structure') - return result - - -def is_global_val(obj: Any, strictness: str = "data") -> bool: - if strictness == GlobalsStrictness.SCALARS: - return GlobalsStrictness.is_scalar(obj=obj) - elif strictness == GlobalsStrictness.DATA: - return GlobalsStrictness.is_data(obj=obj) - elif strictness == GlobalsStrictness.ALL: - return not ( - inspect.ismodule(obj) # exclude modules - or isinstance(obj, type) # exclude classes - or inspect.isfunction(obj) # exclude functions - or callable(obj) # exclude callables - or type(obj).__name__ - == Config.func_interface_cls_name #! 
a hack to exclude memoized functions - ) - else: - raise ValueError( - f"Unknown strictness level for tracking global variables: {strictness}" - ) - - -def is_callable_obj(obj: Any, strict: bool) -> bool: - if type(obj).__name__ == Config.func_interface_cls_name: - return True - if isinstance(obj, types.FunctionType): - return True - if not strict and callable(obj): # quite permissive - return True - return False - - -def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType: - if type(obj).__name__ == Config.func_interface_cls_name: - return obj.f - obj = unwrap_decorators(obj, strict=strict) - if isinstance(obj, types.BuiltinFunctionType): - raise ValueError(f"Expected a non-built-in function, but got {obj}") - if not isinstance(obj, types.FunctionType): - if not strict: - if ( - isinstance(obj, type) - and hasattr(obj, "__init__") - and isinstance(obj.__init__, types.FunctionType) - ): - return obj.__init__ - else: - return unknown_function - else: - raise ValueError(f"Expected a function, but got {obj} of type {type(obj)}") - return obj - - -def extract_code(obj: Callable) -> types.CodeType: - if type(obj).__name__ == Config.func_interface_cls_name: - obj = obj.f - if isinstance(obj, property): - obj = obj.fget - obj = unwrap_decorators(obj, strict=True) - if not isinstance(obj, (types.FunctionType, types.MethodType)): - logger.debug(f"Expected a function or method, but got {type(obj)}") - # raise ValueError(f"Expected a function or method, but got {obj}") - return obj.__code__ - - -def get_runtime_description(code: types.CodeType) -> Any: - assert isinstance(code, types.CodeType) - return get_sanitized_bytecode_representation(code=code) - - -def get_global_names_candidates(code: types.CodeType) -> Set[str]: - result = set() - instructions = list(dis.get_instructions(code)) - for instr in instructions: - if instr.opname == "LOAD_GLOBAL": - result.add(instr.argval) - if isinstance(instr.argval, types.CodeType): - result.update(get_global_names_candidates(instr.argval)) - return result - - -def get_sanitized_bytecode_representation( - code: types.CodeType, -) -> List[dis.Instruction]: - instructions = list(dis.get_instructions(code)) - result = [] - for instr in instructions: - if isinstance(instr.argval, types.CodeType): - result.append( - dis.Instruction( - instr.opname, - instr.opcode, - instr.arg, - get_sanitized_bytecode_representation(instr.argval), - "", - instr.offset, - instr.starts_line, - is_jump_target=instr.is_jump_target, - ) - ) - else: - result.append(instr) - return result - - -def unknown_function(): - # this is a placeholder function that we use to get the source of - # functions that we can't get the source of - pass - - -UNKNOWN_GLOBAL_VAR = "UNKNOWN_GLOBAL_VAR" - - -def get_bytecode(f: Union[types.FunctionType, types.CodeType, str]) -> str: - if isinstance(f, str): - f = compile(f, "", "exec") - instructions = dis.get_instructions(f) - return "\n".join([str(i) for i in instructions]) - - -def hash_dict(d: dict) -> str: - return get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())]) - - -def load_obj(module_name: str, obj_name: str) -> Tuple[Any, bool]: - module = importlib.import_module(module_name) - parts = obj_name.split(".") - current = module - found = True - for part in parts: - if not hasattr(current, part): - found = False - break - else: - current = getattr(current, part) - return current, found - - -def get_dep_key_from_func(func: types.FunctionType) -> DepKey: - module_name = func.__module__ - qualname = func.__qualname__ - return 
module_name, qualname - - -def get_func_qualname( - func_name: str, - code: types.CodeType, - frame: types.FrameType, -) -> str: - # this is evil - referrers = gc.get_referrers(code) - func_referrers = [r for r in referrers if isinstance(r, types.FunctionType)] - matching_name = [r for r in func_referrers if r.__name__ == func_name] - if len(matching_name) != 1: - return get_func_qualname_fallback(func_name=func_name, code=code, frame=frame) - else: - return matching_name[0].__qualname__ - - -def get_func_qualname_fallback( - func_name: str, code: types.CodeType, frame: types.FrameType -) -> str: - # get the argument names to *try* to tell if the function is a method - arg_names = code.co_varnames[: code.co_argcount] - # a necessary but not sufficient condition for this to - # be a method - is_probably_method = ( - len(arg_names) > 0 - and arg_names[0] == "self" - and hasattr(frame.f_locals["self"].__class__, func_name) - ) - if is_probably_method: - # handle nested classes via __qualname__ - cls_qualname = frame.f_locals["self"].__class__.__qualname__ - func_qualname = f"{cls_qualname}.{func_name}" - else: - func_qualname = func_name - return func_qualname diff --git a/mandala/_next/deps/versioner.py b/mandala/_next/deps/versioner.py deleted file mode 100644 index 3552080..0000000 --- a/mandala/_next/deps/versioner.py +++ /dev/null @@ -1,596 +0,0 @@ -import typing -from collections import OrderedDict -import textwrap - -from ..common_imports import * -from ..utils import is_subdict -from ..config import Config -from .utils import DepKey, hash_dict -from .model import ( - Node, - CallableNode, - GlobalVarNode, - TerminalNode, - DependencyGraph, -) -from .crawler import crawl_static -from .tracers import TracerABC -from ..viz import _get_colorized_diff - -if Config.has_rich: - from rich.panel import Panel - from rich.text import Text - from rich.syntax import Syntax - from rich.console import Group - -from .shallow_versions import DAG -from .deep_versions import Version - - -class CodeState: - def __init__(self, nodes: Dict[DepKey, Node]): - self.nodes = nodes - - def __repr__(self) -> str: - lines = [] - for dep_key, node in self.nodes.items(): - lines.append(f"{dep_key}:") - lines.append(f"{node.content()}") - return "\n".join(lines) - - def get_content_version(self, support: Iterable[DepKey]) -> str: - return hash_dict({k: self.nodes[k].content_hash for k in support}) - - def add_globals_from(self, graph: DependencyGraph): - for node in graph.nodes.values(): - if isinstance(node, GlobalVarNode) and node.key not in self.nodes: - self.nodes[node.key] = node - - -class Versioner: - def __init__( - self, - paths: List[Path], - TracerCls: type, - strict, - track_methods, - package_name: Optional[str] = None, - ): - assert len(paths) in [0, 1] - self.paths = paths - self.TracerCls = TracerCls - self.strict = strict - self.allow_methods = track_methods - self.package_name = package_name - self.global_topology: DependencyGraph = DependencyGraph() - self.nodes: Dict[DepKey, Node] = {} - self.component_dags: Dict[DepKey, DAG] = {} - # all versions here must be synced with the DAGs already - self.versions: Dict[DepKey, Dict[str, Version]] = {} - columns = [ - "pre_call_uid", - "semantic_version", - "content_version", - "outputs", - ] - self.df = pd.DataFrame(columns=columns) - - def get_version_ids( - self, - pre_call_uid: str, - tracer_option: Optional[TracerABC], - is_recompute: bool, - ) -> Tuple[Optional[str], Optional[str]]: - """ - Get the content and semantic IDs for the version 
corresponding to the - given pre-call uid. - - Inputs: - - `is_recompute`: this should be true only if this is a call with - transient outputs that we already computed once. - """ - assert tracer_option is not None - version = self.process_trace( - graph=tracer_option.graph, - pre_call_uid=pre_call_uid, - outputs=None, - is_recompute=is_recompute, - ) - content_version = version.content_version - semantic_version = version.semantic_version - return content_version, semantic_version - - def update_global_topology(self, graph: DependencyGraph): - for node in graph.nodes.values(): - if isinstance(node, (CallableNode, GlobalVarNode)): - self.global_topology.add_node(node) - for edge in graph.edges: - if ( - edge[0] in self.global_topology.nodes.keys() - and edge[1] in self.global_topology.nodes.keys() - ): - self.global_topology.edges.add(edge) - - def make_tracer(self) -> TracerABC: - return self.TracerCls( - paths=[Config.mandala_path] + self.paths, - strict=self.strict, - allow_methods=self.allow_methods, - ) - - def guess_code_state(self) -> CodeState: - result_graph = DependencyGraph() - fallback_result = {} - for dep_key in self.global_topology.nodes.keys(): - node = self.global_topology.nodes[dep_key] - if ( - isinstance(node, (GlobalVarNode, CallableNode)) - and dep_key not in result_graph.nodes.keys() - ): - obj = node.load_obj(allow_fallback=not self.strict) - fallback_result[dep_key] = node.from_obj(obj=obj, dep_key=dep_key) - nodes = {**result_graph.nodes, **fallback_result} - # fill in the gaps with a static crawl - static_result, objs = crawl_static( - root=None if len(self.paths) == 0 else self.paths[0], - strict=self.strict, - package_name=self.package_name, - include_methods=self.allow_methods, - ) - for dep_key, node in static_result.items(): - if dep_key not in nodes.keys(): - nodes[dep_key] = node - result = CodeState(nodes=nodes) - return result - - def get_codestate_semantic_hashes( - self, - code_state: CodeState, - ) -> Optional[Dict[DepKey, str]]: - """ - Given a code state, return the semantic hashes of the components found - in this code state that *also* appear in the global component topology, - or None if the code state is not fully compatible with the commits we - have in the DAGs. - """ - result = {} - if not self.global_topology.nodes.keys() <= code_state.nodes.keys(): - extra_keys = self.global_topology.nodes.keys() - code_state.nodes.keys() - raise ValueError( - f"Found extra keys in global topology not in code state: {extra_keys}." 
- ) - for component in code_state.nodes.keys(): - if component in self.global_topology.nodes.keys(): - component_content_hash = code_state.nodes[component].content_hash - dag = self.component_dags[component] - if component_content_hash not in dag.commits: - print( - f"Could not find commit for {component} with content hash {component_content_hash}" - ) - return None - result[component] = dag.commits[component_content_hash].semantic_hash - return result - - def apply_state_hypothesis( - self, hypothesis: CodeState, trace_result: Dict[DepKey, Node] - ): - keys_to_remove = [] - for trace_dep_key, trace_node in trace_result.items(): - if isinstance(trace_node, TerminalNode): - continue - if trace_dep_key not in hypothesis.nodes.keys() and isinstance( - trace_node, GlobalVarNode - ): - continue - if trace_dep_key in hypothesis.nodes.keys(): - if isinstance(trace_node, GlobalVarNode): - if ( - not hypothesis.nodes[trace_dep_key].content() - == trace_node.content() - ): - print(f"Content mismatch for {trace_dep_key}") - print(f"Expected: {hypothesis.nodes[trace_dep_key].content()}") - print(f"Actual: {trace_node.content()}") - print( - f"Diff:\n{_get_colorized_diff(hypothesis.nodes[trace_dep_key].content(), trace_node.content())}" - ) - raise ValueError(f"Content mismatch for {trace_dep_key}") - elif isinstance(trace_node, CallableNode): - runtime_description_hypothesis = hypothesis.nodes[ - trace_dep_key - ].runtime_description - if runtime_description_hypothesis != trace_node.runtime_description: - if self.strict: - raise ValueError( - f"Bytecode mismatch for {trace_dep_key} (you probably need to re-import the module)" - ) - trace_node.representation = hypothesis.nodes[ - trace_dep_key - ].representation - else: - continue - else: - if self.strict: - raise ValueError( - f"Unexpected dependency {trace_dep_key}; expected {textwrap.shorten(str(hypothesis.nodes.keys()), width=80)}" - ) - else: - keys_to_remove.append(trace_dep_key) - for key in keys_to_remove: - del trace_result[key] - - def get_semantic_version( - self, semantic_hashes: Dict[DepKey, str], support: Iterable[DepKey] - ) -> str: - return hash_dict({k: semantic_hashes[k] for k in support}) - - def init_component(self, component: DepKey, node: Node, initial_content: str): - """ - Initialize a new component with an initial state. - """ - if isinstance(node, CallableNode): - content_type = "code" - elif isinstance(node, GlobalVarNode): - content_type = "global_variable" - else: - raise ValueError(f"Unexpected node type {type(node)}") - dag = DAG(content_type=content_type) - dag.init(initial_content=initial_content) - self.nodes[component] = node - self.component_dags[component] = dag - self.versions[component] = {} - - def sync_codebase(self, code_state: CodeState): - """ - Sync all the known components from the current state of the codebase. 
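    A rough usage sketch (hypothetical; normally invoked by the storage layer
    rather than called by hand):

        code_state = versioner.guess_code_state()
        versioner.sync_codebase(code_state)  # may prompt per changed component

    The DAGs are deep-copied and only swapped in at the end, so a failed or
    aborted commit leaves the recorded versions untouched.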
- """ - dags = copy.deepcopy(self.component_dags) - for component, dag in dags.items(): - content = code_state.nodes[component].content() - content_hash = code_state.nodes[component].content_hash - if content_hash not in dag.commits.keys() and dag.head is not None: - dependent_versions = self.get_dependent_versions( - dep_key=component, commit=dag.head - ) - dependent_versions_presentation = textwrap.indent( - text="\n".join([v.presentation for v in dependent_versions]), - prefix=" ", - ) - print(f"CHANGE DETECTED in {component[1]} from module {component[0]}") - print(f"Dependent components:\n{dependent_versions_presentation}") - print(f"===DIFF===:") - dag.sync(content=content) - # update the DAGs if all commits succeeded - self.component_dags = dags - - def sync_component( - self, - component: DepKey, - is_semantic_change: Optional[bool], - code_state: CodeState, - ) -> str: - """ - Sync a single component from the current state of the codebase. Useful - as a low-level API for testing. - """ - commit = self.component_dags[component].sync( - content=code_state.nodes[component].content(), - is_semantic_change=is_semantic_change, - ) - return commit - - def get_current_versions( - self, component: DepKey, code_state: CodeState - ) -> List[Version]: - code_semantic_hashes = self.get_codestate_semantic_hashes(code_state=code_state) - result = [] - if code_semantic_hashes is None: - return result - for _, version in self.versions[component].items(): - if is_subdict(version.semantic_expansion, code_semantic_hashes): - result.append(version) - return result - - def get_semantically_compatible_versions( - self, component: DepKey, code_state: CodeState - ) -> List[Version]: - code_semantic_hashes = self.get_codestate_semantic_hashes(code_state=code_state) - if code_semantic_hashes is None: - return [] - result = [] - for version in self.versions[component].values(): - if all( - [ - version.semantic_expansion[dep_key] == code_semantic_hashes[dep_key] - for dep_key in version.semantic_expansion.keys() - ] - ): - result.append(version) - return result - - ############################################################################ - ### processing traces - ############################################################################ - def create_new_components_from_nodes(self, nodes: Dict[DepKey, Node]): - """ - Given the result of a trace, create any components necessary. - """ - ### new components must be found among the nodes in the trace result - for dep_key, node in nodes.items(): - if dep_key not in self.nodes and not isinstance(node, TerminalNode): - content = node.content() - self.init_component( - component=dep_key, node=node, initial_content=content - ) - - def sync_version(self, version: Version, require_exists: bool = False) -> Version: - # TODO - this is impure - version.sync(component_dags=self.component_dags, all_versions=self.versions) - if version.content_version not in self.versions[version.component]: - if require_exists: - raise ValueError(f"Version {version} does not exist in VersioningState") - # logging.info(f'Adding new version for {version.component}') - self.versions[version.component][version.content_version] = version - return version - - def lookup_call( - self, component: DepKey, pre_call_uid: str, code_state: CodeState - ) -> Optional[Tuple[str, str]]: - """ - Return a tuple of (content_version, semantic_version), or None if the - call is not found. - - Inputs: - - `pre_call_uid`: this is a hash of the content IDs of the inputs, - together with the function's name. 
- - This works as follows: - - we figure out the semantic hashes (i.e. shallow semantic versions) of - the elements of the code state present in the global topology we have on - record - - we restrict to the records that match the given `pre_call_uid` - - we search among these - """ - codebase_semantic_hashes = self.get_codestate_semantic_hashes( - code_state=code_state - ) - if codebase_semantic_hashes is None: - return None - candidates = self.df[self.df["pre_call_uid"] == pre_call_uid] - if len(candidates) == 0: - return None - else: - content_versions = candidates["content_version"].values.tolist() - semantic_versions = candidates["semantic_version"].values.tolist() - for content_version, semantic_version in zip( - content_versions, semantic_versions - ): - version = self.versions[component][content_version] - codebase_semantic = self.get_semantic_version( - semantic_hashes=codebase_semantic_hashes, support=version.support - ) - if codebase_semantic == semantic_version: - return content_version, semantic_version - return None - - def process_trace( - self, - graph: DependencyGraph, - pre_call_uid: str, - outputs: Any, - is_recompute: bool, - ) -> Version: - component, nodes = graph.get_trace_state() - self.create_new_components_from_nodes(nodes=nodes) - version = Version.from_trace( - component=component, nodes=nodes, strict=self.strict - ) - version = self.sync_version(version=version) - row = { - "pre_call_uid": pre_call_uid, - "semantic_version": version.semantic_version, - "content_version": version.content_version, - "outputs": outputs, - } - if not is_recompute: - self._check_semantic_distinguishability( - component=component, pre_call_uid=pre_call_uid, call_version=version - ) - # logging.info(f"Adding new call for {pre_call_uid} for {component}") - self.df = pd.concat([self.df, pd.DataFrame([row])], ignore_index=True) - return version - - def _check_semantic_distinguishability( - self, component: DepKey, pre_call_uid: str, call_version: Version - ): - ### check semantic distinguishability between calls - # TODO: make this more efficient - candidates = self.df[self.df["pre_call_uid"] == pre_call_uid] - existing_semantic_expansions = {} - for content_version, semantic_version in zip( - candidates["content_version"].values.tolist(), - candidates["semantic_version"].values.tolist(), - ): - if semantic_version not in existing_semantic_expansions.keys(): - existing_call_version = self.versions[component][content_version] - existing_semantic_expansions[ - semantic_version - ] = existing_call_version.semantic_expansion - new_semantic_dep_hashes = call_version.semantic_expansion - for semantic_version, semantic_hashes in existing_semantic_expansions.items(): - overlap = set(semantic_hashes.keys()).intersection( - set(new_semantic_dep_hashes.keys()) - ) - if all([semantic_hashes[k] == new_semantic_dep_hashes[k] for k in overlap]): - raise ValueError( - f"Call to {component} with pre_call_uid={pre_call_uid} is not semantically distinguishable from call for semantic version {semantic_version}" - ) - - ############################################################################ - ### inspecting the state - ############################################################################ - def get_flat_versions(self) -> Dict[str, Version]: - return { - k: v - for component, versions in self.versions.items() - for k, v in versions.items() - } - - def get_dependent_versions(self, dep_key: DepKey, commit: str) -> List[Version]: - """ - Get a list of versions of components dependent on a given commit 
to a - given component - """ - dep_semantic = self.component_dags[dep_key].commits[commit].semantic_hash - return [ - version - for version in self.get_flat_versions().values() - if version.semantic_expansion.get(dep_key) == dep_semantic - ] - - def present_dependencies( - self, - commits: Dict[DepKey, str], - include_metadata: bool = True, - header: Optional[str] = None, - ) -> str: - """ - Get a code snippet for a given state of some dependencies - """ - result_lines = [] - if header is not None: - result_lines.extend(header.splitlines()) - module_groups = self.get_canonical_groups(components=commits.keys()) - for module_name, components_in_module in module_groups.items(): - result_lines.append(80 * "#") - result_lines.append(f'### IN MODULE "{module_name}"') - result_lines.append(80 * "#") - commits_in_module = {k: commits[k] for k in components_in_module} - nodes = {k: self.nodes[k] for k in commits_in_module.keys()} - semantic_hashes = { - k: self.component_dags[k].commits[v].semantic_hash - for k, v in commits_in_module.items() - } - global_keys = {k for k, v in nodes.items() if isinstance(v, GlobalVarNode)} - callable_keys = {k for k, v in nodes.items() if isinstance(v, CallableNode)} - metadatas = { - k: f"### {nodes[k].present_key()}\n### content_commit={commits_in_module[k]}\n### semantic_commit={semantic_hashes[k]}" - for k in commits_in_module.keys() - } - for global_key in sorted(global_keys): - if include_metadata: - result_lines.append(metadatas[global_key]) - global_name = global_key[1] - result_lines.append( - f"{global_name} = {self.component_dags[global_key].get_presentable_content(commits_in_module[global_key])}" - ) - result_lines.append("") - callable_keys = list(sorted(callable_keys)) - is_class_start = [] - for i, callable_key in enumerate(callable_keys): - this_is_class = "." in callable_key[1] - if i == 0 and this_is_class: - is_class_start.append(True) - continue - prev_is_class = "." 
in callable_keys[i - 1][1] - if this_is_class and not prev_is_class: - is_class_start.append(True) - continue - if ( - this_is_class - and prev_is_class - and callable_keys[i - 1][1].rsplit(".", maxsplit=1)[0] - != callable_key[1].rsplit(".", maxsplit=1)[0] - ): - is_class_start.append(True) - continue - is_class_start.append(False) - for i, callable_key in enumerate(callable_keys): - if is_class_start[i]: - result_lines.append( - f"### in class {callable_key[1].rsplit('.', 1)[0]}:" - ) - if include_metadata: - result_lines.append(metadatas[callable_key]) - result_lines.append( - self.component_dags[callable_key].get_presentable_content( - commits_in_module[callable_key] - ) - ) - result_lines.append("") - return "\n".join(result_lines) - - def show_versions( - self, - component: DepKey, - only_semantic: bool = False, - include_metadata: bool = True, - plain: bool = False, - ): - versions_dict = self.versions[component] - if only_semantic: - # use just 1 semantic representative per content version - versions = list( - {v.semantic_version: v for v in versions_dict.values()}.values() - ) - else: - versions = list(versions_dict.values()) - if Config.has_rich and not plain: - version_panels: List[Panel] = [] - for version in versions: - header_lines = [ - f"### Dependencies for version of {self.nodes[component].present_key()}" - ] - header_lines.append(f"### content_version_id={version.content_version}") - header_lines.append( - f"### semantic_version_id={version.semantic_version}\n\n" - ) - version_panels.append( - Panel( - Syntax( - self.present_dependencies( - header="\n".join(header_lines), - commits=version.semantic_expansion, - include_metadata=include_metadata, - ), - lexer="python", - theme="solarized-light", - ), - title=None, - expand=True, - ) - ) - rich.print(Group(*version_panels)) - else: - for version in versions: - print(version.presentation) - print( - textwrap.indent( - self.present_dependencies( - commits=version.semantic_expansion, - include_metadata=include_metadata, - ), - prefix=" ", - ) - ) - - def get_canonical_groups( - self, components: Iterable[DepKey] - ) -> typing.OrderedDict[str, List[DepKey]]: - """ - Order components by module name alphabetically, and within each module, - put the global variables first, then the callables. - """ - result = OrderedDict() - for component in components: - module_name = component[0] - if module_name not in result: - result[module_name] = [] - result[module_name].append(component) - for module_name, module_components in result.items(): - result[module_name] = sorted( - module_components, - key=lambda x: (isinstance(self.nodes[x], CallableNode), x), - ) - result = OrderedDict(sorted(result.items(), key=lambda x: x[0])) - return result diff --git a/mandala/_next/deps/viz.py b/mandala/_next/deps/viz.py deleted file mode 100644 index 82907fb..0000000 --- a/mandala/_next/deps/viz.py +++ /dev/null @@ -1,110 +0,0 @@ -import textwrap -from ..common_imports import * -from ..viz import ( - Node as DotNode, - Edge as DotEdge, - Group as DotGroup, - to_dot_string, - SOLARIZED_LIGHT, -) - -from .utils import ( - DepKey, -) - - -def to_string(graph: "model.DependencyGraph") -> str: - """ - Get a string for pretty-printing. 
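    Used by DependencyGraph.__repr__: nodes are grouped by module, with global
    variables listed before functions, and methods grouped under their class.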
- """ - # group the nodes by module - module_groups: Dict[str, List["model.Node"]] = {} - for key, node in graph.nodes.items(): - module_name, _ = key - module_groups.setdefault(module_name, []).append(node) - lines = [] - for module_name, nodes in module_groups.items(): - global_nodes = [node for node in nodes if isinstance(node, model.GlobalVarNode)] - callable_nodes = [ - node for node in nodes if isinstance(node, model.CallableNode) - ] - module_desc = f"MODULE: {module_name}" - lines.append(module_desc) - lines.append("-" * len(module_desc)) - lines.append("===Global Variables===") - for node in global_nodes: - desc = f"{node.obj_name} = {node.readable_content()}" - lines.append(textwrap.indent(desc, 4 * " ")) - # lines.append(f" {node.diff_representation()}") - lines.append("") - lines.append("===Functions===") - # group the methods by class - method_nodes = [node for node in callable_nodes if node.is_method] - func_nodes = [node for node in callable_nodes if not node.is_method] - methods_by_class: Dict[str, List["model.CallableNode"]] = {} - for method_node in method_nodes: - methods_by_class.setdefault(method_node.class_name, []).append(method_node) - for class_name, method_nodes in methods_by_class.items(): - lines.append(textwrap.indent(f"class {class_name}:", 4 * " ")) - for node in method_nodes: - desc = node.readable_content() - lines.append(textwrap.indent(textwrap.dedent(desc), 8 * " ")) - lines.append("") - for node in func_nodes: - desc = node.readable_content() - lines.append(textwrap.indent(textwrap.dedent(desc), 4 * " ")) - lines.append("") - return "\n".join(lines) - - -def to_dot(graph: "model.DependencyGraph") -> str: - nodes: Dict[DepKey, DotNode] = {} - module_groups: Dict[str, DotGroup] = {} # module name -> Group - class_groups: Dict[str, DotGroup] = {} # class name -> Group - for key, node in graph.nodes.items(): - module_name, obj_addr = key - if module_name not in module_groups: - module_groups[module_name] = DotGroup( - label=module_name, nodes=[], parent=None - ) - if isinstance(node, model.GlobalVarNode): - color = SOLARIZED_LIGHT["red"] - elif isinstance(node, model.CallableNode): - color = ( - SOLARIZED_LIGHT["blue"] - if not node.is_method - else SOLARIZED_LIGHT["violet"] - ) - else: - color = SOLARIZED_LIGHT["base03"] - dot_node = DotNode( - internal_name=".".join(key), label=node.obj_name, color=color - ) - nodes[key] = dot_node - module_groups[module_name].nodes.append(dot_node) - if isinstance(node, model.CallableNode) and node.is_method: - class_name = node.class_name - class_groups.setdefault( - class_name, - DotGroup( - label=class_name, - nodes=[], - parent=module_groups[module_name], - ), - ).nodes.append(dot_node) - edges: Dict[Tuple[DotNode, DotNode], DotEdge] = {} - for source, target in graph.edges: - source_node = nodes[source] - target_node = nodes[target] - edge = DotEdge(source_node=source_node, target_node=target_node) - edges[(source_node, target_node)] = edge - dot_string = to_dot_string( - nodes=list(nodes.values()), - edges=list(edges.values()), - groups=list(module_groups.values()) + list(class_groups.values()), - rankdir="BT", - ) - return dot_string - - -from . 
import model diff --git a/mandala/_next/imports.py b/mandala/_next/imports.py deleted file mode 100644 index 66ff18c..0000000 --- a/mandala/_next/imports.py +++ /dev/null @@ -1,10 +0,0 @@ -from .storage import Storage -from .model import op, Ignore, NewArgDefault -from .tps import MList, MDict -from .deps.tracers.dec_impl import track - -from .common_imports import sess - - -def pprint_dict(d) -> str: - return '\n'.join([f" {k}: {v}" for k, v in d.items()]) \ No newline at end of file diff --git a/mandala/_next/tests/test_cfs.py b/mandala/_next/tests/test_cfs.py deleted file mode 100644 index b75ad16..0000000 --- a/mandala/_next/tests/test_cfs.py +++ /dev/null @@ -1,18 +0,0 @@ -from mandala._next.imports import * - - -def test_single_func(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage: - for i in range(10): - inc(i) - - cf = storage.cf(inc) - df = cf.df() - assert df.shape == (10, 3) - assert (df['output_0'] == df['x'] + 1).all() \ No newline at end of file diff --git a/mandala/_next/tests/test_memoization.py b/mandala/_next/tests/test_memoization.py deleted file mode 100644 index 2f47bda..0000000 --- a/mandala/_next/tests/test_memoization.py +++ /dev/null @@ -1,133 +0,0 @@ -from mandala._next.imports import * - - -def test_storage(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage: - x = 1 - y = inc(x) - z = inc(2) - w = inc(y) - - assert w.cid == z.cid - assert w.hid != y.hid - assert w.cid != y.cid - assert storage.unwrap(y) == 2 - assert storage.unwrap(z) == 3 - assert storage.unwrap(w) == 3 - for ref in (y, z, w): - assert storage.attach(ref).in_memory - assert storage.attach(ref).obj == storage.unwrap(ref) - - -def test_signatures(): - storage = Storage() - - @op # a function with a wild input/output signature - def add(x, *args, y: int = 1, **kwargs): - # just sum everything - res = x + sum(args) + y + sum(kwargs.values()) - if kwargs: - return res, kwargs - elif args: - return None - else: - return res - - with storage: - # call the func in all the ways - sum_1 = add(1) - sum_2 = add(1, 2, 3, 4, ) - sum_3 = add(1, 2, 3, 4, y=5) - sum_4 = add(1, 2, 3, 4, y=5, z=6) - sum_5 = add(1, 2, 3, 4, z=5, w=7) - - assert storage.unwrap(sum_1) == 2 - assert storage.unwrap(sum_2) == None - assert storage.unwrap(sum_3) == None - assert storage.unwrap(sum_4) == (21, {'z': 6}) - assert storage.unwrap(sum_5) == (23, {'z': 5, 'w': 7}) - - -def test_retracing(): - storage = Storage() - - @op - def inc(x): - return x + 1 - - ### iterating a function - with storage: - start = 1 - for i in range(10): - start = inc(start) - - with storage: - start = 1 - for i in range(10): - start = inc(start) - - ### composing functions - @op - def add(x, y): - return x + y - - with storage: - inp = [1, 2, 3, 4, 5] - stage_1 = [inc(x) for x in inp] - stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)] - - with storage: - inp = [1, 2, 3, 4, 5] - stage_1 = [inc(x) for x in inp] - stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)] - - -def test_lists(): - storage = Storage() - - @op - def get_sum(elts: MList[int]) -> int: - return sum(elts) - - @op - def primes_below(n: int) -> MList[int]: - primes = [] - for i in range(2, n): - for p in primes: - if i % p == 0: - break - else: - primes.append(i) - return primes - - @op - def chunked_square(elts: MList[int]) -> MList[int]: - # a model for an op that does something on chunks of a big thing - # to prevent OOM errors - return [x*x for x in elts] - - with storage: - n = 10 - primes = 
primes_below(n) - sum_primes = get_sum(primes) - assert len(primes) == 4 - # check indexing - assert storage.unwrap(primes[0]) == 2 - assert storage.unwrap(primes[:2]) == [2, 3] - - ### lists w/ overlapping elements - with storage: - n = 100 - primes = primes_below(n) - for i in range(0, len(primes), 2): - sum_primes = get_sum(primes[:i+1]) - - with storage: - elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - squares = chunked_square(elts) diff --git a/mandala/_next/utils.py b/mandala/_next/utils.py deleted file mode 100644 index bbfb65a..0000000 --- a/mandala/_next/utils.py +++ /dev/null @@ -1,249 +0,0 @@ -from .common_imports import * -import joblib -import io -import inspect -import prettytable -import sqlite3 -from .config import * -from abc import ABC, abstractmethod - -def dataframe_to_prettytable(df: pd.DataFrame) -> str: - # Initialize a PrettyTable object - table = prettytable.PrettyTable() - - # Set the column names - table.field_names = df.columns.tolist() - - # Add rows to the table - for row in df.itertuples(index=False): - table.add_row(row) - - # Return the pretty-printed table as a string - return table.get_string() - - -def serialize(obj: Any) -> bytes: - """ - ! this may lead to different serializations for objects x, y such that x - ! == y in Python. This is because of things like set ordering, which is not - ! determined by the contents of the set. For example, {1, 2} and {2, 1} would - ! `serialize()` to different things, but they would be equal in Python. - """ - buffer = io.BytesIO() - joblib.dump(obj, buffer) - return buffer.getvalue() - - -def deserialize(value: bytes) -> Any: - buffer = io.BytesIO(value) - return joblib.load(buffer) - - -def get_content_hash(obj: Any) -> str: - if hasattr(obj, "__get_mandala_dict__"): - obj = obj.__get_mandala_dict__() - if Config.has_torch: - # TODO: ideally, should add a label to distinguish this from a numpy - # array with the same contents! 
- obj = tensor_to_numpy(obj) - if isinstance(obj, pd.DataFrame): - # DataFrames cause collisions for joblib hashing for some reason - # TODO: the below may be incomplete - obj = { - "columns": obj.columns, - "values": obj.values, - "index": obj.index, - } - result = joblib.hash(obj) # this hash is canonical wrt python collections - if result is None: - raise RuntimeError("joblib.hash returned None") - return result - - -def dump_output_name(index: int, output_names: Optional[List[str]] = None) -> str: - if output_names is not None and index < len(output_names): - return output_names[index] - else: - return f"output_{index}" - - -def parse_output_name(name: str) -> int: - return int(name.split("_")[-1]) - - -def get_setdict_union( - a: Dict[str, Set[str]], b: Dict[str, Set[str]] -) -> Dict[str, Set[str]]: - return {k: a.get(k, set()) | b.get(k, set()) for k in a.keys() | b.keys()} - - -def get_setdict_intersection( - a: Dict[str, Set[str]], b: Dict[str, Set[str]] -) -> Dict[str, Set[str]]: - return {k: a[k] & b[k] for k in a.keys() & b.keys()} - - -def get_dict_union_over_keys(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]: - return {k: a[k] if k in a else b[k] for k in a.keys() | b.keys()} - - -def get_dict_intersection_over_keys( - a: Dict[str, Any], b: Dict[str, Any] -) -> Dict[str, Any]: - return {k: a[k] for k in a.keys() & b.keys()} - - -def get_adjacency_union( - a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]] -) -> Dict[str, Dict[str, Set[str]]]: - return { - k: get_setdict_union(a.get(k, {}), b.get(k, {})) for k in a.keys() | b.keys() - } - - -def get_adjacency_intersection( - a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]] -) -> Dict[str, Dict[str, Set[str]]]: - return {k: get_setdict_intersection(a[k], b[k]) for k in a.keys() & b.keys()} - - -def get_nullable_union(*sets: Set[str]) -> Set[str]: - return set.union(*sets) if len(sets) > 0 else set() - - -def get_nullable_intersection(*sets: Set[str]) -> Set[str]: - return set.intersection(*sets) if len(sets) > 0 else set() - - -def get_adj_from_edges( - edges: Set[Tuple[str, str, str]], node_support: Optional[Set[str]] = None -) -> Tuple[Dict[str, Dict[str, Set[str]]], Dict[str, Dict[str, Set[str]]]]: - """ - Given edges, convert them into the adjacency representation used by the - `ComputationFrame` class. - """ - out = {} - inp = {} - for src, dst, label in edges: - if src not in out: - out[src] = {} - if label not in out[src]: - out[src][label] = set() - out[src][label].add(dst) - if dst not in inp: - inp[dst] = {} - if label not in inp[dst]: - inp[dst][label] = set() - inp[dst][label].add(src) - if node_support is not None: - for node in node_support: - if node not in out: - out[node] = {} - if node not in inp: - inp[node] = {} - return out, inp - - -def parse_returns( - sig: inspect.Signature, - returns: Any, - nout: Union[Literal["auto", "var"], int], - output_names: Optional[List[str]] = None, -) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - Return two dicts based on the returns: - - {output name: output value} - - {output name: output type annotation}, where things like `Tuple[T, ...]` are expanded. 
- """ - ### figure out the number of outputs, and convert them to a tuple - if nout == "auto": # infer from the returns - if isinstance(returns, tuple): - nout = len(returns) - returns_tuple = returns - else: - nout = 1 - returns_tuple = (returns,) - elif nout == "var": - assert isinstance(returns, tuple) - nout = len(returns) - returns_tuple = returns - else: # nout is an integer - assert isinstance(nout, int) - assert isinstance(returns, tuple) - assert len(returns) == nout - returns_tuple = returns - ### get the dict of outputs - outputs_dict = { - dump_output_name(i, output_names): returns_tuple[i] for i in range(nout) - } - ### figure out the annotations - annotations_dict = {} - output_annotation = sig.return_annotation - if output_annotation is inspect._empty: # no annotation - annotations_dict = {k: Any for k in outputs_dict.keys()} - else: - if ( - hasattr(output_annotation, "__origin__") - and output_annotation.__origin__ is tuple - ): - if ( - len(output_annotation.__args__) == 2 - and output_annotation.__args__[1] == Ellipsis - ): - annotations_dict = { - k: output_annotation.__args__[0] for k in outputs_dict.keys() - } - else: - annotations_dict = { - k: output_annotation.__args__[i] - for i, k in enumerate(outputs_dict.keys()) - } - else: - assert nout == 1 - annotations_dict = {k: output_annotation for k in outputs_dict.keys()} - return outputs_dict, annotations_dict - - -def unwrap_decorators( - obj: Callable, strict: bool = True -) -> Union[types.FunctionType, types.MethodType]: - while hasattr(obj, "__wrapped__"): - obj = obj.__wrapped__ - if not isinstance(obj, (types.FunctionType, types.MethodType)): - msg = f"Expected a function or method, but got {type(obj)}" - if strict: - raise RuntimeError(msg) - else: - logger.debug(msg) - return obj - -def is_subdict(a: Dict, b: Dict) -> bool: - """ - Check that all keys in `a` are in `b` with the same value. - """ - return all((k in b and a[k] == b[k]) for k in a) - -_KT, _VT = TypeVar("_KT"), TypeVar("_VT") -def invert_dict(d: Dict[_KT, _VT]) -> Dict[_VT, List[_KT]]: - """ - Invert a dictionary - """ - out = {} - for k, v in d.items(): - if v not in out: - out[v] = [] - out[v].append(k) - return out - -def ask_user(question: str, valid_options: List[str]) -> str: - """ - Ask the user a question and return their response. 
- """ - prompt = f"{question} " - while True: - print(prompt) - response = input().strip().lower() - if response in valid_options: - return response - else: - print(f"Invalid response: {response}") diff --git a/mandala/all.py b/mandala/all.py deleted file mode 100644 index badf8c7..0000000 --- a/mandala/all.py +++ /dev/null @@ -1,22 +0,0 @@ -from .common_imports import * -from .core.model import ValueRef, Call, FuncOp, wrap_atom, TransientObj -from .core.wrapping import unwrap -from .core.builtins_ import DictRef, ListRef, SetRef -from .queries.weaver import BuiltinQueries -from .queries.main import Querier -from .queries.viz import show -from .core.sig import Signature -from .core.config import Config -from .ui.storage import Storage, MODES, FuncInterface -from .ui.contexts import Context, GlobalContext -from .ui.funcs import op, Q, superop, Transient -from .ui.utils import wrap_ui as wrap -from .ui.cfs import ComputationFrame - -# from .storages.rel_impls.duckdb_impl import DuckDBRelStorage -from .storages.rel_impls.sqlite_impl import SQLiteRelStorage -from .storages.rels import serialize, deserialize -from .deps.tracers.dec_impl import track, TracerState -from .deps.tracers import TracerABC, DecTracer, SysTracer - -from .queries import ListQ, SetQ, DictQ diff --git a/mandala/_next/cf.py b/mandala/cf.py similarity index 100% rename from mandala/_next/cf.py rename to mandala/cf.py diff --git a/mandala/_next/cf_examples.ipynb b/mandala/cf_examples.ipynb similarity index 100% rename from mandala/_next/cf_examples.ipynb rename to mandala/cf_examples.ipynb diff --git a/mandala/common_imports.py b/mandala/common_imports.py index 68c8962..c65c395 100644 --- a/mandala/common_imports.py +++ b/mandala/common_imports.py @@ -28,6 +28,7 @@ Set, Union, TypeVar, + Literal, ) from pathlib import Path @@ -35,26 +36,6 @@ import pyarrow as pa import numpy as np - -class Session: - # for debugging - - def __init__(self): - self.items = [] - self._scope = None - - def d(self): - scope = inspect.currentframe().f_back.f_locals - self._scope = scope - - def dump(self): - # put the scope into the current locals - assert self._scope is not None - scope = inspect.currentframe().f_back.f_locals - print(f"Dumping {self._scope.keys()} into local scope") - scope.update(self._scope) - - try: import rich @@ -78,10 +59,5 @@ def dump(self): logging.basicConfig(format=FORMAT) logger.setLevel(logging.INFO) -sess = Session() - -TableType = TypeVar("TableType", pa.Table, pd.DataFrame) - -class InternalError(Exception): - pass +from mandala.common_imports import sess diff --git a/mandala/_next/config.py b/mandala/config.py similarity index 100% rename from mandala/_next/config.py rename to mandala/config.py diff --git a/mandala/core/builtins_.py b/mandala/core/builtins_.py deleted file mode 100644 index 9056d2e..0000000 --- a/mandala/core/builtins_.py +++ /dev/null @@ -1,431 +0,0 @@ -from collections.abc import Sequence, Mapping, Set as SetABC -from ..common_imports import * - -from .config import MODES -from .model import Ref, FuncOp, Call, wrap_atom, ValueRef -from .utils import Hashing -from .tps import AnyType - -HASHERS = { - "__list__": Hashing.hash_list, - "__set__": Hashing.hash_multiset, - "__dict__": Hashing.hash_dict, -} - - -class StructRef(Ref): - builtin_id = None - - def __init__( - self, - uid: Optional[str], - obj: Optional[Any], - in_memory: bool, - transient: bool = False, - ): - if uid is None: - builtin_id = self.builtin_id - hasher = HASHERS[builtin_id] - uids = Builtins.map(func=lambda elt: elt.uid, 
obj=obj, struct_id=builtin_id) - uid = Builtins._make_builtin_uid(uid=hasher(uids), builtin_id=builtin_id) - super().__init__(uid=uid, obj=obj, in_memory=in_memory, transient=transient) - - def as_inputs_list(self) -> List[Dict[str, Ref]]: - raise NotImplementedError - - def causify_up(self): - raise NotImplementedError - - def get_call(self, wrapped_inputs: Dict[str, Ref]) -> Call: - call_uid = Builtins.OPS[self.builtin_id].get_pre_call_uid( - input_uids={k: v.uid for k, v in wrapped_inputs.items()} - ) - return Call( - uid=call_uid, - inputs=wrapped_inputs, - outputs=[], - func_op=Builtins.OPS[self.builtin_id], - transient=False, - ) - - def get_calls(self) -> List[Call]: - inputs_list = self.as_inputs_list() - res = [] - for wrapped_inputs in inputs_list: - res.append(self.get_call(wrapped_inputs=wrapped_inputs)) - return res - - @staticmethod - def map(obj: Iterable, func: Callable) -> Iterable: - raise NotImplementedError - - @staticmethod - def elts(obj: Iterable) -> Iterable: - raise NotImplementedError - - def children(self) -> Iterable[Ref]: - return self.elts(self.obj) - - def unlinked(self, keep_causal: bool) -> "StructRef": - if not self.in_memory: - res = type(self)(uid=self.uid, obj=None, in_memory=False) - if keep_causal: - res._causal_uid = self._causal_uid - return res - else: - unlinked_elts = type(self).map( - obj=self.obj, func=lambda elt: elt.unlinked(keep_causal=keep_causal) - ) - res = type(self)(uid=self.uid, obj=unlinked_elts, in_memory=True) - if keep_causal: - res._causal_uid = self._causal_uid - return res - - -class ListRef(StructRef, Sequence): - """ - Immutable list of Refs. - """ - - builtin_id = "__list__" - - def as_inputs_list(self) -> List[Dict[str, Ref]]: - idxs = [wrap_atom(idx) for idx in range(len(self))] - for idx in idxs: - causify_atom(idx) - wrapped_inputs_list = [ - {"lst": self, "elt": elt, "idx": idx} for elt, idx in zip(self.obj, idxs) - ] - return wrapped_inputs_list - - def causify_up(self): - assert all(elt.causal_uid is not None for elt in self.obj) - self.causal_uid = Hashing.hash_list([elt.causal_uid for elt in self.obj]) - - @staticmethod - def map(obj: list, func: Callable) -> list: - return [func(elt) for elt in obj] - - @staticmethod - def elts(obj: list) -> list: - return obj - - def dump(self) -> "ListRef": - return ListRef( - uid=self.uid, obj=[vref.detached() for vref in self.obj], in_memory=True - ) - - ############################################################################ - ### list interface - ############################################################################ - def __getitem__( - self, idx: Union[int, "ValNode", Ref, slice] - ) -> Union[Ref, "ValNode"]: - self._auto_attach() - if isinstance(idx, Ref): - prepare_query(ref=idx, tp=AnyType()) - res_query = BuiltinQueries.GetListItemQuery(lst=self.query, idx=idx.query) - res = self.obj[idx.obj].unlinked(keep_causal=True) - res._query = res_query - return res - elif isinstance(idx, ValNode): - res = BuiltinQueries.GetListItemQuery(lst=self.query, idx=idx) - return res - elif isinstance(idx, int): - if ( - GlobalContext.current is not None - and GlobalContext.current.mode != MODES.run - ): - raise ValueError - res: Ref = self.obj[idx] - res = res.unlinked(keep_causal=True) - wrapped_idx = wrap_atom(obj=idx) - causify_atom(ref=wrapped_idx) - call = self.get_call( - wrapped_inputs={"lst": self, "idx": wrapped_idx, "elt": res} - ) - call.link(orientation=StructOrientations.destruct) - return res - elif isinstance(idx, slice): - if ( - GlobalContext.current is not None 
- and GlobalContext.current.mode != MODES.run - ): - raise ValueError - res = self.obj[idx] - res = [elt.unlinked(keep_causal=True) for elt in res] - wrapped_idxs = [wrap_atom(obj=i) for i in range(*idx.indices(len(self)))] - for wrapped_idx in wrapped_idxs: - causify_atom(ref=wrapped_idx) - for wrapped_idx, elt in zip(wrapped_idxs, res): - call = self.get_call( - wrapped_inputs={"lst": self, "idx": wrapped_idx, "elt": elt} - ) - call.link(orientation=StructOrientations.destruct) - return res - - def __iter__(self): - self._auto_attach() - return iter(self.obj) - - def __len__(self) -> int: - self._auto_attach() - return len(self.obj) - - -class DictRef(StructRef): # don't inherit from Mapping because it's not hashable - """ - Immutable string-keyed dict of Refs. - """ - - builtin_id = "__dict__" - - def as_inputs_list(self) -> List[Dict[str, Ref]]: - keys = {k: wrap_atom(k) for k in self.obj.keys()} - for k in keys.values(): - causify_atom(k) - wrapped_inputs_list = [ - {"dct": self, "key": keys[k], "val": v} for k, v in self.obj.items() - ] - return wrapped_inputs_list - - def causify_up(self): - assert all(elt.causal_uid is not None for elt in self.obj.values()) - self.causal_uid = Hashing.hash_dict( - {k: v.causal_uid for k, v in self.obj.items()} - ) - - @staticmethod - def map(obj: dict, func: Callable) -> dict: - return {k: func(v) for k, v in obj.items()} - - @staticmethod - def elts(obj: dict) -> Iterable: - return iter(obj.values()) - - def dump(self) -> "DictRef": - assert self.in_memory - return DictRef( - uid=self.uid, - obj={k: vref.detached() for k, vref in self.obj.items()}, - in_memory=True, - ) - - ############################################################################ - ### dict interface - ############################################################################ - def __getitem__(self, key: Union[str, "ValNode", Ref]) -> Union[Ref, "ValNode"]: - self._auto_attach() - if isinstance(key, str): - if ( - GlobalContext.current is not None - and GlobalContext.current.mode != MODES.run - ): - raise ValueError - return self.obj[key] - if isinstance(key, Ref): - prepare_query(ref=key, tp=AnyType()) - res_query = BuiltinQueries.GetDictItemQuery(dct=self.query, key=key.query) - res = self.obj[key.obj].unlinked(keep_causal=True) - res._query = res_query - return res - elif isinstance(key, ValNode): - res = BuiltinQueries.GetDictItemQuery(dct=self.query, key=key) - return res - else: - raise ValueError - - def __iter__(self): - self._auto_attach() - return iter(self.obj) - - def __len__(self) -> int: - self._auto_attach() - return len(self.obj) - - -class SetRef(StructRef): # don't subclass from set, because it's not hashable - """ - Immutable set of Refs. 
- """ - - builtin_id = "__set__" - - def as_inputs_list(self) -> List[Dict[str, Ref]]: - wrapped_inputs_list = [{"st": self, "elt": elt} for elt in self.obj] - return wrapped_inputs_list - - def causify_up(self): - assert all(elt.causal_uid is not None for elt in self.obj) - self.causal_uid = Hashing.hash_multiset([elt.causal_uid for elt in self.obj]) - - @staticmethod - def map(obj: set, func: Callable) -> list: - return [func(elt) for elt in obj] - - @staticmethod - def elts(obj: set) -> Iterable: - return iter(obj) - - def dump(self) -> "SetRef": - assert self.in_memory - return SetRef( - uid=self.uid, obj={vref.detached() for vref in self.obj}, in_memory=True - ) - - ############################################################################ - ### set interface - ############################################################################ - def __contains__(self, item: Ref) -> bool: - from .wrapping import unwrap - - if not self.in_memory: - logging.warning( - "Checking membership in a lazy SetRef requires loading the entire set into memory." - ) - self._auto_attach(shallow=False) - return item in unwrap(self.obj) - else: - return item in unwrap(self.obj) - - def __iter__(self): - self._auto_attach() - return iter(self.obj) - - def __len__(self) -> int: - self._auto_attach() - return len(self.obj) - - -class Builtins: - IDS = ("__list__", "__dict__", "__set__") - - @staticmethod - def list_func(lst: List[Any], elt: Any, idx: Any): - assert lst[idx] is elt - - @staticmethod - def dict_func(dct: Dict[str, Any], key: str, val: Any): - assert dct[key] is val - - @staticmethod - def set_func(st: Set[Any], elt: Any): - assert elt in st - - list_op = FuncOp( - func=list_func.__func__, _is_builtin=True, version=0, ui_name="__list__" - ) - dict_op = FuncOp( - func=dict_func.__func__, _is_builtin=True, version=0, ui_name="__dict__" - ) - set_op = FuncOp( - func=set_func.__func__, _is_builtin=True, version=0, ui_name="__set__" - ) - - OPS = { - "__list__": list_op, - "__dict__": dict_op, - "__set__": set_op, - } - - REF_CLASSES = { - "__list__": ListRef, - "__dict__": DictRef, - "__set__": SetRef, - } - - PY_TYPES = { - "__list__": list, - "__dict__": dict, - "__set__": set, - } - - IO = { - "construct": { - "__list__": {"in": {"elt", "idx"}, "out": {"lst"}}, - "__dict__": {"in": {"key", "val"}, "out": {"dct"}}, - "__set__": {"in": {"elt"}, "out": {"st"}}, - }, - "destruct": { - "__list__": {"in": {"lst", "idx"}, "out": {"elt"}}, - "__dict__": {"in": {"dct", "key"}, "out": {"val"}}, - "__set__": {"in": {"st"}, "out": {"elt"}}, - }, - } - - @staticmethod - def reassign_io_using_orientation( - in_dict: Dict[str, Any], - out_dict: Dict[str, Any], - orientation: str, - builtin_id: str, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - io_assignment = Builtins.IO[orientation][builtin_id] - in_names = io_assignment["in"] - out_names = io_assignment["out"] - res_in = {} - res_out = {} - for k, v in itertools.chain(in_dict.items(), out_dict.items()): - if k in in_names: - res_in[k] = v - if k in out_names: - res_out[k] = v - return res_in, res_out - - @staticmethod - def _make_builtin_uid(uid: str, builtin_id: str) -> str: - return f"{builtin_id}.{uid}" - - @staticmethod - def is_builtin_uid(uid: str) -> bool: - return ("." 
in uid) and (uid.split(".")[0] in Builtins.IDS) - - @staticmethod - def parse_builtin_uid(uid: str) -> Tuple[str, str]: - assert Builtins.is_builtin_uid(uid) - builtin_id, uid = uid.split(".", 1) - return builtin_id, uid - - @staticmethod - def spawn_builtin( - builtin_id: str, uid: str, causal_uid: Optional[str] = None - ) -> Ref: - assert builtin_id in Builtins.IDS - uid = Builtins._make_builtin_uid(uid=uid, builtin_id=builtin_id) - res = Builtins.REF_CLASSES[builtin_id](uid=uid, obj=None, in_memory=False) - if causal_uid is not None: - res._causal_uid = causal_uid - return res - - @staticmethod - def map( - func: Callable, obj: Union[List, Dict, Set], struct_id: str - ) -> Union[List, Dict, Set]: - if struct_id == "__list__": - return [func(elt) for elt in obj] - elif struct_id == "__dict__": - return {key: func(val) for key, val in obj.items()} - elif struct_id == "__set__": - return {func(elt) for elt in obj} - else: - raise ValueError(f"Invalid struct_id: {struct_id}") - - @staticmethod - def collect_all_calls(ref: Ref) -> List[Call]: - if isinstance(ref, ValueRef): - return [] - elif isinstance(ref, StructRef): - if not ref.in_memory: - return [] - else: - calls = ref.get_calls() - for elt in ref.elts(ref.obj): - calls.extend(Builtins.collect_all_calls(elt)) - return calls - else: - raise ValueError(f"Unexpected ref type: {type(ref)}") - - -from ..queries.weaver import ValNode, BuiltinQueries, prepare_query, StructOrientations -from ..ui.contexts import GlobalContext -from .wrapping import causify_atom diff --git a/mandala/core/config.py b/mandala/core/config.py deleted file mode 100644 index 9d6b7a5..0000000 --- a/mandala/core/config.py +++ /dev/null @@ -1,164 +0,0 @@ -from ..common_imports import * -from typing import Literal - - -def get_mandala_path() -> Path: - import mandala - - return Path(os.path.dirname(mandala.__file__)) - - -class Config: - ### options modifying library behavior - # whether ops automatically wrap their inputs as value references, or - # require to be explicitly passed value references - ### settings - # whether to automatically wrap inputs when a call is made to an op - autowrap_inputs = True - # whether to automatically unwrap inputs when an op is actually executed - autounwrap_inputs = True - # how to assign UIDs to outputs - output_wrap_method = "content" - # whether to empty the call and vref caches upon committing to the RDBMS - evict_on_commit = True - # whether to commit on context exit - autocommit = True - # whether signatures are verified against the database each time a function - # is called - check_signature_on_each_call = False - # always create storage with a persistent database - _persistent_storage_testing = False - enable_ref_magics = False - warnings = True - spillover_threshold_mb = 50 - db_backend = "sqlite" - query_engine: Literal["sql", "naive", "_test"] = "sql" - verbose_queries: bool = True - func_interface_cls_name = "FuncInterface" - - ### constants - # used for columns containing UIDs of value references or calls - uid_col = "__uid__" - causal_uid_col = "__causal_uid__" - full_uid_col = "__full_uid__" - content_version_col = "__content_version__" - semantic_version_col = "__semantic_version__" - transient_col = "__transient__" - # columns that are not inputs or outputs in a memoization table - special_call_cols = [ - uid_col, - causal_uid_col, - content_version_col, - semantic_version_col, - transient_col, - ] - # name for the table that holds the value reference UIDs - vref_table = "__vrefs__" - causal_vref_table = 
"__causal_vrefs__" - vref_value_col = "value" - temp_arrow_table = "__arrow__" - # name for the event log table - event_log_table = "__event_log__" - # todo: currently unused? - # schema_event_log_table = "__schema_event_log__" - schema_table = "__schema__" - # table for keeping track of function dependencies - deps_table = "__deps__" - provenance_table = "__provenance__" - # all output names begin with this string - # todo: prevent creating inputs with this name - output_name_prefix = "output_" - - ### checking for optional dependencies - try: - import dask - - has_dask = True - except ImportError: - has_dask = False - - try: - import torch - - has_torch = True - except ImportError: - has_torch = False - - try: - import cityhash - - has_cityhash = True - except ImportError: - has_cityhash = False - - try: - import PIL - - has_pil = True - except ImportError: - has_pil = False - - try: - import duckdb - - has_duckdb = True - except ImportError: - has_duckdb = False - - try: - import rich - - has_rich = True - except ImportError: - has_rich = False - - if has_rich: - from rich import pretty - - pretty.install() - - ### some module names needed by internals - mandala_path = get_mandala_path() - module_name = "mandala" - tests_module_name = "mandala.tests" - - # hashing method - content_hasher: Literal["cityhash", "blake2b", "joblib"] = "joblib" - - -def dump_output_name(index: int) -> str: - return f"{Config.output_name_prefix}{index}" - - -def parse_output_idx(output_name: str) -> int: - return int(output_name[len(Config.output_name_prefix) :]) - - -def is_output_name(name: str) -> bool: - return ( - name.startswith(Config.output_name_prefix) - and name[len(Config.output_name_prefix) :].isdigit() - ) - - -if Config.has_torch: - import torch - - -class MODES: - run = "run" - query = "query" - batch = "batch" - noop = "noop" - define = "define" - delete = "delete" - - all_ = (run, query, batch, noop, define, delete) - - -class Provenance: - causal_uid = "causal" - direction = "direction" - call_causal_uid = "call_causal" - name = "name" - op_id = "op_id" diff --git a/mandala/core/integrations.py b/mandala/core/integrations.py deleted file mode 100644 index 43c8c9c..0000000 --- a/mandala/core/integrations.py +++ /dev/null @@ -1,54 +0,0 @@ -from .sig import Signature -from ..common_imports import * -from .config import Config - -if Config.has_torch: - import torch - - def sig_from_jit_script( - f: torch.jit.ScriptFunction, version: int - ) -> Tuple[Signature, inspect.Signature]: - """ - Parse a torch.jit.ScriptFunction into a FuncOp. - - @torch.jit.script-decorated functions must follow a special format: - - there can't be a variable number of arguments - - there can't be keyword arguments with defaults - """ - # get the names of the inputs - input_names = [arg.name for arg in f.schema.arguments] - # get the number of outputs - n_outputs = len(f.schema.returns) - # get the default values for the inputs of the ScriptFunction - defaults = {} - for argument in f.schema.arguments: - if argument.default_value is not None: - # defaults are assigned non-null values by the JIT based on - # inferred type and default value in the signature. 
- defaults[argument.name] = argument.default_value - parameters = OrderedDict() - for arg in f.schema.arguments: - kind = ( - inspect.Parameter.KEYWORD_ONLY - if arg.kwarg_only - else inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - if defaults.get(arg.name) is not None: - param = inspect.Parameter( - name=arg.name, kind=kind, default=defaults[arg.name] - ) - else: - param = inspect.Parameter(name=arg.name, kind=kind) - parameters[arg.name] = param - if n_outputs == 0: - return_annotation = inspect._empty - elif n_outputs == 1: - return_annotation = f.schema.returns[0].type - else: - return_annotation = tuple([r.type for r in f.schema.returns]) - py_sig = inspect.Signature( - parameters=list(parameters.values()), - return_annotation=return_annotation, - __validate_parameters__=True, - ) - return Signature.from_py(name=str(f.name), version=version, sig=py_sig), py_sig diff --git a/mandala/core/model.py b/mandala/core/model.py deleted file mode 100644 index 070158f..0000000 --- a/mandala/core/model.py +++ /dev/null @@ -1,743 +0,0 @@ -from typing import Type -import textwrap -from .config import Config, parse_output_idx, dump_output_name, is_output_name, MODES -from ..common_imports import * -from .utils import Hashing, get_uid, get_full_uid, parse_full_uid -from .sig import ( - Signature, - _postprocess_outputs, -) -from .tps import Type, AnyType, ListType, DictType, SetType - -from ..deps.tracers import TracerABC, DecTracer - -if Config.has_torch: - import torch - from .integrations import sig_from_jit_script - - -class Delayed: - pass - - -class TransientObj: - def __init__(self, obj: Any, unhashable: bool = False): - self.obj = obj - self.unhashable = unhashable - - -def get_transient_uid(content_hash: str) -> str: - return f"__transient__.{content_hash}" - - -def is_transient_uid(uid: str) -> bool: - return uid.startswith("__transient__") - - -################################################################################ -### refs -################################################################################ -class Ref: - def __init__(self, uid: str, obj: Any, in_memory: bool, transient: bool): - self.uid = uid - self._causal_uid = None - self._obj = obj - self.in_memory = in_memory - self.transient = transient - # runtime only - self._query: Optional[ValNode] = None - - @property - def causal_uid(self) -> str: - return self._causal_uid - - @causal_uid.setter - def causal_uid(self, causal_uid: str): - if self.causal_uid is not None and self.causal_uid != causal_uid: - raise ValueError( - f"causal_uid already set to {self.causal_uid}, cannot set to {causal_uid}" - ) - self._causal_uid = causal_uid - - @property - def full_uid(self) -> str: - return get_full_uid(uid=self.uid, causal_uid=self.causal_uid) - - @staticmethod - def parse_full_uid(full_uid: str) -> Tuple[str, str]: - return parse_full_uid(full_uid=full_uid) - - @property - def obj(self) -> Any: - return self._obj - - @staticmethod - def from_full_uid(full_uid: str) -> "Ref": - uid, causal_uid = Ref.parse_full_uid(full_uid=full_uid) - res = Ref.from_uid(uid=uid) - res.causal_uid = causal_uid - return res - - @staticmethod - def from_uid(uid: str, causal_uid: Optional[str] = None) -> "Ref": - from .builtins_ import Builtins - - if Builtins.is_builtin_uid(uid=uid): - builtin_id, uid = Builtins.parse_builtin_uid(uid=uid) - return Builtins.spawn_builtin( - builtin_id=builtin_id, uid=uid, causal_uid=causal_uid - ) - elif is_transient_uid(uid=uid): - res = ValueRef(uid=uid, obj=None, in_memory=False, transient=True) - else: - res = 
ValueRef(uid=uid, obj=None, in_memory=False) - if causal_uid is not None: - res._causal_uid = causal_uid - return res - - def is_delayed(self) -> bool: - return isinstance(self.obj, Delayed) - - @staticmethod - def make_delayed(RefCls) -> "Ref": - return RefCls(uid="", obj=Delayed(), in_memory=False) - - def attach(self, reference: "Ref"): - assert self.uid == reference.uid and reference.in_memory - self._obj = reference.obj - self.in_memory = True - - def _auto_attach(self, shallow: bool = True): - if not self.in_memory: - - context = GlobalContext.current - if context is None: - ## emit a warning - # logger.warning( - # "No context found, cannot attach ref." - # ) - # fail silently - return - if context.mode != MODES.run: - return - assert context.storage is not None - storage = context.storage - storage.cache.mattach(vrefs=[self], shallow=shallow) - # storage.rel_adapter.mattach(vrefs=[self], shallow=shallow) - causify_down(ref=self, start=self.causal_uid, stop_at_causal=False) - - def detached(self, keep_causal: bool = True) -> "Ref": - """ - Produce a metadata-only copy of this `Ref` that is unlinked from the - computational graph. - """ - res = self.__class__(uid=self.uid, obj=None, in_memory=False) - if keep_causal: - res.causal_uid = self.causal_uid - return res - - def unlinked(self, keep_causal: bool) -> "Ref": - """ - Produce a copy unlinked from the computational graph that preserves the - `in_memory` status and all other properties. - - If `keep_causal` is True, the causal UID is preserved. Otherwise, it is - removed recursively. - - ! Actually, this is quite a lot like a deepcopy for refs, which suggests - that we could refactor - """ - raise NotImplementedError - - def clone(self) -> "Ref": - """ - Clone a `Ref` by creating a new object which has the same data - """ - return self.unlinked(keep_causal=True) - - @property - def _uid_suffix(self) -> str: - return self.uid.split(".")[-1] - - @property - def _short_uid(self) -> str: - return self._uid_suffix[:3] + "..." 
- - def __repr__(self, shorten: bool = False) -> str: - if self.in_memory: - obj_repr = repr(self.obj) - if shorten: - obj_repr = textwrap.shorten(obj_repr, width=50, placeholder="...") - if "\n" in obj_repr: - obj_repr = f'\n{textwrap.indent(obj_repr, " ")}' - return f"{self.__class__.__name__}({obj_repr}, uid={self._short_uid})" - else: - return f"{self.__class__.__name__}(in_memory=False, uid={self._short_uid})" - - @property - def query(self) -> "ValNode": - if self._query is None: - raise ValueError("Ref has no query") - return self._query - - def pin(self, *values): - assert self._query is not None - if len(values) == 0: - constraint = [self.full_uid] - else: - constraint = self.query.get_constraint(*values) - self._query.constraint = constraint - - def unpin(self): - assert self._query is not None - self._query.constraint = None - - -class ValueRef(Ref): - def __init__(self, uid: str, obj: Any, in_memory: bool, transient: bool = False): - super().__init__(uid=uid, obj=obj, in_memory=in_memory, transient=transient) - - def dump(self) -> "ValueRef": - if not self.transient: - return ValueRef(uid=self.uid, obj=self.obj, in_memory=True) - return ValueRef( - uid=self.uid, obj=TransientObj(obj=None), in_memory=False, transient=True - ) - - def unlinked(self, keep_causal: bool) -> "Ref": - res = ValueRef( - uid=self.uid, - obj=self.obj, - in_memory=self.in_memory, - transient=self.transient, - ) - if keep_causal: - res.causal_uid = self.causal_uid - return res - - ############################################################################ - ### magic methods forwarding - ############################################################################ - def _init_magic(self) -> Callable: - if not Config.enable_ref_magics: - raise RuntimeError( - "Ref magic methods (typecasting/comparison operators/binary operators) are disabled; enable with Config.enable_ref_magics = True" - ) - if Config.warnings: - logging.warning( - f"Automatically unwrapping `Ref` to run magic method (typecasting/comparison operators/binary operators)." 
- ) - from .wrapping import unwrap - - self._auto_attach(shallow=True) - return unwrap - - ### typecasting - def __bool__(self) -> bool: - self._init_magic() - return self.obj.__bool__() - - def __int__(self) -> int: - self._init_magic() - return self.obj.__int__() - - def __index__(self) -> int: - self._init_magic() - return self.obj.__index__() - - ### comparison - def __lt__(self, other: Any) -> bool: - unwrap = self._init_magic() - return self.obj.__lt__(unwrap(other)) - - def __le__(self, other: Any) -> bool: - unwrap = self._init_magic() - return self.obj.__le__(unwrap(other)) - - def __eq__(self, other: Any) -> bool: - unwrap = self._init_magic() - return self.obj.__eq__(unwrap(other)) - - def __hash__(self) -> int: - return id(self) - - def __ne__(self, other: Any) -> bool: - unwrap = self._init_magic() - return self.obj.__ne__(unwrap(other)) - - def __gt__(self, other: Any) -> bool: - unwrap = self._init_magic() - return self.obj.__gt__(unwrap(other)) - - def __ge__(self, other: Any) -> bool: - unwrap = self._init_magic() - return self.obj.__ge__(unwrap(other)) - - ### binary operations - def __add__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__add__(unwrap(other)) - - def __sub__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__sub__(unwrap(other)) - - def __mul__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__mul__(unwrap(other)) - - def __floordiv__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__floordiv__(unwrap(other)) - - def __truediv__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__truediv__(unwrap(other)) - - def __mod__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__mod__(unwrap(other)) - - def __or__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__or__(unwrap(other)) - - def __and__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__and__(unwrap(other)) - - def __xor__(self, other: Any) -> Any: - unwrap = self._init_magic() - return self.obj.__xor__(unwrap(other)) - - -################################################################################ -### calls -################################################################################ - - -class Call: - def __init__( - self, - uid: str, - inputs: Dict[str, Ref], - outputs: List[Ref], - func_op: "FuncOp", - transient: bool, - causal_uid: Optional[str] = None, - semantic_version: Optional[str] = None, - content_version: Optional[str] = None, - ): - self.func_op = func_op.detached() - self.uid = uid - self.semantic_version = semantic_version - self.content_version = content_version - self.inputs = inputs - self.outputs = outputs - self.transient = transient # if outputs contain transient objects - if causal_uid is None: - input_uids = {k: v.uid for k, v in self.inputs.items()} - input_causal_uids = {k: v.causal_uid for k, v in self.inputs.items()} - assert all([v is not None for v in input_causal_uids.values()]) - causal_uid = self.func_op.get_call_causal_uid( - input_uids=input_uids, - input_causal_uids=input_causal_uids, - semantic_version=semantic_version, - ) - self.causal_uid = causal_uid - self._func_query = None - - @property - def full_uid(self) -> str: - return f"{self.uid}.{self.causal_uid}" - - def link( - self, - orientation: Optional[str] = None, - ): - if self._func_query is not None: - return - input_types, output_types = self.func_op.input_types, self.func_op.output_types - for 
k, v in self.inputs.items(): - prepare_query(ref=v, tp=input_types[k]) - for i, v in enumerate(self.outputs): - prepare_query(ref=v, tp=output_types[i]) - outputs = {dump_output_name(i): v.query for i, v in enumerate(self.outputs)} - self._func_query = CallNode.link( - inputs={k: v.query for k, v in self.inputs.items()}, - func_op=self.func_op.detached(), - outputs=outputs, - orientation=orientation, - constraint=None, - ) - - def unlink(self): - assert self._func_query is not None - self._func_query.unlink() - self._func_query = None - - @property - def func_query(self) -> "CallNode": - assert self._func_query is not None - return self._func_query - - def __repr__(self) -> str: - tuples: List[Tuple[str, str]] = [ - ("uid", self.uid), - ] - if self.semantic_version is not None: - tuples.append(("semantic_version", self.semantic_version)) - if self.content_version is not None: - tuples.append(("content_version", self.content_version)) - tuples.extend( - [ - ("inputs", textwrap.shorten(str(self.inputs), width=80)), - ("outputs", textwrap.shorten(str(self.outputs), width=80)), - ("func_op", self.func_op), - ] - ) - data_str = ",\n".join([f" {k}={v}" for k, v in tuples]) - return f"Call(\n{data_str}\n)" - - @staticmethod - def from_row(row: Union[pa.Table, dict], func_op: "FuncOp") -> "Call": - """ - Generate a `Call` from a single-row table encoding the UID, input and - output UIDs. - - NOTE: this does not include the objects for the inputs and outputs to the call! - """ - columns = row.column_names if isinstance(row, pa.Table) else row.keys() - output_columns = [column for column in columns if is_output_name(column)] - input_columns = [ - column - for column in columns - if column not in output_columns and column not in Config.special_call_cols - ] - process_boolean = lambda x: True if x == "1" else False - if isinstance(row, pa.Table): - return Call( - uid=row.column(Config.uid_col)[0].as_py(), - causal_uid=row.column(Config.causal_uid_col)[0].as_py(), - semantic_version=row.column(Config.semantic_version_col)[0].as_py(), - content_version=row.column(Config.content_version_col)[0].as_py(), - transient=process_boolean(row.column(Config.transient_col)[0].as_py()), - inputs={ - k: Ref.from_full_uid(full_uid=row.column(k)[0].as_py()) - for k in input_columns - }, - outputs=[ - Ref.from_full_uid(full_uid=row.column(k)[0].as_py()) - for k in sorted(output_columns, key=parse_output_idx) - ], - func_op=func_op, - ) - else: - return Call( - uid=row[Config.uid_col], - causal_uid=row[Config.causal_uid_col], - semantic_version=row[Config.semantic_version_col], - content_version=row[Config.content_version_col], - transient=process_boolean(row[Config.transient_col]), - inputs={k: Ref.from_full_uid(full_uid=row[k]) for k in input_columns}, - outputs=[ - Ref.from_full_uid(full_uid=row[k]) - for k in sorted(output_columns, key=parse_output_idx) - ], - func_op=func_op, - ) - - def set_input_values(self, inputs: Dict[str, Ref]) -> "Call": - res = copy.deepcopy(self) - assert set(inputs.keys()) == set(res.inputs.keys()) - for k, v in inputs.items(): - current = res.inputs[k] - current._obj = v.obj - current.in_memory = True - return res - - def set_output_values(self, outputs: List[Ref]) -> "Call": - res = copy.deepcopy(self) - assert len(outputs) == len(res.outputs) - for i, v in enumerate(outputs): - current = res.outputs[i] - current._obj = v.obj - current.in_memory = True - return res - - def detached(self) -> "Call": - """ - Produce a copy of this call that is unlinked from the computational - graph, and 
has all its inputs and outputs detached. - - See `Ref.detached` for more details. - """ - return Call( - uid=self.uid, - causal_uid=self.causal_uid, - inputs={k: v.detached() for k, v in self.inputs.items()}, - outputs=[v.detached() for v in self.outputs], - func_op=self.func_op, - semantic_version=self.semantic_version, - content_version=self.content_version, - transient=self.transient, - ) - - -################################################################################ -### ops -################################################################################ -class FuncOp: - def __init__( - self, - func: Optional[Callable] = None, - sig: Optional[Signature] = None, - version: Optional[int] = None, - ui_name: Optional[str] = None, - is_super: bool = False, - n_outputs_override: Optional[int] = None, - _is_builtin: bool = False, - ): - self.is_super = is_super - self._is_builtin = _is_builtin - if func is None: - self.sig = sig - self.py_sig = None - self._func = None - self._module = None - self._qualname = None - else: - self._func = func - self._module = func.__module__ - self._qualname = func.__qualname__ - if Config.has_torch and isinstance(func, torch.jit.ScriptFunction): - sig, py_sig = sig_from_jit_script(self._func, version=version) - ui_name = sig.ui_name - else: - py_sig = inspect.signature(self._func) - ui_name = self._func.__name__ if ui_name is None else ui_name - self.py_sig = py_sig - self.sig = Signature.from_py( - sig=self.py_sig, name=ui_name, version=version, _is_builtin=_is_builtin - ) - self.n_outputs_override = n_outputs_override - if n_outputs_override is not None: - self.sig.n_outputs = n_outputs_override - self.sig.output_annotations = [Any] * n_outputs_override - - def __repr__(self) -> str: - return f"FuncOp({self.sig.ui_name}, version={self.sig.version})" - - @property - def func(self) -> Callable: - assert self._func is not None - return self._func - - @property - def is_builtin(self) -> bool: - return self._is_builtin - - @func.setter - def func(self, func: Optional[Callable]): - self._func = func - if func is not None: - self.py_sig = inspect.signature(self._func) - - @property - def input_annotations(self) -> Dict[str, Any]: - assert self.sig is not None - return self.sig.input_annotations - - @property - def input_types(self) -> Dict[str, Type]: - return {k: Type.from_annotation(v) for k, v in self.input_annotations.items()} - - @property - def output_annotations(self) -> List[Any]: - assert self.sig is not None - return self.sig.output_annotations - - @property - def output_types(self) -> List[Type]: - return [Type.from_annotation(a) for a in self.output_annotations] - - def compute( - self, - inputs: Dict[str, Any], - tracer: Optional[TracerABC] = None, - ) -> List[Any]: - if tracer is not None: - with tracer: - if isinstance(tracer, DecTracer): - node = tracer.register_call(func=self.func) - result = self.func(**inputs) - if isinstance(tracer, DecTracer): - tracer.register_return(node=node) - else: - result = self.func(**inputs) - return _postprocess_outputs(sig=self.sig, result=result) - - @staticmethod - def _from_data( - sig: Signature, - func: Callable, - ) -> "FuncOp": - """ - Create a `FuncOp` object based on a signature and maybe a function. For - internal use only. 
- """ - res = FuncOp( - func=func, - version=sig.version, - ui_name=sig.ui_name, - _is_builtin=sig.is_builtin, - ) - res.sig = sig - return res - - @staticmethod - def _from_sig(sig: Signature) -> "FuncOp": - return FuncOp(func=None, sig=sig, _is_builtin=sig.is_builtin) - - def detached(self) -> "FuncOp": - if self._func is None: - return copy.deepcopy(self) - result = FuncOp( - func=None, - sig=copy.deepcopy(self.sig), - version=self.sig.version, - ui_name=self.sig.ui_name, - is_super=self.is_super, - n_outputs_override=self.n_outputs_override, - _is_builtin=self._is_builtin, - ) - result._module = self._module - result._qualname = self._qualname - result.py_sig = self.py_sig # signature objects are immutable - return result - - def get_active_inputs(self, input_uids: Dict[str, str]) -> Dict[str, str]: - """ - Return a dict of external -> internal input names for inputs that are - not set to their default values. - """ - res = {} - for k, v in input_uids.items(): - internal_k = self.sig.ui_to_internal_input_map[k] - if internal_k in self.sig._new_input_defaults_uids: - internal_uid = Ref.parse_full_uid( - full_uid=self.sig._new_input_defaults_uids[internal_k] - )[0] - if internal_uid == v: - continue - res[k] = internal_k - return res - - def get_call_causal_uid( - self, - input_uids: Dict[str, str], - input_causal_uids: Dict[str, str], - semantic_version: Optional[str], - ) -> str: - """ - Combine the causal UIDs of the inputs, the semantic version of the call, - and the versioned internal name of the function to generate a unique - causal UID for the call. - """ - active_inputs = self.get_active_inputs(input_uids=input_uids) - return Hashing.get_content_hash( - obj=[ - { - active_inputs[k]: v - for k, v in input_causal_uids.items() - if k in active_inputs.keys() - }, - semantic_version, - self.sig.versioned_internal_name, - ] - ) - - def get_pre_call_uid(self, input_uids: Dict[str, str]) -> str: - # get call UID using *internal names* to guarantee the same UID will be - # assigned regardless of renamings - active_inputs = self.get_active_inputs(input_uids=input_uids) - hashable_input_uids = { - active_inputs[k]: v - for k, v in input_uids.items() - if k in active_inputs.keys() - } - call_uid = Hashing.get_content_hash( - obj=[ - hashable_input_uids, - self.sig.versioned_internal_name, - ] - ) - return call_uid - - def get_call_uid(self, pre_call_uid: str, semantic_version: Optional[str]) -> str: - return Hashing.get_content_hash((pre_call_uid, semantic_version)) - - -################################################################################ -### wrapping -################################################################################ -def wrap_atom(obj: Any, uid: Optional[str] = None) -> ValueRef: - """ - Wraps a value as a `ValueRef`, if it isn't one already. - - The uid is either explicitly set, or a content hash is generated. Note that - content hashing may take non-trivial time for large objects. When `obj` is - already a `ValueRef` and `uid` is provided, an error is raised. 
- """ - if isinstance(obj, Ref) and not isinstance(obj, ValueRef): - raise ValueError(f"Cannot wrap {obj} as a ValueRef") - if isinstance(obj, ValueRef): - if uid is not None: - # protect against accidental misuse - raise ValueError(f"Cannot change uid of ValueRef: {obj}") - return obj - elif not isinstance(obj, TransientObj): - uid = Hashing.get_content_hash(obj) if uid is None else uid - return ValueRef(uid=uid, obj=obj, in_memory=True) - else: - if obj.unhashable: - uid = get_uid() - else: - uid = Hashing.get_content_hash(obj.obj) if uid is None else uid - uid = get_transient_uid(content_hash=uid) - return ValueRef(uid=uid, obj=obj.obj, in_memory=True, transient=True) - - -def collect_detached(refs: Iterable[Ref], include_transient: bool) -> List[Ref]: - """ - Recursively get all detached `Ref`s present in `refs` - """ - detached_vrefs = [] - for ref in refs: - if ( - isinstance(ref, Ref) - and not ref.in_memory - and (include_transient or not ref.transient) - ): - detached_vrefs.append(ref) - elif isinstance(ref, StructRef) and ref.in_memory: - detached_vrefs.extend( - collect_detached(ref.children(), include_transient=include_transient) - ) - else: - continue - return detached_vrefs - - -def clone(ref: Ref) -> Ref: - """ - Clone a `Ref` by creating a new object which has the same data - """ - if isinstance(ref, ValueRef): - return ref.dump() - - -from .builtins_ import ListRef, DictRef, SetRef, Builtins, StructRef -from ..queries.weaver import ValNode, StructOrientations, CallNode, prepare_query -from ..ui.contexts import GlobalContext -from .wrapping import causify_down, causify_atom diff --git a/mandala/core/prov.py b/mandala/core/prov.py deleted file mode 100644 index a2d2292..0000000 --- a/mandala/core/prov.py +++ /dev/null @@ -1,80 +0,0 @@ -from ..common_imports import * -from .builtins_ import ListRef, DictRef, SetRef - -BUILTIN_IDS = [ListRef.builtin_id, DictRef.builtin_id, SetRef.builtin_id] -BUILTIN_OP_IDS = [f"{x}_0" for x in BUILTIN_IDS] - - -def propagate_struct_provenance(prov_df: pd.DataFrame) -> pd.DataFrame: - """ - Compute directions of structural calls in a new column `direction_new` by - inferring from the data. Currently for backward compatibility. - - The algorithm is as follows: - - find all the refs that are the direct result if a non-struct op call - - find all the struct calls where these refs appear as the struct - - find all the struct calls expressing the items of these structs - - mark these items as outputs - - repeat the process for these items, until no new structs are found among them - - Note that this assigns `direction_new` for all calls involved in the structs - found by this process. - - For every struct call that hasn't been assigned `direction_new` yet, we mark - the struct as output, and the items (and indices) as inputs. 
- """ - prov_df = prov_df.copy() - prov_df["direction_new"] = [None for _ in range(len(prov_df))] - nonstruct_outputs_causal_uids = prov_df.query( - 'direction == "output" and op_id not in @BUILTIN_OP_IDS' - ).causal.values - structs_df = get_structs_df(prov_df, nonstruct_outputs_causal_uids) - items_df = get_items_df(prov_df, structs_df.call_causal.values) - while len(structs_df) > 0: - # mark only the items (not structs or indices) as outputs - prov_df["direction_new"][ - prov_df.call_causal.isin(items_df.call_causal) - & ~(prov_df.name.isin(["lst", "dct", "st", "idx", "key"])) - ] = "output" - items_causal_uids = items_df.causal.values - structs_df = get_structs_df(prov_df, items_causal_uids) - items_df = get_items_df(prov_df, structs_df.call_causal.values) - remaining_struct_mask = ( - (prov_df["direction_new"] != prov_df["direction_new"]) - & (prov_df["op_id"].isin(BUILTIN_OP_IDS)) - & (prov_df["name"].isin(["lst", "dct", "st"])) - ) - prov_df.loc[remaining_struct_mask, "direction_new"] = "output" - - remaining_things = prov_df.query("direction_new != direction_new").index - prov_df.loc[remaining_things, "direction_new"] = prov_df.query( - "direction_new != direction_new" - ).direction - return prov_df - - -def get_structs_df(prov_df: pd.DataFrame, causal_uids: Iterable[str]) -> pd.DataFrame: - """ - Given some causal UIDs and a provenance dataframe, return the sub-dataframe - where these causal UIDs appear in the role of the struct in a structural - call - """ - return prov_df.query( - 'causal in @causal_uids and op_id in @BUILTIN_OP_IDS and name in ["lst", "dct", "st"]' - ) - - -def get_items_df( - prov_df: pd.DataFrame, struct_call_uids: Iterable[str] -) -> pd.DataFrame: - """ - Given some structural causal call UIDs and a provenance dataframe, return - the sub-dataframe where these structural calls are associated with items - (elements/values) of the structs - """ - # get the sub-dataframe for these structural calls containing the items (elts/values) of the structs - return prov_df.query('call_causal in @struct_call_uids and name in ["elt", "val"]') - - -def get_idx_df(prov_df: pd.DataFrame, struct_call_uids: Iterable[str]) -> pd.DataFrame: - return prov_df.query('call_causal in @struct_call_uids and name in ["idx", "key"]') diff --git a/mandala/core/sig.py b/mandala/core/sig.py deleted file mode 100644 index 298aef9..0000000 --- a/mandala/core/sig.py +++ /dev/null @@ -1,511 +0,0 @@ -from ..common_imports import * -from .config import Config, is_output_name -from .utils import get_uid, Hashing, is_subdict, get_full_uid -from ..utils import serialize -from pickle import PicklingError - -if Config.has_torch: - import torch - - -def sanitize_annotation(annotation: Any) -> Any: - try: - serialize(annotation) - return annotation - except PicklingError: - return Any - except Exception as e: - raise ValueError(f"Invalid annotation: {annotation}") from e - - -class Signature: - """ - Holds and manipulates the relevant metadata for a memoized function, which - includes - - the function's user-interface (human-facing) and internal (used by storage) - name, - - the user-interface and internal input names (and the mapping between them), - - the version, - - and the default values. - - (optional) superop status - - Responsible for manipulations to this state, and keeping it consistent, so - e.g. logic for checking if a refactoring makes sense should be hidden here. 
- - The internal name of the function is an immutable UID that is used to - identify the function throughout its entire lifetime for the storage it is - connected to. The UI name is what the function is named in the source - code, and can be changed. Same for the internal/UI input names. - - What goes through most of the system at runtime are the UI names, to make it - easier to debug and inspect things. The internal names are used only in very - specific and isolated parts of the architecture. - """ - - def __init__( - self, - ui_name: str, - input_names: Set[str], - n_outputs: int, - defaults: Dict[str, Any], # ui name -> default value - version: int, - input_annotations: Dict[str, Any], - output_annotations: List[Any], - _is_builtin: bool = False, - ): - self.ui_name = ui_name - self.input_names = input_names - self.defaults = defaults - self.n_outputs = n_outputs - self.version = version - self._internal_name = None - # ui name -> internal name for inputs - self._ui_to_internal_input_map = None - # internal input name -> UID of default value - # this stores the UIDs of default values for inputs that have been - # added to the function since its creation - self._new_input_defaults_uids = {} - - self.input_annotations = { - k: sanitize_annotation(v) for k, v in input_annotations.items() - } - self.output_annotations = [sanitize_annotation(v) for v in output_annotations] - - self._is_builtin = _is_builtin - if self.is_builtin: - self._internal_name = ui_name - self._ui_to_internal_input_map = {name: name for name in input_names} - - @property - def is_builtin(self) -> bool: - return self._is_builtin - - def check_invariants(self): - assert set(self.defaults.keys()) <= self.input_names - assert set(self.input_annotations.keys()) == self.input_names - assert len(self.output_annotations) == self.n_outputs - if self.has_internal_data: - assert set(self._ui_to_internal_input_map.keys()) == self.input_names - assert set(self._new_input_defaults_uids.keys()) <= set( - self._ui_to_internal_input_map.values() - ) - - def __repr__(self) -> str: - return ( - f"Signature(ui_name={self.ui_name}, input_names={self.input_names}, " - f"n_outputs={self.n_outputs}, defaults={self.defaults}, " - f"version={self.version}, internal_name={self._internal_name}, " - f"ui_to_internal_input_map={self._ui_to_internal_input_map}, " - f"new_input_defaults_uids={self._new_input_defaults_uids}, " - f"is_builtin={self.is_builtin})" - ) - - @property - def versioned_ui_name(self) -> str: - """ - Return the version-qualified human-readable name of this signature, used to - disambiguate between different versions of the same function. 
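A small runnable illustration of the version-qualified naming convention returned here and parsed back by `parse_versioned_name` below; the function name `train_model` is only a placeholder.

```
ui_name, version = "train_model", 3
versioned = f"{ui_name}_{version}"                 # what versioned_ui_name produces
name, parsed_version = versioned.rsplit("_", 1)    # what parse_versioned_name does
assert (name, int(parsed_version)) == (ui_name, version)
```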
- """ - return f"{self.ui_name}_{self.version}" - - @property - def versioned_internal_name(self) -> str: - """ - Return the version-qualified internal name of this signature - """ - return f"{self.internal_name}_{self.version}" - - @property - def internal_name(self) -> str: - if self._internal_name is None: - raise ValueError("Internal name not set") - return self._internal_name - - @staticmethod - def parse_versioned_name(versioned_name: str) -> Tuple[str, int]: - """ - Recover the name and version from a version-qualified name - """ - name, version_string = versioned_name.rsplit("_", 1) - return name, int(version_string) - - @staticmethod - def dump_versioned_name(name: str, version: int) -> str: - """ - Return a version-qualified name from a name and version - """ - return f"{name}_{version}" - - @property - def ui_to_internal_input_map(self) -> Dict[str, str]: - if self._ui_to_internal_input_map is None: - raise ValueError("Internal input names not set") - return self._ui_to_internal_input_map - - @property - def internal_to_ui_input_map(self) -> Dict[str, str]: - """ - Mapping from internal input names to their UI names - """ - if not self.has_internal_data: - raise ValueError() - return {v: k for k, v in self.ui_to_internal_input_map.items()} - - @property - def has_internal_data(self) -> bool: - """ - Whether this signature has had its internal data (internal signature - name and internal input names) set. - """ - return ( - self._internal_name is not None - and self._ui_to_internal_input_map is not None - and self._ui_to_internal_input_map.keys() == self.input_names - ) - - @property - def new_ui_input_default_uids(self) -> Dict[str, str]: - return { - self.internal_to_ui_input_map[k]: v - for k, v in self._new_input_defaults_uids.items() - } - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, Signature) - and self.ui_name == other.ui_name - and self.input_names == other.input_names - and self.n_outputs == other.n_outputs - and self.defaults == other.defaults - and self.version == other.version - and self._internal_name == other._internal_name - and self._ui_to_internal_input_map == other._ui_to_internal_input_map - and self._new_input_defaults_uids == other._new_input_defaults_uids - ) - - ############################################################################ - ### PURE methods for manipulating the signature - ### to avoid broken state - ############################################################################ - def _generate_internal(self, internal_name: Optional[str] = None) -> "Signature": - """ - Assign internal names to random UIDs. - - Providing `internal_name` explicitly can be used to set the same - internal name for different versions of the same function. - """ - res = copy.deepcopy(self) - if not self.is_builtin: - if internal_name is None: - internal_name = get_uid() - ui_to_internal_map = {k: get_uid() for k in self.input_names} - res._internal_name, res._ui_to_internal_input_map = ( - internal_name, - ui_to_internal_map, - ) - if len(self._new_input_defaults_uids) > 0: - assert self.has_internal_data - res._new_input_defaults_uids = { - ui_to_internal_map[self.internal_to_ui_input_map[k]]: v - for k, v in self._new_input_defaults_uids.items() - } - res.check_invariants() - return res - - def is_compatible(self, new: "Signature") -> Tuple[bool, Optional[str]]: - """ - Check if a new signature (possibly without internal data) is compatible - with this signature. 
- - Currently, the only way to be compatible is to be either the same object - or an extension with new arguments. - - Returns: - Tuple[bool, str]: (outcome, (reason if `False`, None if True)) - """ - if new.version != self.version: - return False, "Versions do not match" - if new.ui_name != self.ui_name: - return False, "UI names do not match" - if new.has_internal_data and self.has_internal_data: - if new.internal_name != self.internal_name: - return False, "Internal names do not match" - if not is_subdict( - self.ui_to_internal_input_map, new.ui_to_internal_input_map - ): - return False, "UI -> internal input mapping is inconsistent" - new_internal_names = set(new._ui_to_internal_input_map.values()) - current_internal_names = set(self._ui_to_internal_input_map.values()) - if not set.issubset(current_internal_names, new_internal_names): - return False, "Internal input names must be a superset of current" - if not set.issubset(set(self.input_names), set(new.input_names)): - return False, "Removing inputs is not supported" - # if not self.n_outputs == new.n_outputs: - # return False, "Changing the number of outputs is not supported" - if not is_subdict(self.defaults, new.defaults): - return False, "New defaults are inconsistent with current defaults" - if new.has_internal_data and not is_subdict( - self._new_input_defaults_uids, new._new_input_defaults_uids - ): - return False, "New default UIDs are inconsistent with current default UIDs" - for k in new.input_names: - if k not in self.input_names: - if k not in new.defaults.keys(): - return False, f"All new arguments must be created with defaults!" - return True, None - - def update(self, new: "Signature") -> Tuple["Signature", dict]: - """ - Return an updated version of this signature based on a new signature - (possibly without internal data), plus a description of the updates. - - If the new signature has internal data, it is copied over. - - NOTE: the new signature need not have internal data. The goal of this - method is to be able to update from a function provided by the user that - has not been synchronized yet. - - This takes care of - - checking that the new signature is compatible with the old one - - generating names for new inputs, if any. - - Returns: - - new `Signature` object - - a dictionary of {new ui input name: default value} for any new inputs - that were created - """ - is_compatible, reason = self.is_compatible(new) - if not is_compatible: - raise ValueError(reason) - new_defaults = new.defaults - res = copy.deepcopy(self) - updates = {} - for k in new.input_names: - if k not in res.input_names: - # this means a new input is being created - if new.has_internal_data: - internal_name = new.ui_to_internal_input_map[k] - else: - internal_name = None - res = res.create_input( - name=k, - default=new_defaults[k], - internal_name=internal_name, - annotation=new.input_annotations[k], - ) - updates[k] = new_defaults[k] - if new.n_outputs != self.n_outputs: - res.n_outputs = new.n_outputs - res.output_annotations = new.output_annotations - res.check_invariants() - return res, updates - - def create_input( - self, - name: str, - default: Any, - annotation: Any, - internal_name: Optional[str] = None, - ) -> "Signature": - """ - Add an input with a default value to this signature. This takes care of - all the internal bookkeeping, including figuring out the UID for the - default value. 
- """ - if name in self.input_names: - raise ValueError(f'Input "{name}" already exists') - if not self.has_internal_data: - raise InternalError( - "Cannot add inputs to a signature without internal data" - ) - res = copy.deepcopy(self) - res.input_names.add(name) - internal_name = get_uid() if internal_name is None else internal_name - res.ui_to_internal_input_map[name] = internal_name - res.defaults[name] = default - uid = Hashing.get_content_hash(obj=default) - full_uid = get_full_uid(uid=uid, causal_uid=uid) - res._new_input_defaults_uids[internal_name] = full_uid - res.input_annotations[name] = annotation - res.check_invariants() - return res - - def rename(self, new_name: str) -> "Signature": - """ - Change the ui name - """ - res = copy.deepcopy(self) - res.ui_name = new_name - res.check_invariants() - return res - - def rename_inputs(self, mapping: Dict[str, str]) -> "Signature": - """ - Change UI names according to the given mapping. - - Supporting only a method that changes multiple names at once is more - convenient, since we must support applying updates in bulk anyway. - """ - assert all(k in self.input_names for k in mapping.keys()) - current_names = list(self.input_names) - new_names = [mapping.get(k, k) for k in current_names] - if len(set(new_names)) != len(new_names): - raise ValueError("Input name collision") - res = copy.deepcopy(self) - # migrate input names - for current_name in mapping.keys(): - res.input_names.remove(current_name) - for new_name in mapping.values(): - res.input_names.add(new_name) - # migrate defaults - res.defaults = {mapping.get(k, k): v for k, v in res.defaults.items()} - # migrate annotations - res.input_annotations = { - mapping.get(k, k): v for k, v in res.input_annotations.items() - } - # migrate internal data - for current_name, new_name in mapping.items(): - res.ui_to_internal_input_map[new_name] = res.ui_to_internal_input_map.pop( - current_name - ) - res.check_invariants() - return res - - def bump_version(self) -> "Signature": - res = copy.deepcopy(self) - res.version += 1 - res.check_invariants() - return res - - @staticmethod - def from_py( - name: str, - version: int, - sig: inspect.Signature, - _is_builtin: bool = False, - ) -> "Signature": - """ - Create a `Signature` from a Python function's signature and the other - necessary metadata, and check it satisfies mandala-specific constraints. 
- """ - input_names = set( - [ - param.name - for param in sig.parameters.values() - if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ] - ) - # ensure that there will be no collisions with input and output names - if any(is_output_name(name) for name in input_names): - raise ValueError( - f"Input names cannot be of the form {Config.output_name_prefix}_[number]" - ) - return_annotation = sig.return_annotation - if ( - hasattr(return_annotation, "__origin__") - and return_annotation.__origin__ is tuple - ): - n_outputs = len(return_annotation.__args__) - elif return_annotation is inspect._empty: - n_outputs = 0 - else: - n_outputs = 1 - defaults = { - param.name: param.default - for param in sig.parameters.values() - if param.default is not inspect.Parameter.empty - } - res = Signature( - ui_name=name, - input_names=input_names, - n_outputs=n_outputs, - defaults=defaults, - version=version, - input_annotations=_get_arg_annotations_from_sig(sig, support=input_names), - output_annotations=_get_return_annotations_from_sig( - sig=sig, support_size=n_outputs - ), - _is_builtin=_is_builtin, - ) - res.check_invariants() - return res - - -def _postprocess_outputs(sig: Signature, result) -> List[Any]: - if sig.n_outputs == 0: - assert ( - result is None - ), f"Operation with signature {sig} has zero outputs, but its function returned {result}" - return [] - elif sig.n_outputs == 1: - return [result] - else: - assert isinstance( - result, tuple - ), f"Operation {sig.ui_name} has multiple outputs, but its function returned a non-tuple: {result}" - assert ( - len(result) == sig.n_outputs - ), f"Operation {sig.ui_name} has {sig.n_outputs} outputs, but its function returned a tuple of length {len(result)}" - return list(result) - - -def _get_arg_annotations_from_sig( - sig: inspect.Signature, support: Set[str] -) -> Dict[str, Any]: - # Create an empty dictionary to store the argument annotations - arg_annotations = {} - # Iterate over the function's parameters - for param in sig.parameters.values(): - # Add the parameter name and its type annotation to the dictionary - arg_annotations[param.name] = param.annotation - for k in support: - if k not in arg_annotations: - arg_annotations[k] = Any - # Return the dictionary of argument annotations - return arg_annotations - - -def _get_arg_annotations(func: Callable, support: Set[str]) -> Dict[str, Any]: - if Config.has_torch and isinstance(func, torch.jit.ScriptFunction): - return {a.name: a.type for a in func.schema.arguments} - return _get_arg_annotations_from_sig(sig=inspect.signature(func), support=support) - - -def is_typing_tuple(obj: Any) -> bool: - """ - Check if an object is a typing.Tuple[...] 
thing - """ - try: - return obj.__origin__ is tuple - except AttributeError: - return False - - -def unpack_return_annotation(return_annotation: Any, support_size: int) -> List[Any]: - if is_typing_tuple(return_annotation): - # we must unpack the typing tuple in this weird way - res: List[Any] = list(return_annotation.__args__) - elif return_annotation is inspect._empty: - if support_size > 0: - res = [Any for _ in range(support_size)] - else: - res = [] - else: - res = [return_annotation] - return res - - -def _get_return_annotations_from_sig( - sig: inspect.Signature, support_size: int -) -> List[Any]: - return_annotation = sig.return_annotation - return unpack_return_annotation(return_annotation, support_size) - - -def _get_return_annotations(func: Callable, support_size: int) -> List[Any]: - if Config.has_torch and isinstance(func, torch.jit.ScriptFunction): - return_annotation = func.schema.returns[0].type - else: - sig = inspect.signature(func) - return_annotation = sig.return_annotation - return unpack_return_annotation(return_annotation, support_size) diff --git a/mandala/core/tps.py b/mandala/core/tps.py deleted file mode 100644 index 7bb54bf..0000000 --- a/mandala/core/tps.py +++ /dev/null @@ -1,90 +0,0 @@ -from ..common_imports import * -import typing - -################################################################################ -### types -################################################################################ -class Type: - @staticmethod - def from_annotation(annotation: Any) -> "Type": - if (annotation is None) or (annotation is inspect._empty): - return AnyType() - if annotation is typing.Any: - return AnyType() - elif isinstance(annotation, Type): - return annotation - elif isinstance(annotation, type): - if annotation is list: - return ListType() - elif annotation is dict: - return DictType() - elif annotation is set: - return SetType() - else: - return AnyType() - elif hasattr(annotation, "__origin__"): - if annotation.__origin__ is list: - elt_annotation = annotation.__args__[0] - return ListType( - elt_type=Type.from_annotation(annotation=elt_annotation) - ) - elif annotation.__origin__ is dict: - value_annotation = annotation.__args__[1] - return DictType( - elt_type=Type.from_annotation(annotation=value_annotation) - ) - elif annotation.__origin__ is set: - elt_annotation = annotation.__args__[0] - return SetType(elt_type=Type.from_annotation(annotation=elt_annotation)) - elif annotation.__origin__ is tuple: - return AnyType() - else: - return AnyType() - else: - return AnyType() - - def __eq__(self, other: Any) -> bool: - if type(self) != type(other): - return False - if isinstance(self, StructType): - return self.struct_id == other.struct_id and self.elt_type == other.elt_type - elif isinstance(self, AnyType): - return True - else: - raise NotImplementedError - - -class AnyType(Type): - def __repr__(self): - return "AnyType()" - - -class StructType(Type): - struct_id = None - - def __init__(self, elt_type: Optional[Type] = None): - self.elt_type = AnyType() if elt_type is None else elt_type - - -class ListType(StructType): - struct_id = "__list__" - model = list - - def __repr__(self): - return f"ListType(elt_type={self.elt_type})" - - -class DictType(StructType): - struct_id = "__dict__" - model = dict - - def __repr__(self): - return f"DictType(elt_type={self.elt_type})" - - -class SetType(StructType): - struct_id = "__set__" - model = set - - def __repr__(self): - return f"SetType(elt_type={self.elt_type})" diff --git a/mandala/core/utils.py 
b/mandala/core/utils.py deleted file mode 100644 index 6cb089b..0000000 --- a/mandala/core/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import typing -import hashlib -import textwrap -from weakref import WeakKeyDictionary -from .config import * -from ..common_imports import * - -if Config.has_cityhash: - import cityhash - -OpKey = Tuple[str, int] - - -def get_uid() -> str: - """ - Generate a sequence of 32 hexadecimal characters using the operating - system's "randomness". - """ - return "{}".format(binascii.hexlify(os.urandom(16)).decode("utf-8")) - - -def get_full_uid(uid: str, causal_uid: str) -> str: - return f"{uid}.{causal_uid}" - - -def parse_full_uid(full_uid: str) -> Tuple[str, str]: - uid, causal_uid = full_uid.rsplit(".", 1) - return uid, causal_uid - - -_KT = TypeVar("_KT") -_VT = TypeVar("_VT") - - -def get_fibers_as_lists(mapping: Dict[_KT, _VT]) -> Dict[_VT, List[_KT]]: - fibers = defaultdict(list) - for vq, name in mapping.items(): - fibers[name].append(vq) - return fibers - - -def is_subdict(a: Dict, b: Dict) -> bool: - """ - Check that all keys in `a` are in `b` with the same value. - """ - return all((k in b and a[k] == b[k]) for k in a) - - -def invert_dict(d: Dict) -> Dict: - """ - Invert a dictionary, assuming that all values are unique. - """ - return {v: k for k, v in d.items()} - - -def concat_lists(lists: List[list]) -> list: - return [x for lst in lists for x in lst] - - -class Hashing: - """ - Helpers for hashing e.g. function inputs and call metadata. - """ - - @staticmethod - def get_content_hash_blake2b(obj: Any) -> str: - # if Config.has_torch and isinstance(obj, torch.Tensor): - # #! torch tensors do not have a deterministic hash under the below - # # method - # obj = obj.cpu().numpy() - stream = io.BytesIO() - joblib.dump(value=obj, filename=stream) - stream.seek(0) - m = hashlib.blake2b() - m.update(str((stream.read())).encode()) - return m.hexdigest() - - @staticmethod - def get_cityhash(obj: Any) -> str: - stream = io.BytesIO() - joblib.dump(value=obj, filename=stream) - stream.seek(0) - h = cityhash.CityHash128(stream.read()) - digest = h.to_bytes(16, "little") - s = binascii.b2a_hex(digest) - res = s.decode() - return res - - @staticmethod - def get_joblib_hash(obj: Any) -> str: - if hasattr(obj, "__get_mandala_dict__"): - obj = obj.__get_mandala_dict__() - if Config.has_torch: - obj = tensor_to_numpy(obj) - if isinstance(obj, pd.DataFrame): - # DataFrames cause collisions with joblib hashing - obj = { - "columns": obj.columns, - "values": obj.values, - "index": obj.index, - } - result = joblib.hash(obj) - if result is None: - raise RuntimeError("joblib.hash returned None") - return result - - if Config.content_hasher == "blake2b": - get_content_hash = get_content_hash_blake2b - elif Config.content_hasher == "cityhash": - get_content_hash = get_cityhash - elif Config.content_hasher == "joblib": - get_content_hash = get_joblib_hash - else: - raise ValueError("Unknown content hasher: {}".format(Config.content_hasher)) - - ### deterministic hashing of common collections - @staticmethod - def hash_list(elts: List[str]) -> str: - return Hashing.get_content_hash(elts) - - @staticmethod - def hash_dict(elts: Dict[str, str]) -> str: - key_order = sorted(elts.keys()) - return Hashing.get_content_hash([(k, elts[k]) for k in key_order]) - - @staticmethod - def hash_set(elts: Set[str]) -> str: - return Hashing.get_content_hash(sorted(elts)) - - @staticmethod - def hash_multiset(elts: List[str]) -> str: - return Hashing.get_content_hash(sorted(elts)) - - -def 
unwrap_decorators( - obj: Callable, strict: bool = True -) -> Union[types.FunctionType, types.MethodType]: - while hasattr(obj, "__wrapped__"): - obj = obj.__wrapped__ - if not isinstance(obj, (types.FunctionType, types.MethodType)): - msg = f"Expected a function or method, but got {type(obj)}" - if strict: - raise RuntimeError(msg) - else: - logger.debug(msg) - return obj - - -if Config.has_torch: - - def tensor_to_numpy(obj: Union[torch.Tensor, dict, list, tuple, Any]) -> Any: - """ - Recursively convert PyTorch tensors in a data structure to numpy arrays. - - Parameters - ---------- - obj : any - The input data structure. - - Returns - ------- - any - The data structure with tensors converted to numpy arrays. - """ - if isinstance(obj, torch.Tensor): - return obj.detach().cpu().numpy() - elif isinstance(obj, dict): - return {k: tensor_to_numpy(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [tensor_to_numpy(v) for v in obj] - elif isinstance(obj, tuple): - return tuple(tensor_to_numpy(v) for v in obj) - else: - return obj diff --git a/mandala/core/wrapping.py b/mandala/core/wrapping.py deleted file mode 100644 index cf84fbf..0000000 --- a/mandala/core/wrapping.py +++ /dev/null @@ -1,247 +0,0 @@ -import typing -from ..common_imports import * -from .tps import Type, ListType, DictType, SetType, AnyType, StructType -from .utils import Hashing -from .model import Ref, Call, wrap_atom, ValueRef -from .builtins_ import ListRef, DictRef, SetRef, Builtins, StructRef - - -def typecheck(obj: Any, tp: Type): - if isinstance(tp, ListType) and not ( - isinstance(obj, list) or isinstance(obj, ListRef) - ): - raise ValueError(f"Expecting a list, got {type(obj)}") - elif isinstance(tp, DictType) and not ( - isinstance(obj, dict) or isinstance(obj, DictRef) - ): - raise ValueError(f"Expecting a dict, got {type(obj)}") - elif isinstance(tp, SetType) and not ( - isinstance(obj, set) or isinstance(obj, SetRef) - ): - raise ValueError(f"Expecting a set, got {type(obj)}") - - -def causify_atom(ref: ValueRef): - if ref.causal_uid is not None: - return - ref.causal_uid = Hashing.get_content_hash(ref.uid) - - -def causify_down(ref: Ref, start: str, stop_at_causal: bool = True): - """ - In-place top-down assignment of causal hashes to a Ref and its children. - Requires causal hashes to not be present initially. 
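A hedged sketch of the positional scheme `causify_down` applies to collections: each child's causal UID is derived from the parent's causal UID plus the child's position, so the same value used at two positions gets two distinct causal UIDs. `stand_in_hash` below is only a stand-in for `Hashing.hash_list`.

```
import hashlib


def stand_in_hash(parts) -> str:
    # Stand-in for Hashing.hash_list; any deterministic hash of the parts
    # is enough to illustrate the scheme.
    return hashlib.blake2b(repr(list(parts)).encode()).hexdigest()


parent_causal = stand_in_hash(["uid-of-the-list"])
child_causals = [stand_in_hash([parent_causal, i]) for i in range(3)]

# Distinct positions under the same parent yield distinct causal UIDs.
assert len(set(child_causals)) == 3
```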
- """ - assert start is not None - if ref.causal_uid is not None and stop_at_causal: - return - if isinstance(ref, ValueRef): - ref.causal_uid = start - elif isinstance(ref, ListRef): - if ref.in_memory: - for i, elt in enumerate(ref.obj): - causify_down(elt, start=Hashing.hash_list([start, i])) - ref.causal_uid = start - elif isinstance(ref, DictRef): - if ref.in_memory: - for k, elt in ref.obj.items(): - causify_down(elt, start=Hashing.hash_list([start, k])) - ref.causal_uid = start - elif isinstance(ref, SetRef): - # sort by uid to ensure deterministic ordering - if ref.in_memory: - elts_by_uid = {elt.uid: elt for elt in ref.obj} - sorted_uids = sorted({elt.uid for elt in ref.obj}) - for i, uid in enumerate(sorted_uids): - causify_down(elts_by_uid[uid], start=Hashing.hash_list([start, uid])) - ref.causal_uid = start - else: - raise ValueError(f"Unknown ref type {type(ref)}") - - -def decausify(ref: Ref, stop_at_first_missing: bool = False): - """ - In-place recursive removal of causal hashes from a Ref - """ - if ref._causal_uid is None and stop_at_first_missing: - return - ref._causal_uid = None - if isinstance(ref, StructRef): - for elt in ref.children(): - decausify(elt, stop_at_first_missing=stop_at_first_missing) - - -def wrap_constructive(obj: Any, annotation: Any) -> Tuple[Ref, List[Call]]: - tp = Type.from_annotation(annotation=annotation) - typecheck(obj=obj, tp=tp) - calls = [] - if isinstance(obj, Ref): - res = obj, calls - elif isinstance(tp, AnyType): - res = wrap_atom(obj=obj), calls - causify_atom(ref=res[0]) - elif isinstance(tp, StructType): - RefCls: Union[ - typing.Type[ListRef], typing.Type[DictRef], typing.Type[SetRef] - ] = Builtins.REF_CLASSES[tp.struct_id] - assert type(obj) == tp.model - recursive_result = RefCls.map( - obj=obj, func=lambda elt: wrap_constructive(elt, annotation=tp.elt_type) - ) - wrapped_elts = RefCls.map(obj=recursive_result, func=lambda elt: elt[0]) - recursive_calls = RefCls.elts( - RefCls.map(obj=recursive_result, func=lambda elt: elt[1]) - ) - calls.extend([c for cs in recursive_calls for c in cs]) - obj: StructRef = RefCls(obj=wrapped_elts, uid=None, in_memory=True) - obj.causify_up() - obj_calls = obj.get_calls() - calls.extend(obj_calls) - res = obj, calls - else: - raise ValueError(f"Unknown type {tp}") - return res - - -def wrap_inputs( - objs: Dict[str, Any], - annotations: Dict[str, Any], -) -> Tuple[Dict[str, Ref], List[Call]]: - calls = [] - wrapped_objs = {} - for k, v in objs.items(): - wrapped_obj, wrapping_calls = wrap_constructive( - obj=v, - annotation=annotations[k], - ) - wrapped_objs[k] = wrapped_obj - calls.extend(wrapping_calls) - return wrapped_objs, calls - - -def causify_outputs(refs: List[Ref], call_causal_uid: str): - assert isinstance(call_causal_uid, str) - for i, ref in enumerate(refs): - causify_down(ref=ref, start=Hashing.hash_list([call_causal_uid, str(i)])) - - -def wrap_outputs( - objs: List[Any], - annotations: List[Any], -) -> Tuple[List[Ref], List[Call]]: - calls = [] - wrapped_objs = [] - for i, v in enumerate(objs): - wrapped_obj, wrapping_calls = wrap_constructive( - obj=v, - annotation=annotations[i], - ) - wrapped_objs.append(wrapped_obj) - calls.extend(wrapping_calls) - return wrapped_objs, calls - - -################################################################################ -### unwrapping -################################################################################ -T = TypeVar("T") - - -def unwrap(obj: Union[T, Ref], through_collections: bool = True) -> T: - """ - If an object is a 
`ValueRef`, returns the wrapped object; otherwise, return - the object itself. - - If `through_collections` is True, recursively unwraps objects in lists, - tuples, sets, and dict values. - """ - if isinstance(obj, ValueRef) and obj.transient: - return obj.obj - if isinstance(obj, Ref) and not obj.in_memory: - from ..ui.contexts import GlobalContext - - if GlobalContext.current is None: - raise ValueError( - "Cannot unwrap a Ref with `in_memory=False` outside a context" - ) - storage = GlobalContext.current.storage - storage.cache.mattach(vrefs=[obj]) - if isinstance(obj, ValueRef): - return obj.obj - elif isinstance(obj, StructRef): - return type(obj).map( - obj=obj.obj, - func=lambda elt: unwrap(elt, through_collections=through_collections), - ) - elif type(obj) is tuple and through_collections: - return tuple(unwrap(v, through_collections=through_collections) for v in obj) - elif type(obj) is set and through_collections: - return {unwrap(v, through_collections=through_collections) for v in obj} - elif type(obj) is list and through_collections: - return [unwrap(v, through_collections=through_collections) for v in obj] - elif type(obj) is dict and through_collections: - return { - k: unwrap(v, through_collections=through_collections) - for k, v in obj.items() - } - else: - return obj - - -def contains_transient(ref: Ref) -> bool: - if isinstance(ref, ValueRef): - return ref.transient - elif isinstance(ref, StructRef): - return any(contains_transient(elt) for elt in ref.children()) - else: - raise ValueError(f"Unexpected ref type {type(ref)}") - - -def contains_not_in_memory(ref: Ref) -> bool: - if isinstance(ref, ValueRef): - return not ref.in_memory - elif isinstance(ref, StructRef): - return any(contains_not_in_memory(elt) for elt in ref.children()) - else: - raise ValueError(f"Unexpected ref type {type(ref)}") - - -def _sanitize_value(value: Any) -> Any: - if isinstance(value, Ref): - return (_sanitize_value(value.obj), value.in_memory, value.uid) - try: - hash(value) - return value - except TypeError: - if isinstance(value, bytearray): - return value.hex() - elif isinstance(value, list): - return tuple([_sanitize_value(v) for v in value]) - else: - raise NotImplementedError(f"Got value of type {type(value)}") - - -def compare_dfs_as_relations( - df_1: pd.DataFrame, df_2: pd.DataFrame, return_reason: bool = False -) -> Union[bool, Tuple[bool, str]]: - if df_1.shape != df_2.shape: - result, reason = False, f"Shapes differ: {df_1.shape} vs {df_2.shape}" - if set(df_1.columns) != set(df_2.columns): - result, reason = False, f"Columns differ: {df_1.columns} vs {df_2.columns}" - # reorder columns of df_2 to match df_1 - df_2 = df_2[df_1.columns] - # sanitize values to make them hashable - df_1 = df_1.applymap(_sanitize_value) - df_2 = df_2.applymap(_sanitize_value) - # compare as sets of tuples - result = set(map(tuple, df_1.itertuples(index=False))) == set( - map(tuple, df_2.itertuples(index=False)) - ) - if result: - reason = "" - else: - reason = f"Dataframe rows differ: {df_1} vs {df_2}" - if return_reason: - return result, reason - else: - return result diff --git a/mandala/deps/crawler.py b/mandala/deps/crawler.py index 75bee40..25f109c 100644 --- a/mandala/deps/crawler.py +++ b/mandala/deps/crawler.py @@ -1,6 +1,6 @@ import types from ..common_imports import * -from ..core.utils import unwrap_decorators +from ..utils import unwrap_decorators import importlib from .model import ( DepKey, diff --git a/mandala/deps/model.py b/mandala/deps/model.py index b0599d2..731bb0e 100644 --- 
a/mandala/deps/model.py +++ b/mandala/deps/model.py @@ -3,8 +3,8 @@ import types from ..common_imports import * -from ..core.utils import Hashing -from ..ui.viz import ( +from ..utils import get_content_hash +from ..viz import ( write_output, ) @@ -115,7 +115,7 @@ def representation(self) -> str: def _set_representation(self, value: str): assert isinstance(value, str) self._representation = value - self._content_hash = Hashing.get_content_hash(value) + self._content_hash = get_content_hash(value) @representation.setter def representation(self, value: str): @@ -138,8 +138,8 @@ def represent( obj: Union[types.FunctionType, types.CodeType, Callable], allow_fallback: bool = False, ) -> str: - if type(obj).__name__ == "FuncInterface": - obj = obj.func_op.func + if type(obj).__name__ == "Op": + obj = obj.f if not isinstance(obj, (types.FunctionType, types.MethodType, types.CodeType)): logger.warning(f"Found {obj} of type {type(obj)}") try: @@ -201,7 +201,7 @@ def representation(self) -> Tuple[str, str]: def represent(obj: Any, allow_fallback: bool = False) -> Tuple[str, str]: truncated_repr = textwrap.shorten(text=repr(obj), width=80) try: - content_hash = Hashing.get_content_hash(obj=obj) + content_hash = get_content_hash(obj=obj) except Exception as e: shortened_exception = textwrap.shorten(text=str(e), width=80) msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}" diff --git a/mandala/deps/shallow_versions.py b/mandala/deps/shallow_versions.py index c01dc14..98282f8 100644 --- a/mandala/deps/shallow_versions.py +++ b/mandala/deps/shallow_versions.py @@ -1,10 +1,10 @@ from typing import Literal import textwrap from ..common_imports import * -from ..core.utils import Hashing -from ..core.config import Config +from ..utils import get_content_hash +from ..config import Config from ..utils import ask_user -from ..ui.viz import _get_colorized_diff, _get_diff +from ..viz import _get_colorized_diff, _get_diff if Config.has_rich: from rich.tree import Tree @@ -110,7 +110,7 @@ def get_presentable_content(self, content: str) -> str: return content def get_content_hash(self, content: str) -> str: - return Hashing.get_content_hash(content) + return get_content_hash(content) GVContent = Tuple[str, str] # (content hash, repr) @@ -241,7 +241,7 @@ def commit(self, content: T, is_semantic_change: Optional[bool] = None) -> str: ) ) answer = ask_user( - question="Does this change require recomputation of dependent calls? [y]es/[n]o/[a]bort", + question="Does this change require recomputation of dependent calls?\nWARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\nAnswer: [y]es/[n]o/[a]bort", valid_options=["y", "n", "a"], ) print(f'You answered: "{answer}"') diff --git a/mandala/deps/tracers/dec_impl.py b/mandala/deps/tracers/dec_impl.py index 66b01bc..4de9a3a 100644 --- a/mandala/deps/tracers/dec_impl.py +++ b/mandala/deps/tracers/dec_impl.py @@ -1,8 +1,8 @@ import types from functools import wraps, update_wrapper from ...common_imports import * -from ...core.utils import unwrap_decorators -from ...core.config import Config +from ...utils import unwrap_decorators +from ...config import Config from ..model import ( DependencyGraph, CallableNode, @@ -40,6 +40,9 @@ def is_tracked(f: Union[types.FunctionType, type]) -> bool: class TrackedDict(dict): + """ + A dictionary that tracks global variable accesses. 
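For intuition, a generic runnable sketch of the idea behind `TrackedDict`: a dict wrapper that records which keys (standing in for module globals) are looked up. The real class additionally reports accesses to the active tracer; the names below are made up for the example.

```
class AccessLoggingDict(dict):
    # Generic stand-in for TrackedDict: remember every key that is read.
    def __init__(self, original: dict):
        super().__init__(original)
        self.accessed = set()

    def __getitem__(self, key):
        self.accessed.add(key)
        return super().__getitem__(key)


module_globals = AccessLoggingDict({"LEARNING_RATE": 0.1, "SEED": 42})
_ = module_globals["SEED"]
assert module_globals.accessed == {"SEED"}
```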
+ """ def __init__(self, original: dict): self.__original__ = original @@ -147,6 +150,9 @@ def wrapper(*args, **kwargs) -> Any: class DecTracer(TracerABC): + """ + A decorator-based tracer that tracks function calls and global variable accesses. + """ def __init__( self, paths: List[Path], diff --git a/mandala/deps/tracers/tracer_base.py b/mandala/deps/tracers/tracer_base.py index 9542d2e..0081ca4 100644 --- a/mandala/deps/tracers/tracer_base.py +++ b/mandala/deps/tracers/tracer_base.py @@ -1,5 +1,5 @@ from ...common_imports import * -from ...core.config import Config +from ...config import Config import importlib from ..model import DependencyGraph, CallableNode from abc import ABC, abstractmethod diff --git a/mandala/deps/utils.py b/mandala/deps/utils.py index f2dbda4..bad7405 100644 --- a/mandala/deps/utils.py +++ b/mandala/deps/utils.py @@ -5,8 +5,8 @@ from typing import Literal from ..common_imports import * -from ..core.utils import Hashing, unwrap_decorators -from ..core.config import Config +from ..utils import get_content_hash, unwrap_decorators +from ..config import Config DepKey = Tuple[str, str] # (module name, object address in module) @@ -81,7 +81,7 @@ def is_callable_obj(obj: Any, strict: bool) -> bool: def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType: if type(obj).__name__ == Config.func_interface_cls_name: - return obj.func_op.func + return obj.f obj = unwrap_decorators(obj, strict=strict) if isinstance(obj, types.BuiltinFunctionType): raise ValueError(f"Expected a non-built-in function, but got {obj}") @@ -102,7 +102,7 @@ def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType: def extract_code(obj: Callable) -> types.CodeType: if type(obj).__name__ == Config.func_interface_cls_name: - obj = obj.func_op.func + obj = obj.f if isinstance(obj, property): obj = obj.fget obj = unwrap_decorators(obj, strict=True) @@ -169,7 +169,7 @@ def get_bytecode(f: Union[types.FunctionType, types.CodeType, str]) -> str: def hash_dict(d: dict) -> str: - return Hashing.get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())]) + return get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())]) def load_obj(module_name: str, obj_name: str) -> Tuple[Any, bool]: diff --git a/mandala/deps/versioner.py b/mandala/deps/versioner.py index 3fd9a67..3552080 100644 --- a/mandala/deps/versioner.py +++ b/mandala/deps/versioner.py @@ -3,8 +3,8 @@ import textwrap from ..common_imports import * -from ..core.utils import is_subdict -from ..core.config import Config +from ..utils import is_subdict +from ..config import Config from .utils import DepKey, hash_dict from .model import ( Node, @@ -15,7 +15,7 @@ ) from .crawler import crawl_static from .tracers import TracerABC -from ..ui.viz import _get_colorized_diff +from ..viz import _get_colorized_diff if Config.has_rich: from rich.panel import Panel @@ -84,6 +84,10 @@ def get_version_ids( """ Get the content and semantic IDs for the version corresponding to the given pre-call uid. + + Inputs: + - `is_recompute`: this should be true only if this is a call with + transient outputs that we already computed once. """ assert tracer_option is not None version = self.process_trace( @@ -335,6 +339,17 @@ def lookup_call( """ Return a tuple of (content_version, semantic_version), or None if the call is not found. + + Inputs: + - `pre_call_uid`: this is a hash of the content IDs of the inputs, + together with the function's name. + + This works as follows: + - we figure out the semantic hashes (i.e. 
shallow semantic versions) of + the elements of the code state present in the global topology we have on + record + - we restrict to the records that match the given `pre_call_uid` + - we search among these """ codebase_semantic_hashes = self.get_codestate_semantic_hashes( code_state=code_state @@ -408,7 +423,7 @@ def _check_semantic_distinguishability( ) if all([semantic_hashes[k] == new_semantic_dep_hashes[k] for k in overlap]): raise ValueError( - f"Call {pre_call_uid} is not semantically distinguishable from call for semantic version {semantic_version}" + f"Call to {component} with pre_call_uid={pre_call_uid} is not semantically distinguishable from call for semantic version {semantic_version}" ) ############################################################################ diff --git a/mandala/deps/viz.py b/mandala/deps/viz.py index c5baa9b..82907fb 100644 --- a/mandala/deps/viz.py +++ b/mandala/deps/viz.py @@ -1,6 +1,6 @@ import textwrap from ..common_imports import * -from ..ui.viz import ( +from ..viz import ( Node as DotNode, Edge as DotEdge, Group as DotGroup, diff --git a/mandala/_next/docs/01_storage_and_ops.ipynb b/mandala/docs/01_storage_and_ops.ipynb similarity index 100% rename from mandala/_next/docs/01_storage_and_ops.ipynb rename to mandala/docs/01_storage_and_ops.ipynb diff --git a/mandala/_next/docs/02_retracing.ipynb b/mandala/docs/02_retracing.ipynb similarity index 100% rename from mandala/_next/docs/02_retracing.ipynb rename to mandala/docs/02_retracing.ipynb diff --git a/mandala/_next/docs/03_cf.ipynb b/mandala/docs/03_cf.ipynb similarity index 100% rename from mandala/_next/docs/03_cf.ipynb rename to mandala/docs/03_cf.ipynb diff --git a/mandala/_next/docs/04_versions.ipynb b/mandala/docs/04_versions.ipynb similarity index 100% rename from mandala/_next/docs/04_versions.ipynb rename to mandala/docs/04_versions.ipynb diff --git a/mandala/_next/docs/05_collections.ipynb b/mandala/docs/05_collections.ipynb similarity index 100% rename from mandala/_next/docs/05_collections.ipynb rename to mandala/docs/05_collections.ipynb diff --git a/mandala/_next/docs/06_advanced_cf.ipynb b/mandala/docs/06_advanced_cf.ipynb similarity index 100% rename from mandala/_next/docs/06_advanced_cf.ipynb rename to mandala/docs/06_advanced_cf.ipynb diff --git a/mandala/_next/docs/make_docs.py b/mandala/docs/make_docs.py similarity index 100% rename from mandala/_next/docs/make_docs.py rename to mandala/docs/make_docs.py diff --git a/mandala/_next/docs/readme.md b/mandala/docs/readme.md similarity index 100% rename from mandala/_next/docs/readme.md rename to mandala/docs/readme.md diff --git a/mandala/imports.py b/mandala/imports.py index 75de7f7..66ff18c 100644 --- a/mandala/imports.py +++ b/mandala/imports.py @@ -1,11 +1,10 @@ -""" -Intended way for users to import mandala -""" -from .core.model import wrap_atom -from .core.wrapping import unwrap +from .storage import Storage +from .model import op, Ignore, NewArgDefault +from .tps import MList, MDict from .deps.tracers.dec_impl import track -from .queries import ListQ, SetQ, DictQ -from .core.config import Config -from .ui.storage import Storage -from .ui.funcs import op, superop, Q, Transient -from .ui.utils import wrap_ui as wrap + +from .common_imports import sess + + +def pprint_dict(d) -> str: + return '\n'.join([f" {k}: {v}" for k, v in d.items()]) \ No newline at end of file diff --git a/mandala/_next/model.py b/mandala/model.py similarity index 100% rename from mandala/_next/model.py rename to mandala/model.py diff --git 
a/mandala/queries/__init__.py b/mandala/queries/__init__.py deleted file mode 100644 index e7c481c..0000000 --- a/mandala/queries/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .weaver import BuiltinQueries - -ListQ = BuiltinQueries.ListQ -DictQ = BuiltinQueries.DictQ -SetQ = BuiltinQueries.SetQ diff --git a/mandala/queries/compiler.py b/mandala/queries/compiler.py deleted file mode 100644 index 902d43b..0000000 --- a/mandala/queries/compiler.py +++ /dev/null @@ -1,129 +0,0 @@ -from ..common_imports import * -from ..core.model import FuncOp -from ..core.sig import Signature -from ..core.config import Config, dump_output_name -from .weaver import ValNode, CallNode, traverse_all -from .viz import visualize_graph, get_names -from ..core.utils import invert_dict, OpKey -from pypika import Query, Table, Criterion - - -class Compiler: - def __init__(self, vqs: List[ValNode], fqs: List[CallNode]): - if not len(vqs) == len({id(x) for x in vqs}): - raise InternalError - if not len(fqs) == len({id(x) for x in fqs}): - raise InternalError - self.vqs = vqs - self.fqs = fqs - self.val_aliases, self.func_aliases = self._generate_aliases() - - def _generate_aliases(self) -> Tuple[Dict[ValNode, Table], Dict[CallNode, Table]]: - func_aliases = {} - func_counter = 0 - for i, func_query in enumerate(self.fqs): - op_table = Table(func_query.func_op.sig.versioned_ui_name) - func_aliases[func_query] = op_table.as_(f"_func_{func_counter}") - func_counter += 1 - val_aliases = {} - val_counter = 0 - for val_query in self.vqs: - val_table = Table(Config.causal_vref_table) - val_aliases[val_query] = val_table.as_(f"_var_{val_counter}") - val_counter += 1 - return val_aliases, func_aliases - - def compile_func( - self, fq: CallNode, semantic_versions: Optional[Set[str]] = None - ) -> Tuple[list, list]: - """ - Compile the query corresponding to an op, including built-in ops - """ - constraints = [] - select_fields = [] - func_alias = self.func_aliases[fq] - for input_name, val_query in fq.inputs.items(): - val_alias = self.val_aliases[val_query] - constraints.append(val_alias[Config.full_uid_col] == func_alias[input_name]) - for output_name, val_query in fq.outputs.items(): - val_alias = self.val_aliases[val_query] - constraints.append( - val_alias[Config.full_uid_col] == func_alias[output_name] - ) - if semantic_versions is not None: - constraints.append( - func_alias[Config.semantic_version_col].isin(semantic_versions) - ) - if fq.constraint is not None: - constraints.append(func_alias[Config.causal_uid_col].isin(fq.constraint)) - return constraints, select_fields - - def compile_val(self, val_query: ValNode) -> Tuple[list, list]: - """ - Compile the query corresponding to a variable - """ - constraints = [] - select_fields = [] - val_alias = self.val_aliases[val_query] - if val_query.constraint is not None: - constraints.append( - val_alias[Config.full_uid_col].isin(val_query.constraint) - ) - select_fields.append(val_alias[Config.full_uid_col]) - return constraints, select_fields - - def compile( - self, - select_queries: List[ValNode], - semantic_version_constraints: Optional[Dict[OpKey, Optional[Set[str]]]] = None, - filter_duplicates: bool = False, - ): - """ - Compile the query induced by the data of this compiler instance to - an SQL select query. If `semantic_version_constraints` is not provided, - no constraints are placed. - - NOTE: - - for each value query, we select both columns of the variable - table: the index and the partition. 
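To make the compilation step concrete, a small runnable pypika sketch of the kind of SELECT the `Compiler` assembles: one aliased table per op call and per value node, joined by equality constraints on the UID column. The table and column names (`f_0`, `causal_vrefs`, `full_uid`, `output_0`) are placeholders, not the actual configured names.

```
from pypika import Criterion, Query, Table

func_t = Table("f_0").as_("_func_0")           # memo table of an op (placeholder name)
var_t = Table("causal_vrefs").as_("_var_0")    # value-reference table (placeholder name)

query = (
    Query.from_(func_t)
    .from_(var_t)
    .select(var_t.full_uid)
    .where(Criterion.all([var_t.full_uid == func_t.output_0]))
    .distinct()
)
print(query)  # the generated SELECT, with one alias per node and the join constraints
```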
This is to be able to convert - the query result directly into locations. - - The list of columns, partitioned into sublists per value query, is - also returned. - """ - if not len(select_queries) == len({id(x) for x in select_queries}): - raise InternalError - assert all([vq in self.vqs for vq in select_queries]) - from_tables = [] - all_constraints = [] - select_cols = [ - self.val_aliases[vq][Config.full_uid_col] for vq in select_queries - ] - if semantic_version_constraints is None: - semantic_version_constraints = { - (op_query.func_op.sig.internal_name, op_query.func_op.sig.version): None - for op_query in self.fqs - } - for func_query in self.fqs: - op_key = ( - func_query.func_op.sig.internal_name, - func_query.func_op.sig.version, - ) - constraints, select_fields = self.compile_func( - func_query, semantic_versions=semantic_version_constraints[op_key] - ) - func_alias = self.func_aliases[func_query] - from_tables.append(func_alias) - all_constraints.extend(constraints) - for val_query in self.vqs: - val_alias = self.val_aliases[val_query] - constraints, select_fields = self.compile_val(val_query) - from_tables.append(val_alias) - all_constraints.extend(constraints) - query = Query - for table in from_tables: - query = query.from_(table) - query = query.select(*select_cols) - if filter_duplicates: - query = query.distinct() - query = query.where(Criterion.all(all_constraints)) - return query diff --git a/mandala/queries/graphs.py b/mandala/queries/graphs.py deleted file mode 100644 index aefc120..0000000 --- a/mandala/queries/graphs.py +++ /dev/null @@ -1,340 +0,0 @@ -from typing import Literal -from collections import deque -from ..core.utils import Hashing, invert_dict -from ..core.tps import StructType -from ..core.config import Provenance -from ..common_imports import * -from .weaver import ValNode, CallNode, Node - - -def get_deps(nodes: Set[Node]) -> Set[Node]: - res = set() - - def visit(node: Node): - if node not in res: - res.add(node) - for n in node.neighbors("backward"): - visit(n) - - for n in nodes: - visit(n) - return res - - -def prune( - vqs: Set[ValNode], fqs: Set[CallNode], selection: List[ValNode] -) -> Tuple[Set[ValNode], Set[CallNode]]: - # remove vqs of degree 1 that are not in selection - pass - - -def is_connected(val_queries: Set[ValNode], func_queries: Set[CallNode]) -> bool: - """ - Check if the graph defined by these objects is a connected graph - """ - if len(val_queries) == 0: - raise NotImplementedError - visited: Dict[Node, bool] = {} - - def visit(node: Node): - if node not in visited.keys(): - visited[node] = True - for n in node.neighbors("both"): - visit(n) - - start = list(val_queries)[0] - visit(start) - return len(visited) == len(val_queries) + len(func_queries) - - -def is_component(nodes: Set[Node]) -> bool: - """ - Check if the given nodes are a full component of the graph - """ - for node in nodes: - if any(n not in nodes for n in node.neighbors()): - return False - return True - - -def get_fibers(v_map: Dict[Node, str]) -> Dict[str, Set[Node]]: - fibers = defaultdict(set) - for vq, name in v_map.items(): - fibers[name].add(vq) - return fibers - - -def hash_groups(groups: Dict[str, Set[Node]]) -> str: - return Hashing.hash_set( - {Hashing.hash_set({str(id(v)) for v in g}) for g in groups.values()} - ) - - -class InducedSubgraph: - def __init__(self, vqs: Set[ValNode], fqs: Set[CallNode]): - self.vqs = vqs - self.fqs = fqs - self.nodes: Set[Node] = vqs.union(fqs) - - def neighbors( - self, node: Node, direction: Literal["forward", 
"backward", "both"] - ) -> Set[Node]: - return {n for n in node.neighbors(direction=direction) if n in self.nodes} - - def fq_inputs(self, fq: CallNode) -> Dict[str, ValNode]: - return {k: v for k, v in fq.inputs.items() if v in self.vqs} - - def fq_outputs(self, fq: CallNode) -> Dict[str, ValNode]: - return {k: v for k, v in fq.outputs.items() if v in self.vqs} - - def consumers(self, vq: ValNode) -> List[Tuple[str, CallNode]]: - return [(k, v) for k, v in zip(vq.consumed_as, vq.consumers) if v in self.fqs] - - def creators(self, vq: ValNode) -> List[Tuple[str, CallNode]]: - return [(k, v) for k, v in zip(vq.created_as, vq.creators) if v in self.fqs] - - def topsort( - self, canonical_labels: Optional[Dict[Node, str]] - ) -> Tuple[List[Node], Set[Node], Set[Node]]: - """ - Get the topsort, the sources, and the sinks. Optionally, provide a - canonical label for each node to get a canonical topsort. - """ - # return a topological sort of the graph, and the sources and sinks - in_degrees = { - n: len(self.neighbors(node=n, direction="backward")) for n in self.nodes - } - sources = {v for v in self.nodes if in_degrees[v] == 0} - sinks = set() - res = [] - S = [v for v in self.nodes if in_degrees[v] == 0] - if canonical_labels is not None: - S = sorted(S, key=lambda x: canonical_labels[x]) - while len(S) > 0: - v = S.pop(0) - res.append(v) - forward_neighbors = self.neighbors(node=v, direction="forward") - if len(forward_neighbors) == 0: - sinks.add(v) - for w in self.neighbors(node=v, direction="forward"): - in_degrees[w] -= 1 - if in_degrees[w] == 0: - S.append(w) - if canonical_labels is not None: - # TODO: inefficient, needs a heap instead - S = sorted(S, key=lambda x: canonical_labels[x]) - return res, sources, sinks - - def digest_neighbors_fq(self, fq: CallNode, colors: Dict[Node, str]) -> str: - neighs = {**self.fq_inputs(fq=fq), **self.fq_outputs(fq=fq)} - return Hashing.hash_dict({k: colors[v] for k, v in neighs.items()}) - - def digest_neighbors_vq(self, vq: ValNode, colors: Dict[Node, str]) -> str: - creator_labels = [ - Hashing.hash_list([k, colors[v]]) for k, v in self.creators(vq=vq) - ] - consumer_labels = [ - Hashing.hash_list([k, colors[v]]) for k, v in self.consumers(vq=vq) - ] - if isinstance(vq.tp, StructType): - lst = [ - Hashing.hash_set(set(creator_labels)), - Hashing.hash_set(set(consumer_labels)), - ] - else: - lst = [ - Hashing.hash_multiset(creator_labels), - Hashing.hash_multiset(consumer_labels), - ] - return Hashing.hash_list(lst) - - def run_color_refinement( - self, initialization: Dict[Node, str], verbose: bool = False - ) -> Dict[Node, str]: - """ - Run the color refinement algorithm on this graph. Return a dict mapping - each node to its color. 
- """ - colors = initialization - groups = get_fibers(initialization) - iteration = 0 - while True: - if verbose: - logger.info(f"Color refinement: iteration {iteration}") - new_colors = {} - for node in self.nodes: - if isinstance(node, CallNode): - neighbors_digest = self.digest_neighbors_fq(fq=node, colors=colors) - elif isinstance(node, ValNode): - neighbors_digest = self.digest_neighbors_vq(vq=node, colors=colors) - else: - raise Exception(f"Unknown node type {type(node)}") - new_colors[node] = Hashing.hash_list([colors[node], neighbors_digest]) - new_groups = get_fibers(new_colors) - if hash_groups(new_groups) == hash_groups(groups): - break - else: - groups = new_groups - colors = new_colors - iteration += 1 - return colors - - @staticmethod - def is_homomorphism( - s: "InducedSubgraph", - t: "InducedSubgraph", - v_map: Dict[ValNode, ValNode], - f_map: Dict[CallNode, CallNode], - ) -> bool: - for svq, tvq in v_map.items(): - if not isinstance(svq, ValNode): - return False - if not isinstance(tvq, ValNode): - return False - if not svq.tp == tvq.tp: - return False - for sfq, tfq in f_map.items(): - if not isinstance(sfq, CallNode): - return False - if not isinstance(tfq, CallNode): - return False - if not sfq.func_op.sig == tfq.func_op.sig: - return False - sfq_inps, sfq_outps = s.fq_inputs(fq=sfq), s.fq_outputs(fq=sfq) - tfq_inps, tfq_outps = t.fq_inputs(fq=tfq), t.fq_outputs(fq=tfq) - if not set(sfq_inps.keys()) == set(tfq_inps.keys()): - return False - if not set(sfq_outps.keys()) == set(tfq_outps.keys()): - return False - if not all(v_map[v] == tfq_inps[k] for k, v in sfq_inps.items()): - return False - if not all(v_map[v] == tfq_outps[k] for k, v in sfq_outps.items()): - return False - return True - - @staticmethod - def are_canonically_isomorphic(s: "InducedSubgraph", t: "InducedSubgraph") -> bool: - try: - s_vlabels, s_flabels, s_topsort = s.canonicalize(strict=True) - t_vlabels, t_flabels, t_topsort = t.canonicalize(strict=True) - except AssertionError: - raise NotImplementedError - if len(s_vlabels) != len(t_vlabels) or len(s_flabels) != len(t_flabels): - return False - t_vinverse, t_finverse = invert_dict(t_vlabels), invert_dict(t_flabels) - v_map = {vq: t_vinverse[v] for vq, v in s_vlabels.items()} - f_map = {fq: t_finverse[f] for fq, f in s_flabels.items()} - if not InducedSubgraph.is_homomorphism(s, t, v_map, f_map): - return False - if not InducedSubgraph.is_homomorphism( - t, s, invert_dict(v_map), invert_dict(f_map) - ): - return False - return True - - def get_projections( - self, colors: Dict[Node, str] - ) -> Tuple[Dict[ValNode, ValNode], Dict[CallNode, CallNode]]: - v_groups = defaultdict(list) - f_groups = defaultdict(list) - for node, color in colors.items(): - if isinstance(node, ValNode): - v_groups[color].append(node) - elif isinstance(node, CallNode): - f_groups[color].append(node) - v_map: Dict[ValNode, ValNode] = {} - f_map: Dict[CallNode, CallNode] = {} - for color, gp in v_groups.items(): - representative = gp[0] - prototype = ValNode( - tp=representative.tp, constraint=representative.constraint - ) - with_constraint = [vq for vq in gp if vq.constraint is not None] - if len(with_constraint) == 0: - representative_constraint = None - elif len(with_constraint) == 1: - representative_constraint = with_constraint[0].constraint - else: - raise ValueError("Multiple constraints for a single value") - prototype = ValNode( - tp=representative.tp, constraint=representative_constraint - ) - for fq in gp: - v_map[fq] = prototype - for color, gp in f_groups.items(): - 
representative = gp[0] - rep_inps = self.fq_inputs(fq=representative) - rep_outps = self.fq_outputs(fq=representative) - prototype = CallNode.link( - inputs={k: v_map[vq] for k, vq in rep_inps.items()}, - outputs={k: v_map[vq] for k, vq in rep_outps.items()}, - func_op=representative.func_op, - orientation=representative.orientation, - constraint=None, - ) - for fq in gp: - f_map[fq] = prototype - t = InducedSubgraph(vqs=set(v_map.values()), fqs=set(f_map.values())) - assert self.is_homomorphism(s=self, t=t, v_map=v_map, f_map=f_map) - return v_map, f_map - - def canonicalize( - self, strict: bool = False, method: str = "wl1" - ) -> Tuple[Dict[ValNode, str], Dict[CallNode, str], List[Node]]: - """ - Return canonical labels for each node, as well as a canonical topological sort. - """ - initialization = {} - for vq in self.vqs: - initialization[vq] = "null" - for fq in self.fqs: - initialization[fq] = fq.func_op.sig.versioned_internal_name - colors = self.run_color_refinement(initialization=initialization) - v_colors = {vq: colors[vq] for vq in self.vqs} - f_colors = {fq: colors[fq] for fq in self.fqs} - if strict: - assert len(set(v_colors.values())) == len(v_colors) - assert len(set(f_colors.values())) == len(f_colors) - canonical_topsort, sources, sinks = self.topsort( - canonical_labels={**v_colors, **f_colors} - ) - return v_colors, f_colors, canonical_topsort - - def project( - self, - ) -> Tuple[Dict[ValNode, ValNode], Dict[CallNode, CallNode], List[Node]]: - v_colors, f_colors, topsort = self.canonicalize() - colors: Dict[Node, str] = {**v_colors, **f_colors} - return *self.get_projections(colors=colors), topsort - - -def get_canonical_order(vqs: Set[ValNode], fqs: Set[CallNode]) -> List[ValNode]: - g = InducedSubgraph(vqs=vqs, fqs=fqs) - _, _, canonical_topsort = g.canonicalize() - return [vq for vq in canonical_topsort if isinstance(vq, ValNode)] - - -def copy_subgraph( - vqs: Set[ValNode], fqs: Set[CallNode] -) -> Tuple[Dict[ValNode, ValNode], Dict[CallNode, CallNode]]: - """ - Copy the subgraph supported on the given nodes. Return maps from the - original nodes to their copies. - """ - v_map = { - v: ValNode(tp=v.tp, constraint=v.constraint, name=v.name, refs=v.refs) - for v in vqs - } - f_map = {} - for fq in fqs: - inputs = {k: v_map[v] for k, v in fq.inputs.items() if v in vqs} - outputs = {k: v_map[v] for k, v in fq.outputs.items() if v in vqs} - f_map[fq] = CallNode.link( - inputs=inputs, - outputs=outputs, - func_op=fq.func_op, - orientation=fq.orientation, - constraint=fq.constraint, - calls=fq.calls, - ) - return v_map, f_map diff --git a/mandala/queries/main.py b/mandala/queries/main.py deleted file mode 100644 index 6e4ed6f..0000000 --- a/mandala/queries/main.py +++ /dev/null @@ -1,182 +0,0 @@ -from collections import Counter -from ..common_imports import * -from ..core.utils import OpKey -from ..core.config import parse_output_idx -from .weaver import ValNode, CallNode -from .compiler import Compiler -from .solver import NaiveQueryEngine -from .viz import get_names -from .graphs import ( - InducedSubgraph, - is_connected, - copy_subgraph, -) - - -class Querier: - @staticmethod - def check_df( - vqs: Set[ValNode], - fqs: Set[CallNode], - df: pd.DataFrame, - funcs: Dict[str, Callable], - ): - """ - Check validity of a query result projected from the given graph against - executables. 
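`check_df` above validates a projected query result by re-executing the underlying functions on each row's input values and comparing against the recorded outputs. A toy sketch of that row-wise check, assuming a single function `f` and a two-column result table (both are illustrative stand-ins, not mandala API):

```
import pandas as pd


def f(x: int) -> int:
    return x + 1


# a claimed query result relating inputs `x` to outputs `y = f(x)`
df = pd.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4]})

for _, row in df.iterrows():
    # re-execute the op on the row's inputs and compare with the stored output
    assert f(row["x"]) == row["y"]
```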
- """ - cols = set(df.columns) - assert cols <= set(vq.name for vq in vqs) - for fq in fqs: - func = funcs[fq.func_op.sig.ui_name] - input_cols: Dict[str, str] = { - k: vq.name for k, vq in fq.inputs.items() if vq.name in cols - } - ouptut_cols: Dict[int, str] = { - parse_output_idx(k): vq.name - for k, vq in fq.outputs.items() - if vq.name in cols - } - for i, row in df.iterrows(): - inputs = {k: row[v] for k, v in input_cols.items()} - outputs = func(**inputs) - if not (isinstance(outputs, tuple)): - outputs = (outputs,) - for j, v in enumerate(outputs): - assert row[ouptut_cols[j]] == v - - @staticmethod - def execute_naive( - vqs: Set[ValNode], - fqs: Set[CallNode], - selection: List[ValNode], - memoization_tables: Dict[str, pd.DataFrame], - filter_duplicates: bool, - table_evaluator: Callable, - visualize_steps_at: Optional[Path] = None, - ) -> pd.DataFrame: - v_copymap, f_copymap = copy_subgraph(vqs=vqs, fqs=fqs) - vqs = set(v_copymap.values()) - fqs = set(f_copymap.values()) - select_copies = [v_copymap[vq] for vq in selection] - tables = { - f: memoization_tables[f.func_op.sig.versioned_internal_name] for f in fqs - } - query_graph = NaiveQueryEngine( - vqs=vqs, - fqs=fqs, - selection=select_copies, - tables=tables, - _table_evaluator=table_evaluator, - _visualize_steps_at=visualize_steps_at, - ) - logger.debug("Solving query...") - df = query_graph.solve() - if filter_duplicates: - df = df.drop_duplicates(keep="first") - return df - - @staticmethod - def compile( - selection: List[ValNode], - vqs: Set[ValNode], - fqs: Set[CallNode], - version_constraints: Optional[Dict[OpKey, Optional[Set[str]]]], - filter_duplicates: bool = True, - call_uids: Optional[Dict[Tuple[str, int], List[str]]] = None, - ) -> str: - """ - Execute the given queries and return the result as a pandas DataFrame. - """ - Querier.add_fq_constraints(fqs=fqs, call_uids=call_uids) - compiler = Compiler(vqs=vqs, fqs=fqs) - query = compiler.compile( - select_queries=selection, - filter_duplicates=filter_duplicates, - semantic_version_constraints=version_constraints, - ) - return query - - @staticmethod - def add_fq_constraints( - fqs: Set[CallNode], call_uids: Optional[Dict[Tuple[str, int], List[str]]] - ): - if call_uids is None: - return - for fq in fqs: - if not fq.func_op.is_builtin: - sig = fq.func_op.sig - op_id = (sig.internal_name, sig.version) - if op_id in call_uids: - fq.constraint = call_uids[op_id] - - @staticmethod - def validate_query( - vqs: Set[ValNode], - fqs: Set[CallNode], - selection: List[ValNode], - names: Dict[ValNode, str], - ): - if not selection: # empty selection - raise ValueError("Empty selection") - if len(vqs) == 0: # empty query - raise ValueError("Query is empty") - if not is_connected(val_queries=vqs, func_queries=fqs): # disconnected query - msg = f"Query is not connected! 
This could lead to a very large table.\n" - logger.warning(msg) - if not len(set(names.values())) == len(names): # duplicate names - duplicates = [k for k, v in Counter(names.values()).items() if v > 1] - raise ValueError("Duplicate names in value queries: " + str(duplicates)) - - @staticmethod - def prepare_projection_query( - vqs: Set[ValNode], - fqs: Set[CallNode], - selection: List[ValNode], - name_hints: Dict[ValNode, str], - ): - graph = InducedSubgraph(vqs=vqs, fqs=fqs) - v_map, f_map, _ = graph.project() - validate_projection( - source_selection=selection, - v_map=v_map, - source_selection_names={ - k: v for k, v in name_hints.items() if k in selection - }, - ) - target_selection = [v_map[vq] for vq in selection] - ### get the names in the projected graph - g = InducedSubgraph(vqs=set(v_map.values()), fqs=set(f_map.values())) - _, _, canonical_topsort = g.canonicalize() - target_name_hints = { - v_map[vq]: name for vq, name in name_hints.items() if vq in v_map.keys() - } - target_names = get_names( - hints=target_name_hints, - canonical_order=[vq for vq in canonical_topsort if isinstance(vq, ValNode)], - ) - assert set(target_names.keys()) == set(v_map.values()) - return v_map, f_map, target_selection, target_names - - -def validate_projection( - source_selection: List[ValNode], - v_map: Dict[ValNode, ValNode], - source_selection_names: Dict[ValNode, str], -): - """ - Check that the selected nodes in the source project to distinct nodes in the - target. Print out an error message if this is not the case. - - Here `names` is a (partial) dict from source nodes to names - """ - fibers = defaultdict(list) - for vq in source_selection: - fibers[v_map[vq]].append(vq) - if any(len(fiber) > 1 for fiber in fibers.values()): - # find the first fiber with more than one query - fiber = next(fiber for fiber in fibers.values() if len(fiber) > 1) - raise ValueError( - f"Ambiguous query: nodes {[source_selection_names.get(x, '?') for x in fiber]} have the " - f"same role in the computational graph." - ) diff --git a/mandala/queries/solver.py b/mandala/queries/solver.py deleted file mode 100644 index c83d028..0000000 --- a/mandala/queries/solver.py +++ /dev/null @@ -1,264 +0,0 @@ -from ..common_imports import * -from ..core.model import FuncOp -from ..core.sig import Signature -from ..core.config import Config -from .weaver import ValNode, CallNode, traverse_all -from .graphs import InducedSubgraph -from .viz import visualize_graph, get_names -from ..core.utils import invert_dict - - -class NaiveQueryEngine: - """ - Represents the graph expressing a query in a form that is suitable for - incrementally computing the join of all the tables. - - Used as an alternative to a RDBMS engine for computing queries. - - Should work with induced subgraphs - """ - - def __init__( - self, - vqs: Set[ValNode], - fqs: Set[CallNode], - selection: List[ValNode], - tables: Dict[CallNode, pd.DataFrame], - _table_evaluator: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, - _visualize_steps_at: Optional[Path] = None, - ): - self.vqs = vqs - self.fqs = list(fqs) - self.g = InducedSubgraph(vqs=vqs, fqs=fqs) - self.selection = selection - # {func query: table of data}. Note that there may be multiple func - # queries with the same table, but we keep separate references to each - # in order to enable recursively joining nodes in the graph. 
- - # pass to induced tables wrt the graph - self.induce_tables(tables=tables) - self.tables = tables - for k, v in self.tables.items(): - for col in Config.special_call_cols: - if col in v.columns: - v.drop(columns=[col], inplace=True) - - # for visualization - self._visualize_intermediate_states = _visualize_steps_at is not None - self._table_evaluator = _table_evaluator - self._visualize_steps_at = _visualize_steps_at - if self._visualize_intermediate_states: - assert self._table_evaluator is not None - - def induce_tables(self, tables: Dict[CallNode, pd.DataFrame]): - for fq, df in tables.items(): - inps, outps = self.get_fq_inputs(fq), self.get_fq_outputs(fq) - induced_keys = set(inps.keys()) | set(outps.keys()) - df.drop( - columns=[c for c in df.columns if c not in induced_keys], inplace=True - ) - - def get_fq_inputs(self, fq: CallNode) -> Dict[str, ValNode]: - if fq in self.g.fqs: - return self.g.fq_inputs(fq=fq) - else: - return fq.inputs - - def get_fq_outputs(self, fq: CallNode) -> Dict[str, ValNode]: - if fq in self.g.fqs: - return self.g.fq_outputs(fq=fq) - else: - return fq.outputs - - def _get_col_to_vq_mappings( - self, func_query: CallNode - ) -> Tuple[Dict[str, ValNode], Dict[ValNode, List[str]]]: - """ - Given a FuncQuery, returns: - - a mapping from column names to the ValQuery that they point to - - a mapping from ValQuery objects to the list of column names that point to it - """ - df = self.tables[func_query] - col_to_vq = {} - vq_to_cols = defaultdict(list) - for name, val_query in self.get_fq_inputs(func_query).items(): - assert name in df.columns - col_to_vq[name] = val_query - vq_to_cols[val_query].append(name) - for output_name, val_query in self.get_fq_outputs(func_query).items(): - assert output_name in df.columns - col_to_vq[output_name] = val_query - vq_to_cols[val_query].append(output_name) - return col_to_vq, vq_to_cols - - def _drop_self_constraints( - self, df: pd.DataFrame, vq_to_cols: Dict[ValNode, List[str]] - ) -> Tuple[pd.DataFrame, Dict[str, ValNode], Dict[ValNode, str]]: - new_col_to_vq = {} - df = df.copy() - for vq, cols in vq_to_cols.items(): - if len(cols) > 1: - representative = cols[0] - for col in cols[1:]: - df = df[df[representative] == df[col]] - df.drop(columns=col, inplace=True) - new_col_to_vq[cols[0]] = vq - new_vq_to_col = {vq: col for col, vq in new_col_to_vq.items()} - return df, new_col_to_vq, new_vq_to_col - - def _join_dataframes( - self, - df1: pd.DataFrame, - df2: pd.DataFrame, - left_on: List[str], - right_on: List[str], - ) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, str]]: - """ - Join two dataframes along the specified dimensions. - - Returns - - the result, - - together with coprojections from the columns of each dataframe to - the columns of the result. 
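`_join_dataframes` above renames the columns of both tables to avoid name clashes, then merges them on the column pairs bound to shared value queries, falling back to a cross join when there is nothing to join on. A minimal sketch of that merge step in plain pandas, using two toy memoization tables that share an `x` column (the tables and column names are assumptions for illustration):

```
import pandas as pd

df1 = pd.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})     # e.g. rows of y = f(x)
df2 = pd.DataFrame({"x": [2, 3, 4], "z": [200, 300, 400]})  # e.g. rows of z = g(x)

# columns that correspond to the same value query in both tables
shared = [c for c in df1.columns if c in df2.columns]

if shared:
    joined = df1.merge(df2, on=shared)    # natural join on the shared variables
else:
    joined = df1.merge(df2, how="cross")  # no shared variables: Cartesian product

print(joined)
#    x   y    z
# 0  2  20  200
# 1  3  30  300
```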
- """ - assert len(left_on) == len(right_on) - # rename the dataframe columns to avoid conflicts - mapping1 = {col: f"input_{i}" for i, col in enumerate(df1.columns)} - ncols1 = len(df1.columns) - mapping2 = {col: f"input_{i + ncols1}" for i, col in enumerate(df2.columns)} - renamed_df1 = df1.rename(columns=mapping1) - renamed_df2 = df2.rename(columns=mapping2) - # join the dataframes - logger.info(f"Joining tables of shapes {df1.shape} and {df2.shape}...") - start = time.time() - if len(left_on) == 0: - df = renamed_df1.merge(renamed_df2, how="cross") - else: - df = renamed_df1.merge( - renamed_df2, - left_on=[mapping1[col] for col in left_on], - right_on=[mapping2[col] for col in right_on], - ) - end = time.time() - logger.info(f"Join took {round(end - start, 3)} seconds") - # drop duplicate columns from the *right* dataframe - df.drop(columns=[mapping2[col] for col in right_on], inplace=True) - # construct the mapping functions from the columns of each dataframe to - # the columns of the result - from1 = mapping1 - from2 = {} - for col in df2.columns: - if mapping2[col] in df.columns: - # first, assign the columns that we did not drop - from2[col] = mapping2[col] - for left_col, right_col in zip(left_on, right_on): - # fill in the remaining ones - from2[right_col] = mapping1[left_col] - return df, from1, from2 - - def merge(self, f1: CallNode, f2: CallNode): - """ - Merge two func query nodes in the graph by joining their tables along - the columns that correspond to the shared inputs/outputs - """ - # get the data - df1, df2 = self.tables[f1], self.tables[f2] - # compute correspondence between columns and vqs - col_to_vq1, vq_to_cols1 = self._get_col_to_vq_mappings(f1) - col_to_vq2, vq_to_cols2 = self._get_col_to_vq_mappings(f2) - # apply self-join constraints - df1, col_to_vq1, vq_to_col1 = self._drop_self_constraints(df1, vq_to_cols1) - df2, col_to_vq2, vq_to_col2 = self._drop_self_constraints(df2, vq_to_cols2) - # compute the pairs of columns along which we need to join - # {shared value query: (col1, col2)} - intersection_vq_to_col_pairs = OrderedDict({}) - for col, vq in col_to_vq1.items(): - if vq in col_to_vq2.values(): - intersection_vq_to_col_pairs[vq] = (col, vq_to_col2[vq]) - left_on = [col for _, (col, _) in intersection_vq_to_col_pairs.items()] - right_on = [col for _, (_, col) in intersection_vq_to_col_pairs.items()] - df, from1, from2 = self._join_dataframes( - df1=df1, df2=df2, left_on=left_on, right_on=right_on - ) - # get the correspondence between columns and vqs for the new table - inputs = {} - for col in df.columns: - if col in from1.values(): - col_1 = invert_dict(from1)[col] - inputs[col] = col_to_vq1[col_1] - elif col in from2.values(): - col_2 = invert_dict(from2)[col] - inputs[col] = col_to_vq2[col_2] - else: - raise ValueError() - # insert new func query - sig = Signature( - ui_name="internal_node", - input_names=set(inputs.keys()), - n_outputs=0, - version=0, - defaults={}, - input_annotations={k: Any for k in inputs.keys()}, - output_annotations=[Any for _ in range(0)], - ) - func_op = FuncOp._from_sig(sig=sig) - f = CallNode(inputs=inputs, func_op=func_op, outputs={}, constraint=None) - for k, v in inputs.items(): - v.add_consumer(consumer=f, consumed_as=k) - self.tables[f] = df - self.fqs.append(f) - # remove the old func queries from the graph - f1.unlink() - f2.unlink() - self.fqs = [f for f in self.fqs if f not in (f1, f2)] - del self.tables[f1], self.tables[f2] - - ### solver and solver utils - def compute_intersection_size(self, f1: CallNode, f2: 
CallNode) -> int: - col_to_vq1, vq_to_cols1 = self._get_col_to_vq_mappings(f1) - col_to_vq2, vq_to_cols2 = self._get_col_to_vq_mappings(f2) - return len(set(col_to_vq1.values()) & set(col_to_vq2.values())) - - def _visualize_state(self, step_num: int): - assert self._visualize_intermediate_states - val_queries, func_queries = traverse_all(vqs=self.selection) - memoization_tables = { - k: self._table_evaluator(v) for k, v in self.tables.items() - } - visualize_graph( - vqs=val_queries, - fqs=func_queries, - names=get_names(hints={}, canonical_order=list(val_queries)), - output_path=self._visualize_steps_at / f"{step_num}.svg", - layout="bipartite", - memoization_tables=memoization_tables, - ) - - def solve(self, verbose: bool = False) -> pd.DataFrame: - step_num = 0 - while len(self.fqs) > 1: - if verbose: - logger.info(f"step {step_num}") - if self._visualize_intermediate_states: - self._visualize_state(step_num=step_num) - intersections = {} - # compute pairwise intersections - for f1 in self.fqs: - for f2 in self.fqs: - if f1 == f2: - continue - intersections[(f1, f2)] = self.compute_intersection_size(f1, f2) - # pick the pair with the largest intersection - f1, f2 = max(intersections, key=lambda x: intersections.get(x, 0)) - # merge the pair - self.merge(f1, f2) - step_num += 1 - if self._visualize_intermediate_states: - self._visualize_state(step_num=step_num) - assert len(self.tables) == 1 - df = self.tables[self.fqs[0]] - f = self.fqs[0] - # figure out which columns to select - col_to_vq, vq_to_cols = self._get_col_to_vq_mappings(f) - cols = [vq_to_cols[vq][0] for vq in self.selection] - return df[cols] diff --git a/mandala/queries/viz.py b/mandala/queries/viz.py deleted file mode 100644 index 6492bd5..0000000 --- a/mandala/queries/viz.py +++ /dev/null @@ -1,600 +0,0 @@ -from abc import ABC, abstractmethod -from ..common_imports import * -from typing import Literal -from ..core.config import parse_output_idx, Config -from ..core.model import Ref -from ..core.wrapping import unwrap -from ..core.tps import ListType, StructType, DictType, SetType -from .weaver import ( - ValNode, - CallNode, - StructOrientations, - get_items, - get_elts, - get_elt_and_struct, - get_idx, - is_key, - is_idx, - traverse_all, - get_vq_orientation, -) -from .graphs import InducedSubgraph, get_canonical_order -from ..ui.viz import ( - Node, - Edge, - SOLARIZED_LIGHT, - to_dot_string, - write_output, - HTMLBuilder, - Cell, -) -import textwrap - - -class ValueLoaderABC(ABC): - @abstractmethod - def load_value(self, full_uid: str) -> Any: - raise NotImplementedError - - -class GraphPrinter: - def __init__( - self, - vqs: Set[ValNode], - fqs: Set[CallNode], - value_loader: Optional[ValueLoaderABC] = None, - names: Optional[Dict[ValNode, str]] = None, - fnames: Optional[Dict[CallNode, str]] = None, - ): - self.vqs = vqs - self.fqs = fqs - self.value_loader = value_loader - self.g = InducedSubgraph(vqs=vqs, fqs=fqs) - self.v_labels, self.f_labels, self.vq_to_node = self.g.canonicalize() - self.full_topsort, self.sources, self.sinks = self.g.topsort( - canonical_labels={**self.v_labels, **self.f_labels} - ) - if names is None: - names = get_names( - hints={}, canonical_order=get_canonical_order(vqs=vqs, fqs=fqs) - ) - self.names = names - self.fnames = fnames - - def get_struct_comment( - self, - vq: ValNode, - elt_names: Tuple[str, ...], - idx_names: Optional[Tuple[str, ...]] = None, - ) -> str: - id_to_name = {"__list__": "list", "__dict__": "dict", "__set__": "set"} - if len(elt_names) == 1: - s = 
f"{self.names[vq]} will match any {id_to_name[vq.tp.struct_id]} containing a match for {elt_names[0]}" - if idx_names is not None: - s += f" at index {idx_names[0]}" - else: - s = f'{self.names[vq]} will match any {id_to_name[vq.tp.struct_id]} containing matches for each of {", ".join(elt_names)}' - if idx_names is not None: - s += f" at indices {', '.join(idx_names)}" - return s - - def get_source_comment(self, vq: ValNode) -> str: - if is_idx(vq=vq): - return "index into list" - elif is_key(vq=vq): - return "key into dict" - else: - return "input to computation; can match anything" - - def get_construct_computation_rhs(self, node: ValNode) -> str: - """ - Given a constructive struct, return - [name_0, ..., name_1] for lists - {key_0: name_0, ..., key_n: name_n} for dicts - {name_0, ..., name_n} for sets - """ - if isinstance(node.tp, (ListType, DictType)): - idxs_and_elts = list(get_items(vq=node).values()) - idxs = [x[0] for x in idxs_and_elts] - idx_values = [ - unwrap(self.value_loader.load_value(full_uid=idx.constraint[0])) - if idx.constraint is not None - else i - for i, idx in enumerate(idxs) - ] - elts = [x[1] for x in idxs_and_elts] - elt_names = tuple([str(self.names[elt]) for elt in elts]) - idx_value_to_elt_name = { - idx_value: elt_name - for idx_value, elt_name in zip(idx_values, elt_names) - } - if isinstance(node.tp, ListType): - assert sorted(idx_value_to_elt_name.keys()) == list(range(len(elts))) - rhs = f'[{", ".join([idx_value_to_elt_name[i] for i in range(len(elts))])}]' - elif isinstance(node.tp, DictType): - elt_strings = [ - f"'{idx_value}': {elt_name}" - for idx_value, elt_name in idx_value_to_elt_name.items() - ] - rhs = f'{", ".join(elt_strings)}' - rhs = f"{{{rhs}}}" - else: - raise RuntimeError - elif isinstance(node.tp, SetType): - elts = list(get_elts(vq=node).values()) - elt_names = tuple([str(self.names[elt]) for elt in elts]) - rhs = f'{", ".join(elt_names)}' - rhs = f"{{{rhs}}}" - else: - raise ValueError - return rhs - - def get_construct_query_rhs(self, node: ValNode) -> str: - if isinstance(node.tp, (ListType, DictType)): - idxs_and_elts = list(get_items(vq=node).values()) - idxs = [x[0] for x in idxs_and_elts] - elts = [x[1] for x in idxs_and_elts] - elt_names = tuple([str(self.names[elt]) for elt in elts]) - idx_names = tuple( - [str(self.names[idx]) if idx is not None else "?" 
for idx in idxs] - ) - comment = self.get_struct_comment( - vq=node, elt_names=elt_names, idx_names=idx_names - ) - if isinstance(node.tp, ListType): - rhs = f'ListQ(elts=[{", ".join(elt_names)}], idxs=[{", ".join(idx_names)}]) # {comment}' - elif isinstance(node.tp, DictType): - elt_strings = [ - f"{idx_name}: {elt_name}" - for idx_name, elt_name in zip(idx_names, elt_names) - ] - rhs = f'DictQ(dct={{..., {", ".join(elt_strings)}, ...}}) # {comment}' - else: - raise RuntimeError - elif isinstance(node.tp, SetType): - elts = list(get_elts(vq=node).values()) - elt_names = tuple([str(self.names[elt]) for elt in elts]) - comment = self.get_struct_comment(vq=node, elt_names=elt_names) - rhs = f'SetQ(elts=[{", ".join(elt_names)}]) # {comment}' - else: - raise ValueError - return rhs - - def get_op_computation_line( - self, - node: CallNode, - require_all_inputs: bool = False, - add_node_name: bool = True, - ) -> str: - """ - Get the line of code representing a function call of the form - output_0, output_1 = func_name(input_0=input_0, input_1=input_1) - """ - full_returns = node.returns_interp - full_returns_names = [ - self.names[vq] if vq is not None else "_" for vq in full_returns - ] - lhs = ", ".join(full_returns_names) - # rhs - args_dict = {arg_name: self.names[vq] for arg_name, vq in node.inputs.items()} - for inp_name in node.func_op.sig.input_names: - if inp_name not in args_dict and require_all_inputs: - raise RuntimeError - args_string = ", ".join( - [f"{k}={v}" for k, v in args_dict.items()] - + [f"{k}=_" for k in node.func_op.sig.input_names if k not in args_dict] - ) - rhs = f"{node.func_op.sig.ui_name}({args_string})" - if add_node_name: - rhs = rhs + " # OP NAME: " + self.fnames[node] - if len(lhs) == 0: - return rhs - else: - return f"{lhs} = {rhs}" - - def get_op_query_line(self, node: CallNode) -> str: - # lhs - full_returns = node.returns_interp - full_returns_names = [ - self.names[vq] if vq is not None else "_" for vq in full_returns - ] - lhs = ", ".join(full_returns_names) - # rhs - args_dict = {arg_name: self.names[vq] for arg_name, vq in node.inputs.items()} - for inp_name in node.func_op.sig.input_names: - if inp_name not in args_dict: - args_dict[inp_name] = "Q()" - args_string = ", ".join([f"{k}={v}" for k, v in args_dict.items()]) - rhs = f"{node.func_op.sig.ui_name}({args_string})" - if len(lhs) == 0: - return rhs - else: - return f"{lhs} = {rhs}" - - def get_destruct_computation_line(self, node: CallNode) -> str: - elt, struct = get_elt_and_struct(fq=node) - idx = get_idx(fq=node) - lhs = f"{self.names[elt]}" - if idx is None: - raise NotImplementedError - # inline the index value if the index is a source in the graph - if idx in self.sources: - if idx.constraint is None: - rhs = f"{self.names[struct]}[{self.names[idx]}]" - else: - assert len(idx.constraint) == 1 - idx_ref = self.value_loader.load_value(full_uid=idx.constraint[0]) - idx_value = unwrap(idx_ref) - if isinstance(idx_value, int): - rhs = f"{self.names[struct]}[{idx_value}]" - elif isinstance(idx_value, str): - rhs = f"{self.names[struct]}['{idx_value}']" - else: - raise RuntimeError - else: - rhs = f"{self.names[struct]}[{self.names[idx]}]" - return f"{lhs} = {rhs}" - - def get_destruct_query_line(self, node: CallNode) -> str: - elt, struct = get_elt_and_struct(fq=node) - idx = get_idx(fq=node) - lhs = f"{self.names[elt]}" - if idx is None: - idx_label = "?" 
- else: - idx_label = self.names[idx] - rhs = f"{self.names[struct]}[{idx_label}] # {self.names[elt]} will match any element of a match for {self.names[struct]} at index matching {idx_label}" - return f"{lhs} = {rhs}" - - def print_computational_graph( - self, show_sources_as: Literal["values", "uids", "omit", "name_only"] = "values" - ) -> str: - res = [] - for node in self.full_topsort: - if isinstance(node, ValNode): - if node in self.sources: - if is_idx(node) or is_key(node): - # exclude indices/keys if they are sources (we will inline them) - continue - if show_sources_as == "omit": - continue - elif show_sources_as == "name_only": - res.append(self.names[node]) - continue - assert node.constraint is not None - assert len(node.constraint) == 1 - if show_sources_as == "values": - ref = self.value_loader.load_value(full_uid=node.constraint[0]) - value = unwrap(ref) - rep = textwrap.shorten(repr(value), 25) - res.append(f"{self.names[node]} = {rep}") - elif show_sources_as == "uids": - uid, causal_uid = Ref.parse_full_uid( - full_uid=node.constraint[0] - ) - res.append( - f"{self.names[node]} = Ref(uid={uid}, causal_uid={causal_uid})" - ) - else: - raise ValueError - elif isinstance(node.tp, StructType): - if get_vq_orientation(node) != StructOrientations.construct: - continue - rhs = self.get_construct_computation_rhs(node=node) - lhs = f"{self.names[node]}" - line = f"{lhs} = {rhs}" - res.append(line) - elif isinstance(node, CallNode): - if not node.func_op.is_builtin: - res.append(self.get_op_computation_line(node=node)) - else: - if node.orientation == StructOrientations.destruct: - res.append(self.get_destruct_computation_line(node=node)) - return "\n".join(res) - - def print_query_graph( - self, - selection: Optional[List[ValNode]], - pprint: bool = False, - ): - res = [] - for node in self.full_topsort: - if isinstance(node, ValNode): - if node in self.sources: - comment = self.get_source_comment(vq=node) - res.append(f"{self.names[node]} = Q() # {comment}") - elif isinstance(node.tp, StructType): - if get_vq_orientation(node) != StructOrientations.construct: - continue - rhs = self.get_construct_query_rhs(node=node) - lhs = f"{self.names[node]}" - line = f"{lhs} = {rhs}" - res.append(line) - elif isinstance(node, CallNode): - if not node.func_op.is_builtin: - res.append(self.get_op_query_line(node=node)) - else: - if node.orientation == StructOrientations.destruct: - res.append(self.get_destruct_query_line(node=node)) - if selection is not None: - res.append( - f"result = storage.df({', '.join([self.names[vq] for vq in selection])})" - ) - res = [textwrap.indent(line, " ") for line in res] - if Config.has_rich and pprint: - from rich.syntax import Syntax - - highlighted = Syntax( - "\n".join(res), - "python", - theme="solarized-light", - line_numbers=False, - ) - # return Panel.fit(highlighted, title="Computational Graph") - return highlighted - else: - return "\n".join(res) - - -def print_graph( - vqs: Set[ValNode], - fqs: Set[CallNode], - names: Dict[ValNode, str], - selection: Optional[List[ValNode]], - pprint: bool = False, -): - printer = GraphPrinter(vqs=vqs, fqs=fqs, names=names) - s = printer.print_query_graph(selection=selection) - if Config.has_rich and pprint: - rich.print(s) - else: - print(s) - - -def graph_to_dot( - vqs: List[ValNode], - fqs: List[CallNode], - names: Dict[ValNode, str], - layout: Literal["computational", "bipartite"] = "computational", - memoization_tables: Optional[Dict[CallNode, pd.DataFrame]] = None, -) -> str: - # should work for subgraphs - 
assert set(vqs) <= set(names.keys()) - nodes = {} # val/op query -> Node obj - edges = [] - col_names = [] - counter = 0 - for vq in vqs: - if names[vq] is not None: - col_names.append(names[vq]) - else: - col_names.append(f"unnamed_{counter}") - counter += 1 - for vq, col_name in zip(vqs, col_names): - html_label = HTMLBuilder() - if hasattr(vq, "_hidden_message"): - msg = vq._hidden_message - text = f"{col_name} ({msg})" - else: - text = str(col_name) - html_label.add_row( - cells=[ - Cell( - text=text, - port=None, - bold=True, - bgcolor=SOLARIZED_LIGHT["orange"], - font_color=SOLARIZED_LIGHT["base3"], - ) - ] - ) - node = Node( - internal_name=str(id(vq)), - # label=str(val_query.column_name), - label=html_label.to_html_like_label(), - color=SOLARIZED_LIGHT["blue"], - shape="plain", - ) - nodes[vq] = node - for func_query in fqs: - html_label = HTMLBuilder() - func_preview = ( - # f'{func_query.displayname}({", ".join(func_query.inputs.keys())})' - func_query.displayname - ) - if hasattr(func_query, "_hidden_message"): - msg = func_query._hidden_message - func_preview = f"{func_preview} ({msg})" - title_cell = Cell( - text=func_preview, - port=None, - bgcolor=SOLARIZED_LIGHT["blue"], - bold=True, - font_color=SOLARIZED_LIGHT["base3"], - ) - input_cells = [] - output_cells = [] - for input_name in func_query.inputs.keys(): - input_cells.append(Cell(text=input_name, port=input_name, bold=True)) - # html_label.add_row(elts=[Cell(text=input_name, port=input_name)]) - for output_idx in range(len(func_query.outputs)): - output_cells.append( - Cell( - text=f"output_{output_idx}", port=f"output_{output_idx}", bold=True - ) - ) - if layout == "bipartite": - html_label.add_row(cells=[title_cell]) - if len(input_cells + output_cells) > 0: - html_label.add_row(cells=input_cells + output_cells) - if memoization_tables is not None: - column_names = [cell.text for cell in input_cells + output_cells] - port_names = [cell.port for cell in input_cells + output_cells] - # remove ports from the table column cells - for cell in input_cells + output_cells: - cell.port = None - df = memoization_tables[func_query][column_names] - rows = list(df.head().itertuples(index=False)) - for tup in rows: - html_label.add_row( - cells=[ - Cell(text=textwrap.shorten(str(x), 25), bold=True) - for x in tup - ] - ) - # add port names to the cells in the *last* row - for cell, port_name in zip(html_label.rows[-1], port_names): - cell.port = port_name - elif layout == "computational": - if len(input_cells) > 0: - html_label.add_row(input_cells) - html_label.add_row([title_cell]) - if len(output_cells) > 0: - html_label.add_row(output_cells) - else: - raise ValueError(f"Unknown layout: {layout}") - node = Node( - internal_name=str(id(func_query)), - # label=str(func_query.func_op.sig.ui_name), - label=html_label.to_html_like_label(), - color=SOLARIZED_LIGHT["red"], - shape="plain", - ) - nodes[func_query] = node - for input_name, vq in func_query.inputs.items(): - if not vq in nodes: - continue - if layout == "bipartite": - edges.append( - Edge( - target_node=nodes[vq], - source_node=nodes[func_query], - source_port=input_name, - arrowtail="none", - arrowhead="none", - ) - ) - elif layout == "computational": - edges.append( - Edge( - source_node=nodes[vq], - target_node=nodes[func_query], - target_port=input_name, - ) - ) - else: - raise ValueError(f"Unknown layout: {layout}") - for output_name, vq in func_query.outputs.items(): - if not vq in nodes: - continue - edges.append( - Edge( - source_node=nodes[func_query], - 
target_node=nodes[vq], - source_port=output_name, - arrowtail="none" if layout == "bipartite" else None, - arrowhead="none" if layout == "bipartite" else None, - ) - ) - return to_dot_string(nodes=list(nodes.values()), edges=edges, groups=[]) - - -def visualize_graph( - vqs: Set[ValNode], - fqs: Set[CallNode], - names: Optional[Dict[ValNode, str]], - layout: Literal["computational", "bipartite"] = "computational", - memoization_tables: Optional[Dict[CallNode, pd.DataFrame]] = None, - output_path: Optional[Path] = None, - show_how: Literal["none", "browser", "inline", "open"] = "none", -): - if names is None: - names = get_names( - hints={}, canonical_order=get_canonical_order(vqs=vqs, fqs=fqs) - ) - dot_string = graph_to_dot( - vqs=list(vqs), - fqs=list(fqs), - names=names, - layout=layout, - memoization_tables=memoization_tables, - ) - if output_path is None: - tempfile_obj, output_name = tempfile.mkstemp(suffix=".svg") - output_path = Path(output_name) - write_output( - output_path=output_path, - dot_string=dot_string, - output_ext="svg", - show_how=show_how, - ) - return output_path - - -def show(*vqs: ValNode): - vqs, fqs = traverse_all(vqs=list(vqs), direction="both") - visualize_graph(vqs=vqs, fqs=fqs, show_how="browser") - - -def extract_names_from_scope(scope: Dict[str, Any]) -> Dict[ValNode, str]: - """ - Heuristic to get deterministic name for all ValQueries we can find in the - scope. - """ - names_per_vq = defaultdict(list) - for k, v in scope.items(): - if k.startswith("_"): - continue - if isinstance(v, Ref) and v._query is not None: - names_per_vq[v.query].append(k) - elif isinstance(v, ValNode): - names_per_vq[v].append(k) - res = {} - for vq, names_per_vq in names_per_vq.items(): - if len(names_per_vq) > 1: - logger.warning( - f"Found multiple names for {vq}: {names_per_vq}, choosing {sorted(names_per_vq)[0]}" - ) - res[vq] = names_per_vq[0] - return res - - -def get_names( - hints: Dict[ValNode, str], canonical_order: List[ValNode] -) -> Dict[ValNode, str]: - """ - Get names for the given oredered list of ValQueries, using the following - priority: - - from the object itself; - - the hints; - - a_{counter} for the rest. 
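`get_names` above resolves a display name for each value node by priority: the node's own `name`, then a caller-supplied hint, then a generated `a{counter}` (with `idx`/`key` prefixes for index and key nodes), skipping names already taken. A simplified sketch of that fallback logic, assuming nodes are reduced to an ordered list of optional preset names plus a hint dictionary, and dropping the idx/key special-casing:

```
from itertools import count
from typing import Dict, List, Optional


def assign_names(preset: List[Optional[str]], hints: Dict[int, str]) -> List[str]:
    """Pick a unique name per position: preset name > hint > generated a{n}."""
    taken = {name for name in preset if name is not None} | set(hints.values())
    counter = count()
    result = []
    for i, name in enumerate(preset):
        if name is not None:
            chosen = name
        elif i in hints:
            chosen = hints[i]
        else:
            chosen = f"a{next(counter)}"
            while chosen in taken:
                chosen = f"a{next(counter)}"
            taken.add(chosen)
        result.append(chosen)
    return result


assert assign_names([None, "x", None, None], {2: "a1"}) == ["a0", "x", "a1", "a2"]
```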
- """ - counter = Count() - idx_counter = Count() - key_counter = Count() - existing_names = set(hints.values()) | { - vq.name for vq in canonical_order if vq.name is not None - } - res = {} - for vq in canonical_order: - if vq.name is not None: - res[vq] = vq.name - elif vq in hints: - res[vq] = hints[vq] - else: - if is_key(vq): - c, prefix = key_counter, "key" - elif is_idx(vq): - c, prefix = idx_counter, "idx" - else: - c, prefix = counter, "a" - while f"{prefix}{c.i}" in existing_names: - c.i += 1 - res[vq] = f"{prefix}{c.i}" - c.i += 1 - return res - - -class Count: - def __init__(self): - self.i = 0 diff --git a/mandala/queries/weaver.py b/mandala/queries/weaver.py deleted file mode 100644 index 5e6f486..0000000 --- a/mandala/queries/weaver.py +++ /dev/null @@ -1,710 +0,0 @@ -from abc import ABC, abstractmethod -from ..common_imports import * -from ..core.model import FuncOp, Ref, wrap_atom, Call -from ..core.wrapping import causify_atom -from ..core.config import dump_output_name, parse_output_idx -from ..core.tps import Type, ListType, DictType, SetType, AnyType, StructType -from ..core.builtins_ import Builtins, ListRef, DictRef, SetRef -from ..core.utils import Hashing, concat_lists, invert_dict -from typing import Literal -from typing import Sequence, Iterator - -T = TypeVar("T") -T1 = TypeVar("T1") - - -class StructOrientations: - # at runtime only - construct = "construct" - destruct = "destruct" - - -class Node(ABC): - @abstractmethod - def neighbors( - self, direction: Literal["backward", "forward", "both"] = "both" - ) -> List["Node"]: - raise NotImplementedError - - -class ValNode(Node): - def __init__( - self, - constraint: Optional[List[str]], - tp: Optional[Type], - creators: Optional[List["CallNode"]] = None, - created_as: Optional[List[str]] = None, - _label: Optional[str] = None, - name: Optional[str] = None, - refs: Optional[List[Ref]] = None, - ): - self.creators: List["CallNode"] = [] if creators is None else creators - self.created_as = [] if created_as is None else created_as - self.consumers: List["CallNode"] = [] - self.consumed_as: List[str] = [] - self.name = name - self.constraint = constraint - self.tp = tp - self._label: Optional[str] = _label - self._refs: Optional[List[Ref]] = None - self._refs_hash: Optional[str] = None - - if refs is not None: - self.refs = refs - - @property - def refs(self) -> List[Ref]: - return self._refs - - @refs.setter - def refs(self, value: List[Ref]): - # a single assignment expression guarantees no broken state - self._refs, self._refs_hash = value, ValNode.get_refs_hash(value) - - @staticmethod - def get_refs_hash(refs: List[Ref]) -> str: - return Hashing.get_content_hash( - obj=[r.causal_uid for r in refs], - ) - - @property - def refs_hash(self) -> str: - if self._refs_hash is None: - raise ValueError("No refs") - return self._refs_hash - - def add_consumer(self, consumer: "CallNode", consumed_as: str): - self.consumers.append(consumer) - self.consumed_as.append(consumed_as) - - def add_creator(self, creator: "CallNode", created_as: str): - assert isinstance(created_as, str) - self.creators.append(creator) - self.created_as.append(created_as) - - def neighbors( - self, direction: Literal["backward", "forward", "both"] = "both" - ) -> Set["CallNode"]: - backward = self.creators - forward = self.consumers - if direction == "backward": - return set(backward) - elif direction == "forward": - return set(forward) - elif direction == "both": - return set(backward + forward) - - def named(self, name: str) -> Any: - self.name = 
name - return self - - def __getitem__(self, idx: Union[int, str, "ValNode"]) -> "ValNode": - tp = self.tp or _infer_type(self) - if isinstance(tp, ListType): - return BuiltinQueries.GetListItemQuery( - lst=self, idx=qwrap(idx, tp=tp.elt_type) - ) - elif isinstance(tp, DictType): - assert isinstance(idx, (ValNode, str)) - return BuiltinQueries.GetDictItemQuery( - dct=self, key=qwrap(idx, tp=tp.elt_type) - ) - else: - raise NotImplementedError(f"Cannot index into query of type {tp}") - - def __repr__(self): - if self.name is not None: - return f"ValNode({self.name})" - else: - return f"ValNode({self.tp})" - - def inplace_mask(self, mask: np.ndarray): - """ - Inplace mask the refs of this `ValQuery` object. - """ - if self.refs is None: - raise ValueError("No refs to mask") - if isinstance(mask, np.ndarray): - assert mask.dtype == np.dtype("bool") - assert mask.shape[0] == len(self.refs) - self.refs = [r for r, m in zip(self.refs, mask) if m] - else: - raise NotImplementedError("Indexing only supported for boolean arrays") - - def get_constraint(self, *values) -> List[str]: - wrapped = [wrap_atom(v) for v in values] - for w in wrapped: - causify_atom(w) - return [w.full_uid for w in wrapped] - - def pin(self, *values): - if len(values) == 0: - raise ValueError("Must pin to at least one value") - else: - self.constraint = self.get_constraint(*values) - - def unpin(self): - self.constraint = None - - -def copy(vq: ValNode, label: Optional[str] = None) -> ValNode: - return ValNode( - constraint=vq.constraint, - tp=vq.tp, - creators=vq.creators, - created_as=vq.created_as, - _label=label, - ) - - -class CallNode(Node): - def __init__( - self, - inputs: Dict[str, ValNode], - func_op: FuncOp, - outputs: Dict[str, ValNode], - constraint: Optional[List[str]], - orientation: Optional[str] = None, - calls: Optional[List[Call]] = None, - ): - self.func_op = func_op - self.inputs = inputs - self.outputs = outputs - self.orientation = orientation - self.constraint = constraint - self._calls: Optional[List[Call]] = None - self._calls_hash: Optional[str] = None - - if calls is not None: - self.calls = calls - - @property - def calls(self) -> List[Call]: - return self._calls - - @property - def calls_hash(self) -> str: - if self._calls_hash is None: - raise ValueError("No df") - return self._calls_hash - - @calls.setter - def calls(self, value: List[Call]): - # a single assignment expression guarantees no broken state - self._calls, self._calls_hash = value, CallNode.get_calls_hash(value) - - @staticmethod - def get_calls_hash(calls: List[Call]) -> str: - return Hashing.get_content_hash( - [c.causal_uid for c in calls], - ) - - def inplace_mask(self, mask: np.ndarray): - """ - Inplace mask the call UIDs of this `FuncQuery` object. 
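`ValNode.inplace_mask` and `CallNode.inplace_mask` both filter their cached lists (`refs` / `calls`) with a boolean numpy mask of matching length. The filtering step itself is just the following, sketched here with strings standing in for `Ref`/`Call` objects:

```
import numpy as np

calls = ["call_a", "call_b", "call_c", "call_d"]  # stand-ins for Call objects
mask = np.array([True, False, True, False])

assert mask.dtype == np.dtype("bool") and mask.shape[0] == len(calls)
calls = [c for c, keep in zip(calls, mask) if keep]
assert calls == ["call_a", "call_c"]
```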
- """ - if self.calls is None: - raise ValueError("No call UIDs to mask") - if isinstance(mask, np.ndarray): - assert mask.dtype == np.dtype("bool") - assert mask.shape[0] == len(self.calls) - self.calls = [c for c, m in zip(self.calls, mask) if m] - else: - raise NotImplementedError("Indexing only supported for boolean arrays") - - def set_outputs(self, outputs: Dict[str, ValNode]): - self.outputs = outputs - - @property - def returns(self) -> List[ValNode]: - if self.func_op.is_builtin: - raise NotImplementedError() - else: - ord_outputs = {parse_output_idx(k): v for k, v in self.outputs.items()} - ord_outputs = [ord_outputs[i] for i in range(len(ord_outputs))] - return ord_outputs - - @property - def returns_interp(self) -> List[Optional[ValNode]]: - if self.func_op.is_builtin: - raise NotImplementedError() - else: - ord_outputs = {parse_output_idx(k): v for k, v in self.outputs.items()} - res = [] - for i in range(self.func_op.sig.n_outputs): - if i in ord_outputs: - res.append(ord_outputs[i]) - else: - res.append(None) - return res - - @property - def displayname(self) -> str: - data = { - "__list__": {"construct": "__list__", "destruct": "__getitem__"}, - "__dict__": {"construct": "__dict__", "destruct": "__getitem__"}, - "__set__": {"construct": "__set__", "destruct": "__getitem__"}, - } - if self.func_op._is_builtin: - return data[self.func_op.sig.ui_name][self.orientation] - else: - return self.func_op.sig.ui_name - - def neighbors( - self, direction: Literal["backward", "forward", "both"] = "both" - ) -> Set[ValNode]: - backward = list(self.inputs.values()) - forward = list(self.outputs.values()) - if direction == "backward": - return set(backward) - elif direction == "forward": - return set(forward) - elif direction == "both": - return set(backward + forward) - - @staticmethod - def link( - inputs: Dict[str, ValNode], - func_op: FuncOp, - outputs: Dict[str, ValNode], - constraint: Optional[List[str]], - orientation: Optional[str], - include_indexing: bool = True, - calls: Optional[List[Call]] = None, - ) -> "CallNode": - """ - Link a func query into the graph - """ - if func_op._is_builtin: - struct_id = func_op.sig.ui_name - assert orientation is not None - joined = {**inputs, **outputs} - if not include_indexing: - if struct_id == "__list__": - joined = {k: v for k, v in joined.items() if k != "idx"} - if struct_id == "__dict__": - joined = {k: v for k, v in joined.items() if k != "key"} - input_keys = Builtins.IO[orientation][struct_id]["in"] & joined.keys() - output_keys = Builtins.IO[orientation][struct_id]["out"] & joined.keys() - effective_inputs = {k: joined[k] for k in input_keys} - effective_outputs = {k: joined[k] for k in output_keys} - else: - assert orientation is None - effective_inputs = inputs - effective_outputs = outputs - result = CallNode( - inputs=effective_inputs, - func_op=func_op, - outputs=effective_outputs, - orientation=orientation, - constraint=constraint, - calls=calls, - ) - for name, inp in effective_inputs.items(): - inp.add_consumer(consumer=result, consumed_as=name) - for name, out in effective_outputs.items(): - out.add_creator(creator=result, created_as=name) - return result - - def unlink(self): - """ - Remove this `FuncQuery` from the graph. 
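`CallNode.link` and `CallNode.unlink` above maintain the doubly-linked structure of the graph: linking registers the call as a consumer of its inputs and a creator of its outputs, and unlinking removes exactly those back-references again. A stripped-down sketch of that bookkeeping, with `Var` and `Call` as hypothetical stand-ins for `ValNode` and `CallNode`:

```
class Var:
    def __init__(self, name: str):
        self.name = name
        self.consumers, self.consumed_as = [], []
        self.creators, self.created_as = [], []


class Call:
    def __init__(self, inputs: dict, outputs: dict):
        self.inputs, self.outputs = inputs, outputs
        for arg_name, var in inputs.items():    # register as consumer of each input
            var.consumers.append(self)
            var.consumed_as.append(arg_name)
        for out_name, var in outputs.items():   # register as creator of each output
            var.creators.append(self)
            var.created_as.append(out_name)

    def unlink(self):
        """Remove every back-reference to this call from its neighbouring variables."""
        for var in self.inputs.values():
            keep = [i for i, c in enumerate(var.consumers) if c is not self]
            var.consumers = [var.consumers[i] for i in keep]
            var.consumed_as = [var.consumed_as[i] for i in keep]
        for var in self.outputs.values():
            keep = [i for i, c in enumerate(var.creators) if c is not self]
            var.creators = [var.creators[i] for i in keep]
            var.created_as = [var.created_as[i] for i in keep]


x, y = Var("x"), Var("y")
f = Call(inputs={"x": x}, outputs={"output_0": y})
assert x.consumers == [f] and y.creators == [f]
f.unlink()
assert x.consumers == [] and y.creators == []
```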
- """ - for inp in self.inputs.values(): - idxs = [i for i, x in enumerate(inp.consumers) if x is self] - inp.consumers = [x for i, x in enumerate(inp.consumers) if i not in idxs] - inp.consumed_as = [ - x for i, x in enumerate(inp.consumed_as) if i not in idxs - ] - for out in self.outputs.values(): - idxs = [i for i, x in enumerate(out.creators) if x is self] - out.creators = [x for i, x in enumerate(out.creators) if i not in idxs] - out.created_as = [x for i, x in enumerate(out.created_as) if i not in idxs] - - def __repr__(self): - args_string = ", ".join(f"{k}={v}" for k, v in self.inputs.items()) - if self.orientation is not None: - args_string += f", orientation={self.orientation}" - return f"CallNode({self.func_op.sig.ui_name}, {args_string})" - - -def traverse_all( - vqs: Set[ValNode], - direction: Literal["backward", "forward", "both"] = "both", -) -> Tuple[Set[ValNode], Set[CallNode]]: - """ - Extend the given `ValQuery` objects to all objects connected to them through - function inputs and/or outputs. - """ - vqs_ = {_ for _ in vqs} - fqs_: Set[CallNode] = set() - found_new = True - while found_new: - found_new = False - val_neighbors = concat_lists([v.neighbors(direction=direction) for v in vqs_]) - op_neighbors = concat_lists([o.neighbors(direction=direction) for o in fqs_]) - if any(k not in fqs_ for k in val_neighbors): - found_new = True - for neigh in val_neighbors: - if neigh not in fqs_: - fqs_.add(neigh) - if any(k not in vqs_ for k in op_neighbors): - found_new = True - for neigh in op_neighbors: - if neigh not in vqs_: - vqs_.add(neigh) - return vqs_, fqs_ - - -class BuiltinQueries: - @staticmethod - def ListQ( - elts: List[ValNode], idxs: Optional[List[Optional[ValNode]]] = None - ) -> ValNode: - result = ValNode( - creators=[], created_as=[], tp=ListType(elt_type=AnyType()), constraint=None - ) - if idxs is None: - idxs = [ - ValNode(constraint=None, tp=AnyType(), creators=[], created_as=[]) - for _ in elts - ] - for elt, idx in zip(elts, idxs): - CallNode.link( - inputs={"lst": result, "elt": elt, "idx": idx}, - func_op=Builtins.list_op, - outputs={}, - constraint=None, - orientation=StructOrientations.construct, - ) - return result - - @staticmethod - def DictQ(dct: Dict[ValNode, ValNode]) -> ValNode: - result = ValNode( - creators=[], created_as=[], tp=DictType(elt_type=AnyType()), constraint=None - ) - for key, val in dct.items(): - CallNode.link( - inputs={"dct": result, "key": key, "val": val}, - func_op=Builtins.dict_op, - outputs={}, - constraint=None, - orientation=StructOrientations.construct, - ) - return result - - @staticmethod - def SetQ(elts: Set[ValNode]) -> ValNode: - result = ValNode( - creators=[], created_as=[], tp=SetType(elt_type=AnyType()), constraint=None - ) - for elt in elts: - CallNode.link( - inputs={"st": result, "elt": elt}, - func_op=Builtins.set_op, - outputs={}, - constraint=None, - orientation=StructOrientations.construct, - ) - return result - - @staticmethod - def GetListItemQuery(lst: ValNode, idx: Optional[ValNode] = None) -> ValNode: - elt_tp = lst.tp.elt_type if isinstance(lst.tp, ListType) else None - result = ValNode(creators=[], created_as=[], tp=elt_tp, constraint=None) - CallNode.link( - inputs={"lst": lst, "elt": result, "idx": idx}, - func_op=Builtins.list_op, - outputs={}, - orientation=StructOrientations.destruct, - constraint=None, - ) - return result - - @staticmethod - def GetDictItemQuery(dct: ValNode, key: Optional[ValNode] = None) -> ValNode: - val_tp = dct.tp.elt_type if isinstance(dct.tp, DictType) else None 
- result = ValNode(creators=[], created_as=[], tp=val_tp, constraint=None) - CallNode.link( - inputs={"dct": dct, "key": key, "val": result}, - func_op=Builtins.dict_op, - outputs={}, - orientation=StructOrientations.destruct, - constraint=None, - ) - return result - - ############################################################################ - ### syntactic sugar - ############################################################################ - @staticmethod - def is_pattern(obj: Any) -> bool: - if type(obj) is list and Ellipsis in obj: - return all( - BuiltinQueries.is_pattern(elt) or isinstance(elt, ValNode) - for elt in obj - if elt is not Ellipsis - ) - elif type(obj) is dict and Ellipsis in obj: - return all( - BuiltinQueries.is_pattern(elt) or isinstance(elt, ValNode) - for elt in obj.values() - if elt is not Ellipsis - ) - elif type(obj) is set: - return all( - BuiltinQueries.is_pattern(elt) or isinstance(elt, ValNode) - for elt in obj - if elt is not Ellipsis - ) - else: - return False - - @staticmethod - def link_pattern(obj: Union[list, dict, set, ValNode]) -> ValNode: - if isinstance(obj, ValNode): - return obj - elif type(obj) is list: - elts = [ - BuiltinQueries.link_pattern(elt) for elt in obj if elt is not Ellipsis - ] - result = ValNode( - creators=[], - created_as=[], - tp=ListType(elt_type=AnyType()), - constraint=None, - ) - for elt in elts: - CallNode.link( - inputs={ - "lst": result, - "elt": elt, - "idx": ValNode( - constraint=None, tp=AnyType(), creators=[], created_as=[] - ), - }, - func_op=Builtins.list_op, - outputs={}, - constraint=None, - orientation=StructOrientations.construct, - ) - elif type(obj) is dict: - elts = { - k: BuiltinQueries.link_pattern(v) - for k, v in obj.items() - if k is not Ellipsis - } - result = ValNode( - creators=[], - created_as=[], - tp=DictType(elt_type=AnyType()), - constraint=None, - ) - for k, v in elts.items(): - CallNode.link( - inputs={"dct": result, "key": k, "val": v}, - func_op=Builtins.dict_op, - outputs={}, - constraint=None, - orientation=StructOrientations.construct, - ) - elif type(obj) is set: - elts = { - BuiltinQueries.link_pattern(elt) for elt in obj if elt is not Ellipsis - } - result = ValNode( - creators=[], - created_as=[], - tp=SetType(elt_type=AnyType()), - constraint=None, - ) - for elt in elts: - CallNode.link( - inputs={"st": result, "elt": elt}, - func_op=Builtins.set_op, - outputs={}, - constraint=None, - orientation=StructOrientations.construct, - ) - else: - raise ValueError - return result - - -def qwrap(obj: Any, tp: Optional[Type] = None, strict: bool = False) -> ValNode: - """ - Produce a ValQuery from an object. 
- """ - if isinstance(obj, ValNode): - return obj - elif isinstance(obj, Ref): - assert obj._query is not None, "Ref must be linked to a query" - return obj.query - elif BuiltinQueries.is_pattern(obj=obj): - if not strict: - return BuiltinQueries.link_pattern(obj=obj) - else: - raise ValueError - else: - if strict: - raise ValueError("value must be a `ValQuery` or `Ref`") - if tp is None: - tp = AnyType() - # wrap a raw value as a pointwise constraint - uid = obj.uid if isinstance(obj, Ref) else Hashing.get_content_hash(obj) - return ValNode( - tp=tp, - creators=[], - created_as=[], - constraint=[uid], - ) - - -def call_query( - func_op: FuncOp, inputs: Dict[str, Union[list, dict, set, ValNode, Ref, Any]] -) -> List[ValNode]: - for k in inputs.keys(): - inputs[k] = qwrap(obj=inputs[k]) - assert all(isinstance(inp, ValNode) for inp in inputs.values()) - ord_outputs = [ - ValNode(creators=[], created_as=[], tp=tp, constraint=None) - for tp in func_op.output_types - ] - outputs = {dump_output_name(index=i): o for i, o in enumerate(ord_outputs)} - CallNode.link( - inputs=inputs, - func_op=func_op, - outputs=outputs, - orientation=None, - constraint=None, - ) - return ord_outputs - - -################################################################################ -### introspection -################################################################################ -def _infer_type(val_query: ValNode) -> Type: - consumer_op_names = [c.func_op.sig.ui_name for c in val_query.consumers] - mapping = {"__list__": ListType(), "__dict__": DictType(), "__set__": SetType()} - tps = [mapping.get(x, None) for x in consumer_op_names] - struct_tps = [x for x in tps if x is not None] - if len(struct_tps) == 0: - return AnyType() - elif len(struct_tps) == 1: - return struct_tps[0] - else: - raise RuntimeError(f"Multiple types for {val_query}: {struct_tps}") - - -def get_vq_orientation(vq: ValNode) -> str: - if not isinstance(vq.tp, StructType): - raise ValueError - if ( - len(vq.creators) == 1 - and vq.creators[0].orientation == StructOrientations.destruct - ): - return StructOrientations.destruct - elif len(vq.creators) == 1 and vq.creators[0].orientation is None: - return StructOrientations.destruct - else: - return StructOrientations.construct - - -def is_idx(vq: ValNode) -> bool: - for consumer, consumed_as in zip(vq.consumers, vq.consumed_as): - if consumed_as == "idx" and consumer.func_op.sig.ui_name == "__list__": - return True - return False - - -def is_key(vq: ValNode) -> bool: - for consumer, consumed_as in zip(vq.consumers, vq.consumed_as): - if consumed_as == "key" and consumer.func_op.sig.ui_name == "__dict__": - return True - return False - - -def get_elt_fqs(vq: ValNode) -> List[CallNode]: - assert isinstance(vq.tp, StructType) - struct_id = vq.tp.struct_id - orientation = get_vq_orientation(vq) - fqs_to_search = ( - vq.consumers if orientation == StructOrientations.destruct else vq.creators - ) - fqs = [ - fq - for fq in fqs_to_search - if fq.func_op.sig.ui_name == struct_id and fq.orientation == orientation - ] - return fqs - - -def get_elts(vq: ValNode) -> Dict[CallNode, ValNode]: - """ - Get the constituent element queries of a set, as a dictionary of {fq: vq} - pairs. 
- """ - orientation = get_vq_orientation(vq) - elt_fqs = get_elt_fqs(vq) - return { - fq: fq.inputs["elt"] - if orientation == StructOrientations.construct - else fq.outputs["elt"] - for fq in elt_fqs - } - - -def get_items(vq: ValNode) -> Dict[CallNode, Tuple[Optional[ValNode], ValNode]]: - """ - Get the constituent elements and indices of a list or dict in the form of - {fq: (idx_vq, elt_vq)} pairs. - """ - assert isinstance(vq.tp, (ListType, DictType)) - orientation = get_vq_orientation(vq) - elt_fqs = get_elt_fqs(vq) - elt_key = "elt" if isinstance(vq.tp, ListType) else "val" - idx_key = "idx" if isinstance(vq.tp, ListType) else "key" - return { - fq: (fq.inputs.get(idx_key), fq.inputs[elt_key]) - if orientation == StructOrientations.destruct - else (fq.inputs.get(idx_key), fq.inputs[elt_key]) - for fq in elt_fqs - } - - -def get_elt_and_struct(fq: CallNode) -> Tuple[ValNode, ValNode]: - assert fq.func_op.is_builtin - struct_id = fq.func_op.sig.ui_name - elt_target = ( - fq.outputs if fq.orientation == StructOrientations.destruct else fq.inputs - ) - struct_target = ( - fq.inputs if fq.orientation == StructOrientations.destruct else fq.outputs - ) - if struct_id == "__list__": - return elt_target["elt"], struct_target["lst"] - elif struct_id == "__dict__": - return elt_target["val"], struct_target["dct"] - elif struct_id == "__set__": - return elt_target["elt"], struct_target["st"] - else: - raise NotImplementedError() - - -def get_idx(fq: CallNode) -> Optional[ValNode]: - assert fq.func_op.is_builtin - struct_id = fq.func_op.sig.ui_name - idx_target = fq.inputs - if struct_id == "__list__": - return idx_target.get("idx", None) - elif struct_id == "__dict__": - return idx_target.get("key", None) - else: - raise ValueError - - -def prepare_query(ref: Ref, tp: Type): - if ref._query is None: - ref._query = ValNode(tp=tp, constraint=None, creators=[], created_as=[]) diff --git a/mandala/queries/workflow.py b/mandala/queries/workflow.py deleted file mode 100644 index 2643586..0000000 --- a/mandala/queries/workflow.py +++ /dev/null @@ -1,288 +0,0 @@ -from ..common_imports import * -from .weaver import * -from ..core.model import Ref, FuncOp -from ..core.utils import Hashing -from ..core.config import dump_output_name, parse_output_idx - - -class CallStruct: - def __init__(self, func_op: FuncOp, inputs: Dict[str, Ref], outputs: List[Ref]): - self.func_op = func_op - self.inputs = inputs - self.outputs = outputs - - -class Workflow: - """ - An intermediate representation of a collection of calls, possibly not all of - which have been executed, following a particular computational graph. - - Used to: - - represent work to be done in a `batch` context as a data structure - - encode an entire workflow for e.g. 
testing scenarios that simulate - real-world workloads - """ - - def __init__(self): - ### encoding the shape - # in topological order - self.var_nodes: List[ValNode] = [] - # note that there may be many var nodes with the same causal hash - self.var_node_to_causal_hash: Dict[ValNode, str] = {} - # in topological order - self.op_nodes: List[CallNode] = [] - # self.causal_hash_to_op_node: Dict[str, FuncQuery] = {} - self.op_node_to_causal_hash: Dict[CallNode, str] = {} - ### encoding instance data - # multiple refs may map to the same query node - self.value_to_var: Dict[Ref, ValNode] = {} - # for a given op node, there may be multiple call structs - self.op_node_to_call_structs: Dict[CallNode, List[CallStruct]] = {} - - def check_invariants(self): - assert set(self.var_node_to_causal_hash.keys()) == set(self.var_nodes) - assert set(self.op_node_to_causal_hash.keys()) == set(self.op_nodes) - assert set(self.op_node_to_call_structs.keys()) == set(self.op_nodes) - assert set(self.value_to_var.values()) <= set(self.var_nodes) - for op_node in self.op_nodes: - for call_struct in self.op_node_to_call_structs[op_node]: - input_locations = { - k: self.value_to_var[v] for k, v in call_struct.inputs.items() - } - assert input_locations == op_node.inputs - output_locations = { - dump_output_name(i): self.value_to_var[v] - for i, v in enumerate(call_struct.outputs) - } - assert output_locations == op_node.outputs - - def get_default_hash(self) -> str: - return Hashing.get_content_hash(obj="null") - - @property - def callable_op_nodes(self) -> List[CallNode]: - # return op nodes that have non-empty inputs - res = [] - var_to_values = self.var_to_values() - for op_node in self.op_nodes: - if all([len(var_to_values[var]) > 0 for var in op_node.inputs.values()]): - res.append(op_node) - return res - - @property - def inputs(self) -> List[ValNode]: - # return [var for var in self.var_nodes if var.creator is None] - return [var for var in self.var_nodes if len(var.creators) == 0] - - def var_to_values(self) -> Dict[ValNode, List[Ref]]: - res = defaultdict(list) - for value, var in self.value_to_var.items(): - res[var].append(value) - return res - - def add_var(self, val_query: Optional[ValNode] = None) -> ValNode: - res = ( - val_query - if val_query is not None - # else ValQuery(creator=None, created_as=None) - else ValNode(creators=[], created_as=[], constraint=None, tp=AnyType()) - ) - # if res.creator is None: - if len(res.creators) == 0: - causal_hash = self.get_default_hash() - else: - # creator_hash = self.op_node_to_causal_hash[res.creator] - creator_hash = self.op_node_to_causal_hash[res.creators[0]] - # causal_hash = Hashing.get_content_hash(obj=[creator_hash, - # res.created_as]) - causal_hash = Hashing.get_content_hash( - obj=[creator_hash, res.created_as[0]] - ) - self.var_nodes.append(res) - self.var_node_to_causal_hash[res] = causal_hash - return res - - def get_op_hash( - self, - func_op: FuncOp, - node_inputs: Optional[Dict[str, ValNode]] = None, - val_inputs: Optional[Dict[str, Ref]] = None, - ) -> str: - assert (node_inputs is None) != (val_inputs is None) - if val_inputs is not None: - node_inputs = { - name: self.value_to_var[val] for name, val in val_inputs.items() - } - assert node_inputs is not None - input_causal_hashes = { - name: self.var_node_to_causal_hash[val] for name, val in node_inputs.items() - } - input_causal_hashes = { - k: v for k, v in input_causal_hashes.items() if v != self.get_default_hash() - } - input_causal_hashes = sorted(input_causal_hashes.items()) - 
op_representation = [ - input_causal_hashes, - func_op.sig.versioned_internal_name, - ] - causal_hash = Hashing.get_content_hash(obj=op_representation) - return causal_hash - - def add_op( - self, - inputs: Dict[str, ValNode], - func_op: FuncOp, - ) -> Tuple[CallNode, Dict[str, ValNode]]: - # TODO: refactor the `FuncQuery` creation here - res = CallNode(inputs=inputs, func_op=func_op, outputs={}, constraint=None) - causal_hash = self.get_op_hash(node_inputs=inputs, func_op=func_op) - self.op_nodes.append(res) - self.op_node_to_causal_hash[res] = causal_hash - # create outputs - outputs = {} - for i in range(func_op.sig.n_outputs): - # output = self.add_var(val_query=ValQuery(creator=res, - # created_as=i)) - output_name = dump_output_name(index=i) - output = self.add_var( - val_query=ValNode( - creators=[res], - created_as=[output_name], - constraint=None, - tp=AnyType(), - ) - ) - outputs[output_name] = output - # assign outputs to op - res.set_outputs(outputs=outputs) - self.op_node_to_call_structs[res] = [] - return res, outputs - - def add_value(self, value: Ref, var: ValNode): - assert var in self.var_nodes - self.value_to_var[value] = var - - def add_call_struct(self, call_struct: CallStruct): - # process inputs - func_op, inputs, outputs = ( - call_struct.func_op, - call_struct.inputs, - call_struct.outputs, - ) - if any([inp not in self.value_to_var.keys() for inp in inputs.values()]): - raise NotImplementedError() - op_hash = self.get_op_hash(func_op=func_op, val_inputs=inputs) - if op_hash not in self.op_node_to_causal_hash.values(): - # create op - op_node, output_nodes = self.add_op( - inputs={name: self.value_to_var[inp] for name, inp in inputs.items()}, - func_op=func_op, - ) - else: - candidates = [ - op_node - for op_node in self.op_nodes - if self.op_node_to_causal_hash[op_node] == op_hash - and op_node.inputs - == {name: self.value_to_var[inp] for name, inp in inputs.items()} - ] - op_node = candidates[0] - output_nodes = op_node.outputs - # process outputs - outputs_dict = {dump_output_name(i): output for i, output in enumerate(outputs)} - for k in outputs_dict.keys(): - self.value_to_var[outputs_dict[k]] = output_nodes[k] - self.op_node_to_call_structs[op_node].append(call_struct) - - ############################################################################ - ### - ############################################################################ - @staticmethod - def from_call_structs(call_structs: List[CallStruct]) -> "Workflow": - """ - Assumes calls are given in topological order - """ - res = Workflow() - input_var = res.add_var() - for call_struct in call_structs: - inputs = call_struct.inputs - for inp in inputs.values(): - if inp not in res.value_to_var.keys(): - res.add_value(value=inp, var=input_var) - res.add_call_struct(call_struct) - return res - - @staticmethod - def from_traversal( - vqs: List[ValNode], - ) -> Tuple["Workflow", Dict[ValNode, ValNode]]: - vqs, fqs = traverse_all(vqs, direction="backward") - vqs_topsort = reversed(vqs) - fqs_topsort = reversed(fqs) - # input_vqs = [vq for vq in vqs_topsort if vq.creator is None] - input_vqs = [vq for vq in vqs_topsort if len(vq.creators) == 0] - res = Workflow() - vq_to_new_vq = {} - for vq in input_vqs: - new_vq = res.add_var(val_query=vq) - vq_to_new_vq[vq] = new_vq - for fq in fqs_topsort: - new_inputs = {name: vq_to_new_vq[vq] for name, vq in fq.inputs.items()} - new_fq, new_outputs = res.add_op(inputs=new_inputs, func_op=fq.func_op) - for k in new_outputs.keys(): - vq, new_vq = fq.outputs[k], 
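The causal hash computed in `get_op_hash` above combines the sorted causal hashes of the non-default inputs with the op's versioned internal name. Below is a minimal standalone sketch of that idea, assuming a SHA-256 over a canonical JSON dump in place of mandala's own `Hashing.get_content_hash` (which is not shown in this diff); the op name `my_op_0` is illustrative only.

```python
import hashlib
import json

def content_hash(obj) -> str:
    # Deterministic hash of a JSON-serializable object; a stand-in for
    # Hashing.get_content_hash, not the real implementation.
    payload = json.dumps(obj, sort_keys=True, default=str).encode()
    return hashlib.sha256(payload).hexdigest()

# Causal hash of an op node: sorted (input name, input causal hash) pairs
# plus the op's versioned name, mirroring the structure built in get_op_hash.
input_hashes = sorted({"x": content_hash(1), "y": content_hash([2, 3])}.items())
print(content_hash([input_hashes, "my_op_0"]))
```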
new_outputs[k] - vq_to_new_vq[vq] = new_vq - return res, vq_to_new_vq - - ############################################################################ - ### - ############################################################################ - @property - def empty(self) -> bool: - return len(self.value_to_var) == 0 - - @property - def shape_size(self) -> int: - return len(self.op_nodes) + len(self.var_nodes) - - @property - def num_calls(self) -> int: - return sum( - [ - len(call_structs) - for call_structs in self.op_node_to_call_structs.values() - ] - ) - - @property - def is_saturated(self) -> bool: - var_to_values = self.var_to_values() - return all([len(var_to_values[var]) > 0 for var in self.var_nodes]) - - @property - def has_delayed(self) -> bool: - return any([value.is_delayed() for value in self.value_to_var.keys()]) - - def print_shape(self): - var_names = {var: f"var_{i}" for i, var in enumerate(self.var_nodes)} - for var in self.inputs: - print(f"{var_names[var]} = Q()") - for op_node in self.op_nodes: - numbered_outputs = { - parse_output_idx(k): v for k, v in op_node.outputs.items() - } - outputs_list = [numbered_outputs[i] for i in range(len(numbered_outputs))] - lhs = ", ".join([var_names[var] for var in outputs_list]) - print( - f"{lhs} = {op_node.func_op.sig.ui_name}(" - + ", ".join( - [f"{name}={var_names[var]}" for name, var in op_node.inputs.items()] - ) - + ")" - ) - - -class History: - def __init__(self, workflow: Workflow, node: ValNode): - self.workflow = workflow - self.node = node diff --git a/mandala/_next/storage.py b/mandala/storage.py similarity index 100% rename from mandala/_next/storage.py rename to mandala/storage.py diff --git a/mandala/_next/storage_utils.py b/mandala/storage_utils.py similarity index 100% rename from mandala/_next/storage_utils.py rename to mandala/storage_utils.py diff --git a/mandala/storages/__init__.py b/mandala/storages/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mandala/storages/kv.py b/mandala/storages/kv.py deleted file mode 100644 index 64d88b0..0000000 --- a/mandala/storages/kv.py +++ /dev/null @@ -1,135 +0,0 @@ -from ..common_imports import * -from typing import Generic, TypeVar -from multiprocessing import Manager -from multiprocessing.managers import DictProxy - - -_KT = TypeVar("_KT") -_VT = TypeVar("_VT") - - -class KVCache(Generic[_KT, _VT]): - """ - Interface for key-value stores for Python objects (keyed by strings). - """ - - def __init__(self): - self.dirty_entries = set() - - def exists(self, k: _KT) -> bool: - raise NotImplementedError - - def set(self, k: _KT, v: _VT) -> None: - raise NotImplementedError - - def get(self, k: _KT) -> Any: - raise NotImplementedError - - def __getitem__(self, k: _KT) -> _VT: - raise NotImplementedError - - def __setitem__(self, k: _KT, v: _VT) -> None: - raise NotImplementedError - - def delete(self, k: _KT) -> None: - raise NotImplementedError - - def keys(self) -> List[_KT]: - raise NotImplementedError - - def items(self) -> Iterable[Tuple[_KT, _VT]]: - raise NotImplementedError - - def evict_all(self) -> None: - """ - Remove all entries from the cache. - """ - raise NotImplementedError - - def clear_all(self) -> None: - """ - Mark all entries as clean. 
- """ - raise NotImplementedError - - -class InMemoryStorage(KVCache): - """ - Simple in-memory implementation for local testing and/or buffering - """ - - def __init__(self): - self.data: dict[str, Any] = {} - self.dirty_entries: set[str] = set() - - def __repr__(self): - return f"InMemoryStorage(data={self.data})" - - def exists(self, k: str) -> bool: - return k in self.data - - def set(self, k: str, v: Any): - self.data[k] = v - self.dirty_entries.add(k) - - def get(self, k: str) -> Any: - return self.data[k] - - def __getitem__(self, k: str) -> Any: - return self.data[k] - - def __setitem__(self, k: str, v: Any) -> None: - self.data[k] = v - self.dirty_entries.add(k) - - def delete(self, k: str): - del self.data[k] - self.dirty_entries.remove(k) - - def keys(self) -> List[str]: - return list(self.data.keys()) - - def items(self) -> Iterable[Tuple[str, Any]]: - return self.data.items() - - def evict_all(self) -> None: - self.data = {} - self.dirty_entries = set() - - def clear_all(self) -> None: - self.dirty_entries = set() - - @property - def is_clean(self) -> bool: - return len(self.dirty_entries) == 0 - - -class MultiProcInMemoryStorage(KVCache): - def __init__(self): - manager = Manager() - self.data: DictProxy[str, Any] = manager.dict() - self.dirty_entries: DictProxy[str, None] = manager.dict() - - def __repr__(self): - return f"MultiProcInMemoryStorage(data={self.data})" - - def exists(self, k: str) -> bool: - return k in self.data.keys() - - def set(self, k: str, v: Any): - self.data[k] = v - self.dirty_entries[k] = None - - def get(self, k: str) -> Any: - return self.data[k] - - def delete(self, k: str): - del self.data[k] - self.dirty_entries.pop(k) - - def keys(self) -> List[str]: - return list(self.data.keys()) - - @property - def is_clean(self) -> bool: - return len(self.dirty_entries) == 0 diff --git a/mandala/storages/rel_impls/__init__.py b/mandala/storages/rel_impls/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mandala/storages/rel_impls/bases.py b/mandala/storages/rel_impls/bases.py deleted file mode 100644 index f06ce49..0000000 --- a/mandala/storages/rel_impls/bases.py +++ /dev/null @@ -1,85 +0,0 @@ -from abc import ABC, abstractmethod -import sqlite3 -from pypika import Query, Column -from ...common_imports import * -from .utils import Connection -import pyarrow as pa - - -class RelStorage(ABC): - """ - Responsible for the low-level (i.e., unaware of mandala-specific concepts) - interactions with the relational part of the storage, such as creating and - extending tables, running queries, etc. This is intended to be a pretty - generic, minimal database interface, supporting just the things we need. - - It's deliberately referred to as "relational storage" as opposed to a - "relational database" because simpler implementations exist. - """ - - @abstractmethod - def create_relation( - self, - name: str, - columns: List[Tuple[str, Optional[str]]], # [(col name, type), ...] - defaults: Dict[str, Any], # {col name: default value, ...} - primary_key: Optional[Union[str, List[str]]] = None, - if_not_exists: bool = True, - conn: Optional[Any] = None, - ): - """ - Create a relation with the given name and columns. 
- """ - raise NotImplementedError() - - @abstractmethod - def create_column(self, relation: str, name: str, default_value: str): - raise NotImplementedError() - - @abstractmethod - def insert(self, name: str, df: pd.DataFrame): - """ - Append rows to a table - """ - raise NotImplementedError() - - @abstractmethod - def upsert(self, relation: str, ta: pa.Table, conn: Optional[Connection] = None): - """ - Upsert rows in a table based on index - """ - raise NotImplementedError() - - @abstractmethod - def delete( - self, - relation: str, - where_col: str, - where_values: List[str], - conn: Optional[Connection] = None, - ): - """ - Delete rows from a table where `where_col` is in `where_values` - """ - raise NotImplementedError() - - @abstractmethod - def get_data( - self, table: str, conn: Optional[sqlite3.Connection] = None - ) -> pd.DataFrame: - """ - Fetch data from a table. - """ - raise NotImplementedError() - - @abstractmethod - def execute_df( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[sqlite3.Connection] = None, - ) -> pd.DataFrame: - """ - Execute a query and return the result as a DataFrame. - """ - raise NotImplementedError() diff --git a/mandala/storages/rel_impls/duckdb_impl.py b/mandala/storages/rel_impls/duckdb_impl.py deleted file mode 100644 index d37a397..0000000 --- a/mandala/storages/rel_impls/duckdb_impl.py +++ /dev/null @@ -1,307 +0,0 @@ -import duckdb -import pyarrow as pa -from duckdb import DuckDBPyConnection as Connection -from pypika import Query, Column - -from .bases import RelStorage -from .utils import Transactable, transaction -from ...core.config import Config -from ...core.utils import get_uid -from ...common_imports import * - - -class DuckDBRelStorage(RelStorage, Transactable): - UID_DTYPE = "VARCHAR" # TODO - change this - TEMP_ARROW_TABLE = Config.temp_arrow_table - - def __init__(self, address: Optional[str] = None, _read_only: bool = False): - self.in_memory = address is None - self._read_only = _read_only - if self.in_memory: - self._conn = duckdb.connect(database=":memory:", read_only=self._read_only) - else: - self.path = address - - ############################################################################ - ### transaction interface - ############################################################################ - def _get_connection(self) -> Connection: - return ( - self._conn - if self.in_memory - else duckdb.connect(database=self.path, read_only=self._read_only) - ) - - def _end_transaction(self, conn: Connection): - if not self.in_memory: - conn.close() - - ############################################################################ - ### - ############################################################################ - @transaction() - def get_tables(self, conn: Optional[Connection] = None) -> List[str]: - return self.execute_df(query="SHOW TABLES;", conn=conn)["name"].tolist() - - @transaction() - def table_exists(self, relation: str, conn: Optional[Connection] = None) -> bool: - return relation in self.get_tables(conn=conn) - - @transaction() - def get_data(self, table: str, conn: Optional[Connection] = None) -> pd.DataFrame: - return self.execute_df(query=f"SELECT * FROM {table};", conn=conn) - - @transaction() - def get_count(self, table: str, conn: Optional[Connection] = None) -> int: - df = self.execute_df(query=f"SELECT COUNT(*) FROM {table};", conn=conn) - return df["count_star()"].item() - - @transaction() - def get_all_data( - self, conn: Optional[Connection] = None - ) -> Dict[str, 
pd.DataFrame]: - tables = self.get_tables(conn=conn) - data = {} - for table in tables: - data[table] = self.get_data(table=table, conn=conn) - return data - - ############################################################################ - ### schema management - ############################################################################ - @transaction() - def create_relation( - self, - name: str, - columns: List[Tuple[str, Optional[str]]], # [(col name, type), ...] - defaults: Dict[str, Any], # {col name: default value, ...} - primary_key: Optional[str] = None, - if_not_exists: bool = True, - conn: Optional[Connection] = None, - ): - """ - Create a table with given columns, with an optional primary key. - Columns are given as tuples of (name, type). - Columns without a dtype are assumed to be of type `self.UID_DTYPE`. - """ - query = Query.create_table(table=name).columns( - *[ - Column( - column_name=column_name, - column_type=column_type - if column_type is not None - else self.UID_DTYPE, - default=defaults.get(column_name, None), - # nullable=False, - ) - for column_name, column_type in columns - ], - ) - if if_not_exists: - query = query.if_not_exists() - if primary_key is not None: - query = query.primary_key(primary_key) - conn.execute(str(query)) - logger.debug( - f'Created table "{name}" with columns {[elt[0] for elt in columns]}' - ) - - @transaction() - def create_column( - self, - relation: str, - name: str, - default_value: str, - conn: Optional[Connection] = None, - ): - """ - Add a new column to a table. - """ - query = f"ALTER TABLE {relation} ADD COLUMN {name} {self.UID_DTYPE} DEFAULT '{default_value}'" - conn.execute(query=query) - logger.debug(f'Added column "{name}" to table "{relation}"') - - @transaction() - def drop_column(self, relation: str, name: str, conn: Optional[Connection] = None): - """ - Drop a column from a table. 
- """ - query = f"ALTER TABLE {relation} DROP COLUMN {name}" - conn.execute(query=query) - logger.debug(f'Dropped column "{name}" from table "{relation}"') - - @transaction() - def rename_relation( - self, name: str, new_name: str, conn: Optional[Connection] = None - ): - """ - Rename a table - """ - query = f"ALTER TABLE {name} RENAME TO {new_name};" - conn.execute(query) - logger.debug(f'Renamed table "{name}" to "{new_name}"') - - @transaction() - def rename_column( - self, relation: str, name: str, new_name: str, conn: Optional[Connection] = None - ): - """ - Rename a column - """ - query = f'ALTER TABLE {relation} RENAME "{name}" TO "{new_name}";' - conn.execute(query) - logger.debug(f'Renamed column "{name}" of table "{relation}" to "{new_name}"') - - @transaction() - def rename_columns( - self, relation: str, mapping: Dict[str, str], conn: Optional[Connection] = None - ): - # factorize the renaming into two maps that can be applied atomically - part_1 = {k: get_uid() for k in mapping.keys()} - part_2 = {part_1[k]: v for k, v in mapping.items()} - for k, v in part_1.items(): - self.rename_column(relation=relation, name=k, new_name=v, conn=conn) - for k, v in part_2.items(): - self.rename_column(relation=relation, name=k, new_name=v, conn=conn) - if len(mapping) > 0: - logger.debug(f'Renamed columns of table "{relation}" via mapping {mapping}') - - ############################################################################ - ### instance management - ############################################################################ - @transaction() - def _get_cols(self, relation: str, conn: Optional[Connection] = None) -> List[str]: - """ - Duckdb-specific method to get the *ordered* columns of a table. - """ - return ( - self.execute_arrow(query=f'DESCRIBE "{relation}";', conn=conn) - .column("column_name") - .to_pylist() - ) - - @transaction() - def _get_primary_keys( - self, relation: str, conn: Optional[Connection] = None - ) -> List[str]: - """ - Duckdb-specific method to get the primary key of a table. - """ - constraint_type = "PRIMARY KEY" - df = self.execute_df(query=f"SELECT * FROM duckdb_constraints();", conn=conn) - df = df[["table_name", "constraint_type", "constraint_column_names"]] - df = df[ - (df["table_name"] == relation) & (df["constraint_type"] == constraint_type) - ] - if len(df) == 0: - raise NotImplementedError() - elif len(df) == 1: - return df["constraint_column_names"].item() - else: - raise NotImplementedError(f"Multiple primary keys for {relation}") - - @transaction() - def insert(self, relation: str, ta: pa.Table, conn: Optional[Connection] = None): - """ - Append rows to a table - """ - if len(ta) == 0: - return - table_cols = self._get_cols(relation=relation, conn=conn) - assert set(ta.column_names) == set(table_cols) - cols_string = ", ".join([f'"{column_name}"' for column_name in ta.column_names]) - conn.register(view_name=self.TEMP_ARROW_TABLE, python_object=ta) - conn.execute( - f'INSERT INTO "{relation}" ({cols_string}) SELECT * FROM {self.TEMP_ARROW_TABLE}' - ) - conn.unregister(view_name=self.TEMP_ARROW_TABLE) - - @transaction() - def upsert( - self, - relation: str, - ta: pa.Table, - key_cols: Optional[List[str]] = None, - conn: Optional[Connection] = None, - ): - """ - Upsert rows in a table based on primary key. 
- - TODO: currently does NOT update matching rows - """ - if len(ta) == 0: - return - if not self.table_exists(relation, conn=conn): - raise RuntimeError() - col_names = ta.column_names if isinstance(ta, pa.Table) else ta.columns - cols_string = ", ".join([f'"{column_name}"' for column_name in col_names]) - conn.register(view_name=self.TEMP_ARROW_TABLE, python_object=ta) - if key_cols is None: - primary_keys = self._get_primary_keys(relation=relation, conn=conn) - if len(primary_keys) != 1: - raise NotImplementedError() - primary_key = primary_keys[0] - query = f'INSERT INTO "{relation}" ({cols_string}) SELECT * FROM {self.TEMP_ARROW_TABLE} WHERE "{primary_key}" NOT IN (SELECT "{primary_key}" FROM "{relation}")' - else: - key_cols_str = ", ".join([f'"{column_name}"' for column_name in key_cols]) - key_cols_str = f"({key_cols_str})" - query = f'INSERT INTO "{relation}" ({cols_string}) SELECT * FROM {self.TEMP_ARROW_TABLE} WHERE {key_cols_str} NOT IN (SELECT {key_cols_str} FROM "{relation}")' - conn.execute(query) - conn.unregister(view_name=self.TEMP_ARROW_TABLE) - - @transaction() - def delete( - self, relation: str, index: List[str], conn: Optional[Connection] = None - ): - """ - Delete rows from a table based on index - """ - primary_keys = self._get_primary_keys(relation=relation, conn=conn) - if len(primary_keys) != 1: - raise NotImplementedError() - primary_key = primary_keys[0] - in_str = ", ".join([f"'{i}'" for i in index]) - query = f'DELETE FROM "{relation}" WHERE {primary_key} IN ({in_str})' - conn.execute(query) - - ############################################################################ - ### queries - ############################################################################ - @transaction() - def execute_arrow( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[Connection] = None, - ) -> pa.Table: - if parameters is None: - parameters = [] - if not isinstance(query, str): - query = str(query) - return conn.execute(query, parameters=parameters).fetch_arrow_table() - - @transaction() - def execute_no_results( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[Connection] = None, - ) -> None: - if parameters is None: - parameters = [] - if not isinstance(query, str): - query = str(query) - return conn.execute(query, parameters=parameters) - - @transaction() - def execute_df( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[Connection] = None, - ) -> pd.DataFrame: - if parameters is None: - parameters = [] - if not isinstance(query, str): - query = str(query) - return conn.execute(query, parameters=parameters).fetchdf() diff --git a/mandala/storages/rel_impls/sqlite_impl.py b/mandala/storages/rel_impls/sqlite_impl.py deleted file mode 100644 index 0e37edc..0000000 --- a/mandala/storages/rel_impls/sqlite_impl.py +++ /dev/null @@ -1,393 +0,0 @@ -import sqlite3 -from pandas.api.types import is_string_dtype -import pyarrow as pa -from pypika import Query, Column - -from .bases import RelStorage -from .utils import Transactable, transaction -from ...core.utils import get_uid -from ...common_imports import * - - -class SQLiteRelStorage(RelStorage, Transactable): - UID_DTYPE = "VARCHAR" - - def __init__( - self, - address: Optional[str] = None, - _read_only: bool = False, - autocommit: bool = False, - journal_mode: str = "WAL", - page_size: int = 32768, - mmap_size_MB: int = 256, - cache_size_pages: int = 1000, - synchronous: str = "normal", - ): - self.journal_mode = 
journal_mode - self.page_size = page_size - self.mmap_size_MB = mmap_size_MB - self.cache_size_pages = cache_size_pages - self.synchronous = synchronous - self.autocommit = autocommit - self._read_only = _read_only - - self.in_memory = address is None - if self.in_memory: - self._id = get_uid() - self._connection_address = f"file:{self._id}?mode=memory&cache=shared" - self._conn = sqlite3.connect( - str(self._connection_address), isolation_level=None, uri=True - ) - with self._conn: - self.apply_optimizations(self._conn) - else: - self._connection_address = address - - def get_optimizations(self) -> List[str]: - """ - This needs some explaining: - - you cannot change `page_size` after setting `journal_mode = WAL` - - `journal_mode = WAL` is persistent across database connections - - `cache_size` is in pages when positive, in kB when negative - """ - if self.mmap_size_MB is None: - mmap_size = 0 - else: - mmap_size = self.mmap_size_MB * 1024**2 - pragma_dict = OrderedDict( - [ - # 'temp_store': 'memory', - ("synchronous", self.synchronous), - ("page_size", self.page_size), - ("cache_size", self.cache_size_pages), - ("journal_mode", self.journal_mode), - ("mmap_size", mmap_size), - ("foreign_keys", "ON"), - ] - ) - lines = [f"PRAGMA {k} = {v};" for k, v in pragma_dict.items()] - return lines - - def apply_optimizations(self, conn: sqlite3.Connection): - opts = self.get_optimizations() - for line in opts: - conn.execute(line) - - def read_cursor(self, c: sqlite3.Cursor) -> pd.DataFrame: - if c.description is None: - assert len(c.fetchall()) == 0 - return pd.DataFrame() - colnames = [col[0] for col in c.description] - df = pd.DataFrame(c.fetchall(), columns=colnames) - return self.postprocess_df(df) - - def postprocess_df(self, df: pd.DataFrame) -> pd.DataFrame: - for col, dtype in df.dtypes.items(): - if not is_string_dtype(dtype): - df[col] = df[col].astype(str) - return df - - ############################################################################ - ### transaction interface - ############################################################################ - def _get_connection(self) -> sqlite3.Connection: - if self.in_memory: - return self._conn - else: - return sqlite3.connect( - str(self._connection_address), isolation_level=None # "IMMEDIATE" - ) - - def _end_transaction(self, conn: sqlite3.Connection): - conn.commit() - if not self.in_memory: - conn.close() - - ############################################################################ - ### - ############################################################################ - @transaction() - def get_tables(self, conn: Optional[sqlite3.Connection] = None) -> List[str]: - query = "SELECT name FROM sqlite_master WHERE type='table';" - cur = conn.cursor() - cur.execute(query) - return [row[0] for row in cur.fetchall()] - - @transaction() - def table_exists( - self, relation: str, conn: Optional[sqlite3.Connection] = None - ) -> bool: - return relation in self.get_tables(conn=conn) - - @transaction() - def get_data( - self, table: str, conn: Optional[sqlite3.Connection] = None - ) -> pd.DataFrame: - return self.execute_df(query=f"SELECT * FROM {table};", conn=conn) - - @transaction() - def get_count(self, table: str, conn: Optional[sqlite3.Connection] = None) -> int: - query = f"SELECT COUNT(*) FROM {table};" - return int(self.execute_df(query=query, conn=conn).iloc[0, 0]) - - @transaction() - def get_all_data( - self, conn: Optional[sqlite3.Connection] = None - ) -> Dict[str, pd.DataFrame]: - return { - table: self.get_data(table, conn=conn) 
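The PRAGMA settings assembled by `get_optimizations` above map directly onto standard SQLite statements. A standalone sketch follows (the file name is illustrative); as the docstring notes, `journal_mode = WAL` persists across connections, `page_size` must be set before switching to WAL, and a positive `cache_size` is measured in pages.

```python
import sqlite3

conn = sqlite3.connect("example.db")
for pragma in [
    "PRAGMA page_size = 32768;",      # must come before journal_mode = WAL
    "PRAGMA journal_mode = WAL;",     # persists across connections
    "PRAGMA synchronous = normal;",
    "PRAGMA cache_size = 1000;",      # pages when positive, kB when negative
    "PRAGMA mmap_size = 268435456;",  # 256 MB
    "PRAGMA foreign_keys = ON;",
]:
    conn.execute(pragma)
print(conn.execute("PRAGMA journal_mode;").fetchone())  # ('wal',)
conn.close()
```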
- for table in self.get_tables(conn=conn) - } - - ############################################################################ - ### schema management - ############################################################################ - @transaction() - def create_relation( - self, - name: str, - columns: List[Tuple[str, Optional[str]]], # [(col name, type), ...] - defaults: Dict[str, Any], # {col name: default value, ...} - primary_key: Optional[Union[str, List[str]]] = None, - if_not_exists: bool = True, - conn: Optional[sqlite3.Connection] = None, - ): - """ - Create a table with given columns, with an optional primary key. - Columns are given as tuples of (name, type). - Columns without a dtype are assumed to be of type `self.UID_DTYPE`. - """ - query = Query.create_table(table=name).columns( - *[ - Column( - column_name=column_name, - column_type=column_type - if column_type is not None - else self.UID_DTYPE, - default=defaults.get(column_name, None), - # nullable=False, - ) - for column_name, column_type in columns - ], - ) - if if_not_exists: - query = query.if_not_exists() - if primary_key is not None: - if isinstance(primary_key, str): - query = query.primary_key(primary_key) - else: - query = query.primary_key(*primary_key) - conn.execute(str(query)) - logger.debug( - f'Created table "{name}" with columns {[elt[0] for elt in columns]}' - ) - - @transaction() - def create_column( - self, - relation: str, - name: str, - default_value: str, - conn: Optional[sqlite3.Connection] = None, - ): - """ - Add a new column to a table. - """ - query = ( - f"ALTER TABLE {relation} ADD COLUMN {name} TEXT DEFAULT '{default_value}'" - ) - conn.execute(query) - logger.debug(f'Created column "{name}" in table "{relation}"') - - @transaction() - def drop_column( - self, relation: str, name: str, conn: Optional[sqlite3.Connection] = None - ): - """ - Drop a column from a table. 
- """ - query = f'ALTER TABLE {relation} DROP COLUMN "{name}"' - conn.execute(query) - logger.debug(f'Dropped column "{name}" from table "{relation}"') - - @transaction() - def rename_relation( - self, name: str, new_name: str, conn: Optional[sqlite3.Connection] = None - ): - """ - Rename a table - """ - query = f"ALTER TABLE {name} RENAME TO {new_name};" - conn.execute(query) - logger.debug(f'Renamed table "{name}" to "{new_name}"') - - @transaction() - def rename_column( - self, - relation: str, - name: str, - new_name: str, - conn: Optional[sqlite3.Connection] = None, - ): - """ - Rename a column - """ - query = f'ALTER TABLE {relation} RENAME COLUMN "{name}" TO "{new_name}";' - conn.execute(query) - logger.debug(f'Renamed column "{name}" in table "{relation}" to "{new_name}"') - - @transaction() - def rename_columns( - self, - relation: str, - mapping: Dict[str, str], - conn: Optional[sqlite3.Connection] = None, - ): - # factorize the renaming into two maps that can be applied atomically - part_1 = {k: get_uid() for k in mapping.keys()} - part_2 = {part_1[k]: v for k, v in mapping.items()} - for k, v in part_1.items(): - self.rename_column(relation=relation, name=k, new_name=v, conn=conn) - for k, v in part_2.items(): - self.rename_column(relation=relation, name=k, new_name=v, conn=conn) - if len(mapping) > 0: - logger.debug(f'Renamed columns of table "{relation}" via mapping {mapping}') - - ############################################################################ - ### instance management - ############################################################################ - @transaction() - def _get_cols( - self, relation: str, conn: Optional[sqlite3.Connection] = None - ) -> List[str]: - """ - get the *ordered* columns of a table. - """ - query = f"PRAGMA table_info({relation})" - df = self.execute_df(query=query, conn=conn) - return list(df["name"]) - - @transaction() - def _get_primary_keys( - self, relation: str, conn: Optional[sqlite3.Connection] = None - ) -> List[str]: - """ - get the primary key of a table. - """ - query = f"PRAGMA table_info({relation})" - df = self.execute_df(query=query, conn=conn) - return list(df[df["pk"].apply(int) == 1]["name"]) - - @transaction() - def insert( - self, relation: str, ta: pa.Table, conn: Optional[sqlite3.Connection] = None - ): - """ - Append rows to a table - """ - df = ta.to_pandas() - if df.empty: - return - columns = df.columns.tolist() - col_str = ", ".join([f'"{col}"' for col in columns]) - placeholder_str = ",".join(["?" for _ in columns]) - query = f"INSERT INTO {relation}({col_str}) VALUES ({placeholder_str})" - parameters = list(df.itertuples(index=False)) - conn.executemany(query, parameters) - - @transaction() - def upsert( - self, - relation: str, - ta: pa.Table, - key_cols: Optional[List[str]] = None, - conn: Optional[sqlite3.Connection] = None, - ): - """ - Upsert rows in a table based on primary key. - """ - if isinstance(ta, pa.Table): - df = ta.to_pandas() - else: - df = ta - if df.empty: # engine complains - return - columns = df.columns.tolist() - col_str = ", ".join([f'"{col}"' for col in columns]) - placeholder_str = ",".join(["?" 
for _ in columns]) - query = ( - f"INSERT OR REPLACE INTO {relation}({col_str}) VALUES ({placeholder_str})" - ) - parameters = list(df.itertuples(index=False)) - conn.executemany(query, parameters) - - @transaction() - def delete( - self, - relation: str, - where_col: str, - where_values: List[str], - conn: Optional[sqlite3.Connection] = None, - ): - """ - Delete rows from a table where `where_col` is in `where_values` - """ - query = f"DELETE FROM {relation} WHERE {where_col} IN ({','.join(['?']*len(where_values))})" - conn.execute(query, where_values) - - @transaction() - def vacuum(self, warn: bool = True, conn: Optional[sqlite3.Connection] = None): - """ - ! this needs a lot of free space on disk to work (~2x db size) - """ - if warn: - total_db_size = ( - self.execute_df("PRAGMA page_count").astype(float).iloc[0, 0] - * self.page_size - ) - question = "Vacuuming a database of size {:.2f} MB, this may take a long time and requires ~2x as much free space on disk, are you sure?".format( - total_db_size / 1024**2 - ) - # ask the user if they are sure - user_input = input(question + " (y/n): ") - if user_input.lower() != "y": - logging.info("Aborting vacuuming.") - return - conn.execute("VACUUM") - - @transaction() - def execute_df( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[sqlite3.Connection] = None, - ) -> pd.DataFrame: - if parameters is None: - parameters = [] - cursor = conn.execute(str(query), parameters) - return self.postprocess_df(self.read_cursor(cursor)) - - @transaction() - def execute_arrow( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[sqlite3.Connection] = None, - ) -> pa.Table: - if isinstance(query, Query): - query = str(query) - df = self.execute_df(query=query, parameters=parameters, conn=conn) - return pa.Table.from_pandas(df) - - @transaction() - def execute_no_results( - self, - query: Union[str, Query], - parameters: List[Any] = None, - conn: Optional[sqlite3.Connection] = None, - ) -> None: - if parameters is None: - parameters = [] - if not isinstance(query, str): - query = str(query) - return conn.execute(query, parameters) diff --git a/mandala/storages/rel_impls/utils.py b/mandala/storages/rel_impls/utils.py deleted file mode 100644 index 54bf7a9..0000000 --- a/mandala/storages/rel_impls/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -from ...common_imports import * -from ...core.config import Config -from abc import ABC, abstractmethod -import functools - -if Config.has_duckdb: - from duckdb import DuckDBPyConnection as Connection -else: - - class Connection: - pass - - -class Transactable(ABC): - @abstractmethod - def _get_connection(self) -> Connection: - raise NotImplementedError() - - @abstractmethod - def _end_transaction(self, conn: Connection): - raise NotImplementedError() - - -class Transaction: - def __init__(self): - pass - - def __call__(self, method) -> "method": - @functools.wraps(method) - def inner(instance: Transactable, *args, conn: Connection = None, **kwargs): - transaction_started_here = False - if conn is None: - # new transaction - conn = instance._get_connection() - instance._current_conn = conn - transaction_started_here = True - # conn.execute("BEGIN IMMEDIATE") - try: - result = method(instance, *args, conn=conn, **kwargs) - if transaction_started_here: - instance._end_transaction(conn=conn) - return result - except Exception as e: - if transaction_started_here: - conn.rollback() - instance._end_transaction(conn=conn) - raise e - # 
instance._end_transaction(conn=conn) - # return result - # else: - # # nest in existing transaction - # result = method(instance, *args, conn=conn, **kwargs) - # return result - - return inner - - -transaction = Transaction diff --git a/mandala/storages/rels.py b/mandala/storages/rels.py deleted file mode 100644 index 122bd0c..0000000 --- a/mandala/storages/rels.py +++ /dev/null @@ -1,1351 +0,0 @@ -from collections import defaultdict - -import pyarrow as pa -import pyarrow.parquet as pq -from pypika import Query, Table, Parameter -from pypika.terms import LiteralValue - -from ..common_imports import * -from ..core.config import Config, dump_output_name, Provenance -from ..core.model import Call, FuncOp, Ref, ValueRef, collect_detached -from ..core.builtins_ import Builtins -from ..core.wrapping import unwrap -from ..core.utils import get_fibers_as_lists -from ..core.sig import Signature -from ..deps.versioner import Versioner -from ..utils import serialize, deserialize, _rename_cols - -if Config.has_duckdb: - from .rel_impls.duckdb_impl import DuckDBRelStorage -from .rel_impls.utils import Transactable, transaction, Connection -from .rel_impls.bases import RelStorage - - -# {internal table name -> serialized (internally named) table} -RemoteEventLogEntry = Dict[str, bytes] - - -class SpilledRef: - pass - - -class VersionAdapter(Transactable): - # todo: this is too similar to SignatureAdapter, refactor - # like SignatureAdapter, but for dependency state. - # encapsulates methods to load and write the dependency table - def __init__(self, rel_adapter: "RelAdapter"): - self.rel_adapter = rel_adapter - self.rel_storage = rel_adapter.rel_storage - - ############################################################################ - ### `Transactable` interface - ############################################################################ - def _get_connection(self) -> Connection: - return self.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_storage._end_transaction(conn=conn) - - ### - @transaction() - def dump_state( - self, - state: Versioner, - conn: Optional[Connection] = None, - ): - """ - Dump the given state of the signatures to the database. Should always - call this after the signatures have been updated. - """ - # delete existing, if any - index_col = "index" - query = f'DELETE FROM {self.rel_adapter.DEPS_TABLE} WHERE "{index_col}" = 0' - conn.execute(query) - # insert new - serialized = serialize(obj=state) - df = pd.DataFrame( - { - index_col: [0], - "deps": [serialized], - } - ) - ta = pa.Table.from_pandas(df) - self.rel_storage.insert(relation=self.rel_adapter.DEPS_TABLE, ta=ta, conn=conn) - - @transaction() - def has_state(self, conn: Optional[Connection] = None) -> bool: - query = f'SELECT * FROM {self.rel_adapter.DEPS_TABLE} WHERE "index" = 0' - df = self.rel_storage.execute_df(query=query, conn=conn) - return len(df) != 0 - - @transaction() - def load_state(self, conn: Optional[Connection] = None) -> Optional[Versioner]: - """ - Load the state of the signatures from the database. All interactions - with the state of the signatures are done transactionally through this - method. 
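The `Transaction` decorator above implements a common nesting pattern: the outermost call opens the connection and ends the transaction (committing on success, rolling back on error), while nested calls reuse the connection handed to them. A self-contained sketch of the same pattern over `sqlite3`, with illustrative class and table names and a plain function decorator in place of mandala's `Transaction` class:

```python
import functools
import sqlite3

def transaction(method):
    """Outermost call owns the connection; nested calls reuse `conn`."""
    @functools.wraps(method)
    def inner(self, *args, conn=None, **kwargs):
        started_here = conn is None
        if started_here:
            conn = sqlite3.connect(self.path)
        try:
            result = method(self, *args, conn=conn, **kwargs)
            if started_here:
                conn.commit()
                conn.close()
            return result
        except Exception:
            if started_here:
                conn.rollback()
                conn.close()
            raise
    return inner

class Store:
    def __init__(self, path: str):
        self.path = path

    @transaction
    def init(self, conn=None):
        conn.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)")

    @transaction
    def put(self, k, v, conn=None):
        self.init(conn=conn)  # nested call: shares the caller's transaction
        conn.execute("INSERT OR REPLACE INTO kv(k, v) VALUES (?, ?)", (k, v))

Store(":memory:").put("a", "1")
```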
- """ - query = f'SELECT * FROM {self.rel_adapter.DEPS_TABLE} WHERE "index" = 0' - df = self.rel_storage.execute_df(query=query, conn=conn) - if len(df) == 0: - return None - else: - return deserialize(df["deps"][0]) - - -class SigAdapter(Transactable): - """ - Responsible for state transitions of the schema that update the - signature objects *and* the relational tables in a transactional way. - """ - - def __init__( - self, - rel_adapter: "RelAdapter", - ): - self.rel_adapter = rel_adapter - self.rel_storage = self.rel_adapter.rel_storage - - @transaction() - def dump_state( - self, state: Dict[Tuple[str, int], Signature], conn: Optional[Connection] = None - ): - """ - Dump the given state of the signatures to the database. Should always - call this after the signatures have been updated. - """ - # delete existing, if any - index_col = "index" - query = ( - f'DELETE FROM {self.rel_adapter.SIGNATURES_TABLE} WHERE "{index_col}" = 0' - ) - conn.execute(query) - # insert new - serialized = serialize(obj=state) - df = pd.DataFrame( - { - index_col: [0], - "signatures": [serialized], - } - ) - ta = pa.Table.from_pandas(df) - self.rel_storage.insert( - relation=self.rel_adapter.SIGNATURES_TABLE, ta=ta, conn=conn - ) - - @transaction() - def has_state(self, conn: Optional[Connection] = None) -> bool: - query = f'SELECT * FROM {self.rel_adapter.SIGNATURES_TABLE} WHERE "index" = 0' - df = self.rel_storage.execute_df(query=query, conn=conn) - return len(df) != 0 - - @transaction() - def load_state( - self, conn: Optional[Connection] = None - ) -> Dict[Tuple[str, int], Signature]: - """ - Load the state of the signatures from the database. All interactions - with the state of the signatures are done transactionally through this - method. - """ - query = f'SELECT * FROM {self.rel_adapter.SIGNATURES_TABLE} WHERE "index" = 0' - df = self.rel_storage.execute_df(query=query, conn=conn) - if len(df) == 0: - return {} - else: - return deserialize(df["signatures"][0]) - - def check_invariants( - self, - sigs: Optional[Dict[Tuple[str, int], Signature]] = None, - conn: Optional[Connection] = None, - ): - """ - This checks that the invariants of the *set* of signatures for the storage - hold. This means that: - - all versions of a signature are consecutive integers starting from - 0 - - signatures have the same UI name iff they have the same internal - name - - Invariants for individual signatures are not checked by this (they are - checked by the `Signature` class). 
- """ - # check version numbering - if sigs is None: - sigs = self.load_state(conn=conn) - internal_names = {internal_name for internal_name, _ in sigs.keys()} - for internal_name in internal_names: - versions = [version for _, version in sigs.keys() if _ == internal_name] - assert sorted(versions) == list(range(len(versions))) - # check exactly 1 UI name per internal name - internal_to_ui_names = defaultdict(set) - for (internal_name, _), sig in sigs.items(): - internal_to_ui_names[internal_name].add(sig.ui_name) - for internal_name, ui_names in internal_to_ui_names.items(): - assert ( - len(ui_names) == 1 - ), f"Internal name {internal_name} has multiple UI names: {ui_names}" - # check exactly 1 internal name per UI name - ui_to_internal_names = defaultdict(set) - for (internal_name, _), sig in sigs.items(): - ui_to_internal_names[sig.ui_name].add(internal_name) - for ui_name, internal_names in ui_to_internal_names.items(): - assert ( - len(internal_names) == 1 - ), f"UI name {ui_name} has multiple internal names: {internal_names}" - - ############################################################################ - ### `Transactable` interface - ############################################################################ - def _get_connection(self) -> Connection: - return self.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_storage._end_transaction(conn=conn) - - ############################################################################ - ### - ############################################################################ - @transaction() - def load_ui_sigs( - self, conn: Optional[Connection] = None - ) -> Dict[Tuple[str, int], Signature]: - """ - Get the signatures indexed by (ui_name, version) - """ - sigs = self.load_state(conn=conn) - res = {(sig.ui_name, sig.version): sig for sig in sigs.values()} - assert len(res) == len(sigs) - return res - - @transaction() - def exists_versioned_ui( - self, sig: Signature, conn: Optional[Connection] = None - ) -> bool: - """ - Check if the signature exists based on its UI name *and* version - """ - return (sig.ui_name, sig.version) in self.load_ui_sigs(conn=conn) - - @transaction() - def exists_any_version( - self, sig: Signature, conn: Optional[Connection] = None - ) -> bool: - """ - Check using internal name (or UI name, if it has no internal data) if - there exists any version for this signature. - """ - if sig.has_internal_data: - return any( - sig.internal_name == k[0] for k in self.load_state(conn=conn).keys() - ) - else: - return any(sig.ui_name == k[0] for k in self.load_ui_sigs(conn=conn).keys()) - - @transaction() - def get_latest_version( - self, sig: Signature, conn: Optional[Connection] = None - ) -> Signature: - """ - Get the latest version of the signature, based on internal name or UI - name if it has no internal data. 
- """ - sigs = self.load_state(conn=conn) - if sig.has_internal_data: - versions = [k[1] for k in sigs.keys() if k[0] == sig.internal_name] - if len(versions) == 0: - raise ValueError(f"No versions for signature {sig}") - version = max(versions) - return sigs[(sig.internal_name, version)] - else: - versions = [ - k[1] for k in self.load_ui_sigs(conn=conn).keys() if k[0] == sig.ui_name - ] - if len(versions) == 0: - raise ValueError(f"No versions for signature {sig}") - version = max(versions) - return self.load_ui_sigs(conn=conn)[(sig.ui_name, version)] - - @transaction() - def get_versions( - self, sig: Signature, conn: Optional[Connection] = None - ) -> List[int]: - """ - Get all versions of the signature, based on internal name or UI name if - it has no internal data. - """ - sigs = self.load_state(conn=conn) - if sig.has_internal_data: - return [k[1] for k in sigs.keys() if k[0] == sig.internal_name] - else: - ui_sigs = self.load_ui_sigs(conn=conn) - return [k[1] for k in ui_sigs.keys() if k[0] == sig.ui_name] - - @transaction() - def exists_internal( - self, sig: Signature, conn: Optional[Connection] = None - ) -> bool: - """ - Check if the signature exists based on its *internal* name - """ - return (sig.internal_name, sig.version) in self.load_state(conn=conn) - - @transaction() - def internal_to_ui(self, conn: Optional[Connection] = None) -> Dict[str, str]: - """ - Get a mapping from internal names to UI names - """ - return {k[0]: v.ui_name for k, v in self.load_state(conn=conn).items()} - - @transaction() - def ui_to_internal(self, conn: Optional[Connection] = None) -> Dict[str, str]: - """ - Get a mapping from UI names to internal names - """ - return {v.ui_name: k[0] for k, v in self.load_state(conn=conn).items()} - - @transaction() - def ui_names(self, conn: Optional[Connection] = None) -> Set[str]: - # return the set of ui names - return set(self.ui_to_internal(conn=conn).keys()) - - @transaction() - def internal_names(self, conn: Optional[Connection] = None) -> Set[str]: - # return the set of internal names - return set(self.internal_to_ui(conn=conn).keys()) - - @transaction() - def is_sig_table_name( - self, name: str, use_internal: bool, conn: Optional[Connection] = None - ) -> bool: - """ - Check if the name is a valid name for a table corresponding to a - signature. 
- """ - parts = name.split("_", 1) - return ( - parts[0] - in ( - self.internal_names(conn=conn) - if use_internal - else self.ui_names(conn=conn) - ) - and parts[1].isdigit() - ) - - ############################################################################ - ### elementary transitions for local state - ############################################################################ - @transaction() - def _init_deps(self, sig: Signature, conn: Optional[Connection] = None): - pass - # deps = self.deps_adapter.load_state(conn=conn) - # deps.op_graphs[(sig.internal_name, sig.version)] = DependencyGraph() - # self.deps_adapter.dump_state(state=deps, conn=conn) - - @transaction() - def _create_relation(self, sig: Signature, conn: Optional[Connection] = None): - io_colnames = list(sig.input_names) + [ - dump_output_name(index=i) for i in range(sig.n_outputs) - ] - all_cols = [(col, None) for col in Config.special_call_cols] + [ - (column, None) for column in io_colnames - ] - self.rel_storage.create_relation( - name=sig.versioned_ui_name, - columns=all_cols, - primary_key=Config.causal_uid_col, - defaults=sig.new_ui_input_default_uids, - conn=conn, - ) - self._init_deps(sig=sig, conn=conn) - - @transaction() - def create_sig(self, sig: Signature, conn: Optional[Connection] = None): - """ - Create a new signature `sig`. `sig` must have internal data, and not be - present in storage at any version. - """ - assert sig.has_internal_data - sigs = self.load_state(conn=conn) - assert sig.internal_name not in self.internal_names(conn=conn) - # assert (sig.internal_name, sig.version) not in sigs.keys() - sigs[(sig.internal_name, sig.version)] = sig - # write signatures - self.dump_state(state=sigs, conn=conn) - # create relation - self._create_relation(sig=sig, conn=conn) - logger.debug(f"Created signature:\n{sig}") - - @transaction() - def create_new_version(self, sig: Signature, - strict: bool = False, - conn: Optional[Connection] = None): - """ - Create a new version of an already existing function using the `sig` - object. `sig` must have internal data, and the internal name must - already be present in some version. - """ - assert sig.has_internal_data - latest_sig = self.get_latest_version(sig=sig, conn=conn) - if strict: - assert sig.version == latest_sig.version + 1 - # update signatures - sigs = self.load_state(conn=conn) - sigs[(sig.internal_name, sig.version)] = sig - self.dump_state(state=sigs, conn=conn) - # create relation - self._create_relation(sig=sig, conn=conn) - logger.debug(f"Created new version:\n{sig}") - - @transaction() - def update_sig(self, sig: Signature, conn: Optional[Connection] = None): - """ - Update an existing signature. `sig` must have internal data, and - must already exist in storage. 
- """ - assert sig.has_internal_data - sigs = self.load_state(conn=conn) - assert (sig.internal_name, sig.version) in sigs.keys() - current = sigs[(sig.internal_name, sig.version)] - # the `update` method also ensures that the signature is compatible - n_outputs_new = sig.n_outputs - n_outputs_old = current.n_outputs - new_sig, updates = current.update(new=sig) - # update the signature data - sigs[(sig.internal_name, sig.version)] = new_sig - # create new inputs in the database, if any - for new_input, default_value in updates.items(): - internal_input_name = new_sig.ui_to_internal_input_map[new_input] - full_default_uid = new_sig._new_input_defaults_uids[internal_input_name] - self.rel_storage.create_column( - relation=new_sig.versioned_ui_name, - name=new_input, - default_value=full_default_uid, - conn=conn, - ) - # insert the default in the objects *in the database*, if it's - # not there already - default_uid, _ = Ref.parse_full_uid(full_uid=full_default_uid) - default_vref = ValueRef(uid=default_uid, obj=default_value, in_memory=True) - self.rel_adapter.obj_set(uid=default_uid, value=default_vref, conn=conn) - self.rel_adapter.obj_set_causal(full_uid=full_default_uid, conn=conn) - # update the outputs in the database, if this is allowed - n_rows = self.rel_storage.get_count(table=new_sig.versioned_ui_name, conn=conn) - if n_rows > 0 and n_outputs_new != n_outputs_old: - raise ValueError( - f"Cannot change the number of outputs of a signature that has already been used. " - f"Current number of outputs: {n_outputs_old}, new number of outputs: {n_outputs_new}." - ) - if n_outputs_new > n_outputs_old: - for i in range(n_outputs_old, n_outputs_new): - self.rel_storage.create_column( - relation=new_sig.versioned_ui_name, - name=dump_output_name(index=i), - default_value=None, - conn=conn, - ) - if n_outputs_new < n_outputs_old: - for i in range(n_outputs_new, n_outputs_old): - self.rel_storage.drop_column( - relation=new_sig.versioned_ui_name, - name=dump_output_name(index=i), - conn=conn, - ) - if len(updates) > 0: - logger.debug( - f"Updated signature:\n new inputs:{updates} new signature:\n {sig}" - ) - self.dump_state(state=sigs, conn=conn) - - @transaction() - def update_ui_name( - self, - sig: Signature, - conn: Optional[Connection] = None, - validate_only: bool = False, - ) -> Dict[Tuple[str, int], Signature]: - """ - Update a signature's UI name using the given `Signature` object to get - the new name. `sig` must have internal data, and must carry the new UI - name. - - NOTE: the `sig` may have the same UI name as the current signature, in - which case this method does nothing but return the current state of the - signatures. - - This method has the option of only generating the new state of the - signatures without performing the update. - """ - assert sig.has_internal_data - assert self.exists_internal(sig=sig, conn=conn) - sigs = self.load_state(conn=conn) - all_versions = self.get_versions(sig=sig, conn=conn) - current_ui_name = sigs[(sig.internal_name, all_versions[0])].ui_name - new_ui_name = sig.ui_name - if current_ui_name == new_ui_name: - # nothing to do - return sigs - # make sure there are no conflicts - internal_to_ui = self.internal_to_ui(conn=conn) - if new_ui_name in internal_to_ui.values(): - raise ValueError( - f"UI name {new_ui_name} already exists for another signature." 
- ) - for version in all_versions: - current = sigs[(sig.internal_name, version)] - if current.ui_name != sig.ui_name: - new_sig = current.rename(new_name=sig.ui_name) - # update signature object in memory - sigs[(sig.internal_name, version)] = new_sig - if not validate_only: - # update table - self.rel_storage.rename_relation( - name=current.versioned_ui_name, - new_name=new_sig.versioned_ui_name, - conn=conn, - ) - if not validate_only: - # update signatures state - self.dump_state(state=sigs, conn=conn) - if current_ui_name != new_ui_name: - logger.debug( - f"Updated UI name of signature: from {current_ui_name} to {new_ui_name}" - ) - return sigs - - @transaction() - def update_input_ui_names( - self, sig: Signature, conn: Optional[Connection] = None - ) -> Signature: - """ - Update a signature's input UI names from the given `Signature` object. - `sig` must have internal data, and must carry the new UI input names. - """ - assert sig.has_internal_data - sigs = self.load_state(conn=conn) - current = sigs[(sig.internal_name, sig.version)] - current_internal_to_ui = current.internal_to_ui_input_map - new_internal_to_ui = sig.internal_to_ui_input_map - renaming_map = { - current_internal_to_ui[k]: new_internal_to_ui[k] - for k in current_internal_to_ui - if current_internal_to_ui[k] != new_internal_to_ui[k] - } - # update signature object - new_sig = current.rename_inputs(mapping=renaming_map) - sigs[(sig.internal_name, sig.version)] = new_sig - self.dump_state(state=sigs, conn=conn) - # update table columns - self.rel_storage.rename_columns( - relation=new_sig.versioned_ui_name, mapping=renaming_map, conn=conn - ) - if len(renaming_map) > 0: - logger.debug( - f"Updated input UI names of signature named {sig.ui_name}: via mapping {renaming_map}" - ) - return new_sig - - @transaction() - def rename_tables( - self, - tables: Dict[str, TableType], - to: str = "internal", - conn: Optional[Connection] = None, - ) -> Dict[str, TableType]: - """ - Rename a dictionary of {versioned name: table} pairs and the tables' - columns to either internal or UI names. 
- """ - result = {} - assert to in ["internal", "ui"] - for table_name, table in tables.items(): - if self.is_sig_table_name( - name=table_name, use_internal=(to != "internal"), conn=conn - ): - if to == "internal": - ui_name, version = Signature.parse_versioned_name(table_name) - sig = self.load_ui_sigs(conn=conn)[ui_name, version] - new_table_name = sig.versioned_internal_name - mapping = sig.ui_to_internal_input_map - else: - internal_name, version = Signature.parse_versioned_name(table_name) - sig = self.load_state(conn=conn)[internal_name, version] - new_table_name = sig.versioned_ui_name - mapping = sig.internal_to_ui_input_map - result[new_table_name] = _rename_cols(table=table, mapping=mapping) - else: - result[table_name] = table - return result - - -class RelAdapter(Transactable): - EVENT_LOG_TABLE = Config.event_log_table - VREF_TABLE = Config.vref_table - CAUSAL_VREF_TABLE = Config.causal_vref_table - SIGNATURES_TABLE = Config.schema_table - DEPS_TABLE = Config.deps_table - PROVENANCE_TABLE = Config.provenance_table - # tables to be excluded from certain operations - SPECIAL_TABLES = ( - EVENT_LOG_TABLE, - Config.temp_arrow_table, - SIGNATURES_TABLE, - DEPS_TABLE, - PROVENANCE_TABLE, - ) - - def __init__( - self, - rel_storage: RelStorage, - spillover_dir: Optional[Path] = None, - spillover_threshold_mb: Optional[float] = None, - ): - self.rel_storage = rel_storage - self.spillover_dir = spillover_dir - self.spillover_threshold_mb = ( - Config.spillover_threshold_mb - if spillover_threshold_mb is None - else spillover_threshold_mb - ) - if self.spillover_dir is not None: - self.spillover_dir.mkdir(parents=True, exist_ok=True) - self.sig_adapter = SigAdapter(rel_adapter=self) - self.init() - # check if we are connecting to an existing instance - conn = self._get_connection() - if not self.sig_adapter.has_state(conn=conn): - self.sig_adapter.dump_state(state={}, conn=conn) - self._end_transaction(conn=conn) - - @transaction() - def init(self, conn: Optional[Connection] = None): - if self.rel_storage._read_only: - return - self.rel_storage.create_relation( - name=self.VREF_TABLE, - columns=[(Config.uid_col, None), (Config.vref_value_col, "blob")], - primary_key=Config.uid_col, - defaults={}, - if_not_exists=True, - conn=conn, - ) - self.rel_storage.create_relation( - name=self.CAUSAL_VREF_TABLE, - columns=[(Config.full_uid_col, None)], - primary_key=Config.full_uid_col, - defaults={}, - if_not_exists=True, - conn=conn, - ) - self.rel_storage.create_relation( - name=self.PROVENANCE_TABLE, - columns=[ - (Provenance.causal_uid, None), - (Provenance.name, None), - (Provenance.call_causal_uid, None), - (Provenance.direction, None), - (Provenance.op_id, None), - ], - primary_key=[Provenance.call_causal_uid, Provenance.name], - defaults={}, - if_not_exists=True, - conn=conn, - ) - # Initialize the event log. - # The event log is just a list of UIDs that changed, for now. - # the UID column stores the vref/call uid, the `table` column stores the - # table in which this UID is to be found. 
- self.rel_storage.create_relation( - name=self.EVENT_LOG_TABLE, - columns=[(Config.uid_col, None), ("table", "varchar")], - primary_key=Config.uid_col, - defaults={}, - if_not_exists=True, - conn=conn, - ) - # The signatures table is a binary dump of the signatures - self.rel_storage.create_relation( - name=self.SIGNATURES_TABLE, - columns=[ - ("index", "int"), - ("signatures", "blob"), - ], - primary_key="index", - defaults={}, - if_not_exists=True, - conn=conn, - ) - self.rel_storage.create_relation( - name=self.DEPS_TABLE, - columns=[ - ("index", "int"), - ("deps", "blob"), - ], - primary_key="index", - defaults={}, - if_not_exists=True, - conn=conn, - ) - - @transaction() - def get_call_tables(self, conn: Optional[Connection] = None) -> List[str]: - tables = self.rel_storage.get_tables(conn=conn) - return [ - t - for t in tables - if t not in self.SPECIAL_TABLES - and t != self.VREF_TABLE - and t != self.CAUSAL_VREF_TABLE - ] - - @transaction() - def get_vrefs(self, conn: Optional[Connection] = None) -> pd.DataFrame: - """ - Returns a dataframe of the deserialized values of the value references - in the storage. - """ - data = self.rel_storage.get_data(table=self.VREF_TABLE, conn=conn) - data["value"] = data["value"].apply(lambda vref: unwrap(deserialize(vref))) - return data - - @transaction() - def get_causal_vrefs(self, conn: Optional[Connection] = None) -> pd.DataFrame: - data = self.rel_storage.get_data(table=self.CAUSAL_VREF_TABLE, conn=conn) - return data - - @transaction() - def get_all_call_data( - self, conn: Optional[Connection] = None - ) -> Dict[str, pd.DataFrame]: - """ - Return a dictionary of all the memoization tables (labeled by versioned - ui name) - """ - result = {} - for table in self.get_call_tables(conn=conn): - result[table] = self.rel_storage.get_data(table=table, conn=conn) - return result - - ############################################################################ - ### event log stuff - ############################################################################ - @transaction() - def get_event_log(self, conn: Optional[Connection] = None) -> pd.DataFrame: - return self.rel_storage.get_data(table=self.EVENT_LOG_TABLE, conn=conn) - - @transaction() - def clear_event_log(self, conn: Optional[Connection] = None): - event_log_table = Table(self.EVENT_LOG_TABLE) - query = Query.from_(event_log_table).delete() - self.rel_storage.execute_no_results(query=query, conn=conn) - - ############################################################################ - ### `Transactable` interface - ############################################################################ - def _get_connection(self) -> Connection: - return self.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_storage._end_transaction(conn=conn) - - ############################################################################ - ### call methods - ############################################################################ - @transaction() - def _get_current_names( - self, sig: Signature, conn: Optional[Connection] = None - ) -> Tuple[str, Dict[str, str]]: - """ - Given a possibly stale signature `sig`, return - - the current ui name for this signature - - a mapping of stale input names to their current values - """ - current_sigs = self.sig_adapter.load_state(conn=conn) - current_sig = current_sigs[sig.internal_name, sig.version] - true_versioned_ui_name = current_sig.versioned_ui_name - stale_to_true_input_mapping = { - k: 
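The event log created above is deliberately minimal: one row per changed UID, tagged with the table that UID belongs to, so a later sync step can work out what changed and then clear the log. A standalone sqlite3 sketch of that layout, with illustrative table, column, and UID names:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# One row per changed UID, plus the table in which that UID is to be found.
conn.execute('CREATE TABLE event_log (uid TEXT PRIMARY KEY, "table" TEXT)')
conn.executemany(
    'INSERT OR REPLACE INTO event_log(uid, "table") VALUES (?, ?)',
    [("call-uid-1", "f_0"), ("vref-uid-1", "vrefs")],
)
# A sync step can group the changed UIDs by table, ship them, then clear the log.
print(conn.execute('SELECT "table", COUNT(*) FROM event_log GROUP BY "table"').fetchall())
conn.execute("DELETE FROM event_log")
conn.close()
```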
current_sig.internal_to_ui_input_map[v] - for k, v in sig.ui_to_internal_input_map.items() - # for k, v in current_sig.ui_to_internal_input_map.items() - } - return true_versioned_ui_name, stale_to_true_input_mapping - - @transaction() - def tabulate_calls( - self, calls: List[Call], conn: Optional[Connection] = None - ) -> Dict[str, pa.Table]: - """ - Converts call objects to a dictionary of {op UI name: table - to upsert} pairs. - - Note that the calls can involve functions in different stages of - staleness. This method can handle calls to many different variants of - the same function (adding inputs, renaming the function or its inputs). - To handle calls to stale functions, this passes through internal names - to get the current UI names. - """ - if not len(calls) == len(set([call.full_uid for call in calls])): - # something fishy may be going on - raise InternalError("Calls must have unique UIDs") - # split by operation *internal* name to group calls to the same op in - # the same group, even if UI names are different. - calls_by_op = defaultdict(list) - for call in calls: - calls_by_op[call.func_op.sig.versioned_internal_name].append(call) - res = {} - for versioned_internal_name, calls in calls_by_op.items(): - rows = [] - true_sig = self.sig_adapter.load_state()[ - Signature.parse_versioned_name(versioned_internal_name) - ] - true_versioned_ui_name = None - for call in calls: - # it is necessary to process each call individually to properly - # handle multiple stale variants of this op - sig = call.func_op.sig - # get the current state of this signature - ( - true_versioned_ui_name, - # stale UI input -> current UI input. This could vary across calls - stale_to_true_input_mapping, - ) = self._get_current_names(sig, conn=conn) - # form the input UIDs - input_uids = { - stale_to_true_input_mapping[k]: v.full_uid - for k, v in call.inputs.items() - } - # patch the input uids using the true signature. This is - # necesary to do here because it seems duckdb has issues with - # interpreting NaNs from pyarrow as nulls - for k, v in true_sig.new_ui_input_default_uids.items(): - if k not in input_uids: - input_uids[k] = v - row = { - Config.uid_col: call.uid, - Config.causal_uid_col: call.causal_uid, - Config.content_version_col: call.content_version, - Config.semantic_version_col: call.semantic_version, - Config.transient_col: call.transient, - **input_uids, - **{ - dump_output_name(index=i): v.full_uid - for i, v in enumerate(call.outputs) - }, - } - rows.append(row) - assert true_versioned_ui_name is not None - res[true_versioned_ui_name] = pa.Table.from_pylist(rows) - return res - - @transaction() - def upsert_calls(self, calls: List[Call], conn: Optional[Connection] = None): - """ - Upserts *detached* calls in the relational storage so that they will - show up in declarative queries. 
- """ - if len(calls) > 0: # avoid dealing with empty dataframes - ### upsert full ref uids - full_uids = set() - for call in calls: - full_uids.update({v.full_uid for v in call.inputs.values()}) - full_uids.update({v.full_uid for v in call.outputs}) - self.rel_storage.upsert( - relation=self.CAUSAL_VREF_TABLE, - ta=pa.Table.from_pydict( - { - Config.full_uid_col: list(full_uids), - } - ), - conn=conn, - ) - ### upsert calls - for table_name, ta in self.tabulate_calls(calls).items(): - self.rel_storage.upsert(relation=table_name, ta=ta, conn=conn) - # Write changes to the event log table - self.rel_storage.upsert( - relation=self.EVENT_LOG_TABLE, - ta=pa.Table.from_pydict( - { - Config.uid_col: ta[Config.uid_col], - "table": [table_name] * len(ta), - } - ), - conn=conn, - ) - ### upsert in provenance - self.upsert_provenance(calls=calls, conn=conn) - - @transaction() - def _query_call( - self, uid: str, by_causal: bool, conn: Optional[Connection] = None - ) -> pa.Table: - # TODO: replace this by something more efficient - col = Config.causal_uid_col if by_causal else Config.uid_col - all_tables = [ - Query.from_(table_name) - .where(Table(table_name)[col] == Parameter("$1")) - .select( - Table(table_name)[col], - LiteralValue(f"'{table_name}'"), - ) - for table_name in self.get_call_tables() - ] - query = "\nUNION\n".join([str(q) for q in all_tables]) - # query = sum(all_tables[1:], start=all_tables[0]) - return self.rel_storage.execute_arrow(query, [uid], conn=conn) - - @transaction() - def call_exists( - self, uid: str, by_causal: bool, conn: Optional[Connection] = None - ) -> bool: - return len(self._query_call(uid, by_causal=by_causal, conn=conn)) > 0 - - @transaction() - def call_get_lazy( - self, uid: str, by_causal: bool, conn: Optional[Connection] = None - ) -> Call: - """ - Return the call with the inputs/outputs as lazy value references. - """ - row = self._query_call(uid, by_causal=by_causal, conn=conn).take([0]) - col = Config.causal_uid_col if by_causal else Config.uid_col - table_name = row.column(1)[0] - table = Table(table_name) - query = ( - Query.from_(table).where(table[col] == Parameter("$1")).select(table.star) - ) - results = self.rel_storage.execute_arrow(query, [uid], conn=conn) - # determine the signature for this call - ui_name, version = Signature.parse_versioned_name( - versioned_name=str(table_name) - ) - sig = self.sig_adapter.load_ui_sigs(conn=conn)[ui_name, version] - return Call.from_row(results, func_op=FuncOp._from_sig(sig=sig)) - - @transaction() - def mget_call_lazy( - self, - versioned_ui_name: str, - uids: List[str], - by_causal: bool = True, - conn: Optional[Connection] = None, - ) -> List[Call]: - """ - Get many calls to the same op. 
- """ - if not by_causal: - raise NotImplementedError() - table_name = versioned_ui_name - table = Table(table_name) - query = ( - Query.from_(table) - .where(table[Config.causal_uid_col].isin(uids)) - .select(table.star) - ) - results = self.rel_storage.execute_df(query, conn=conn).set_index( - Config.causal_uid_col - ) - ui_name, version = Signature.parse_versioned_name( - versioned_name=str(table_name) - ) - sig = self.sig_adapter.load_ui_sigs(conn=conn)[ui_name, version] - calls_by_causal_uid = {} - for causal_uid, row in results.iterrows(): - call_dict = dict(row) - call_dict.update({Config.causal_uid_col: causal_uid}) - calls_by_causal_uid[causal_uid] = Call.from_row( - call_dict, func_op=FuncOp._from_sig(sig=sig) - ) - return [calls_by_causal_uid[uid] for uid in uids] - - ############################################################################ - ### object methods - ############################################################################ - def _spillover_criterion(self, serialized: bytes) -> bool: - return len(serialized) / 1024**2 > self.spillover_threshold_mb - - def _mset_spillover(self, uids: List[str], refs: List[Ref]): - assert self.spillover_dir is not None - for uid, ref in zip(uids, refs): - dump_path = self.spillover_dir / f"{uid}.joblib" - with open(dump_path, "wb") as f: - joblib.dump(ref, f) - if Config.warnings: - size_in_mb = round(os.path.getsize(dump_path) / 10**6, 2) - logger.warning( - f"Spill over {ref.__repr__(shorten=True)} of size {size_in_mb}MB" - ) - - def _mget_spillover(self, uids: List[str]) -> List[Ref]: - assert self.spillover_dir is not None - values = [] - for uid in uids: - try: - with open(self.spillover_dir / f"{uid}.joblib", "rb") as f: - values.append(joblib.load(f)) - except FileNotFoundError: - if Config.warnings: - logger.warning(f"Could not find spillover file for uid {uid}") - values.append(Ref.from_uid(uid=uid)) - return values - - def _serialize_spillover(self, ref: Ref) -> bytes: - s = serialize(ref) - if self._spillover_criterion(s): - substitute = Ref.from_uid(uid=ref.uid) - substitute._obj = SpilledRef() - self._mset_spillover([ref.uid], [ref]) - return serialize(substitute) - else: - return s - - def _deserialize_spillover(self, serialized: bytes) -> Ref: - value = deserialize(serialized) - if isinstance(value._obj, SpilledRef): - return self._mget_spillover([value.uid])[0] - else: - return value - - @transaction() - def obj_exists( - self, uids: List[str], conn: Optional[Connection] = None - ) -> List[bool]: - if len(uids) == 0: - return [] - table = Table(Config.vref_table) - query = ( - Query.from_(table) - .where(table[Config.uid_col].isin(uids)) - .select(table[Config.uid_col]) - ) - existing_uids = set( - self.rel_storage.execute_df(query=query, conn=conn)[Config.uid_col] - ) - return [uid in existing_uids for uid in uids] - - @transaction() - def obj_set( - self, - uid: str, - value: Ref, - shallow: bool = True, - conn: Optional[Connection] = None, - ) -> None: - self.obj_sets(vrefs={uid: value}, shallow=shallow, conn=conn) - - @transaction() - def obj_set_causal( - self, - full_uid: str, - conn: Optional[Connection] = None, - ) -> None: - self.rel_storage.upsert( - relation=Config.causal_vref_table, - ta=pd.DataFrame({Config.full_uid_col: [full_uid]}), - conn=conn, - ) - - @transaction() - def obj_sets( - self, - vrefs: Dict[str, Ref], - shallow: bool = True, - conn: Optional[Connection] = None, - ) -> None: - if not shallow: - raise NotImplementedError() - uids = list(vrefs.keys()) - indicators = self.obj_exists(uids, 
conn=conn) - new_uids = [uid for uid, indicator in zip(uids, indicators) if not indicator] - if self.spillover_dir is not None: - serializer = self._serialize_spillover - else: - serializer = serialize - serialized_refs = [serializer(vrefs[new_uid].dump()) for new_uid in new_uids] - # to get around the 2GB limit of pyarrow - ta = pd.DataFrame( - { - Config.uid_col: new_uids, - "value": serialized_refs, - } - ) - self.rel_storage.upsert(relation=Config.vref_table, ta=ta, conn=conn) - log_ta = pa.Table.from_pylist( - [ - {Config.uid_col: new_uid, "table": Config.vref_table} - for new_uid in new_uids - ] - ) - self.rel_storage.upsert(relation=self.EVENT_LOG_TABLE, ta=log_ta, conn=conn) - - @transaction() - def obj_gets( - self, - uids: List[str], - depth: Optional[int] = None, - _attach_atoms: bool = True, - conn: Optional[Connection] = None, - ) -> List[Ref]: - """ - Returns a list of value references for the given uids. - - Note that causal UIDs are not set for these values. - """ - if len(uids) == 0: - return [] - if depth == 0: - return [Ref.from_uid(uid=uid) for uid in uids] - elif depth == 1: - table = Table(Config.vref_table) - if not _attach_atoms: - query_uids = [uid for uid in uids if Builtins.is_builtin_uid(uid=uid)] - else: - query_uids = uids - if len(query_uids) > 0: - query = ( - Query.from_(table) - .where(table[Config.uid_col].isin(query_uids)) - .select(table[Config.uid_col], table["value"]) - ) - output = self.rel_storage.execute_df(query, conn=conn) - if self.spillover_dir is not None: - deserializer = self._deserialize_spillover - else: - deserializer = deserialize - output["value"] = output["value"].map(lambda x: deserializer(bytes(x))) - else: - output = pd.DataFrame(columns=[Config.uid_col, "value"]) - if not _attach_atoms: - atoms_df = pd.DataFrame( - { - Config.uid_col: [ - uid for uid in uids if not Builtins.is_builtin_uid(uid=uid) - ], - "value": None, - } - ) - atoms_df["value"] = atoms_df[Config.uid_col].map( - lambda x: Ref.from_uid(uid=x) - ) - output = pd.concat([output, atoms_df]) - return output.set_index(Config.uid_col).loc[uids, "value"].tolist() - elif depth is None: - results = [Ref.from_uid(uid=uid) for uid in uids] - self.mattach( - vrefs=results, shallow=False, _attach_atoms=_attach_atoms, conn=conn - ) - return results - - @transaction() - def obj_get( - self, - uid: str, - depth: Optional[int] = None, - _attach_atoms: bool = True, - conn: Optional[Connection] = None, - ) -> Ref: - vref_option = self.obj_gets( - uids=[uid], depth=depth, _attach_atoms=_attach_atoms, conn=conn - )[0] - if vref_option is None: - raise ValueError(f"Ref with uid {uid} does not exist") - return vref_option - - @transaction() - def mattach( - self, - vrefs: List[Ref], - shallow: bool = False, - _attach_atoms: bool = True, - conn: Optional[Connection] = None, - ) -> None: - """ - In-place attach objects. If `shallow`, only attach the next level; - otherwise, attach until leaf nodes. - - Note that some objects may already be attached. - """ - ### pass to the vrefs that need to be attached - detached_vrefs = collect_detached(refs=vrefs, include_transient=False) - vrefs = detached_vrefs - ### group the vrefs by uid - vrefs_by_uid = get_fibers_as_lists(mapping={vref: vref.uid for vref in vrefs}) - unique_uids = list(vrefs_by_uid.keys()) - ### load one level of the unique vrefs - vals = self.obj_gets( - uids=unique_uids, depth=1, _attach_atoms=_attach_atoms, conn=conn - ) #! 
this can be optimized - for i, uid in enumerate(unique_uids): - for obj in vrefs_by_uid[uid]: - if vals[i].in_memory: - obj.attach(reference=vals[i]) - if not shallow: - residues = collect_detached(refs=vals, include_transient=False) - if len(residues) > 0: - self.mattach(vrefs=residues, shallow=False, conn=conn) - - ############################################################################ - ### provenance methods - ############################################################################ - @transaction() - def upsert_provenance(self, calls: List[Call], conn: Optional[Connection] = None): - rows = [] - for call in calls: - call_causal = call.causal_uid - for name, inp in call.inputs.items(): - rows.append( - { - Provenance.causal_uid: inp.causal_uid, - Provenance.name: name, - Provenance.call_causal_uid: call_causal, - Provenance.direction: "input", - Provenance.op_id: call.func_op.sig.versioned_internal_name, - } - ) - for i, output in enumerate(call.outputs): - name = dump_output_name(index=i) - rows.append( - { - Provenance.causal_uid: output.causal_uid, - Provenance.name: name, - Provenance.call_causal_uid: call_causal, - Provenance.direction: "output", - Provenance.op_id: call.func_op.sig.versioned_internal_name, - } - ) - df = pd.DataFrame(rows) - self.rel_storage.upsert(relation=Config.provenance_table, ta=df, conn=conn) - - ############################################################################ - ### deletion - ############################################################################ - @transaction() - def delete_causal_refs( - self, full_uids: List[str], conn: Optional[Connection] = None - ): - self.rel_storage.delete( - relation=Config.causal_vref_table, - where_col=Config.full_uid_col, - where_values=full_uids, - conn=conn, - ) - - @transaction() - def delete_refs( - self, uids: List[str], verbose: bool = False, conn: Optional[Connection] = None - ): - self.rel_storage.delete( - relation=Config.vref_table, - where_col=Config.uid_col, - where_values=uids, - conn=conn, - ) - full_uids = set( - self.rel_storage.execute_df( - query=f"SELECT {Config.full_uid_col} FROM {Config.causal_vref_table}", - conn=conn, - )[Config.full_uid_col].tolist() - ) - full_uids_to_delete = { - full_uid for full_uid in full_uids if full_uid.rsplit(".", 1)[0] in uids - } - if verbose: - logger.info(f"Deleting {len(full_uids_to_delete)} causal refs") - self.delete_causal_refs(full_uids=list(full_uids_to_delete), conn=conn) - spillover_uids = self._get_spillover_uids(uids=uids) - for uid in spillover_uids: - (self.spillover_dir / f"{uid}.joblib").unlink(missing_ok=True) - - @transaction() - def cleanup_vrefs(self, verbose: bool = False, conn: Optional[Connection] = None): - """ - Delete all value references that are not referenced by any call. 
- """ - # collect all the full uids referenced by calls - all_referenced_uids = set() - for table in self.get_call_tables(conn=conn): - df = self.rel_storage.get_data(table=table, conn=conn) - for col in [ - col for col in df.columns if col not in Config.special_call_cols - ]: - all_referenced_uids.update( - {x.rsplit(".", 1)[0] for x in df[col].unique()} - ) - # delete the unreferenced full uids - stored_uids = set( - self.rel_storage.execute_df( - query=f"SELECT {Config.uid_col} FROM {Config.vref_table}", conn=conn - )[Config.uid_col].tolist() - ) - unreferenced_uids = stored_uids - all_referenced_uids - if verbose: - spillover_uids = self._get_spillover_uids(uids=list(unreferenced_uids)) - total_memory_usage = 0 - for spillover_uid in spillover_uids: - size_in_mb = round( - os.path.getsize(self.spillover_dir / f"{spillover_uid}.joblib") - / 1024**2, - 2, - ) - total_memory_usage += size_in_mb - total_memory_usage += self.get_vref_memory_usage( - uids=list(unreferenced_uids - spillover_uids), conn=conn - ) - logger.info( - f"Deleting {len(unreferenced_uids)} unreferenced value refs that take up {total_memory_usage}MB" - ) - self.delete_refs(uids=list(unreferenced_uids), conn=conn) - - def _get_spillover_uids(self, uids: Iterable[str]) -> Set[str]: - if self.spillover_dir is None: - return set() - spillover_uids = set() - for uid in uids: - if (self.spillover_dir / f"{uid}.joblib").exists(): - spillover_uids.add(uid) - return spillover_uids - - @transaction() - def delete_calls( - self, - versioned_ui_name: str, - causal_uids: List[str], - conn: Optional[Connection] = None, - ): - # delete from memoization table - self.rel_storage.delete( - relation=versioned_ui_name, - where_col=Config.causal_uid_col, - where_values=causal_uids, - conn=conn, - ) - # delete from provenance table - self.rel_storage.delete( - relation=Config.provenance_table, - where_col=Provenance.call_causal_uid, - where_values=causal_uids, - conn=conn, - ) - - @transaction() - def get_vref_memory_usage( - self, uids: List[str], units: str = "MB", conn: Optional[Connection] = None - ) -> float: - if not uids: - return 0 - if units != "MB": - raise NotImplementedError - if len(uids) == 1: - in_clause = f"('{uids[0]}')" - else: - in_clause = str(tuple(uids)) - query = ( - f"SELECT length({Config.vref_value_col}) AS size_in_bytes FROM {Config.vref_table} WHERE __uid__ IN " - + in_clause - ) - df = self.rel_storage.execute_df(query=query, conn=conn) - total_size_in_bytes = df["size_in_bytes"].astype(int).sum() - size_in_mb = round(total_size_in_bytes / 1024**2, 2) - return size_in_mb diff --git a/mandala/storages/remote_impls/__init__.py b/mandala/storages/remote_impls/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mandala/storages/remote_impls/mongo_impl.py b/mandala/storages/remote_impls/mongo_impl.py deleted file mode 100644 index 4c85c99..0000000 --- a/mandala/storages/remote_impls/mongo_impl.py +++ /dev/null @@ -1,35 +0,0 @@ -# import datetime -# from typing import Tuple, List -# -# import pymongo -# -# from mandala.storages.rels import RemoteEventLogEntry -# from mandala.storages.remote_storage import RemoteStorage -# -# -# class MongoRemoteStorage(RemoteStorage): -# def __init__(self, db_name: str, client: pymongo.MongoClient): -# self.db_name = db_name -# self.client = client -# self.log = client.experiment_data[self.db_name].event_log -# -# def save_event_log_entry(self, entry: RemoteEventLogEntry): -# response = self.log.insert_one({"tables": entry}) -# assert response.acknowledged -# # Set 
the timestamp based on the server time. -# self.log.update_one( -# {"_id": response.inserted_id}, -# {"$currentDate": {"timestamp": {"$type": "date"}}}, -# ) -# -# def get_log_entries_since( -# self, timestamp: datetime.datetime -# ) -> Tuple[List[RemoteEventLogEntry], datetime.datetime]: -# entries = [] -# last_timestamp = datetime.datetime.fromtimestamp(0) -# for entry in self.log.find({"timestamp": {"$gt": timestamp}}): -# entries.append(entry["tables"]) -# if entry["timestamp"] > last_timestamp: -# last_timestamp = entry["timestamp"] -# return entries, last_timestamp -# diff --git a/mandala/storages/remote_impls/mongo_mock.py b/mandala/storages/remote_impls/mongo_mock.py deleted file mode 100644 index 667f1b4..0000000 --- a/mandala/storages/remote_impls/mongo_mock.py +++ /dev/null @@ -1,51 +0,0 @@ -# from ...core.sig import Signature -# from ...common_imports import * -# import mongomock -# import datetime -# -# from mandala.storages.rels import RemoteEventLogEntry -# from mandala.storages.remote_storage import RemoteStorage -# -# -# class MongoMockRemoteStorage(RemoteStorage): -# def __init__(self, db_name: str, client: mongomock.MongoClient): -# self.db_name = db_name -# self.client = client -# self.log = client.experiment_data[self.db_name].event_log -# self.sigs: Dict[Tuple[str, int], Signature] = {} -# -# def pull_signatures(self) -> Dict[Tuple[str, int], Signature]: -# return self.sigs -# -# def push_signatures(self, new_sigs: Dict[Tuple[str, int], Signature]) -> None: -# current_internal_sigs = self.sigs -# for (internal_name, version), new_sig in new_sigs.items(): -# if (internal_name, version) in current_internal_sigs: -# current_sig = current_internal_sigs[(internal_name, version)] -# if not current_sig.is_compatible(new_sig): -# raise ValueError( -# f"Signature {internal_name}:{version} is incompatible with {new_sig}" -# ) -# self.sigs = new_sigs -# -# def save_event_log_entry(self, entry: RemoteEventLogEntry): -# response = self.log.insert_one({"tables": entry}) -# assert response.acknowledged -# # Set the timestamp based on the server time. 
-# self.log.update_one( -# {"_id": response.inserted_id}, -# {"$currentDate": {"timestamp": {"$type": "date"}}}, -# ) -# -# def get_log_entries_since( -# self, timestamp: datetime.datetime -# ) -> Tuple[List[RemoteEventLogEntry], datetime.datetime]: -# logger.debug(f"Getting log entries since {timestamp}") -# entries = [] -# last_timestamp = datetime.datetime.fromtimestamp(0) -# for entry in self.log.find({"timestamp": {"$gt": timestamp}}): -# entries.append(entry["tables"]) -# if entry["timestamp"] > last_timestamp: -# last_timestamp = entry["timestamp"] -# return entries, last_timestamp -# diff --git a/mandala/storages/remote_storage.py b/mandala/storages/remote_storage.py deleted file mode 100644 index 20a57ef..0000000 --- a/mandala/storages/remote_storage.py +++ /dev/null @@ -1,55 +0,0 @@ -import abc -import datetime -from abc import abstractmethod -from typing import Optional - -from ..common_imports import * -from ..core.sig import Signature - -from mandala.storages.rels import RemoteEventLogEntry, RelAdapter - - -class RemoteStorage(abc.ABC): - @abstractmethod - def save_event_log_entry(self, entry: RemoteEventLogEntry): - raise NotImplementedError() - - @abstractmethod - def get_log_entries_since( - self, timestamp: datetime.datetime - ) -> Tuple[List[RemoteEventLogEntry], datetime.datetime]: - raise NotImplementedError() - - @abstractmethod - def pull_signatures(self) -> Dict[Tuple[str, int], Signature]: - raise NotImplementedError() - - @abstractmethod - def push_signatures(self, new_sigs: Dict[Tuple[str, int], Signature]) -> None: - raise NotImplementedError() - - -# class RemoteSyncManager: -# def __init__( -# self, -# local_storage: RelAdapter, -# remote_storage: RemoteStorage, -# timestamp: Optional[datetime.datetime] = None, -# ): -# self.local_storage = local_storage -# self.remote_storage = remote_storage -# self.last_timestamp = ( -# timestamp if timestamp is not None else datetime.datetime.fromtimestamp(0) -# ) -# -# def sync_from_remote(self): -# new_log_entries, timestamp = self.remote_storage.get_log_entries_since( -# self.last_timestamp -# ) -# self.local_storage.apply_from_remote(new_log_entries) -# self.last_timestamp = timestamp -# -# def sync_to_remote(self): -# changes = self.local_storage.bundle_to_remote() -# self.remote_storage.save_event_log_entry(changes) -# diff --git a/mandala/storages/sigs.py b/mandala/storages/sigs.py deleted file mode 100644 index dfed46d..0000000 --- a/mandala/storages/sigs.py +++ /dev/null @@ -1,296 +0,0 @@ -import pyarrow as pa - -from ..common_imports import * -from ..core.config import Config, dump_output_name -from ..core.sig import Signature -from .rel_impls.utils import Transactable, transaction, Connection -from .rels import RelAdapter, serialize, SigAdapter -from .remote_storage import RemoteStorage - - -class SigSyncer(Transactable): - """ - Responsible for syncing the local schema and the server. - - There are two kinds of updates: - - updates from the server: these come in bulk when you pull the current - true state of the signatures. The signature objects come with their - internal data. - - updates from the client: these could happen piece by piece, or in bulk - if using a context manager. The signature objects may not have - internal data if they come straight from the client's code, so need to - be matched to the existing signatures. 
- - This class ensures that all updates are valid against the current state of - the signatures on the server, and that only successful updates against this - copy go through to the local storage. - - """ - - def __init__( - self, - sig_adapter: SigAdapter, - root: Optional[Union[Path, RemoteStorage]] = None, - ): - self.sig_adapter = sig_adapter - self.root = root - self.rel_storage = self.sig_adapter.rel_storage - - ############################################################################ - ### `Transactable` interface - ############################################################################ - def _get_connection(self) -> Connection: - return self.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_storage._end_transaction(conn=conn) - - ############################################################################ - ### sync with server - ############################################################################ - @property - def has_remote(self) -> bool: - return self.root is not None - - def pull_signatures(self) -> Dict[Tuple[str, int], Signature]: - """ - Pull the current state of the signatures from the remote, make sure that - they are compatible with the current ones, and then update or create - according to the new signatures. - """ - if isinstance(self.root, RemoteStorage): - new_sigs = self.root.pull_signatures() - return new_sigs - else: - raise ValueError("No remote storage to pull from.") - - def push_signatures(self, sigs: Dict[Tuple[str, int], Signature]): - if isinstance(self.root, RemoteStorage): - assert isinstance(self.root, RemoteStorage) - self.root.push_signatures(new_sigs=sigs) - - @transaction() - def sync_from_remote(self, conn: Optional[Connection] = None): - """ - Update state from signatures *with internal data* (coming from the - server). This includes: - - creating new signatures - - updating existing signatures - - renaming functions and inputs - """ - logger.debug("Syncing signatures from remote...") - if self.has_remote: - sigs = self.pull_signatures() - # sort them by internal name and version to ensure earlier versions - # are updated first - sigs = sorted(sigs.values(), key=lambda s: (s.internal_name, s.version)) - for sig in sigs: - logging.debug(f"Processing signature {sig}") - if self.sig_adapter.exists_internal(sig=sig, conn=conn): - # first set the ui name to the current one (if necessary) - self.sig_adapter.update_ui_name(sig=sig, conn=conn) - # then, update the input names too (if necessary) - self.sig_adapter.update_input_ui_names(sig=sig, conn=conn) - # now, update the (already name-aligned) signature from the new - self.sig_adapter.update_sig(sig=sig, conn=conn) - elif self.sig_adapter.exists_any_version(sig=sig, conn=conn): - self.sig_adapter.create_new_version(sig=sig, conn=conn) - else: - self.sig_adapter.create_sig(sig=sig, conn=conn) - - ############################################################################ - ### atomic operations by the client - ############################################################################ - @transaction() - def validate_transaction( - self, - new_sig: Signature, - all_sigs: Dict[Tuple[str, int], Signature], - conn: Optional[Connection] = None, - ) -> bool: - """ - Check that a new signature is compatible with a current state of the - signatures WITHOUT actually updating the state. This is used to check - that a transaction is valid before committing it. 
- """ - assert new_sig.has_internal_data - if self.sig_adapter.exists_internal(sig=new_sig, conn=conn): - current = all_sigs[new_sig.internal_name, new_sig.version] - compatible, reason_not = current.is_compatible( - new=new_sig, - ) - if compatible: - return True - else: - raise ValueError(reason_not) - else: - return True - - @transaction() - def sync_create( - self, sig: Signature, conn: Optional[Connection] = None - ) -> Signature: - """ - Pull the current state of the signatures from the remote, make sure that - they are compatible with the creation of this signature, then push the - updated signatures to the remote and create the signature locally. - """ - self.sync_from_remote(conn=conn) - new_sig = sig._generate_internal() - self.validate_transaction( - new_sig=new_sig, - all_sigs=self.sig_adapter.load_state(conn=conn), - conn=conn, - ) - all_sigs = self.sig_adapter.load_state(conn=conn) - all_sigs[new_sig.internal_name, new_sig.version] = new_sig - self.push_signatures(sigs=all_sigs) - self.sig_adapter.create_sig(sig=new_sig, conn=conn) - return new_sig - - @transaction() - def sync_update( - self, sig: Signature, conn: Optional[Connection] = None - ) -> Signature: - """ - Pull the current state of the signatures from the remote, make sure that - they are compatible with the update of this signature, then push the - updated signatures to the remote and update the signature locally. - """ - self.sync_from_remote(conn=conn) - current = self.sig_adapter.load_ui_sigs(conn=conn)[sig.ui_name, sig.version] - if current != sig: - new_sig, _ = current.update(new=sig) - self.validate_transaction( - new_sig=new_sig, all_sigs=self.sig_adapter.load_state(conn=conn) - ) - all_sigs = self.sig_adapter.load_state(conn=conn) - all_sigs[(current.internal_name, current.version)] = new_sig - self.push_signatures(sigs=all_sigs) - self.sig_adapter.update_sig(sig=new_sig, conn=conn) - return new_sig - else: - return current - - @transaction() - def sync_new_version( - self, sig: Signature, conn: Optional[Connection] = None, - strict: bool = False, - ) -> Signature: - """ - Pull the current state of the signatures from the remote, make sure that - they are compatible with the creation of a new version given by this - signature, then push the updated signatures to the remote and create the - new version locally. - """ - self.sync_from_remote(conn=conn) - if not self.sig_adapter.exists_any_version(sig=sig, conn=conn): - raise ValueError() - latest_sig = self.sig_adapter.get_latest_version(sig=sig, conn=conn) - new_version = latest_sig.version + 1 - if strict: - if not sig.version == new_version: - raise ValueError(f"New version must be {new_version}, not {sig.version}") - new_sig = sig._generate_internal(internal_name=latest_sig.internal_name) - new_sig.check_invariants() - self.validate_transaction( - new_sig=new_sig, all_sigs=self.sig_adapter.load_state(conn=conn) - ) - all_sigs = self.sig_adapter.load_state(conn=conn) - all_sigs[(new_sig.internal_name, new_sig.version)] = new_sig - self.push_signatures(sigs=all_sigs) - self.sig_adapter.create_new_version(sig=new_sig, conn=conn) - return new_sig - - @transaction() - def sync_rename_sig( - self, sig: Signature, new_name: str, conn: Optional[Connection] = None - ) -> Signature: - """ - Pull the current state of the signatures from the remote, make sure that - they are compatible with the renaming of this signature, then push the - updated signatures to the remote and rename the signature locally. - """ - self.sync_from_remote(conn=conn) - #! 
note: we validate before the renaming. Ideally we should have logic - # to do this for the new signature directly - # self.validate_transaction( - # new_sig=sig, all_sigs=self.sig_adapter.load_state(conn=conn) - # ) - new_sig = sig.rename(new_name=new_name) - all_sigs = self.sig_adapter.update_ui_name( - sig=new_sig, conn=conn, validate_only=True - ) - # all_sigs = self.sig_adapter.load_state(conn=conn) - # all_sigs[(new_sig.internal_name, new_sig.version)] = new_sig - self.push_signatures(sigs=all_sigs) - self.sig_adapter.update_ui_name(sig=new_sig, conn=conn) - return new_sig - - @transaction() - def sync_rename_input( - self, - sig: Signature, - input_name: str, - new_input_name: str, - conn: Optional[Connection] = None, - ) -> Signature: - """ - Pull the current state of the signatures from the remote, make sure that - they are compatible with the renaming of this input, then push the - updated signatures to the remote and rename the input locally. - """ - self.sync_from_remote(conn=conn) - #! note: we validate before the renaming. Ideally we should have logic - # to do this for the new signature directly - self.validate_transaction( - new_sig=sig, all_sigs=self.sig_adapter.load_state(conn=conn) - ) - new_sig = sig.rename_inputs(mapping={input_name: new_input_name}) - all_sigs = self.sig_adapter.load_state(conn=conn) - all_sigs[(new_sig.internal_name, new_sig.version)] = new_sig - self.push_signatures(sigs=all_sigs) - self.sig_adapter.update_input_ui_names(sig=new_sig, conn=conn) - return new_sig - - @transaction() - def sync_from_local( - self, - sig: Signature, - conn: Optional[Connection] = None, - ) -> Signature: - """ - Create a new signature, create a new version, or update an existing one, - and immediately send changes to the server. - """ - if sig.version is None: - if self.sig_adapter.exists_any_version(sig=sig, conn=conn): - action = "use_latest" - else: - action = "create" - else: - if self.sig_adapter.exists_versioned_ui(sig=sig, conn=conn): - action = "update_current" - elif self.sig_adapter.exists_any_version(sig=sig, conn=conn): - action = "new_version" - else: - action = "create" - if action == "use_latest": - latest_sig = self.sig_adapter.get_latest_version(sig=sig, conn=conn) - new_sig = copy.deepcopy(sig) - new_sig.version = latest_sig.version - res = self.sync_update(sig=new_sig, conn=conn) - elif action == "update_current": - # current_version = self.sig_adapter.get_version(sig=sig, version=sig.version, conn=conn) - # new_sig = copy.deepcopy(current_sig) - # new_sig.version = latest_sig.version - res = self.sync_update(sig=sig, conn=conn) - elif action == "new_version": - res = self.sync_new_version(sig=sig, conn=conn) - elif action == "create": - sig = copy.deepcopy(sig) - sig.version = 0 - res = self.sync_create(sig=sig, conn=conn) - else: - raise ValueError(f"Unknown action {action}") - return res diff --git a/mandala/tests/__init__.py b/mandala/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mandala/tests/_test_remote_mock.py b/mandala/tests/_test_remote_mock.py deleted file mode 100644 index 1e7905f..0000000 --- a/mandala/tests/_test_remote_mock.py +++ /dev/null @@ -1,301 +0,0 @@ -import mongomock - -from mandala.all import * -from mandala.tests.utils import * - - -def test_disjoint_funcs(): - """ - Test a basic scenario where two users define two different functions and do - some work with each, then sync their results. - - Expected behavior: the two storages end up in the same state. 
- """ - client = mongomock.MongoClient() - root = MongoMockRemoteStorage(db_name="test", client=client) - - # create multiple storages connected to it - storage_1 = Storage(root=root) - storage_2 = Storage(root=root) - - ### do work with storage 1 - @op - def inc(x: int) -> int: - return x + 1 - - with storage_1.run(): - inc(23) - - ### do work with storage 2 - @op - def mult(x: int, y: int) -> int: - return x * y - - with storage_2.run(): - mult(23, 42) - - # synchronize storage_1 with the new work - storage_1.sync_with_remote() - - # verify that both storages have the same state - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - -def test_create_func(): - """ - Unit test for function creation in a multi-user setting. - - Specifically, test a scenario where: - - one user defines a function - - another user defines the same function under the same name - - Expected behavior: - - the first creator wins: both functions are assigned the UID issued to - the first creator. - """ - client = mongomock.MongoClient() - root = MongoMockRemoteStorage(db_name="test", client=client) - storage_1 = Storage(root=root) - storage_2 = Storage(root=root) - - @op(ui_name="f") - def f_1(x: int) -> int: - return x + 1 - - @op(ui_name="f") - def f_2(x: int) -> int: - return x + 1 - - storage_1.synchronize(f=f_1) - storage_2.synchronize(f=f_2) - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - - -def test_add_input(): - """ - Unit test for adding an input to a function in a multi-user setting. - - Specifically, test a scenario where: - - users 1 and 2 agree on the definition of a function `f` - - user 1 adds an input to the function - - user 2 tries to send calls to the initial function to the server - - Expected behavior: - - calls to the old variant of the function work as expected: - - user 2 is able to commit their calls and send them to the server - - user 1 is able to then load these new calls into their local storage - - the two storages end up in the same state - """ - client = mongomock.MongoClient() - root = MongoMockRemoteStorage(db_name="test", client=client) - storage_1 = Storage(root=root) - storage_2 = Storage(root=root) - - @op(ui_name="inc") - def inc_1(x: int) -> int: - return x + 1 - - @op(ui_name="inc") - def inc_2(x: int) -> int: - return x + 1 - - storage_1.synchronize(f=inc_1) - storage_2.synchronize(f=inc_2) - - @op(ui_name="inc") - def inc_1(x: int, how_many_times: int = 1) -> int: - return x + how_many_times - - storage_1.synchronize(f=inc_1) - assert not signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - with storage_2.run(): - inc_2(23) - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert not data_is_equal(storage_1=storage_1, storage_2=storage_2) - storage_1.sync_from_remote() - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - -def test_rename_func(): - """ - Unit test for renaming a function in a multi-user setting. - - Specifically, test a scenario where: - - users 1 and 2 agree on the definition of a function `f` - - the user 1 renames the function to `g` - - unbeknownst to that, the user 2 still sees `f` and uses it in - computations. - - Expected behavior: - - calls to the old variant of the function work as expected: - - user 2 is able to commit their calls and send them to the server - - user 1 is able to then load these new calls into their local storage - - the two storages end up in the same state - - ! 
Possible confusion: - - if user 2 then re-synchronizes the old variant of the function called - `f`, this will create a new function. - """ - client = mongomock.MongoClient() - root = MongoMockRemoteStorage(db_name="test", client=client) - storage_1 = Storage(root=root) - storage_2 = Storage(root=root) - - @op(ui_name="f") - def f_1(x: int) -> int: - return x + 1 - - @op(ui_name="f") - def f_2(x: int) -> int: - return x + 1 - - storage_1.synchronize(f=f_1) - storage_2.synchronize(f=f_2) - - storage_1.rename_func(func=f_1, new_name="g") - - assert not signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - with storage_2.run(): - f_2(23) - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert not data_is_equal(storage_1=storage_1, storage_2=storage_2) - storage_1.sync_from_remote() - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - -def test_rename_input(): - """ - Analogous to `test_rename_func`. - """ - client = mongomock.MongoClient() - root = MongoMockRemoteStorage(db_name="test", client=client) - storage_1 = Storage(root=root) - storage_2 = Storage(root=root) - - @op(ui_name="f") - def f_1(x: int) -> int: - return x + 1 - - @op(ui_name="f") - def f_2(x: int) -> int: - return x + 1 - - storage_1.synchronize(f=f_1) - storage_2.synchronize(f=f_2) - - storage_1.rename_arg(func=f_1, name="x", new_name="y") - - assert not signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - with storage_2.run(): - f_2(23) - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert not data_is_equal(storage_1=storage_1, storage_2=storage_2) - storage_1.sync_from_remote() - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - -def test_remote_lots_of_stuff(): - Config.autowrap_inputs = True - Config.autounwrap_inputs = True - # create a *single* (mock) remote database - client = mongomock.MongoClient() - root = MongoMockRemoteStorage(db_name="test", client=client) - - # create multiple storages connected to it - storage_1 = Storage(root=root) - storage_2 = Storage(root=root) - - def check_all_invariants(): - check_invariants(storage=storage_1) - check_invariants(storage=storage_2) - - ### do stuff with storage 1 - @op - def inc(x: int) -> int: - return x + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage_1.run(): - for i in range(20, 25): - j = inc(x=i) - final = add(i, j) - - ### do stuff with storage 2 - @op - def mult(x: int, y: int) -> int: - return x * y - - with storage_2.run(): - for i, j in zip(range(20, 25), range(20, 25)): - k = mult(i, j) - - storage_2.sync_with_remote() - storage_1.sync_with_remote() - storage_2.sync_with_remote() - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - ### now, rename a function in storage 1! 
- storage_1.rename_func(func=inc, new_name="inc_new") - - @op - def inc_new(x: int) -> int: - return x + 1 - - storage_1.synchronize(f=inc_new) - - ### rename argument too - storage_1.rename_arg(func=inc_new, name="x", new_name="x_new") - - @op - def inc_new(x_new: int) -> int: - return x_new + 1 - - storage_1.synchronize(f=inc_new) - - # do work with the renamed function in storage 1 - with storage_1.run(): - for i in range(20, 30): - j = inc_new(x_new=i) - final = add(i, j) - - # now sync stuff - storage_2.sync_with_remote() - - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - # check we can do work with the renamed function in storage_2 - @op - def inc_new(x_new: int) -> int: - return x_new + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage_2.run(): - for i in range(20, 40): - j = inc_new(x_new=i) - final = add(i, j) - storage_2.sync_with_remote() - storage_1.sync_with_remote() - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) - - # do some versioning stuff - @op(version=1) - def add(x: int, y: int, z: int) -> int: - return x + y + z - - with storage_2.run(): - add(x=1, y=2, z=3) - - storage_2.sync_with_remote() - storage_1.sync_with_remote() - assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2) - assert data_is_equal(storage_1=storage_1, storage_2=storage_2) diff --git a/mandala/tests/stateful_utils.py b/mandala/tests/stateful_utils.py deleted file mode 100644 index b4a4048..0000000 --- a/mandala/tests/stateful_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -from collections import OrderedDict -from mandala.common_imports import * -from mandala.all import * -from mandala.tests.utils import * -from mandala.core.utils import Hashing, invert_dict -from mandala.queries.compiler import * -from mandala.queries.weaver import ValNode, CallNode -import string - - -def combine_inputs(*args, **kwargs) -> str: - return Hashing.get_content_hash(obj=(args, kwargs)) - - -def generate_deterministic(seed: str, n_outputs: int) -> List[str]: - result = [] - current = seed - for i in range(n_outputs): - new = Hashing.get_content_hash(obj=current) - result.append(new) - current = new - return result - - -def random_string(size: int = 10) -> str: - return "".join(random.choice(string.ascii_letters) for _ in range(size)) - - -TEMPLATE = """ -def {name}({inputs}) -> {output_annotation}: - if {n_outputs} == 0: - return None - elif {n_outputs} == 1: - return generate_deterministic(seed=combine_inputs({inputs}), - n_outputs=1)[0] - else: - return tuple(generate_deterministic(seed=combine_inputs({inputs}), n_outputs={n_outputs})) -""" - - -def make_func( - ui_name: str, - input_names: List[str], - n_outputs: int, -) -> types.FunctionType: - """ - Generate a deterministic function with given interface - """ - inputs = ", ".join(input_names) - output_annotation = ( - "None" - if n_outputs == 0 - else "Any" - if n_outputs == 1 - else f"Tuple[{', '.join(['Any'] * n_outputs)}]" - ) - code = TEMPLATE.format( - name=ui_name, - inputs=inputs, - output_annotation=output_annotation, - n_outputs=n_outputs, - ) - f = compile(code, "", "exec") - exec(f) - f = locals()[ui_name] - return f - - -def make_func_from_sig(sig: Signature) -> types.FunctionType: - return make_func(sig.ui_name, list(sig.input_names), sig.n_outputs) - - -def make_op( - ui_name: str, - input_names: List[str], - n_outputs: int, - defaults: 
Dict[str, Any], - version: int = 0, - deterministic: bool = True, -) -> FuncOp: - """ - Generate a deterministic function with given interface - """ - sig = Signature( - ui_name=ui_name, - input_names=set(input_names), - n_outputs=n_outputs, - version=version, - defaults=defaults, - input_annotations={k: Any for k in input_names}, - output_annotations=[Any] * n_outputs, - ) - f = make_func(ui_name, input_names, n_outputs) - return FuncOp._from_data(func=f, sig=sig) diff --git a/mandala/tests/test_async.py b/mandala/tests/test_async.py deleted file mode 100644 index ea59ec9..0000000 --- a/mandala/tests/test_async.py +++ /dev/null @@ -1,29 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * - - -@pytest.mark.asyncio -async def test_unit(): - storage = Storage() - - @op - async def inc(x: int) -> int: - time.sleep(1) - print("Hi there!") - return x + 1 - - with storage.run(): - z = await inc(1) - assert storage.similar(z).shape[0] == 1 - - # now run 10 calls in parallel - with storage.run(): - tasks = [inc(i) for i in range(10)] - results = await asyncio.gather(*tasks) - assert storage.similar(results[0]).shape[0] == 10 - - # run again - with storage.run(): - tasks = [inc(i) for i in range(10)] - results = await asyncio.gather(*tasks) - assert storage.similar(results[0]).shape[0] == 10 diff --git a/mandala/tests/test_causal.py b/mandala/tests/test_causal.py deleted file mode 100644 index b9fac77..0000000 --- a/mandala/tests/test_causal.py +++ /dev/null @@ -1,88 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * - - -def test_unit(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - x = inc(23) - df = storage.get_table(inc, values="uids") - assert df.shape[0] == 1 - with storage.run(): - y = inc(22) - z = inc(y) - df = storage.get_table(inc, values="uids") - assert df.shape[0] == 3 - - -def test_queries(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def dec(x: int) -> int: - return x - 1 - - with storage.run(): - x = inc(23) - - with storage.run(): - y = dec(24) - - with storage.query(): - x = inc(Q()) - y = dec(x) - df = storage.df(x, y) - assert df.empty - num_content_calls_inc = storage.get_table( - inc, drop_duplicates=True, values="uids" - ).shape[0] - num_content_calls_dec = storage.get_table( - dec, drop_duplicates=True, values="uids" - ).shape[0] - assert num_content_calls_inc == 1 - assert num_content_calls_dec == 1 - - with storage.run(allow_calls=False): - x = inc(23) - y = dec(x) - num_content_calls_inc = storage.get_table( - inc, drop_duplicates=True, values="uids" - ).shape[0] - num_content_calls_dec = storage.get_table( - dec, drop_duplicates=True, values="uids" - ).shape[0] - assert num_content_calls_inc == 1 - assert num_content_calls_dec == 1 - with storage.query(): - x = inc(Q()) - y = dec(x) - df = storage.df(x, y) - assert df.shape[0] == 1 - - -def test_bug(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage.run(): - for x in range(5): - for y in range(5): - z = inc(x) - w = add(z, y) - assert storage.get_table(add).shape[0] == 25 diff --git a/mandala/tests/test_cfs.py b/mandala/tests/test_cfs.py index d76b0b1..b75ad16 100644 --- a/mandala/tests/test_cfs.py +++ b/mandala/tests/test_cfs.py @@ -1,117 +1,18 @@ -from mandala.all import * -from mandala.tests.utils import * -from mandala.queries.weaver import * +from mandala._next.imports import * -def 
test_construction(): - - storage = Storage() - - @op - def add(x: int, y: int) -> int: - return x + y - - @op - def f(x: int, a: int = 1) -> Tuple[int, int]: - return x + a, x - a - - with storage.run(): - things = [add(i, i + 1) for i in range(10)] - other_things = [f(i) for i in range(10)] - - rf1 = ComputationFrame.from_refs(things, storage=storage) - rf2 = ComputationFrame.from_refs([t[0] for t in other_things], storage=storage) - rf3 = ComputationFrame.from_op(func=add, storage=storage) - rf4 = ComputationFrame.from_op(func=f, storage=storage) - - -def test_back(): - storage = Storage() - - @op - def add(x: int, y: int) -> int: - return x + y - - @op - def f(x: int, a: int = 1) -> Tuple[int, int]: - return x + a, x - a - - @op - def mul(x: int, y: int) -> int: - return x * y - - with storage.run(): - cs = [] - for i in range(10): - a = add(i, i + 1) - b = f(a) - c = mul(b[0], b[1]) - cs.append(c) - - rf1 = ComputationFrame.from_refs(cs, storage=storage, name="c") - rf1.back("c") - rf1.back("c", inplace=True) - rf1.back() - - -def test_evals(): - storage = Storage() - - @op - def add(x: int, y: int) -> int: - return x + y - - @op - def f(x: int, a: int = 1) -> Tuple[int, int]: - return x + a, x - a - - @op - def mul(x: int, y: int) -> int: - return x * y - - with storage.run(): - cs = [] - for i in range(10): - a = add(i, i + 1) - b = f(a) - c = mul(b[0], b[1]) - cs.append(c) - - rf = ComputationFrame.from_refs(cs, storage=storage, name="c") - - rf[["c"]] - rf[list(rf.var_nodes.keys())] - df = rf.eval("c") - assert len(df) == 10 - df = rf.eval() - assert len(df) == 10 - - rf = ComputationFrame.from_op(func=add, storage=storage) - sub_rf = rf[rf.eval("x") < 5] - assert len(sub_rf) == 5 - - -def test_deletion(): - +def test_single_func(): storage = Storage() @op def inc(x: int) -> int: - print("hey!") return x + 1 - with storage.run(): - for i in range(10): - inc(i) - - # check the number of rows goes down by the expected amount - rf = ComputationFrame.from_op(func=inc, storage=storage) - rf[rf.eval("x") < 5].delete(delete_dependents=True, ask=False) - rf = ComputationFrame.from_op(func=inc, storage=storage) - assert len(rf) == 5 - - # re-compute the calls - storage.cache.evict_all() - with storage.run(): + with storage: for i in range(10): inc(i) + + cf = storage.cf(inc) + df = cf.df() + assert df.shape == (10, 3) + assert (df['output_0'] == df['x'] + 1).all() \ No newline at end of file diff --git a/mandala/tests/test_components.py b/mandala/tests/test_components.py deleted file mode 100644 index f604095..0000000 --- a/mandala/tests/test_components.py +++ /dev/null @@ -1,622 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * - - -def test_rel_storage(): - rel_storage = SQLiteRelStorage() - assert set(rel_storage.get_tables()) == set() - rel_storage.create_relation( - name="test", columns=[("a", None), ("b", None)], primary_key="a", defaults={} - ) - assert set(rel_storage.get_tables()) == {"test"} - assert rel_storage.get_data(table="test").empty - df = pd.DataFrame({"a": ["x", "y"], "b": ["z", "w"]}) - ta = pa.Table.from_pandas(df) - rel_storage.insert(relation="test", ta=ta) - assert (rel_storage.get_data(table="test") == df).all().all() - rel_storage.upsert(relation="test", ta=ta) - assert (rel_storage.get_data(table="test") == df).all().all() - rel_storage.create_column(relation="test", name="c", default_value="a") - df["c"] = ["a", "a"] - assert (rel_storage.get_data(table="test") == df).all().all() - # rel_storage.delete(relation="test", index=["x", "y"]) - # 
assert rel_storage.get_data(table="test").empty - - -def test_main_storage(): - storage = Storage() - # check that things work on an empty storage - storage.sig_adapter.load_state() - storage.rel_adapter.get_all_call_data() - - -def test_wrapping(): - assert unwrap(23) == 23 - assert unwrap(23.0) == 23.0 - assert unwrap("23") == "23" - assert unwrap([1, 2, 3]) == [1, 2, 3] - - vref = wrap_atom(23) - assert wrap_atom(vref) is vref - try: - wrap_atom(vref, uid="aaaaaa") - except: - assert True - - -def test_unwrapping(): - # tuples - assert unwrap((1, 2, 3)) == (1, 2, 3) - vrefs = (wrap_atom(23), wrap_atom(24), wrap_atom(25)) - assert unwrap(vrefs, through_collections=True) == (23, 24, 25) - assert unwrap(vrefs, through_collections=False) == vrefs - # sets - assert unwrap({1, 2, 3}) == {1, 2, 3} - vrefs = {wrap_atom(23), wrap_atom(24), wrap_atom(25)} - assert unwrap(vrefs, through_collections=True) == {23, 24, 25} - assert unwrap(vrefs, through_collections=False) == vrefs - # lists - assert unwrap([1, 2, 3]) == [1, 2, 3] - vrefs = [wrap_atom(23), wrap_atom(24), wrap_atom(25)] - assert unwrap(vrefs, through_collections=True) == [23, 24, 25] - assert unwrap(vrefs, through_collections=False) == vrefs - # dicts - assert unwrap({"a": 1, "b": 2, "c": 3}) == {"a": 1, "b": 2, "c": 3} - vrefs = {"a": wrap_atom(23), "b": wrap_atom(24), "c": wrap_atom(25)} - assert unwrap(vrefs, through_collections=True) == {"a": 23, "b": 24, "c": 25} - assert unwrap(vrefs, through_collections=False) == vrefs - - -def test_reprs(): - x = wrap_atom(23) - repr(x), str(x) - - -################################################################################ -### contexts -################################################################################ -def test_nesting_new_api(): - storage = Storage() - - with storage.run() as c: - assert c.mode == MODES.run - - with storage.query() as q: - assert q.mode == MODES.query - - with storage.run() as c_1: - with storage.run() as c_2: - with storage.run() as c_3: - assert c_1 is c_2 is c_3 - - with storage.run() as c: - assert c.mode == MODES.run - assert c.storage is storage - with storage.query() as q: - assert q is c - assert q.storage is storage - assert q.mode == MODES.query - assert c.mode == MODES.run - - -def test_noop(): - # check that ops are noops when not in a context - - @op - def inc(x: int) -> int: - return x + 1 - - assert inc(23) == 24 - - -def test_failures(): - storage = Storage() - - try: - with storage.run(bla=23): - pass - except: - assert True - - -################################################################################ -### test ops -################################################################################ -def test_signatures(): - sig = Signature( - ui_name="f", - input_names={"x", "y"}, - n_outputs=1, - defaults={"y": 42}, - version=0, - input_annotations={"x": int, "y": int}, - output_annotations=[Any], - ) - - # if internal data has not been set, it should not be accessible - try: - sig.internal_name - except ValueError: - assert True - - try: - sig.ui_to_internal_input_map - except ValueError: - assert True - - with_internal = sig._generate_internal() - assert not sig.has_internal_data - assert with_internal.has_internal_data - - ### invalid signature changes - # remove input - new = Signature( - ui_name="f", - input_names={"x", "z"}, - n_outputs=1, - defaults={"y": 42}, - version=0, - input_annotations={"x": int, "z": int}, - output_annotations=[Any], - ) - try: - sig.update(new=new) - except ValueError: - assert True - new = 
Signature( - ui_name="f", - input_names={"x", "y"}, - n_outputs=1, - defaults={}, - version=0, - input_annotations={"x": int, "y": int}, - output_annotations=[Any], - ) - try: - sig.update(new=new) - except ValueError: - assert True - # change version - new = Signature( - ui_name="f", - input_names={"x", "y"}, - n_outputs=1, - defaults={"y": 42}, - version=1, - input_annotations={"x": int, "y": int}, - output_annotations=[Any], - ) - try: - sig.update(new=new) - except ValueError: - assert True - - # add input - sig = sig._generate_internal() - try: - sig.create_input(name="y", default=23, annotation=Any) - except ValueError: - assert True - new = sig.create_input(name="z", default=23, annotation=Any) - assert new.input_names == {"x", "y", "z"} - - -def test_output_name_failure(): - - try: - - @op - def f(output_0: int) -> int: - return output_0 - - except: - assert True - - -def test_changing_num_outputs(): - - storage = Storage() - - @op - def f(x: int): - return x - - try: - with storage.run(): - f(1) - except Exception: - assert True - - @op - def f(x: int) -> int: - return x - - with storage.run(): - f(1) - - @op - def f(x: int) -> Tuple[int, int]: - return x - - try: - with storage.run(): - f(1) - except ValueError: - assert True - - @op - def f(x: int) -> int: - return x - - with storage.run(): - f(1) - - -def test_nout(): - storage = Storage() - - @op(nout=2) - def f(x: int): - return x, x - - with storage.run(): - a, b = f(1) - assert unwrap(a) == 1 and unwrap(b) == 1 - - @op(nout=0) - def g(x: int) -> Tuple[int, int]: - pass - - with storage.run(): - c = g(1) - assert c is None - - -################################################################################ -### test storage -################################################################################ -OUTPUT_ROOT = Path(__file__).parent / "output" - - -def test_get(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - y = inc(23) - - y_full = storage.rel_adapter.obj_get(uid=y.uid) - assert y_full.in_memory - assert unwrap(y_full) == 24 - - y_lazy = storage.rel_adapter.obj_get(uid=y.uid, _attach_atoms=False) - assert not y_lazy.in_memory - assert y_lazy.obj is None - - @op - def get_prime_factors(n: int) -> Set[int]: - factors = set() - d = 2 - while d * d <= n: - while (n % d) == 0: - factors.add(d) - n //= d - d += 1 - if n > 1: - factors.add(n) - return factors - - with storage.run(): - factors = get_prime_factors(42) - - factors_full = storage.rel_adapter.obj_get(uid=factors.uid) - assert factors_full.in_memory - assert all([x.in_memory for x in factors_full]) - - factors_shallow = storage.rel_adapter.obj_get(uid=factors.uid, depth=1) - assert factors_shallow.in_memory - assert all([not x.in_memory for x in factors_shallow]) - - storage.rel_adapter.mattach(vrefs=[factors_shallow]) - assert all([x.in_memory for x in factors_shallow]) - - @superop - def get_factorizations(n: int) -> List[List[int]]: - # get all factorizations of a number into factors - n = unwrap(n) - divisors = [i for i in range(2, n + 1) if n % i == 0] - result = [[n]] - for divisor in divisors: - sub_solutions = unwrap(get_factorizations(n // divisor)) - result.extend( - [ - [divisor] + sub_solution - for sub_solution in sub_solutions - if min(sub_solution) >= divisor - ] - ) - return result - - with storage.run(): - factorizations = get_factorizations(42) - - factorizations_full = storage.rel_adapter.obj_get(uid=factorizations.uid) - assert unwrap(factorizations_full) == [[42], [2, 21], [2, 3, 7], [3, 
14], [6, 7]] - factorizations_shallow = storage.rel_adapter.obj_get( - uid=factorizations.uid, depth=1 - ) - assert factorizations_shallow.in_memory - storage.rel_adapter.mattach(vrefs=[factorizations_shallow.obj[0]]) - assert unwrap(factorizations_shallow.obj[0]) == [42] - - # result, call, wrapped_inputs = storage.call_run( - # func_op=get_factorizations.func_op, - # inputs={"n": 42}, - # _call_depth=0, - # ) - - -def test_persistent(): - db_path = OUTPUT_ROOT / "test_persistent.db" - if db_path.exists(): - db_path.unlink() - storage = Storage(db_path=db_path) - - try: - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def get_prime_factors(n: int) -> Set[int]: - factors = set() - d = 2 - while d * d <= n: - while (n % d) == 0: - factors.add(d) - n //= d - d += 1 - if n > 1: - factors.add(n) - return factors - - @superop - def get_factorizations(n: int) -> List[List[int]]: - # get all factorizations of a number into factors - n = unwrap(n) - divisors = [i for i in range(2, n + 1) if n % i == 0] - result = [[n]] - for divisor in divisors: - sub_solutions = unwrap(get_factorizations(n // divisor)) - result.extend( - [ - [divisor] + sub_solution - for sub_solution in sub_solutions - if min(sub_solution) >= divisor - ] - ) - return result - - with storage.run(): - y = inc(23) - factors = get_prime_factors(42) - factorizations = get_factorizations(42) - assert all([x.in_memory for x in (y, factors, factorizations)]) - - with storage.run(): - y = inc(23) - factors = get_prime_factors(42) - factorizations = get_factorizations(42) - assert all([not x.in_memory for x in (y, factors, factorizations)]) - - with storage.run(lazy=False): - y = inc(23) - factors = get_prime_factors(42) - factorizations = get_factorizations(42) - assert all([x.in_memory for x in (y, factors, factorizations)]) - - with storage.run(): - y = inc(23) - assert not y.in_memory - y._auto_attach() - assert y.in_memory - factors = get_prime_factors(42) - assert not factors.in_memory - 7 in factors - assert factors.in_memory - - factorizations = get_factorizations(42) - assert not factorizations.in_memory - n = len(factorizations) - assert factorizations.in_memory - #! this is now in memory b/c new caching - # assert not factorizations[0].in_memory - # factorizations[0][0] - # assert factorizations[0].in_memory - - #! 
this is now in memory b/c new caching - # for elt in factorizations[1]: - # assert not elt.in_memory - - except Exception as e: - raise e - finally: - db_path.unlink() - - -def test_magics(): - db_path = OUTPUT_ROOT / "test_magics.db" - if db_path.exists(): - db_path.unlink() - storage = Storage(db_path=db_path) - - try: - Config.enable_ref_magics = True - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - x = inc(23) - - with storage.run(): - x = inc(23) - assert not x.in_memory - if x > 0: - y = inc(x) - assert x.in_memory - - with storage.run(): - x = inc(23) - y = inc(x) - if x + y > 0: - z = inc(x) - - with storage.run(): - x = inc(23) - y = inc(x) - if x: - z = inc(x) - - except Exception as e: - raise e - finally: - db_path.unlink() - - -def test_spillover(): - db_path = OUTPUT_ROOT / "test_spillover.db" - if db_path.exists(): - db_path.unlink() - spillover_dir = OUTPUT_ROOT / "test_spillover/" - if spillover_dir.exists(): - shutil.rmtree(spillover_dir) - storage = Storage(db_path=db_path, spillover_dir=spillover_dir) - - try: - import numpy as np - - @op - def create_large_array() -> np.ndarray: - return np.random.rand(10_000_000) - - with storage.run(): - x = create_large_array() - - assert len(os.listdir(spillover_dir)) == 1 - path = spillover_dir / os.listdir(spillover_dir)[0] - with open(path, "rb") as f: - data = unwrap(joblib.load(f)) - assert np.allclose(data, unwrap(x)) - - with storage.run(): - x = create_large_array() - assert not x.in_memory - x = unwrap(x) - - except Exception as e: - raise e - finally: - db_path.unlink() - shutil.rmtree(spillover_dir) - - -def test_batching_unit(): - - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.batch(): - y = inc(23) - - assert unwrap(y) == 24 - assert y.uid is not None - all_data = storage.rel_storage.get_all_data() - assert all_data[Config.vref_table].shape[0] == 2 - assert all_data[inc.func_op.sig.versioned_ui_name].shape[0] == 1 - - -def test_exclude_arg(): - storage = Storage() - - @op - def inc(x: int, __excluded__=False) -> int: - return x + 1 - - with storage.run(): - y = inc(23) - - -def test_provenance(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - y = inc(23) - storage.prov(y) - storage.print_graph(y) - - ### struct inputs - @op - def avg_list(nums: List[int]) -> float: - return sum(nums) / len(nums) - - @op - def avg_dict(nums: Dict[str, int]) -> float: - return sum(nums.values()) / len(nums) - - @op - def avg_set(nums: Set[int]) -> float: - return sum(nums) / len(nums) - - with storage.run(): - x = avg_list([1, 2, 3]) - y = avg_dict({"a": 1, "b": 2, "c": 3}) - z = avg_set({1, 2, 3}) - for v in [x, y, z]: - storage.prov(v) - storage.print_graph(v) - - ### struct outputs - @op - def get_list() -> List[int]: - return [1, 2, 3] - - @op - def get_dict() -> Dict[str, int]: - return {"a": 1, "b": 2, "c": 3} - - @op - def get_set() -> Set[int]: - return {1, 2, 3} - - with storage.run(): - x = get_list() - y = get_dict() - z = get_set() - a = x[0] - b = y["a"] - for v in [x, y, z, a, b]: - storage.prov(v) - storage.print_graph(v) - - ### a mess of stuff - with storage.run(): - a = get_list() - x = avg_list(a[:2]) - y = avg_dict(get_dict()) - z = avg_set(get_set()) - for v in [a, x, y, z]: - storage.prov(v) - storage.print_graph(v) diff --git a/mandala/tests/test_deps.py b/mandala/tests/test_deps.py deleted file mode 100644 index fbabce9..0000000 --- a/mandala/tests/test_deps.py +++ /dev/null @@ -1,615 +0,0 @@ -from 
mandala.all import * -from mandala.tests.utils import * -from mandala.deps.shallow_versions import DAG -import numpy as np - - -def test_dag(): - d = DAG(content_type="code") - try: - d.commit("something") - except AssertionError: - pass - - content_hash_1 = d.init(initial_content="something") - assert len(d.commits) == 1 - assert d.head == content_hash_1 - content_hash_2 = d.commit(content="something else", is_semantic_change=True) - assert ( - d.commits[content_hash_2].semantic_hash - != d.commits[content_hash_1].semantic_hash - ) - assert d.head == content_hash_2 - content_hash_3 = d.commit(content="something else #2", is_semantic_change=False) - assert ( - d.commits[content_hash_3].semantic_hash - == d.commits[content_hash_2].semantic_hash - ) - - content_hash_4 = d.sync(content="something else") - assert content_hash_4 == content_hash_2 - assert d.head == content_hash_2 - - d.show() - - -MODULE_NAME = "mandala.tests.test_deps" -DEPS_PACKAGE = "mandala.tests" -DEPS_PATH = Path(__file__).absolute().resolve() -# MODULE_NAME = '__main__' - - -def _test_version_reprs(storage: Storage): - for dag in storage.get_versioner().component_dags.values(): - for compact in [True, False]: - dag.show(compact=compact) - for version in storage.get_versioner().get_flat_versions().values(): - storage.get_versioner().present_dependencies(commits=version.semantic_expansion) - storage.get_versioner().global_topology.show(path=generate_path(ext=".png")) - repr(storage.get_versioner().global_topology) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_unit(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - - # to be able to import this name - global f_1, A - - A = 42 - - @op - def f_1(x) -> int: - return 23 + A - - with storage.run(): - f_1(1) - - vs = storage.get_versioner() - f_1_versions = vs.versions[MODULE_NAME, "f_1"] - assert len(f_1_versions) == 1 - version = f_1_versions[list(f_1_versions.keys())[0]] - assert set(version.support) == {(MODULE_NAME, "f_1"), (MODULE_NAME, "A")} - _test_version_reprs(storage=storage) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_libraries(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - if tracer_impl is SysTracer: - track = lambda x: x - else: - from mandala.deps.tracers.dec_impl import track - - global f_2, f_1 - - @track - def f_1(x) -> int: - return 23 - - # use functions from libraries to make sure we don't trace them - @op - def f_2(x) -> int: - df = pd.DataFrame({"a": [1, 2, 3]}) - array = np.array([1, 2, 3]) - x = array.mean() + np.random.uniform() - return f_1(x) - - with storage.run(): - f_2(1) - - vs = storage.get_versioner() - f_2_versions = vs.versions[MODULE_NAME, "f_2"] - assert len(f_2_versions) == 1 - version = f_2_versions[list(f_2_versions.keys())[0]] - assert set(version.support) == {(MODULE_NAME, "f_2"), (MODULE_NAME, "f_1")} - _test_version_reprs(storage=storage) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_deps(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - global dep_1, f_2, A - if tracer_impl is SysTracer: - track = lambda x: x - else: - from mandala.deps.tracers.dec_impl import track - - A = 42 - - @track - def dep_1(x) -> int: - return 23 - - @track - @op - def f_2(x) -> int: - return dep_1(x) + A - - with storage.run(): - f_2(1) - - vs = 
storage.get_versioner() - f_2_versions = vs.versions[MODULE_NAME, "f_2"] - assert len(f_2_versions) == 1 - version = f_2_versions[list(f_2_versions.keys())[0]] - assert set(version.support) == { - (MODULE_NAME, "f_2"), - (MODULE_NAME, "dep_1"), - (MODULE_NAME, "A"), - } - _test_version_reprs(storage=storage) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_changes(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - - global f - - @op - def f(x) -> int: - return x + 1 - - with storage.run(): - f(1) - commit_1 = storage.sync_component( - component=f, - is_semantic_change=None, - ) - - @op - def f(x) -> int: - return x + 2 - - commit_2 = storage.sync_component(component=f, is_semantic_change=True) - assert commit_1 != commit_2 - with storage.run(): - f(1) - - @op - def f(x) -> int: - return x + 1 - - # confirm we reverted to the previous version - commit_3 = storage.sync_component( - component=f, - is_semantic_change=None, - ) - assert commit_3 == commit_1 - with storage.run(allow_calls=False): - f(1) - - # create a new branch - @op - def f(x) -> int: - return x + 3 - - commit_4 = storage.sync_component( - component=f, - is_semantic_change=True, - ) - assert commit_4 not in (commit_1, commit_2) - with storage.run(): - f(1) - - f_versions = storage.get_versioner().versions[MODULE_NAME, "f"] - assert len(f_versions) == 3 - semantic_versions = [v.semantic_version for v in f_versions.values()] - assert len(set(semantic_versions)) == 3 - _test_version_reprs(storage=storage) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_superops(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - Config.enable_ref_magics = True - - global f_1, f_2, f_3 - - @op - def f_1(x) -> int: - return x + 1 - - @superop - def f_2(x) -> int: - return f_1(x) + 1 - - with storage.run(): - f_1(1) - - with storage.run(attach_call_to_outputs=True): - a = f_2(1) - call: Call = a._call - version = storage.get_versioner().versions[MODULE_NAME, "f_2"][call.content_version] - assert set(version.support) == {(MODULE_NAME, "f_1"), (MODULE_NAME, "f_2")} - - @superop - def f_3(x) -> int: - return f_2(x) + 1 - - with storage.run(attach_call_to_outputs=True): - a = f_3(1) - call: Call = a._call - version = storage.get_versioner().versions[MODULE_NAME, "f_3"][call.content_version] - assert set(version.support) == { - (MODULE_NAME, "f_1"), - (MODULE_NAME, "f_2"), - (MODULE_NAME, "f_3"), - } - _test_version_reprs(storage=storage) - Config.enable_ref_magics = False - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_dependency_patterns(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - global A, B, f_1, f_2, f_3, f_4, f_5, f_6 - if tracer_impl is SysTracer: - track = lambda x: x - else: - from mandala.deps.tracers.dec_impl import track - - # global vars - A = 23 - B = [1, 2, 3] - - # using a global var - @track - def f_1(x) -> int: - return x + A - - # calling another function - @track - def f_2(x) -> int: - return f_1(x) + B[0] - - # different dependencies per call - @op - def f_3(x) -> int: - if x % 2 == 0: - return f_2(2 * x) - else: - return f_1(x + 1) - - with storage.run(attach_call_to_outputs=True): - x = f_3(0) - call: Call = x._call - version = storage.get_versioner().get_flat_versions()[call.content_version] - assert version.support == { - (MODULE_NAME, "f_3"), - 
(MODULE_NAME, "f_2"), - (MODULE_NAME, "A"), - (MODULE_NAME, "B"), - (MODULE_NAME, "f_1"), - } - with storage.run(attach_call_to_outputs=True): - x = f_3(1) - call: Call = x._call - version = storage.get_versioner().get_flat_versions()[call.content_version] - assert version.support == { - (MODULE_NAME, "f_3"), - (MODULE_NAME, "f_1"), - (MODULE_NAME, "A"), - } - - # using a lambda - @op - def f_4(x) -> int: - f = lambda y: f_1(y) + B[0] - return f(x) - - # make sure the call in the lambda is detected - with storage.run(attach_call_to_outputs=True): - x = f_4(10) - call: Call = x._call - version = storage.get_versioner().get_flat_versions()[call.content_version] - assert version.support == { - (MODULE_NAME, "f_4"), - (MODULE_NAME, "f_1"), - (MODULE_NAME, "A"), - (MODULE_NAME, "B"), - } - - # using comprehensions and generators - @superop - def f_5(x) -> int: - x = unwrap(x) - a = [f_1(y) for y in range(x)] - b = {f_2(y) for y in range(x)} - c = {y: f_3(y) for y in range(x)} - return sum(unwrap(f_4(y)) for y in range(x)) - - with storage.run(): - f_5(10) - - f_5_versions = storage.get_versioner().versions[MODULE_NAME, "f_5"] - assert len(f_5_versions) == 1 - version = f_5_versions[list(f_5_versions.keys())[0]] - assert set(version.support) == { - (MODULE_NAME, "f_5"), - (MODULE_NAME, "f_4"), - (MODULE_NAME, "f_3"), - (MODULE_NAME, "f_2"), - (MODULE_NAME, "f_1"), - (MODULE_NAME, "A"), - (MODULE_NAME, "B"), - } - - # nested comprehensions and generators - @superop - def f_6(x) -> int: - x = unwrap(x) - # nested list comprehension - a = sum([sum([f_1(y) for y in range(x)]) for z in range(x)]) - # nested comprehension with generator - b = sum(sum(f_2(y) for y in range(x)) for z in range(unwrap(f_3(x)))) - return a + b - - with storage.run(): - f_6(2) - - f_6_versions = storage.get_versioner().versions[MODULE_NAME, "f_6"] - assert len(f_6_versions) == 1 - version = f_6_versions[list(f_6_versions.keys())[0]] - assert set(version.support) == { - (MODULE_NAME, "f_6"), - (MODULE_NAME, "f_3"), - (MODULE_NAME, "f_2"), - (MODULE_NAME, "f_1"), - (MODULE_NAME, "A"), - (MODULE_NAME, "B"), - } - _test_version_reprs(storage=storage) - storage.versions(f_6) - storage.sources(f_6) - storage.get_code(version_id=version.content_version) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_recursion(tracer_impl): - ### mutual recursion - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - Config.enable_ref_magics = True - global s_1, s_2 - - @superop - def s_1(x) -> int: - if x == 0: - return 0 - else: - return s_2(x - 1) + 1 - - @superop - def s_2(x) -> int: - if x == 0: - return 0 - else: - return s_1(x - 1) + 1 - - with storage.run(attach_call_to_outputs=True): - a = s_1(0) - call_1 = a._call - b = s_1(1) - call_2 = b._call - version_1 = storage.get_versioner().get_flat_versions()[call_1.content_version] - assert version_1.support == {(MODULE_NAME, "s_1")} - version_2 = storage.get_versioner().get_flat_versions()[call_2.content_version] - assert version_2.support == {(MODULE_NAME, "s_1"), (MODULE_NAME, "s_2")} - _test_version_reprs(storage=storage) - Config.enable_ref_magics = False - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_memoized_tracking(tracer_impl): - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - if tracer_impl is SysTracer: - track = lambda x: x - else: - from mandala.deps.tracers.dec_impl import track - - global f_1, f_2, f, g - - @track - def 
f_1(x): - return x + 1 - - @track - def f_2(x): - return x + 2 - - @op - def f(x) -> int: - if x % 2 == 0: - return f_1(x) - else: - return f_2(x) - - @superop - def g(x) -> List[int]: - return [f(i) for i in range(unwrap(x))] - - with storage.run(): - for i in range(10): - f(i) - - with storage.run(): - z = g(10) - - -@pytest.mark.parametrize("tracer_impl", [SysTracer, DecTracer]) -def test_transient(tracer_impl): - global f - - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - - @op - def f(x) -> int: - return Transient(x + 1) - - with storage.run(): - a = f(42) - with storage.run(recompute_transient=True): - a = f(42) - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_queries_unit(tracer_impl): - Config.query_engine = "sql" # doesn't work yet with the naive engine - - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - - global g_1 - - ### create an op, run it and check the query result - @op - def g_1(x) -> int: - return x + 2 - - with storage.run(): - [g_1(i) for i in range(10)] - - with storage.query(): - i = Q() - j = g_1(i) - df = storage.df(i.named("i"), j.named("j")) - - assert df.shape == (10, 2) - assert (df["j"] == df["i"] + 2).all() - - ### change the op semantically and check that the query result is empty - @op - def g_1(x) -> int: - return x + 3 - - storage.sync_component( - component=g_1, - is_semantic_change=True, - ) - - with storage.query(): - i = Q() - j = g_1(i) - df = storage.df(i.named("i"), j.named("j")) - - assert df.empty - - ### run the op again and check that the query result is correct for the new version - with storage.run(): - [g_1(i) for i in range(10)] - - with storage.query(): - i = Q() - j = g_1(i) - df = storage.df(i.named("i"), j.named("j")) - - assert df.shape == (10, 2) - assert (df["j"] == df["i"] + 3).all() - - ### go back to the old version and check that the query result is correct - @op - def g_1(x) -> int: - return x + 2 - - storage.sync_component( - component=g_1, - is_semantic_change=None, - ) - with storage.query(): - i = Q() - j = g_1(i) - df = storage.df(i.named("i"), j.named("j")) - assert df.shape == (10, 2) - assert (df["j"] == df["i"] + 2).all() - - -@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer]) -def test_queries_multiple_versions(tracer_impl): - Config.query_engine = "sql" # doesn't work yet with the naive engine - storage = Storage( - deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl - ) - - global f_1, f_2, f_3 - - ### define an op with multiple semantically-compatible versions - @op - def f_1(x) -> int: - return x + 2 - - @op - def f_2(x) -> int: - return x + 3 - - @op - def f_3(x) -> int: - if x % 2 == 0: - return f_2(2 * x) - else: - return f_1(x + 1) - - with storage.run(): - for i in range(10): # make sure both versions are used - f_3(i) - - with storage.query(): - i = Q() - j = f_3(i) - df_1 = storage.df(i.named("i"), j.named("j")) - assert df_1.shape == (10, 2) - assert (df_1["j"] == df_1["i"].apply(f_3)).all() - - # change one of the dependencies semantically and check that the query - # result is what remains from the other branch - @op - def f_1(x) -> int: - return x + 4 - - storage.sync_component( - f_1, - is_semantic_change=True, - ) - - with storage.query(): - i = Q() - j = f_3(i) - df_2 = storage.df(i.named("i"), j.named("j")) - assert df_2.shape == (5, 2) - assert sorted(df_2["i"].values.tolist()) == [0, 2, 4, 6, 8] - assert (df_2["j"] == 
df_2["i"].apply(f_3)).all() - - ### go back to the old version and check that the query result is recovered - @op - def f_1(x) -> int: - return x + 2 - - storage.sync_component( - f_1, - is_semantic_change=None, - ) - - with storage.query(): - i = Q() - j = f_3(i) - df_3 = storage.df(i.named("i"), j.named("j")) - assert (df_1 == df_3).all().all() diff --git a/mandala/tests/test_memoization.py b/mandala/tests/test_memoization.py index 2c6eccc..2f47bda 100644 --- a/mandala/tests/test_memoization.py +++ b/mandala/tests/test_memoization.py @@ -1,142 +1,133 @@ -from mandala.all import * -from mandala.tests.utils import * +from mandala._next.imports import * -def test_func_creation(): +def test_storage(): storage = Storage() - @op - def add(x: int, y: int = 42) -> int: - return x + y - - assert add.func_op.sig.n_outputs == 1 - assert add.func_op.sig.input_names == {"x", "y"} - assert add.func_op.sig.defaults == {"y": 42} - check_invariants(storage) - - -@pytest.mark.parametrize("storage", generate_storages()) -def test_computation(storage): @op def inc(x: int) -> int: return x + 1 + + with storage: + x = 1 + y = inc(x) + z = inc(2) + w = inc(y) + + assert w.cid == z.cid + assert w.hid != y.hid + assert w.cid != y.cid + assert storage.unwrap(y) == 2 + assert storage.unwrap(z) == 3 + assert storage.unwrap(w) == 3 + for ref in (y, z, w): + assert storage.attach(ref).in_memory + assert storage.attach(ref).obj == storage.unwrap(ref) + + +def test_signatures(): + storage = Storage() - @op - def add(x: int, y: int) -> int: - return x + y + @op # a function with a wild input/output signature + def add(x, *args, y: int = 1, **kwargs): + # just sum everything + res = x + sum(args) + y + sum(kwargs.values()) + if kwargs: + return res, kwargs + elif args: + return None + else: + return res + + with storage: + # call the func in all the ways + sum_1 = add(1) + sum_2 = add(1, 2, 3, 4, ) + sum_3 = add(1, 2, 3, 4, y=5) + sum_4 = add(1, 2, 3, 4, y=5, z=6) + sum_5 = add(1, 2, 3, 4, z=5, w=7) + + assert storage.unwrap(sum_1) == 2 + assert storage.unwrap(sum_2) == None + assert storage.unwrap(sum_3) == None + assert storage.unwrap(sum_4) == (21, {'z': 6}) + assert storage.unwrap(sum_5) == (23, {'z': 5, 'w': 7}) + + +def test_retracing(): + storage = Storage() - # chain some functions - with storage.run(): - x = 23 - y = inc(x) - z = add(x, y=y) + @op + def inc(x): + return x + 1 - check_invariants(storage) - # run it again - with storage.run(): - x = 23 - y = inc(x) - z = add(x, y) - check_invariants(storage) - # do some more things - with storage.run(): - x = 42 - y = inc(x) - z = add(x, y) + ### iterating a function + with storage: + start = 1 for i in range(10): - z = add(z, i) - check_invariants(storage) + start = inc(start) + with storage: + start = 1 + for i in range(10): + start = inc(start) -@pytest.mark.parametrize("storage", generate_storages()) -def test_retracing(storage): - @op - def inc(x: int) -> int: - return x + 1 - + ### composing functions @op - def add(x: int, y: int) -> int: + def add(x, y): return x + y - - with storage.run(): - x = 23 - y = inc(x) - z = add(x, y) - - with storage.run(allow_calls=False): - x = 23 - y = inc(x) - z = add(x, y) - - try: - with storage.run(allow_calls=False): - x = 24 - y = inc(x) - z = add(x, y) - assert False - except Exception as e: - assert True - - -def test_debugging(): + + with storage: + inp = [1, 2, 3, 4, 5] + stage_1 = [inc(x) for x in inp] + stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)] + + with storage: + inp = [1, 2, 3, 4, 5] + stage_1 = 
[inc(x) for x in inp] + stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)] + + +def test_lists(): storage = Storage() @op - def inc(x: int) -> int: - return x + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage.run(debug_calls=True): - x = 23 - y = inc(x) - z = add(x, y) - - with storage.run(debug_calls=True): - x = 23 - y = inc(x) - z = add(x, y) - - -def _a(): + def get_sum(elts: MList[int]) -> int: + return sum(elts) + @op - def generate_dataset() -> Tuple[int, int]: - return 1, 2 - + def primes_below(n: int) -> MList[int]: + primes = [] + for i in range(2, n): + for p in primes: + if i % p == 0: + break + else: + primes.append(i) + return primes + @op - def train_model( - train_dataset: int, - test_dataset: int, - learning_rate: float, - batch_size: int, - num_epochs: int, - ) -> Tuple[int, float]: - return train_dataset + test_dataset + learning_rate, batch_size + num_epochs - - storage = Storage() - - with storage.run(): - X, y = generate_dataset() - for batch_size in (100, 200, 400): - for learning_rate in (1, 2, 3): - model, acc = train_model( - X, - y, - learning_rate=learning_rate, - batch_size=batch_size, - num_epochs=10, - ) - - with storage.run(): - X, y = generate_dataset() - for batch_size in (100, 200, 400): - for learning_rate in (1, 2, 3): - model, acc = train_model( - X, - y, - learning_rate=learning_rate, - batch_size=batch_size, - num_epochs=10, - ) + def chunked_square(elts: MList[int]) -> MList[int]: + # a model for an op that does something on chunks of a big thing + # to prevent OOM errors + return [x*x for x in elts] + + with storage: + n = 10 + primes = primes_below(n) + sum_primes = get_sum(primes) + assert len(primes) == 4 + # check indexing + assert storage.unwrap(primes[0]) == 2 + assert storage.unwrap(primes[:2]) == [2, 3] + + ### lists w/ overlapping elements + with storage: + n = 100 + primes = primes_below(n) + for i in range(0, len(primes), 2): + sum_primes = get_sum(primes[:i+1]) + + with storage: + elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + squares = chunked_square(elts) diff --git a/mandala/tests/test_queries.py b/mandala/tests/test_queries.py deleted file mode 100644 index 37a642e..0000000 --- a/mandala/tests/test_queries.py +++ /dev/null @@ -1,380 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * -from mandala.queries.viz import visualize_graph -from mandala.queries.weaver import * -from mandala.queries.graphs import * -from mandala.queries.viz import * - -OUTPUT_ROOT = Path(__file__).parent / "output/" - - -def get_graph(vqs: Set[ValNode]) -> InducedSubgraph: - vqs, fqs = traverse_all(vqs=vqs, direction="both") - return InducedSubgraph(vqs=vqs, fqs=fqs) - - -@pytest.mark.parametrize("storage", generate_storages()) -def test_queries_basics(storage): - Config.query_engine = "_test" - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage.run(): - for i in range(20, 25): - j = inc(i) - final = add(i, y=j) - - with storage.query(): - i = Q().named("i") - j = inc(i).named("j") - final = add(i, y=j).named("final") - df = storage.df(i, j, final) - assert set(df["i"]) == {i for i in range(20, 25)} - assert all(df["j"] == df["i"] + 1) - df_refs = storage.df(i, j, final, values="refs") - df_uids = storage.df(i, j, final, values="uids") - df_lazy = storage.df(i, j, final, values="lazy") - assert compare_dfs_as_relations(df_refs.applymap(lambda x: x.uid), df_uids) - assert compare_dfs_as_relations( - df_refs.applymap(lambda x: x.detached()), 
df_lazy - ) - assert compare_dfs_as_relations(df_refs.applymap(unwrap), df) - assert compare_dfs_as_relations(df_lazy.applymap(unwrap), df) - - check_invariants(storage) - vqs, fqs = traverse_all([i, j, final]) - visualize_graph( - vqs=vqs, fqs=fqs, output_path=OUTPUT_ROOT / "test_basics.svg", names=None - ) - - -def test_empty(): - storage = Storage() - - with storage.query(): - try: - df = storage.df() - except: - pass - - -def test_queries_static_builder(): - """ - A one-stop test for all of the following cases: - - queries to match data structures incl. nested structures - - queries to match elements of data structures - - wrapping raw objects into queries - - merging the query graph from a run block and a query block - - """ - storage = Storage() - - @op - def f(x) -> int: - return x + 1 - - @op - def g(x, y) -> Tuple[int, int]: - return x + y, x * y - - @op - def avg(numbers: list) -> float: - return sum(numbers) / len(numbers) - - @op - def avg_dict(numbers: dict) -> float: - return sum(numbers.values()) / len(numbers) - - @op - def avg_set(numbers: set) -> float: - return sum(numbers) / len(numbers) - - @op - def repeat(x, n) -> list: - return [x] * n - - @op - def dictify(x, y) -> dict: - return {"x": x, "y": y} - - @op - def nested_dictify(x, y) -> Dict[str, List[int]]: - return {"x": [x, x], "y": [y, y]} - - def get_graph_1(): - with storage.query(): - x = Q() - y = f(x) - z, w = g(y, x) - ### constructive stuff - lst_1 = qwrap([z, w, {Q(): x, ...: ...}, {x, y, ...}, ...]) - a_1 = avg(lst_1) - dct_1 = qwrap({Q(): z, ...: ...}) - a_2 = avg_dict(dct_1) - st_1 = qwrap({x, y, ...}) - a_3 = avg_set(st_1) - ### destructive stuff - lst_2 = repeat(x, Q()) - elt_0 = lst_2[Q()] - dct_2 = dictify(x, y) - val_0 = dct_2[Q()] - dct_3 = nested_dictify(x, y) - val_1 = dct_3[Q()] - elt_1 = val_1[Q()] - res = get_graph(vqs={elt_1, val_0, elt_0, a_3, a_2, a_1}) - return res - - def get_graph_2(): - with storage.query(): - x = Q() - y = f(x) - z, w = g(y, x) - helper_1 = BuiltinQueries.DictQ(dct={Q(): x}) - helper_2 = BuiltinQueries.SetQ(elts={x, y}) - lst_1 = BuiltinQueries.ListQ(elts=[z, w, helper_1, helper_2]) - a_1 = avg(lst_1) - dct_1 = BuiltinQueries.DictQ(dct={Q(): z}) - a_2 = avg_dict(dct_1) - st_1 = BuiltinQueries.SetQ(elts={x, y}) - a_3 = avg_set(st_1) - lst_2 = repeat(x, Q()) - elt_0 = lst_2[Q()] - dct_2 = dictify(x, y) - val_0 = dct_2[Q()] - dct_3 = nested_dictify(x, y) - val_1 = dct_3[Q()] - elt_1 = val_1[Q()] - res = get_graph(vqs={elt_1, val_0, elt_0, a_3, a_2, a_1}) - return res - - g_1, g_2 = get_graph_1(), get_graph_2() - assert InducedSubgraph.are_canonically_isomorphic(g_1, g_2) - - -def test_queries_visualization(): - storage = Storage() - - @op - def f(x: int, y: int) -> Tuple[int, int, int]: - return x, y, x + y - - @op - def g() -> int: - return 42 - - @op - def h(z: int, w: int): - pass - - for func in [f, g, h]: - storage.synchronize(func) - - with storage.query(): - a = g().named("a") - b = Q().named("b") - c, d, e = f(x=a, y=b) - c.named("c"), d.named("d"), e.named("e") - h(z=c, w=d) - h(z=a, w=b) - x, y, z = f(x=d, y=e) - x.named("x"), y.named("y"), z.named("z") - storage.draw_graph(a, b, c, d, e, traverse="both", show_how="none") - storage.print_graph(a, b, c, d, e, traverse="both") - - -def test_queries_exceptions(): - storage = Storage() - - @op - def inc(x: int) -> int: - return unwrap(x) + 1 - - try: - # passing raw values in queries should (currently) raise an error - with storage.query(): - y = inc(23) - assert False - except Exception: - assert True - - -def 
test_queries_filter_duplicates(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage.run(): - for x in range(5): - for y in range(5): - z = inc(x) - print(z.causal_uid) - w = add(z, y) - - with storage.query(): - x = Q().named("x") - y = Q().named("y") - z = inc(x).named("z") - w = add(z, y).named("w") - df_1 = storage.df(x, z, drop_duplicates=True, context=False) - df_2 = storage.df(x, z, drop_duplicates=False, context=False) - df_3 = storage.df( - x, - z, - drop_duplicates=True, - context=False, - engine="naive", - _visualize_steps_at=OUTPUT_ROOT, - ) - df_4 = storage.df( - x, - z, - drop_duplicates=False, - context=False, - engine="naive", - _visualize_steps_at=OUTPUT_ROOT, - ) - - assert len(df_1) == 5 - assert len(df_2) == 25 - assert compare_dfs_as_relations(df_1, df_3) - assert compare_dfs_as_relations(df_2, df_4) - - -def test_queries_weird(): - Config.autowrap_inputs = True - Config.autounwrap_inputs = True - Config.query_engine = "_test" - - storage = Storage() - - @op - def a(f: int, g: int): - return - - @op - def b(k: int, l: int): - return - - @op - def c(h: int, i: int) -> int: - return h + i - - @op - def d(j: int) -> Tuple[int, int, int]: - return j, j, j - - @op - def e(m: int) -> int: - return m + 1 - - for f in [a, b, c, d, e]: - storage.synchronize(f) - - with storage.run(): - var_0 = 23 - var_1 = 42 - a(var_0, var_0) - b(var_0, var_0) - a(f=var_0, g=var_0) - var_2 = c(h=var_0, i=var_0) - a(f=var_0, g=var_0) - var_3, var_4, var_5 = d(j=var_1) - b(k=var_4, l=var_1) - var_6 = e(m=var_2) - - with storage.query(): - var_0 = Q() - var_1 = Q() - a(f=var_0, g=var_0) - b(k=var_0, l=var_0) - a(f=var_0, g=var_0) - var_2 = c(h=var_0, i=var_0) - a(f=var_0, g=var_0) - var_3, var_4, var_5 = d(j=var_1) - b(k=var_4, l=var_1) - var_6 = e(m=var_2) - df = storage.df(var_0, var_1, var_2, var_3, var_4, var_5, var_6) - - -def test_queries_required_args(): - storage = Storage() - - @op - def add(x: int, y: int) -> int: - return x + y - - with storage.run(): - for x in [1, 2, 3]: - for y in [4, 5, 6]: - add(x, y) - - with storage.query(): - x = Q().named("x") - result = add(x=x).named("result") - df = storage.df(x, result) - assert df.shape == (9, 2) - - -def test_generalized(): - Config.query_engine = "sql" - storage = Storage() - - @op - def create_list(x) -> list: - return [x + i for i in range(10)] - - @op - def consume_list(nums: list) -> int: - return sum(nums) - - with storage.run(): - for x in wrap(list(range(10))): - nums = create_list(x) - for i in [2, 4, 6, 8, 10]: - res = consume_list(nums[:i]) - df = storage.similar(res, context=True) - assert df.shape[0] == 300 - - with storage.query(): - idx0 = Q() - idx0.pin(0) - x = Q() - nums = create_list(x=x) - a0 = nums[idx0] - a1 = ListQ(elts=[a0], idxs=[idx0]) - res = consume_list(nums=a1) - result = storage.df(idx0, x, nums, a0, a1, res) - assert result.shape[0] == 50 - - -def test_defaults(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - z = inc(23) - - @op - def inc(x: int, amount: int = 1) -> int: - return x + amount - - with storage.run(): - z = inc(23, 10) - - df = storage.similar(z) - assert df.shape[0] == 2 diff --git a/mandala/tests/test_refactoring.py b/mandala/tests/test_refactoring.py deleted file mode 100644 index 645033f..0000000 --- a/mandala/tests/test_refactoring.py +++ /dev/null @@ -1,408 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * - - -def 
test_add_input(): - Config.autowrap_inputs = True - Config.autounwrap_inputs = True - - storage = Storage() - - ############################################################################ - ### check that old calls are preserved - ############################################################################ - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - x = inc(23) - - @op - def inc(x: int, y=1) -> int: - return x + y - - with storage.run(): - x = inc(23) - - assert inc.func_op.sig.input_names == {"x", "y"} - df = storage.get_table(inc) - assert all(c in df.columns for c in ["x", "y"]) - assert df.shape[0] == 1 - - ############################################################################ - ### check that defaults can be overridden - ############################################################################ - @op - def add_many(x: int) -> int: - return x + 1 - - storage.synchronize(f=add_many) - - @op - def add_many(x: int, y: int = 23, z: int = 42) -> int: - return x + y + z - - storage.synchronize(f=add_many) - - with storage.run(): - add_many(0) - add_many(0, 1) - add_many(0, 1, 2) - - with storage.query(): - x, y, z = Q(), Q(), Q() - w = add_many(x, y, z) - df = storage.df(x, y, z, w) - - rows = set(tuple(row) for row in df.values.tolist()) - assert rows == {(0, 1, 2, 3), (0, 1, 42, 43), (0, 23, 42, 65)} - - ############################################################################ - ### check that queries work with defaults - ############################################################################ - with storage.query() as q: - x = Q() - w = add_many(x) - df = storage.df(x, w) - - ############################################################################ - ### check that invalid ways to add an input are not allowed - ############################################################################ - - ### no default - @op - def no_default(x: int) -> int: - return x + 1 - - storage.synchronize(f=no_default) - - try: - - @op - def no_default(x: int, y: int) -> int: - return x + y - - storage.synchronize(f=no_default) - except: - assert True - - -def test_add_input_bug(): - """ - The issue was that the sync logic was expecting to see the UIDs for the - defaults upon re-synchronizing the updated version. 
- """ - storage = Storage() - - @op - def f() -> int: - return 1 - - with storage.run(): - f() - - @op - def f(x: int = 23) -> int: - return x - - with storage.run(): - f() - - @op - def f(x: int = 23) -> int: - return x - - with storage.run(): - f() - - -def test_default_change(): - """ - Changing default values is not allowed for @ops - """ - storage = Storage() - - @op - def f(x: int = 23) -> int: - return x - - with storage.run(): - a = f() - - @op - def f(x: int = 42) -> int: - return x - - try: - with storage.run(): - b = f() - except: - assert True - - -def test_func_renaming(): - - storage = Storage() - - ############################################################################ - ### unit test - ############################################################################ - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - x = inc(23) - - storage.rename_func(func=inc, new_name="inc_new") - assert inc.is_invalidated - - # define correct function - @op - def inc_new(x: int) -> int: - return x + 1 - - with storage.run(): - inc_new(23) - - df = storage.get_table(inc_new) - # make sure the call was not new - assert df.shape[0] == 1 - # make sure we did not create a new function - sigs = [v for k, v in storage.sig_adapter.load_state().items() if not v.is_builtin] - assert len(sigs) == 1 - - ############################################################################ - ### check that name collisions are not allowed - ############################################################################ - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def new_inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=inc) - storage.synchronize(f=new_inc) - - try: - storage.rename_func(func=inc, new_name="new_inc") - except: - assert True - - ############################################################################ - ### permute names - ############################################################################ - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def new_inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=inc) - storage.synchronize(f=new_inc) - - storage.rename_func(func=inc, new_name="temp") - storage.rename_func(func=new_inc, new_name="inc") - - @op - def temp(x: int) -> int: - return x + 1 - - @op - def inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=temp) - storage.synchronize(f=inc) - - storage.rename_func(func=temp, new_name="new_inc") - - -def test_arg_renaming(): - storage = Storage() - - ############################################################################ - ### unit test - ############################################################################ - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - x = inc(23) - - storage.rename_arg(func=inc, name="x", new_name="x_new") - assert inc.is_invalidated - - # define correct function - @op - def inc(x_new: int) -> int: - return x_new + 1 - - with storage.run(): - x = inc(23) - - df = storage.get_table(inc) - # make sure the call was not new - assert df.shape[0] == 1 - # make sure we did not create a new function - sigs = [v for k, v in storage.sig_adapter.load_state().items() if not v.is_builtin] - assert len(sigs) == 1 - - ############################################################################ - ### check collisions are not allowed - ############################################################################ - @op - def add(x: int, y: int) -> int: - return x + y - - 
storage.synchronize(f=add) - try: - storage.rename_arg(func=add, name="x", new_name="y") - except: - assert True - - -def test_renaming_failures_1(): - """ - Try to do a rename on a function that was invalidated - """ - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=inc) - - storage.rename_func(func=inc, new_name="inc_new") - try: - storage.rename_func(func=inc, new_name="inc_other") - except: - assert True - - @op - def add(x: int, y: int) -> int: - return x + y - - storage.synchronize(f=add) - - storage.rename_arg(func=add, name="x", new_name="z") - try: - storage.rename_arg(func=add, name="y", new_name="w") - except: - assert True - - -def test_renaming_failures_2(): - """ - Try renaming a function to a name that already exists for another function - """ - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - @op - def add(x: int, y: int) -> int: - return x + y - - for f in (inc, add): - storage.synchronize(f=f) - - try: - storage.rename_func(func=inc, new_name="add") - except: - assert True - - -def test_renaming_inside_context_1(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=inc) - - try: - with storage.run(): - storage.rename_func(func=inc, new_name="inc_new") - inc(23) - except: - assert True - - @op - def add(x: int, y: int) -> int: - return x + y - - storage.synchronize(f=add) - - try: - with storage.run(): - storage.rename_arg(func=add, name="x", new_name="z") - add(23, 42) - except: - assert True - - -def test_renaming_inside_context_2(): - """ - Like the previous one, but with uncommitted work - """ - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=inc) - - try: - with storage.run(): - inc(23) - storage.rename_func(func=inc, new_name="inc_new") - except: - assert True - - @op - def add(x: int, y: int) -> int: - return x + y - - storage.synchronize(f=add) - - try: - with storage.run(): - add(23, 42) - storage.rename_arg(func=add, name="x", new_name="z") - except: - assert True - - -def test_other_refactoring_failures(): - storage = Storage() - - @op - def inc(x: int) -> int: - return x + 1 - - storage.synchronize(f=inc) - - @op - def inc(y: int) -> int: - return y + 1 - - try: - storage.synchronize(f=inc) - except: - assert True diff --git a/mandala/tests/test_stateful_bugs.py b/mandala/tests/test_stateful_bugs.py deleted file mode 100644 index f2c3883..0000000 --- a/mandala/tests/test_stateful_bugs.py +++ /dev/null @@ -1,273 +0,0 @@ -from .test_stateful_slow import SingleClientSimulator, MultiClientSimulator - - -def test_1(): - state = SingleClientSimulator(n_clients=1) - state.add_workflow() - state.add_input_var_to_workflow() - state.create_op() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.add_input() - state.execute_workflow() - state.execute_workflow() - - -def test_2(): - state = SingleClientSimulator(n_clients=1) - state.add_workflow() - state.create_op() - state.add_input_var_to_workflow() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.add_input() - state.execute_workflow() - state.verify_state() - - -def test_3(): - state = SingleClientSimulator(n_clients=1) - state.create_op() - state.add_workflow() - state.add_workflow() - state.rename_func() - state.create_op() - state.add_workflow() - state.create_op() - state.add_input() - state.create_op() - state.add_workflow() - state.rename_input() - state.add_workflow() - state.rename_input() - state.create_op() - 
state.rename_input() - state.add_input_var_to_workflow() - state.rename_func() - state.add_input() - state.add_op_to_workflow() - state.rename_input() - state.add_call_to_workflow() - state.add_input() - state.add_input() - state.add_input_var_to_workflow() - state.add_op_to_workflow() - state.rename_func() - state.add_op_to_workflow() - state.rename_func() - state.add_input_var_to_workflow() - state.rename_func() - state.rename_func() - state.rename_input() - state.rename_func() - state.rename_func() - state.add_input() - state.execute_workflow() - state.add_op_to_workflow() - state.add_op_to_workflow() - state.rename_input() - state.add_op_to_workflow() - state.execute_workflow() - state.rename_input() - state.add_op_to_workflow() - state.execute_workflow() - state.add_op_to_workflow() - state.rename_func() - state.add_call_to_workflow() - state.add_input_var_to_workflow() - state.rename_func() - state.rename_func() - state.execute_workflow() - state.rename_func() - state.add_call_to_workflow() - state.rename_func() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.add_call_to_workflow() - state.add_input_var_to_workflow() - state.add_call_to_workflow() - state.rename_input() - state.rename_func() - state.rename_input() - state.rename_func() - state.execute_workflow() - state.add_input_var_to_workflow() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.add_call_to_workflow() - state.execute_workflow() - - -def test_4(): - state = SingleClientSimulator(n_clients=1) - state.create_op() - state.create_new_version() - state.rename_func() - state.add_input() - - -def test_5(): - state = SingleClientSimulator(n_clients=1) - state.add_workflow() - state.add_input_var_to_workflow() - state.create_op() - state.add_input() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.create_new_version() - state.add_input_var_to_workflow() - state.add_input() - state.rename_input() - state.create_new_version() - state.add_input_var_to_workflow() - state.add_op_to_workflow() - state.rename_input() - state.add_op_to_workflow() - state.verify_state() - - -def test_6(): - state = SingleClientSimulator(n_clients=1) - state.add_workflow() - state.create_op() - state.add_input() - state.rename_func() - state.rename_func() - state.rename_input() - state.create_op() - state.add_input() - state.add_input_var_to_workflow() - state.add_op_to_workflow() - state.add_input() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.add_input() - state.add_op_to_workflow() - state.execute_workflow() - state.add_call_to_workflow() - state.execute_workflow() - state.verify_state() - - -def test_7(): - state = SingleClientSimulator(n_clients=1) - state.create_op() - state.create_new_version() - state.rename_func() - - -def _test_8(): - state = MultiClientSimulator(n_clients=3) - state.create_op() - state.create_op() - state.add_input() - state.add_input() - state.create_op() - state.add_input() - state.add_input() - state.add_input() - state.rename_input() - state.create_op() - state.add_input() - state.add_input() - state.add_input() - state.rename_input() - state.add_input() - state.add_input() - - -def _test_9(): - state = MultiClientSimulator(n_clients=3) - state.create_op() - state.rename_input() - state.sync_one() - state.sync_one() - state.add_workflow() - state.rename_func() - state.add_workflow() - state.add_input() - state.create_new_version() - state.add_input_var_to_workflow() - state.add_input() - state.add_op_to_workflow() - 
state.rename_func() - state.add_call_to_workflow() - state.execute_workflow() - state.add_call_to_workflow() - state.add_input() - state.add_input_var_to_workflow() - state.sync_one() - - -def _test_10(): - state = MultiClientSimulator() - state.add_workflow() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.create_op() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.execute_workflow() - state.add_input() - state.sync_one() - - -def _test_11(): - state = MultiClientSimulator(n_clients=3) - state.add_workflow() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.create_op() - state.add_input() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.execute_workflow() - state.add_input() - state.sync_one() - state.verify_state() - - -def _test_12(): - state = MultiClientSimulator(n_clients=3) - state.add_workflow() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.create_op() - state.add_input_var_to_workflow() - state.add_input_var_to_workflow() - state.add_op_to_workflow() - state.add_call_to_workflow() - state.execute_workflow() - state.add_input() - state.sync_one() - state.verify_state() - - -def _test_13(): - state = MultiClientSimulator(n_clients=3) - state.create_op() - state.rename_func() - state.rename_input() - state.add_input() - state.rename_func() - state.rename_func() - state.add_input() - state.create_op() - state.add_workflow() - state.check_mock_storage_single() - state.create_op() - state.sync_one() - state.create_op() - state.add_input() - state.rename_func() - state.add_input() - state.rename_input() - state.check_mock_storage_single() - state.rename_func() - state.create_op() - state.create_op() - state.add_input() - state.create_new_version() - state.verify_state() diff --git a/mandala/tests/test_stateful_slow.py b/mandala/tests/test_stateful_slow.py deleted file mode 100644 index 397a560..0000000 --- a/mandala/tests/test_stateful_slow.py +++ /dev/null @@ -1,737 +0,0 @@ -from hypothesis.stateful import ( - RuleBasedStateMachine, - Bundle, - rule, - initialize, - precondition, - invariant, - run_state_machine_as_test, -) -from hypothesis._settings import settings, Verbosity -from hypothesis import strategies as st - -from mandala.common_imports import * -from mandala.all import * -from mandala.tests.utils import * -from mandala.tests.stateful_utils import * -from mandala.queries.workflow import Workflow, CallStruct -from mandala.core.utils import Hashing, get_uid, parse_full_uid -from mandala.queries.compiler import * -from mandala.core.model import Type, ListType -from mandala.core.builtins_ import Builtins -from mandala.core.sig import _get_return_annotations -from mandala.storages.remote_storage import RemoteStorage -from mandala.ui.executors import SimpleWorkflowExecutor -from mandala.ui.funcs import FuncInterface -from mandala.ui.storage import make_delayed - - -class MockStorage: - """ - A simple storage simulator that - - stores all data in memory: calls as tables, vrefs as a dictionary - - only uses internal names for signatures; - - can be synced in a "naive" way with another storage: the state of the - other storage is upserted into this storage. - - note: currently, it replicates only the memoization tables and the table - of values (not the causal UIDs table) - - It is invariant-checked at the entry and exit of every method. 
The - state of this object should only be changed through these methods. This - makes it easy to track down when an inconsistent update happened. - """ - - def __init__(self): - self.calls: Dict[str, pd.DataFrame] = {} - self.values: Dict[str, Any] = {} - # versioned internal op name -> (internal input name -> default uid) - self.default_uids: Dict[str, Dict[str, str]] = {} - for builtin_op in Builtins.OPS.values(): - name = builtin_op.sig.versioned_internal_name - self.calls[name] = pd.DataFrame( - columns=list(builtin_op.sig.input_names) + Config.special_call_cols - ) - self.default_uids[name] = {} - self.check_invariants() - - def __eq__(self, other: Any): - if not isinstance(other, MockStorage): - return False - values_equal = self.values == other.values - default_uids_equal = self.default_uids == other.default_uids - calls_equal = self.calls.keys() == other.calls.keys() and all( - compare_dfs_as_relations(self.calls[k], other.calls[k]) - for k in self.calls.keys() - ) - return values_equal and default_uids_equal and calls_equal - - def check_invariants(self): - assert self.default_uids.keys() == self.calls.keys() - # get all vref uids that appear in calls - vref_uids_from_calls = [] - for k, df in self.calls.items(): - for col in df.columns: - if col not in Config.special_call_cols: - vref_uids_from_calls += df[col].values.tolist() - vref_uids_from_calls = [parse_full_uid(x)[0] for x in vref_uids_from_calls] - assert set(vref_uids_from_calls) <= set(self.values.keys()) - for versioned_internal_name, defaults in self.default_uids.items(): - df = self.calls[versioned_internal_name] - for internal_input_name, default_uid in defaults.items(): - assert internal_input_name in df.columns - - def create_op(self, func_op: FuncOp): - self.check_invariants() - sig = func_op.sig - if sig.versioned_internal_name in self.calls.keys(): - raise ValueError() - if sig.versioned_internal_name in self.default_uids: - raise ValueError() - self.calls[sig.versioned_internal_name] = pd.DataFrame( - columns=Config.special_call_cols - + list(sig.ui_to_internal_input_map.values()) - + [dump_output_name(index=i) for i in range(sig.n_outputs)] - ) - self.default_uids[sig.versioned_internal_name] = {} - self.check_invariants() - - def add_input( - self, - func_op: FuncOp, - internal_name: str, - default_value: Any, - default_full_uid: str, - ): - self.check_invariants() - default_uid, default_causal_uid = parse_full_uid(default_full_uid) - sig = func_op.sig - df = self.calls[sig.versioned_internal_name] - df[internal_name] = [default_full_uid for _ in range(len(df))] - self.values[default_uid] = default_value - self.default_uids[sig.versioned_internal_name][internal_name] = default_full_uid - self.check_invariants() - - def rename_func(self, func_op: FuncOp, new_name: str): - pass - - def rename_input(self, func_op: FuncOp, old_name: str, new_name: str): - pass - - def create_new_version(self, new_version: FuncOp): - self.check_invariants() - self.create_op(func_op=new_version) - self.check_invariants() - - def add_call(self, call: Call): - self.check_invariants() - func_op, inputs, outputs = call.func_op, call.inputs, call.outputs - sig = func_op.sig - row = { - Config.uid_col: call.uid, - Config.causal_uid_col: call.causal_uid, - Config.content_version_col: call.content_version, - Config.semantic_version_col: call.semantic_version, - Config.transient_col: call.transient, - **{sig.ui_to_internal_input_map[k]: v.full_uid for k, v in inputs.items()}, - **{dump_output_name(index=i): v.full_uid for i, v in 
enumerate(outputs)}, - } - df = self.calls[sig.versioned_internal_name] - # handle stale calls - for k in df.columns: - if k not in row.keys(): - row[k] = self.default_uids[sig.versioned_internal_name][k] - row_df = pd.DataFrame([row]) - if row[Config.uid_col] not in df[Config.uid_col].values: - self.calls[sig.versioned_internal_name] = pd.concat( - [df, row_df], - ignore_index=True, - ) - for vref in itertools.chain(inputs.values(), outputs): - self.values[vref.uid] = unwrap(vref) - self.check_invariants() - - def sync_from_other(self, other: "MockStorage"): - """ - Update this storage with the new data from another storage. - - NOTE: always copy the other storage's state into this storage; never use - shared objects. - """ - self.check_invariants() - # update values - for k, v in other.values.items(): - if k in self.values.keys(): - assert v == self.values[k] - # avoid shared objects - self.values[k] = copy.deepcopy(v) - # update defaults - for versioned_internal_name, defaults in other.default_uids.items(): - if versioned_internal_name in self.calls.keys(): - df = self.calls[versioned_internal_name] - for internal_input_name, default_uid in defaults.items(): - if internal_input_name not in df.columns: - df[internal_input_name] = [default_uid for _ in range(len(df))] - self.default_uids[versioned_internal_name][ - internal_input_name - ] = default_uid - else: - # use a copy to avoid a shared mutable object - self.default_uids[versioned_internal_name] = copy.deepcopy(defaults) - # update calls - for versioned_internal_name, df in other.calls.items(): - if versioned_internal_name in self.calls.keys(): - current_df = self.calls[versioned_internal_name] - new_df = df[~df[Config.uid_col].isin(current_df[Config.uid_col])].copy() - self.calls[versioned_internal_name] = pd.concat( - [current_df, new_df], ignore_index=True - ) - else: - self.calls[versioned_internal_name] = df.copy() - self.check_invariants() - - def compare_with_real(self, real_storage: Storage): - self.check_invariants() - # extract values - values_df = real_storage.rel_adapter.get_vrefs() - values = { - row[Config.uid_col]: row[Config.vref_value_col] - for _, row in values_df.iterrows() - } - sigs = real_storage.sig_adapter.load_ui_sigs() - # extract calls and defaults - all_call_data = real_storage.rel_adapter.get_all_call_data() - calls = {} - default_uids = {} - for versioned_ui_name, df in all_call_data.items(): - sig = sigs[Signature.parse_versioned_name(versioned_ui_name)] - versioned_internal_name = sig.versioned_internal_name - calls[versioned_internal_name] = df.rename( - columns=sig.ui_to_internal_input_map - ) - default_uids[versioned_internal_name] = sig._new_input_defaults_uids - sess.d() - return ( - values == self.values - and all( - compare_dfs_as_relations(calls[k], self.calls[k]) for k in calls.keys() - ) - and default_uids == self.default_uids - ) - - -class ClientState: - def __init__(self, root: Optional[RemoteStorage]): - self.storage = Storage(root=root) - self.mock_storage = MockStorage() - self.workflows: List[Workflow] = [] - self.func_ops: List[FuncOp] = [] - self.num_func_renames = 0 - self.num_input_renames = 0 - - -class Preconditions: - """ - A namespace for preconditions for the rules of the state machine. - - NOTE: Preconditions are defined as functions instead of lambdas to enable - type introspection, autorefactoring, etc. - """ - - # control some of the transitions to avoid long chains, especially ones that - # make the DB larger - #! (does this actually optimize things? 
we need to benchmark) - MAX_OPS_PER_CLIENT = 10 - MAX_INPUTS_PER_OP = 5 - MAX_WORKFLOWS_PER_CLIENT = 5 - MAX_WORKFLOW_SIZE_TO_ADD_VAR = 5 - MAX_WORKFLOW_SIZE_TO_ADD_OP = 10 - # prevent too many renames - MAX_FUNC_RENAMES_PER_CLIENT = 20 - MAX_INPUT_RENAMES_PER_CLIENT = 20 - - @staticmethod - def create_op(instance: "SingleClientSimulator") -> Tuple[bool, List[ClientState]]: - # return clients for which an op can be created - candidates = [ - c - for c in instance.clients - if len(c.func_ops) < Preconditions.MAX_OPS_PER_CLIENT - ] - return len(candidates) > 0, candidates - - ############################################################################ - ### refactoring - ############################################################################ - @staticmethod - def add_input( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]: - # return tuples of (client, op, op_idx) for which an input can be added - candidates = [] - for c in instance.clients: - for idx, func_op in enumerate(c.func_ops): - if len(func_op.sig.input_names) < Preconditions.MAX_INPUTS_PER_OP: - candidates.append((c, func_op, idx)) - return len(candidates) > 0, candidates - - @staticmethod - def rename_func( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]: - # return tuples of (client, op, idx) for which the op can be renamed - candidates = [] - for c in instance.clients: - if c.num_func_renames < Preconditions.MAX_FUNC_RENAMES_PER_CLIENT: - for idx, func_op in enumerate(c.func_ops): - candidates.append((c, func_op, idx)) - return len(candidates) > 0, candidates - - @staticmethod - def rename_input( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]: - # return tuples of (client, op, idx) for which an input can be renamed - candidates = [] - for c in instance.clients: - if c.num_input_renames < Preconditions.MAX_INPUT_RENAMES_PER_CLIENT: - candidates += [(c, op, idx) for idx, op in enumerate(c.func_ops)] - return len(candidates) > 0, candidates - - @staticmethod - def create_new_version( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]: - # return tuples of (client, op, idx of this op) for which a new version can be created - candidates = [] - for c in instance.clients: - num_ops = len(c.func_ops) - if num_ops > 0 and num_ops < Preconditions.MAX_OPS_PER_CLIENT: - candidates += [(c, op, idx) for idx, op in enumerate(c.func_ops)] - return len(candidates) > 0, candidates - - ############################################################################ - ### growing workflows - ############################################################################ - @staticmethod - def add_workflow( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[ClientState]]: - # return clients for which a workflow can be created - candidates = [ - c - for c in instance.clients - if len(c.workflows) < Preconditions.MAX_WORKFLOWS_PER_CLIENT - ] - return len(candidates) > 0, candidates - - @staticmethod - def add_input_var_to_workflow( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, Workflow]]]: - # return tuples of (client, workflow) for which an input var can be added - candidates = [] - for c in instance.clients: - for w in c.workflows: - if w.shape_size < Preconditions.MAX_WORKFLOW_SIZE_TO_ADD_VAR: - candidates.append((c, w)) - return len(candidates) > 0, candidates - - @staticmethod - def add_op_to_workflow( - instance: 
"SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, Workflow, FuncOp]]]: - # return tuples of (client, workflow, op) for which an op can be added - candidates = [] - for c in instance.clients: - for w in c.workflows: - if ( - w.shape_size < Preconditions.MAX_WORKFLOW_SIZE_TO_ADD_OP - and len(w.var_nodes) > 0 - ): - candidates += [(c, w, op) for op in c.func_ops] - return len(candidates) > 0, candidates - - @staticmethod - def add_call_to_workflow( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, Workflow, CallNode]]]: - # return tuples of (client, workflow, op_node) for which a call can be added - candidates = [] - for c in instance.clients: - for w in c.workflows: - candidates += [(c, w, op_node) for op_node in w.callable_op_nodes] - return len(candidates) > 0, candidates - - @staticmethod - def execute_workflow( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, Workflow]]]: - # return tuples of (client, workflow) for which a workflow can be - # executed - candidates = [] - for c in instance.clients: - for w in c.workflows: - if w.num_calls > 0: - candidates.append((c, w)) - return len(candidates) > 0, candidates - - @staticmethod - def check_mock_storage_single( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[ClientState]]: - candidates = [c for c in instance.clients if len(c.workflows) > 0] - return len(candidates) > 0, candidates - - @staticmethod - def sync_all( - instance: "SingleClientSimulator", - ) -> bool: - # require at least some calls to be executed - for client in instance.clients: - if any( - len(df) > 0 - for df in client.storage.rel_adapter.get_all_call_data().values() - ): - return True - return False - - @staticmethod - def query_workflow( - instance: "SingleClientSimulator", - ) -> Tuple[bool, List[Tuple[ClientState, Workflow]]]: - # return tuples of (client, workflow) for which a workflow can be - # queried - candidates = [] - for c in instance.clients: - for w in c.workflows: - if len(w.var_nodes) > 0 and w.is_saturated: - candidates.append((c, w)) - return len(candidates) > 0, candidates - - -class SingleClientSimulator(RuleBasedStateMachine): - def __init__(self, n_clients: int = 1): - super().__init__() - # client = mongomock.MongoClient() - # root = MongoMockRemoteStorage(db_name="test", client=client) - root = None - self.clients = [ClientState(root=root) for _ in range(n_clients)] - # a central storage on which all operations are performed - self.mock_storage = MockStorage() - #! keep everything deterministic. only use `random` for generating - random.seed(0) - - ############################################################################ - ### schema modifications - ############################################################################ - @precondition(lambda machine: Preconditions.create_op(machine)[0]) - @rule() - def create_op(self): - """ - Add a random op to the storage. 
- """ - client = random.choice(Preconditions.create_op(self)[1]) - new_func_op = make_op( - ui_name=random_string(), - input_names=[random_string() for _ in range(random.randint(1, 3))], - n_outputs=random.randint(0, 3), - defaults={}, - version=0, - deterministic=True, - ) - client.storage.synchronize_op(func_op=new_func_op) - client.func_ops.append(new_func_op) - for mock_storage in [client.mock_storage, self.mock_storage]: - mock_storage.create_op(func_op=new_func_op) - mock_storage.check_invariants() - - @precondition(lambda machine: Preconditions.add_input(machine)[0]) - @rule() - def add_input(self): - candidates = Preconditions.add_input(self)[1] - client, func_op, idx = random.choice(candidates) - # idx = random.randint(0, len(self._func_ops) - 1) - # func_op = self._func_ops[idx] - f = func_op.func - sig = func_op.sig - # simulate update using low-level API - new_name = random_string() - default_value = 23 - new_sig = sig.create_input(name=new_name, default=default_value, annotation=Any) - # TODO: provide a new function with extra input as a user would - new_func_op = FuncOp._from_data(func=make_func_from_sig(new_sig), sig=new_sig) - client.storage.synchronize_op(func_op=new_func_op) - client.func_ops[idx] = new_func_op - new_sig = new_func_op.sig - for mock_storage in [client.mock_storage, self.mock_storage]: - mock_storage.add_input( - func_op=new_func_op, - internal_name=new_sig.ui_to_internal_input_map[new_name], - default_value=default_value, - default_full_uid=new_sig.new_ui_input_default_uids[new_name], - ) - mock_storage.check_invariants() - - @precondition(lambda machine: Preconditions.rename_func(machine)[0]) - @rule() - def rename_func(self): - candidates = Preconditions.rename_func(self)[1] - client, func_op, idx = random.choice(candidates) - # idx = random.randint(0, len(self._func_ops) - 1) - # func_op = self._func_ops[idx] - new_name = random_string() - # find and rename all versions - all_versions = [ - (i, other_func_op) - for i, other_func_op in enumerate(client.func_ops) - if other_func_op.sig.internal_name == func_op.sig.internal_name - ] - rename_done = False - for version_idx, func_op_version in all_versions: - if not rename_done: - # use the API the user would use to rename. This will rename all - # versions. 
- func_interface = FuncInterface(func_op=func_op_version) - client.storage.synchronize(f=func_interface) - new_sig = client.storage.rename_func( - func=func_interface, new_name=new_name - ) - rename_done = True - else: - # after the rename, get the true signature from the storage - new_sig = client.storage.sig_adapter.load_state()[ - func_op_version.sig.internal_name, func_op_version.sig.version - ] - # now update the state of the simulator - new_func_op_version = FuncOp._from_data( - func=func_op_version.func, sig=new_sig - ) - client.storage.synchronize_op(func_op=new_func_op_version) - client.func_ops[version_idx] = new_func_op_version - client.num_func_renames += 1 - - @precondition(lambda machine: Preconditions.rename_input(machine)[0]) - @rule() - def rename_input(self): - candidates = Preconditions.rename_input(self)[1] - client, func_op, func_idx = random.choice(candidates) - # func_idx = random.randint(0, len(self._func_ops) - 1) - # func_op = self._func_ops[func_idx] - input_to_rename = random.choice(sorted(list(func_op.sig.input_names))) - new_name = random_string() - # use the API the user would use to rename - func_interface = FuncInterface(func_op=func_op) - client.storage.synchronize(f=func_interface) - new_sig = client.storage.rename_arg( - func=func_interface, - name=input_to_rename, - new_name=new_name, - ) - # now update the state of the simulator - new_func = make_func_from_sig(sig=new_sig) - new_func_op = FuncOp._from_data(func=new_func, sig=new_sig) - client.storage.synchronize_op(func_op=new_func_op) - client.func_ops[func_idx] = new_func_op - client.num_input_renames += 1 - - @precondition(lambda machine: Preconditions.create_new_version(machine)[0]) - @rule() - def create_new_version(self): - candidates = Preconditions.create_new_version(self)[1] - client, func_op, func_idx = random.choice(candidates) - # func_idx = random.randint(0, len(self._func_ops) - 1) - # func_op = self._func_ops[func_idx] - latest_version = client.storage.sig_adapter.get_latest_version(sig=func_op.sig) - new_version = latest_version.version + 1 - new_func_op = make_op( - ui_name=func_op.sig.ui_name, - input_names=[random_string() for _ in range(random.randint(1, 3))], - n_outputs=random.randint(0, 3), - defaults={}, - version=new_version, - deterministic=True, - ) - client.storage.synchronize_op(func_op=new_func_op) - client.func_ops.append(new_func_op) - for mock_storage in [client.mock_storage, self.mock_storage]: - mock_storage.create_new_version(new_version=new_func_op) - mock_storage.check_invariants() - - ############################################################################ - ### generating workflows - ############################################################################ - @precondition(lambda machine: Preconditions.add_workflow(machine)[0]) - @rule() - def add_workflow(self): - """ - Add a new (empty) workflow to the test. 
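The workflow-growing rules that follow (`add_workflow`, `add_input_var_to_workflow`, `add_op_to_workflow`, `add_call_to_workflow`) incrementally assemble random computation graphs that `execute_workflow` later runs against the storage. The sketch below illustrates the same grow-then-execute idea in plain Python; it deliberately avoids mandala's internal `Workflow` API, and every name in it is made up for illustration:

```
import random

random.seed(0)  # the simulator also pins the seed to stay deterministic


class ToyWorkflow:
    """Illustrative stand-in: variables plus op applications, grown at random."""

    def __init__(self):
        self.declared = []  # every variable name, in creation order
        self.steps = []     # (function, input variable names, output variable name)
        self.inputs = {}    # input variable name -> concrete value
        self._counter = 0

    def _new_var(self) -> str:
        name = f"v{self._counter}"
        self._counter += 1
        return name

    def add_input_var(self, value) -> str:
        name = self._new_var()
        self.declared.append(name)
        self.inputs[name] = value
        return name

    def add_op(self, fn, n_inputs: int) -> str:
        # like add_op_to_workflow: wire the op to randomly chosen existing variables
        chosen = [random.choice(self.declared) for _ in range(n_inputs)]
        out = self._new_var()
        self.declared.append(out)
        self.steps.append((fn, chosen, out))
        return out

    def execute(self) -> dict:
        # evaluate steps in insertion order; every input was declared earlier
        env = dict(self.inputs)
        for fn, chosen, out in self.steps:
            env[out] = fn(*(env[name] for name in chosen))
        return env


wf = ToyWorkflow()
wf.add_input_var(3)
wf.add_input_var(4)
wf.add_op(lambda x, y: x + y, n_inputs=2)
wf.add_op(lambda x: x * x, n_inputs=1)
print(wf.execute())  # all four variables with their computed values
```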
- """ - candidates = Preconditions.add_workflow(self)[1] - client = random.choice(candidates) - res = Workflow() - client.workflows.append(res) - - @precondition(lambda machine: Preconditions.add_input_var_to_workflow(machine)[0]) - @rule() - def add_input_var_to_workflow(self): - candidates = Preconditions.add_input_var_to_workflow(self)[1] - client, workflow = random.choice(candidates) - # workflow = random.choice([w for w in self._workflows if w.shape_size < 5]) - # always add a value to make sampling proceed faster - var = workflow.add_var() - workflow.add_value(value=wrap_atom(get_uid()), var=var) - - @precondition(lambda machine: Preconditions.add_op_to_workflow(machine)[0]) - @rule() - def add_op_to_workflow(self): - """ - Add an instance of some op to some workflow. - """ - candidates = Preconditions.add_op_to_workflow(self)[1] - client, workflow, func_op = random.choice(candidates) - # func_op = random.choice(self._func_ops) - # workflow = random.choice([w for w in self._workflows if len(w.var_nodes) > 0]) - # pick inputs randomly from workflow - inputs = { - name: random.choice(workflow.var_nodes) for name in func_op.sig.input_names - } - # add function over these inputs - _, _ = workflow.add_op(inputs=inputs, func_op=func_op) - - @precondition(lambda machine: Preconditions.add_call_to_workflow(machine)[0]) - @rule() - def add_call_to_workflow(self): - candidates = Preconditions.add_call_to_workflow(self)[1] - client, workflow, op_node = random.choice(candidates) - input_vars = op_node.inputs - # pick random values - var_to_values = workflow.var_to_values() - input_values = { - name: random.choice(var_to_values[var]) for name, var in input_vars.items() - } - func_op = op_node.func_op - output_types = [Type.from_annotation(a) for a in func_op.output_annotations] - outputs = [make_delayed(tp=tp) for tp in output_types] - call_struct = CallStruct( - func_op=op_node.func_op, inputs=input_values, outputs=outputs - ) - workflow.add_call_struct(call_struct=call_struct) - - def _execute_workflow(self, client: ClientState, workflow: Workflow) -> List[Call]: - client.storage.sync_from_remote() - client.mock_storage.sync_from_other(other=self.mock_storage) - calls = SimpleWorkflowExecutor().execute( - workflow=workflow, storage=client.storage - ) - client.storage.commit(calls=calls) - client.storage.sync_to_remote() - for mock_storage in [client.mock_storage, self.mock_storage]: - for call in calls: - mock_storage.add_call(call=call) - mock_storage.check_invariants() - return calls - - @precondition(lambda machine: Preconditions.execute_workflow(machine)[0]) - @rule() - def execute_workflow(self): - candidates = Preconditions.execute_workflow(self)[1] - client, workflow = random.choice(candidates) - calls = self._execute_workflow(client=client, workflow=workflow) - # client.storage.sync_from_remote() - # calls = SimpleWorkflowExecutor().execute( - # workflow=workflow, storage=client.storage - # ) - # client.storage.commit(calls=calls) - # client.storage.sync_to_remote() - - ############################################################################ - ### multi-client rules - ############################################################################ - @rule() - def sync_one(self): - client = random.choice(self.clients) - client.storage.sync_with_remote() - - @precondition(lambda machine: Preconditions.check_mock_storage_single(machine)[0]) - @rule() - def check_mock_storage_single(self): - # run all workflows for a given client and check that the state equals - # the state of the mock 
storage - candidates = Preconditions.check_mock_storage_single(self)[1] - client = random.choice(candidates) - for workflow in client.workflows: - self._execute_workflow(client=client, workflow=workflow) - assert client.mock_storage.compare_with_real(client.storage) - - @precondition(lambda machine: Preconditions.sync_all(machine)) - @rule() - def sync_all(self): - for client in self.clients: - client.storage.sync_with_remote() - client.mock_storage.sync_from_other(other=self.mock_storage) - for client in self.clients: - # have to do it again! - client.storage.sync_with_remote() - client.mock_storage.sync_from_other(other=self.mock_storage) - for client in self.clients: - assert client.mock_storage == self.mock_storage - assert self.mock_storage.compare_with_real(real_storage=client.storage) - - @invariant() - def verify_state(self): - for client in self.clients: - # make sure that functions called on the storage work - client.storage.rel_adapter.get_all_call_data() - for table in client.storage.rel_storage.get_tables(): - client.storage.rel_storage.get_count(table=table) - # check storage invariants - check_invariants(storage=client.storage) - # check invariants on the workflows - for w in client.workflows: - w.check_invariants() - # w.print_shape() - # check the individual signatures - for func_op in client.func_ops: - func_op.sig.check_invariants() - # check the set of signatures - client.storage.sig_adapter.check_invariants() - - def check_sig_synchronization(self): - for client in self.clients: - assert client.storage.root.sigs == client.storage.sig_adapter.load_state() - - # @precondition(lambda machine: Preconditions.query_workflow(machine)[0]) - # @rule() - # def query_workflow(self): - # candidates = Preconditions.query_workflow(self)[1] - # client, workflow = random.choice(candidates) - # # workflow = random.choice([w for w in client.workflows if w.is_saturated]) - # # path = Path(__file__).parent / f"bug.cloudpickle" - # # op_nodes = copy.deepcopy(workflow.op_nodes) - # # db_dump_path = Path(__file__).parent.absolute() / 'db_dump/' - # # self.storage.rel_storage.execute_no_results(query=f"EXPORT DATABASE '{db_dump_path}';") - # # data = (self._ops) - # # with open(path, "wb") as f: - # # cloudpickle.dump(data, f) - # val_queries, op_queries = workflow.var_nodes, workflow.op_nodes - # # workflow.print_shape() - # df = client.storage.execute_query(select_queries=val_queries, engine='naive') - - -class MultiClientSimulator(SingleClientSimulator): - def __init__(self, n_clients: int = 3): - super().__init__(n_clients=n_clients) - - -MAX_EXAMPLES = 100 -STEP_COUNT = 25 - -TestCaseSingle = SingleClientSimulator.TestCase -TestCaseSingle.settings = settings( - max_examples=MAX_EXAMPLES, deadline=None, stateful_step_count=STEP_COUNT -) - -TestCaseMany = MultiClientSimulator.TestCase -TestCaseMany.settings = settings( - max_examples=MAX_EXAMPLES, deadline=None, stateful_step_count=STEP_COUNT -) diff --git a/mandala/tests/test_structs.py b/mandala/tests/test_structs.py deleted file mode 100644 index bf71da5..0000000 --- a/mandala/tests/test_structs.py +++ /dev/null @@ -1,141 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * -from mandala.queries.weaver import * - - -@pytest.mark.parametrize("storage", generate_storages()) -def test_unit(storage): - Config.query_engine = "_test" - - ### lists - @op - def repeat(x: int, times: int = None) -> List[int]: - return [x] * times - - @op - def get_list_mean(nums: List[float]) -> float: - return sum(nums) / len(nums) - - with 
storage.run(): - lst = repeat(x=42, times=23) - a = lst[0] - x = get_list_mean(nums=lst) - y = get_list_mean(nums=lst[:10]) - assert unwrap(lst[1]) == 42 - assert len(lst) == 23 - storage.rel_adapter.obj_get(uid=lst.uid) - - with storage.query(): - x = Q().named("x") - lst = repeat(x).named("lst") - idx = Q().named("idx") - elt = BuiltinQueries.GetListItemQuery(lst=lst, idx=idx).named("elt") - df = storage.df(x, lst, elt, idx) - assert df.shape == (23, 4) - assert all(df["elt"] == 42) - assert sorted(df["idx"]) == list(range(23)) - - # test list constructor - with storage.query(): - # a query for all the lists whose mean we've taken - lst = BuiltinQueries.ListQ(elts=[Q()]).named("lst") - x = get_list_mean(nums=lst).named("x") - df = storage.df(lst, x) - - # test syntax sugar - with storage.query(): - x = Q().named("x") - lst = repeat(x).named("lst") - first_elt = lst[Q()].named("first_elt") - df = storage.df(x, lst, first_elt) - - ### dicts - @op - def get_dict_mean(nums: Dict[str, float]) -> float: - return sum(nums.values()) / len(nums) - - @op - def describe_sequence(seq: List[int]) -> Dict[str, float]: - return { - "min": min(seq), - "max": max(seq), - "mean": sum(seq) / len(seq), - } - - with storage.run(): - dct = describe_sequence(seq=[1, 2, 3]) - dct_mean = get_dict_mean(nums=dct) - dct_mean_2 = get_dict_mean(nums={"a": dct["min"]}) - assert unwrap(dct_mean) == 2.0 - assert unwrap(dct["min"]) == 1 - assert len(dct) == 3 - storage.rel_adapter.obj_get(uid=dct.uid) - - with storage.query(): - seq = Q().named("seq") - dct = describe_sequence(seq=seq).named("dct") - dct_mean = get_dict_mean(nums=dct).named("dct_mean") - df = storage.df(seq, dct, dct_mean) - - # test dict constructor - with storage.query(): - # a query for all the dicts whose mean we've taken - dct = BuiltinQueries.DictQ(dct={Q(): Q()}).named("dct") - dct_mean = get_dict_mean(nums=dct).named("dct_mean") - df = storage.df(dct, dct_mean) - - # test syntax sugar - # with storage.query(): - # seq = Q().named("seq") - # dct = describe_sequence(seq=seq).named("dct") - # dct_mean = dct["mean"].named("dct_mean") - # df = storage.df(seq, dct, dct_mean) - - ### sets - @op - def mean_set(nums: Set[float]) -> float: - return sum(nums) / len(nums) - - @op - def get_prime_factors(num: int) -> Set[int]: - factors = set() - for i in range(2, num): - while num % i == 0: - factors.add(i) - num /= i - return factors - - with storage.run(): - factors = get_prime_factors(num=42) - factors_mean = mean_set(nums=factors) - assert unwrap(factors_mean) == 4.0 - assert len(factors) == 3 - storage.rel_adapter.obj_get(uid=factors.uid) - - with storage.query(): - num = Q().named("num") - factors = BuiltinQueries.SetQ(elts={num}).named("factors") - factors_mean = mean_set(nums=factors).named("factors_mean") - df = storage.df(num, factors, factors_mean) - - -@pytest.mark.parametrize("storage", generate_storages()) -def test_nested(storage): - @op - def sum_rows(mat: List[List[float]]) -> List[float]: - return [sum(row) for row in mat] - - with storage.run(): - mat = [[1, 2, 3], [4, 5, 6]] - sums = sum_rows(mat=mat) - mat_2 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - sums_2 = sum_rows(mat=mat_2) - assert sums_2[0].uid == sums[0].uid - assert sums_2[1].uid == sums[1].uid - - with storage.query(): - x = Q().named("x") - row = BuiltinQueries.ListQ(elts=[x]).named("row") - mat = BuiltinQueries.ListQ(elts=[row]).named("mat") - row_sums = sum_rows(mat=mat).named("row_sums") - df = storage.df(x, row, mat, row_sums) diff --git a/mandala/tests/test_superops.py 
b/mandala/tests/test_superops.py deleted file mode 100644 index be71443..0000000 --- a/mandala/tests/test_superops.py +++ /dev/null @@ -1,95 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * -from mandala.queries.weaver import * - - -@pytest.mark.parametrize("storage", generate_storages()) -def test_unit(storage): - @op - def f(x: int, y: int) -> int: - return x + y - - @op - def mean(nums: List[int]) -> float: - return sum(nums) / len(nums) - - @op - def dictmean(nums: Dict[str, Any]) -> float: - return sum(nums.values()) / len(nums) - - @op - def repeat(num: int, times: int) -> List[int]: - return [num for something in range(times)] - - @op - def make_dict(a: int, b: int) -> Dict[str, int]: - return {"a": a, "b": b} - - @op - def swap(x: int, y: int) -> Tuple[int, int]: - return y, x - - @op - def concat_lists(a: List[Any], b: List[Any]) -> List[Any]: - return a + b - - @superop - def workflow(a: int, b: int, c: int) -> List[Any]: - dct = {"a": a, "b": b} - dct_mean = dictmean(nums=dct) - things = repeat(num=a, times=b) - things_mean = mean(things) - new = {"x": dct_mean, "y": things_mean} - final = dictmean(nums=new) - x, y = swap(x=final, y=c) - return [x, y] - - @superop - def super_workflow(a: int, b: int) -> List[Any]: - things = workflow(a, a, a) - other_things = workflow(b, b, b) - return concat_lists(things, other_things) - - with storage.run(): - a = f(23, 42) - b = f(4, 8) - c = f(15, 16) - avg = mean([a, b, c]) - sames = repeat(num=23, times=5) - elt = sames[3] - dict_avg = dictmean({"a": 23, "b": 42}) - dct = make_dict(a=23, b=42) - z = workflow(a=23, b=42, c=10) - w = super_workflow(a=a, b=b) - - with storage.run(): - a = f(23, 42) - b = f(4, 8) - c = f(15, 16) - avg = mean([a, b, c]) - sames = repeat(num=23, times=5) - elt = sames[3] - dict_avg = dictmean({"a": 23, "b": 42}) - dct = make_dict(a=23, b=42) - z = workflow(a=23, b=42, c=10) - w = super_workflow(a=a, b=b) - - -@pytest.mark.parametrize("storage", generate_storages()) -def test_mutual_recursion(storage): - @superop - def f(x: int) -> int: - if unwrap(x) == 0: - return 0 - else: - return g(unwrap(x) - 1) - - @superop - def g(x: int) -> int: - if unwrap(x) == 0: - return 0 - else: - return f(unwrap(x) - 1) - - with storage.run(): - a = f(23) diff --git a/mandala/tests/test_transient.py b/mandala/tests/test_transient.py deleted file mode 100644 index f883fef..0000000 --- a/mandala/tests/test_transient.py +++ /dev/null @@ -1,66 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * - - -def test_unit(): - storage = Storage() - - @op - def f(x) -> int: - return Transient(x + 1) - - with storage.run(attach_call_to_outputs=True): - a = f(42) - call: Call = a._call - assert call.transient - assert a.transient - assert unwrap(a) == 43 - - with storage.run(attach_call_to_outputs=True): - a = f(42) - call: Call = a._call - # assert not a.in_memory - no longer true b/c caching - assert a.transient - assert call.transient - - with storage.run(recompute_transient=True, attach_call_to_outputs=True): - a = f(42) - call: Call = a._call - - assert a.in_memory - assert a.transient - assert call.transient - - -def test_composition(): - storage = Storage() - - @op - def f(x) -> int: - return Transient(x + 1) - - @op - def g(x) -> int: - return Transient(x**2) - - with storage.run(): - a = f(42) - b = g(a) - - with storage.run(): - a = f(23) - - storage.cache.evict_all() - - try: - with storage.run(): - a = f(23) - b = g(a) - assert False - except ValueError as e: - assert True - - with 
storage.run(recompute_transient=True): - a = f(23) - b = g(a) - assert unwrap(b) == 576 diff --git a/mandala/_next/tests/test_versioning.py b/mandala/tests/test_versioning.py similarity index 100% rename from mandala/_next/tests/test_versioning.py rename to mandala/tests/test_versioning.py diff --git a/mandala/tests/test_versions.py b/mandala/tests/test_versions.py deleted file mode 100644 index 1b2fac0..0000000 --- a/mandala/tests/test_versions.py +++ /dev/null @@ -1,30 +0,0 @@ -from mandala.all import * -from mandala.tests.utils import * - - -def test_unit(): - storage = Storage() - - ############################################################################ - ### unit test - ############################################################################ - - @op - def inc(x: int) -> int: - return x + 1 - - with storage.run(): - inc(23) - - @op(version=1) - def inc(x: int) -> int: - return x + 1 - - # check unsynchronized functions work for this - storage.sig_adapter.get_versions(sig=inc.func_op.sig) - - with storage.run(): - inc(23) - - sigs = [v for k, v in storage.sig_adapter.load_state().items() if not v.is_builtin] - assert len(sigs) == 2 diff --git a/mandala/tests/utils.py b/mandala/tests/utils.py deleted file mode 100644 index c2ab0bb..0000000 --- a/mandala/tests/utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import uuid -import pytest - -from ..common_imports import * -from mandala.all import * -from mandala.ui.storage import Storage -from mandala.ui.storage import MODES -from mandala.core.config import Config, is_output_name -from mandala.core.wrapping import compare_dfs_as_relations - -# from mandala.storages.remote_impls.mongo_impl import MongoRemoteStorage -# from mandala.storages.remote_impls.mongo_mock import MongoMockRemoteStorage -from mandala.storages.rels import RelAdapter - - -def generate_db_path() -> Path: - output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output") - fname = str(uuid.uuid4()) + ".db" - return output_dir / fname - - -def generate_spillover_dir() -> Path: - output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output") - fname = str(uuid.uuid4()) - return output_dir / fname - - -def generate_path(ext: str) -> Path: - output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output") - fname = str(uuid.uuid4()) + ext - return output_dir / fname - - -def generate_storages() -> List[Storage]: - results = [] - for db_backend in ("sqlite",): - for persistent in (True, False): - for spillover in (True, False): - results.append( - Storage( - db_backend=db_backend, - db_path=generate_db_path() if persistent else None, - spillover_dir=generate_spillover_dir() if spillover else None, - spillover_threshold_mb=0, - ) - ) - return results - - -def signatures_are_equal(storage_1: Storage, storage_2: Storage) -> bool: - sigs_1 = storage_1.sig_adapter.load_state() - sigs_2 = storage_2.sig_adapter.load_state() - if sigs_1.keys() != sigs_2.keys(): - return False - for (internal_name, version), sig_1 in sigs_1.items(): - sig_2 = sigs_2[internal_name, version] - if sig_1 != sig_2: - return False - return True - - -def data_is_equal( - storage_1: Storage, storage_2: Storage, return_reason: bool = False -) -> Union[bool, Tuple[bool, str]]: - data_1 = storage_1.rel_storage.get_all_data() - data_2 = storage_2.rel_storage.get_all_data() - #! 
remove some internal tables from the comparison - for _internal_table in [RelAdapter.DEPS_TABLE, RelAdapter.PROVENANCE_TABLE]: - if _internal_table in data_1: - data_1.pop(_internal_table) - data_2.pop(_internal_table) - # compare the keys - if data_1.keys() != data_2.keys(): - result, reason = False, f"Tables differ: {data_1.keys()} vs {data_2.keys()}" - # compare the signatures - sigs_1 = storage_1.sig_adapter.load_state() - sigs_2 = storage_2.sig_adapter.load_state() - if sigs_1.keys() != sigs_2.keys(): - result, reason = ( - False, - f"Signature keys differ: {sigs_1.keys()} vs {sigs_2.keys()}", - ) - if sigs_1 != sigs_2: - result, reason = False, f"Signatures differ: {sigs_1} vs {sigs_2}" - # compare the data - elementwise_comparisons = { - k: compare_dfs_as_relations(data_1[k], data_2[k], return_reason=True) - for k in data_1.keys() - if k != Config.schema_table - } - if all(result for result, _ in elementwise_comparisons.values()): - result, reason = True, "" - else: - result, reason = ( - False, - f"Found differences between tables: {elementwise_comparisons}", - ) - if return_reason: - return result, reason - else: - return result - - -def check_invariants(storage: Storage): - # check that signatures match tables - ui_sigs = storage.sig_adapter.load_ui_sigs() - call_tables = storage.rel_adapter.get_call_tables() - columns_by_table = {} - # collect table columns - for call_table in call_tables: - columns = storage.rel_storage._get_cols(relation=call_table) - columns_by_table[call_table] = columns - # check that all signatures are accounted for in the tables - for sig in ui_sigs.values(): - table_name = sig.versioned_ui_name - assert table_name in columns_by_table - columns = columns_by_table[table_name] - assert sig.input_names.issubset(set(columns)) - # check that all tables are accounted for in the signatures - for call_table, columns in columns_by_table.items(): - input_cols = [ - col - for col in columns - if not is_output_name(col) and col not in Config.special_call_cols - ] - ui_name, version = Signature.parse_versioned_name(versioned_name=call_table) - assert set(input_cols).issubset(ui_sigs[ui_name, version].input_names) diff --git a/mandala/_next/tps.py b/mandala/tps.py similarity index 100% rename from mandala/_next/tps.py rename to mandala/tps.py diff --git a/mandala/_next/tutorials/gotchas.ipynb b/mandala/tutorials/gotchas.ipynb similarity index 100% rename from mandala/_next/tutorials/gotchas.ipynb rename to mandala/tutorials/gotchas.ipynb diff --git a/mandala/_next/tutorials/hello.ipynb b/mandala/tutorials/hello.ipynb similarity index 100% rename from mandala/_next/tutorials/hello.ipynb rename to mandala/tutorials/hello.ipynb diff --git a/mandala/_next/tutorials/ml.ipynb b/mandala/tutorials/ml.ipynb similarity index 100% rename from mandala/_next/tutorials/ml.ipynb rename to mandala/tutorials/ml.ipynb diff --git a/mandala/_next/ui.py b/mandala/ui.py similarity index 100% rename from mandala/_next/ui.py rename to mandala/ui.py diff --git a/mandala/ui/cfs.py b/mandala/ui/cfs.py deleted file mode 100644 index bcf9d3e..0000000 --- a/mandala/ui/cfs.py +++ /dev/null @@ -1,895 +0,0 @@ -import textwrap -from ..common_imports import * -from ..core.config import * -from ..core.tps import Type -from ..core.model import Ref, FuncOp, Call -from ..core.builtins_ import StructOrientations, Builtins -from ..storages.rel_impls.utils import Transactable, transaction, Connection -from ..queries.graphs import copy_subgraph -from ..core.prov import propagate_struct_provenance -from 
..queries.weaver import CallNode, ValNode -from ..queries.viz import GraphPrinter, visualize_graph -from .funcs import FuncInterface -from .storage import Storage, ValueLoader -from .cfs_utils import estimate_uid_storage, convert_bytes_to - - -class ComputationFrame(Transactable): - """ - In-memory, dynamic representation of a slice of storage representing some - computation, with methods for indexing, evaluating (i.e. loading from - storage), and navigating back/forward along computation paths. These methods - turn a `ComputationFrame` into a generalized dataframe over a computational - graph, with columns corresponding to variables in the computation, and rows - corresponding to values of those variables for a single instance of the - computation. - - This is a simple declarative interface for exploring the storage, in - contrast with using an imperative computational context (i.e., manipulating - some memoized piece of code to interface with storage). - - The main differences between a `ComputationFrame` and a `DataFrame` are as - follows: - - by default, the `ComputationFrame` is lazy, i.e. it does not load the - values of the variables it represents from the storage, but only their - metadata in the form of `Ref` objects; - - the `eval(vars)` method allows for loading the values of chosen - variables from the storage, and returns an ordinary `pandas` dataframe - with columns corresponding to the variables; - - the `forward(vars)` and `back(vars)` methods allow for navigation - along the computation graph, by expanding the graph to include the - operation(s) that created/used given variables represented in the - `ComputationFrame`; - - the `creators(var)` and `consumers(var)` methods allow for inspecting - the operations that created/used a variable (including ones not - currently represented in the `ComputationFrame`); - - It has several important limitations: - - the main one being that it can only represent a single computation in a - given instance (i.e., a single composition of functions). This comes at the - benefit of simplicity and ease of use. - - when it finds a data structure that was constructed "by hand" out of - multiple elements, it picks one of them arbitrarily to represent the - computational history of the entire structure. Sometimes this makes sense - (e.g., when the structure is a list of values with analogous provenance), - sometimes it doesn't. - - Similarly, when it finds a data structure that is the output of a - computation, it picks one of its elements arbitrarily to represent the - computational continuation of the entire structure. Dual caveats apply. 
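Because the docstring above is fairly abstract, here is a short usage sketch against the pre-rewrite API defined in this (now deleted) module. The `inc` op is hypothetical and the import paths refer to the old layout this diff removes, so treat it as a sketch rather than working code for the new version:

```
from mandala.all import *                     # old layout: Storage, @op, ...
from mandala.ui.cfs import ComputationFrame   # the module shown in this diff

storage = Storage()

@op
def inc(x: int) -> int:
    return x + 1

# memoize a few calls so there is something to explore
with storage.run():
    for i in range(3):
        y = inc(i)

# start from inc's memoization table: one op node plus a variable per input/output
cf = ComputationFrame.from_op(func=inc, storage=storage)
cf.info()        # summary of variables, operations and number of rows
cf = cf.back()   # expand the graph backwards along the provenance of every column
df = cf.eval()   # load the values from storage into an ordinary pandas DataFrame
print(df)
```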
- """ - - def __init__( - self, - call_nodes: Dict[str, CallNode], - val_nodes: Dict[str, ValNode], - storage: Storage, - prov_df: Optional[pd.DataFrame] = None, - ): - self.op_nodes = call_nodes - self.var_nodes = val_nodes - - self.storage = storage - if prov_df is None: - prov_df = storage.rel_storage.get_data(table=Config.provenance_table) - prov_df = propagate_struct_provenance(prov_df) - self.prov_df = prov_df - - def __len__(self) -> int: - if len(self.op_nodes) > 0: - return len(self.op_nodes[list(self.op_nodes.keys())[0]].calls) - if len(self.var_nodes) > 0: - return len(self.var_nodes[list(self.var_nodes.keys())[0]].refs) - return 0 - - def to_pandas( - self, - columns: Optional[Union[str, Iterable[str]]] = None, - values: Literal["refs", "lazy", "objs"] = "lazy", - ) -> Union[pd.Series, pd.DataFrame]: - """ - Extract a series/dataframe of refs from the RefFunctor, analogous to - indexing into a pandas dataframe. - """ - if columns is None: - columns = ( - list(self.var_nodes.keys()) - if len(self.var_nodes) > 1 - else list(self.var_nodes.keys()) - ) - if isinstance(columns, str): - refs = pd.Series(self.var_nodes[columns].refs, name=columns) - full_uids_df = refs.apply(lambda x: x.full_uid) - res_df = self.storage.eval_df( - full_uids_df=full_uids_df.to_frame(), values=values - ) - # turn into a series - return res_df[columns] - elif isinstance(columns, list) and all(isinstance(x, str) for x in columns): - refs = pd.DataFrame( - {col: self.var_nodes[col].refs for col in columns} - ) - return self.storage.eval_df( - full_uids_df=refs.applymap( - lambda x: x.full_uid if x is not None else None - ), - values=values, - ) - else: - raise ValueError(f"Invalid columns type: {type(columns)}") - - def __getitem__( - self, indexer: Union[str, Iterable[str], np.ndarray] - ) -> "ComputationFrame": - """ - Analogous to pandas __getitem__, but tries to return a `ComputationFrame` - """ - if isinstance(indexer, str): - if indexer in self.var_nodes: - return self.copy_subgraph( - val_nodes=[self.var_nodes[indexer]], call_nodes=[] - ) - elif indexer in self.op_nodes: - return self.copy_subgraph( - val_nodes=[], call_nodes=[self.op_nodes[indexer]] - ) - else: - raise ValueError( - f"Column {indexer} not found in variables or operations" - ) - elif isinstance(indexer, list) and all(isinstance(x, str) for x in indexer): - var_keys = [x for x in self.var_nodes.keys() if x in indexer] - call_keys = [x for x in self.op_nodes.keys() if x in indexer] - return self.copy_subgraph( - val_nodes={self.var_nodes[k] for k in var_keys}, - call_nodes={self.op_nodes[k] for k in call_keys}, - ) - elif isinstance(indexer, (np.ndarray, pd.Series)): - if isinstance(indexer, pd.Series): - indexer = indexer.values - # boolean mask - if indexer.dtype == bool: - res = self.copy_subgraph() - for k, v in res.var_nodes.items(): - v.inplace_mask(indexer) - for k, v in res.op_nodes.items(): - v.inplace_mask(indexer) - return res - else: - raise NotImplementedError( - "Indexing with a non-boolean mask is not supported" - ) - else: - raise ValueError( - f"Invalid indexer type into {self.__class__}: {type(indexer)}" - ) - - @transaction() - def eval( - self, - indexer: Optional[Union[str, List[str]]] = None, - conn: Optional[Connection] = None, - ) -> Union[pd.Series, pd.DataFrame]: - if indexer is None: - indexer = list(self.var_nodes.keys()) - if isinstance(indexer, str): - full_uids_df = ( - self[indexer] - .to_pandas() - .applymap(lambda x: x.full_uid if x is not None else None) - ) - res_df = 
self.storage.eval_df(full_uids_df=full_uids_df, values="objs") - res = res_df[indexer] - return res - else: - full_uids_df = ( - self[indexer] - .to_pandas() - .applymap(lambda x: x.full_uid if x is not None else None) - ) - return self.storage.eval_df( - full_uids_df=full_uids_df, values="objs", conn=conn - ) - - @transaction() - def creators(self, col: str, conn: Optional[Connection] = None) -> np.ndarray: - calls, output_names = self.storage.get_creators( - refs=self.var_nodes[col].refs, prov_df=self.prov_df, conn=conn - ) - return np.array( - [ - call.func_op.sig.versioned_ui_name if call is not None else None - for call in calls - ] - ) - - @transaction() - def consumers(self, col: str, conn: Optional[Connection] = None) -> np.ndarray: - calls_list, input_names_list = self.storage.get_consumers( - refs=self.var_nodes[col].refs, prov_df=self.prov_df, conn=conn - ) - res = np.empty(len(calls_list), dtype=object) - res[:] = [ - tuple( - [ - call.func_op.sig.versioned_ui_name if call is not None else None - for call in calls - ] - ) - for calls in calls_list - ] - return res - - @transaction() - def get_adjacent_calls( - self, - col: str, - direction: Literal["back", "forward"], - conn: Optional[Connection] = None, - ) -> Dict[Tuple[str, str], List[Optional[Call]]]: - """ - Given a column and a direction to traverse the graph (back or forward), - return the calls that created/used the values in the column, along with - the output/input names under which the values appear in the calls. - - The calls are grouped by the operation and the output/input name under - which the values in this column appear in the calls. - """ - refs: List[Ref] = self.var_nodes[col].refs - if direction == "back": - calls_list, names_list = self.storage.get_creators( - refs=refs, prov_df=self.prov_df, conn=conn - ) - calls_list = [[c] if c is not None else [] for c in calls_list] - names_list = [[n] if n is not None else [] for n in names_list] - elif direction == "forward": - calls_list, names_list = self.storage.get_consumers( - refs=refs, prov_df=self.prov_df, conn=conn - ) - else: - raise ValueError(f"Unknown direction: {direction}") - index_dict = defaultdict(list) # (op_id, input/output name) -> (idx, call) - # for i, (calls, names) in enumerate(zip(calls_list, names_list)): - for i in range(len(refs)): - calls = calls_list[i] - names = names_list[i] - for call, name in zip(calls, names): - index_dict[(call.func_op.sig.versioned_ui_name, name)].append((i, call)) - res = {} - for (op_id, name), indices_and_calls in index_dict.items(): - calls_list = [None for _ in range(len(refs))] - for i, call in indices_and_calls: - calls_list[i] = call - res[(op_id, name)] = calls_list - return res - - @transaction() - def _back_all( - self, - res: "ComputationFrame", - inplace: bool = False, - verbose: bool = False, - conn: Optional[Connection] = None, - ) -> "ComputationFrame": - # this means we want to expand the entire graph - node_frontier = res.var_nodes.keys() - visited = set() - while True: - res = res.back( - cols=list(node_frontier), - inplace=inplace, - skip_failures=True, - verbose=verbose, - conn=conn, - ) - visited |= node_frontier - nodes_after = set(res.var_nodes.keys()) - node_frontier = nodes_after - visited - if not node_frontier: - break - return res - - def join_var_node( - self, - refs: List[Ref], - tp: Type, - name_hint: Optional[str] = None, - ) -> Tuple[str, ValNode]: - refs_hash = ValNode.get_refs_hash(refs=refs) - for var_name, var_node in self.var_nodes.items(): - if var_node.refs_hash == 
refs_hash: - return var_name, var_node - else: - res = ValNode( - tp=tp, - refs=refs, - constraint=None, - ) - res_name = self.get_new_vname(hint=name_hint) - self.var_nodes[res_name] = res - return res_name, res - - def join_op_node( - self, - calls: List[Call], - out_map: Optional[Dict[str, str]] = None, # output name -> var name - in_map: Optional[Dict[str, str]] = None, # input name -> var name - ) -> Tuple[str, CallNode]: - """ - Join an op node to the graph. If a node with this hash already exists, - only connect any not yet connected inputs/outputs; otherwise, create a - new node and then connect inputs/outputs. - """ - out_map = {} if out_map is None else out_map - in_map = {} if in_map is None else in_map - calls_hash = CallNode.get_calls_hash(calls=calls) - for op_name, op_node in self.op_nodes.items(): - if op_node.calls_hash == calls_hash: - res = op_node - res_name = op_name - break - else: - call_representative = calls[0] - op = call_representative.func_op - - #! for struct calls, figure out the orientation; currently ad-hoc - if op.is_builtin: - logging.warning( - f"Found a ref data structure: {op.sig.ui_name}; ComputationFrame support for this is experimental and may not work as expected." - ) - output_names = list(out_map.keys()) - input_names = list(in_map.keys()) - if len(output_names) == 0: - if any(x in input_names for x in ("elt", "value")): - orientation = StructOrientations.construct - elif any(x in input_names for x in ("lst", "dct", "st")): - orientation = StructOrientations.destruct - else: - raise NotImplementedError - else: - if any(x in output_names for x in ("lst", "dct", "st")): - orientation = StructOrientations.construct - elif any(x in output_names for x in ("idx", "key")): - orientation = StructOrientations.destruct - else: - raise NotImplementedError - in_map, out_map = Builtins.reassign_io_using_orientation( - in_dict=in_map, - out_dict=out_map, - orientation=orientation, - builtin_id=op.sig.ui_name, - ) - in_map_for_linking = {**in_map, **out_map} - out_map_for_linking = {} - else: - orientation = None - in_map_for_linking = in_map - out_map_for_linking = out_map - - res = CallNode.link( - calls=calls, - inputs={k: self.var_nodes[v] for k, v in in_map_for_linking.items()}, - outputs={k: self.var_nodes[v] for k, v in out_map_for_linking.items()}, - constraint=None, - func_op=call_representative.func_op, - orientation=orientation, - ) - res_name = self.get_new_cname(op) - self.op_nodes[res_name] = res - for k, v in out_map.items(): - if k not in res.outputs.keys(): - # connect manually - res.outputs[k] = self.var_nodes[v] - self.var_nodes[v].add_creator(creator=res, created_as=k) - for k, v in in_map.items(): - if k not in res.inputs.keys(): - # connect manually - res.inputs[k] = self.var_nodes[v] - self.var_nodes[v].add_consumer(consumer=res, consumed_as=k) - return res_name, res - - @transaction() - def back( - self, - cols: Optional[Union[str, List[str]]] = None, - inplace: bool = False, - skip_failures: bool = False, - verbose: bool = False, - conn: Optional[Connection] = None, - ) -> "ComputationFrame": - res = self if inplace else self.copy_subgraph() - if cols is None: - # this means we want to expand the entire graph - return res._back_all(res, inplace=inplace, verbose=verbose, conn=conn) - if verbose: - logger.info(f"Expanding graph to include the provenance of columns {cols}") - if isinstance(cols, str): - cols = [cols] - - adjacent_calls_data = { - col: res.get_adjacent_calls(col=col, direction="back", conn=conn) - for col in cols - } - - 
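The shape returned by `get_adjacent_calls` is the crux of `back`/`forward`: for every (op, input/output name) pair it produces a list with one slot per row, holding either the adjacent `Call` or `None`. A tiny self-contained illustration of that grouping, with made-up data:

```
from collections import defaultdict

# per_row[i] holds the (op name, output name) pairs adjacent to row i
per_row = [
    [("inc", "output_0")],  # row 0 was created by inc, as its first output
    [("inc", "output_0")],  # row 1, same op and output
    [],                     # row 2 was constructed by hand, no creator call
]

index = defaultdict(lambda: [None] * len(per_row))
for row_idx, pairs in enumerate(per_row):
    for op_name, io_name in pairs:
        index[(op_name, io_name)][row_idx] = f"<call of {op_name} for row {row_idx}>"

print(dict(index))
# {('inc', 'output_0'): ['<call of inc for row 0>', '<call of inc for row 1>', None]}
```

`back` then requires exactly one such group per column, with no `None` entries, unless `skip_failures=True`.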
filtered_cols = [] - for col, calls_dict in adjacent_calls_data.items(): - if len(calls_dict) > 1: - reason = f"Values in column {col} were created by multiple ops and/or as different outputs: {[f'{op}::{out}' for op, out in calls_dict.keys()]}" - if skip_failures: - if verbose: - logger.info(f"{reason}; skipping column {col}") - continue - else: - raise ValueError(reason) - if len(calls_dict) == 0 or any(call is None for call in calls_dict[list(calls_dict.keys())[0]]): - reason = f"Some refs in column {col} were not created by any op" - if skip_failures: - if verbose: - logger.info(f"{reason}; skipping column {col}") - continue - else: - raise ValueError(reason) - filtered_cols.append(col) - cols = filtered_cols - - for col in cols: - calls_dict = adjacent_calls_data[col] - op_id, output_name = list(calls_dict.keys())[0] - calls = calls_dict[(op_id, output_name)] - # produce input nodes - representative_call: Call = calls[0] - op = representative_call.func_op - input_node_names = {} - for input_name, input_tp in op.input_types.items(): - input_node_names[input_name], _ = res.join_var_node( - refs=[call.inputs[input_name] for call in calls], - tp=input_tp, - name_hint=input_name,) - res.join_op_node( - calls=calls, - out_map={output_name: col}, - in_map={k: v for k, v in input_node_names.items()}, - ) - - return res - - @transaction() - def forward( - self, - cols: Optional[Union[str, List[str]]] = None, - inplace: bool = False, - skip_failures: bool = False, - verbose: bool = False, - conn: Optional[Connection] = None, - ) -> "ComputationFrame": - res = self if inplace else self.copy_subgraph() - if cols is None: - # this means we want to expand the entire graph - return res._forward_all(res, inplace=inplace, verbose=verbose, conn=conn) - if verbose: - logger.info(f"Expanding graph to include the consumers of columns {cols}") - if isinstance(cols, str): - cols = [cols] - - adjacent_calls_data = { - col: res.get_adjacent_calls(col=col, direction="forward", conn=conn) - for col in cols - } - - filtered_cols = [] - raise NotImplementedError - # for col, calls_dict in adjacent_calls_data.items(): - - # - # filtered_cols.append(col) - # cols = filtered_cols - - # for col in cols: - # calls_dict = adjacent_calls_data[col] - # op_id, input_name = list(calls_dict.keys())[0] - # calls = calls_dict[(op_id, input_name)] - # # produce output nodes - # representative_call: Call = calls[0] - # op = representative_call.func_op - # output_node_names = {} - # for output_name, output_tp in enumerate(op.output_types): - # output_node_names[dump_output_name(i=output_name)], _ = res.join_var_node( - # refs=[call.outputs[dump_output_name(i=output_name)] for call in calls], - # tp=output_tp - - @transaction() - def delete( - self, - delete_dependents: bool, - verbose: bool = True, - ask: bool = True, - conn: Optional[Connection] = None, - ): - """ - ! this is a powerful method that can delete a lot of data, use with caution - - Delete the calls referenced by this ComputationFrame from the storage, and - clean up any orphaned refs. - - Warning: You probably want to apply this only on RefFunctors that are - "forward-closed", i.e., that have been expanded to include all the calls - that use their values. Otherwise, you may end up with obscure refs for - which you have no provenance, i.e. "zombie" refs that have no meaning in - the context of the rest of the storage. 
Alternatively, you can set - `delete_dependents` to True, which will delete all the calls that depend - on the calls in this ComputationFrame, and then clean up the orphaned refs. - """ - # gather all the calls to be deleted - call_uids_to_delete = defaultdict(list) - call_outputs = {} - for x in self.op_nodes.values(): - # process the call uids - for call in x.calls: - call_uids_to_delete[x.func_op.sig.versioned_ui_name].append( - call.causal_uid - ) - # process the call outputs - if delete_dependents: - for vnode in x.outputs.values(): - for ref in vnode.refs: - call_outputs[ref.causal_uid] = ref - if delete_dependents: - dependent_calls = self.storage.get_dependent_calls( - refs=list(call_outputs.values()), prov_df=self.prov_df, conn=conn - ) - for call in dependent_calls: - call_uids_to_delete[call.func_op.sig.versioned_ui_name].append( - call.causal_uid - ) - if verbose: - # summarize the number of calls per op to be deleted - for op, uids in call_uids_to_delete.items(): - print(f"Op {op} has {len(uids)} calls to be deleted") - if ask: - if input("Proceed? (y/n) ").strip().lower() != "y": - logging.info("Aborting deletion") - return - for versioned_ui_name, call_uids in call_uids_to_delete.items(): - self.storage.rel_adapter.delete_calls( - versioned_ui_name=versioned_ui_name, - causal_uids=call_uids, - conn=conn, - ) - self.storage.rel_adapter.cleanup_vrefs(conn=conn, verbose=verbose) - - ############################################################################ - ### creating new RefFunctors - ############################################################################ - def copy_subgraph( - self, - val_nodes: Optional[Iterable[ValNode]] = None, - call_nodes: Optional[Iterable[CallNode]] = None, - ) -> "ComputationFrame": - """ - Get a copy of the ComputationFrame supported on the given nodes. 
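Both `sync_from_other` earlier and `copy_subgraph` here insist on copying nodes rather than sharing them, so that mutating one frame can never corrupt another. A minimal sketch of that design choice (the `Node`/`Frame` classes are illustrative only, not mandala types):

```
import copy
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    name: str
    refs: List[int] = field(default_factory=list)


class Frame:
    def __init__(self, nodes: Dict[str, Node]):
        self.nodes = nodes

    def restrict(self, keep: List[str]) -> "Frame":
        # deep-copy only the selected nodes; the new frame shares nothing mutable
        return Frame({k: copy.deepcopy(v) for k, v in self.nodes.items() if k in keep})


f = Frame({"x": Node("x", [1, 2]), "y": Node("y", [3]), "z": Node("z", [4])})
sub = f.restrict(["x", "y"])
sub.nodes["x"].refs.append(99)
assert f.nodes["x"].refs == [1, 2]  # the original frame is untouched
```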
- """ - # must copy the graph - val_nodes = ( - set(self.var_nodes.values()) if val_nodes is None else set(val_nodes) - ) - call_nodes = ( - set(self.op_nodes.values()) if call_nodes is None else set(call_nodes) - ) - val_map, call_map = copy_subgraph( - vqs=val_nodes, - fqs=call_nodes, - ) - return ComputationFrame( - call_nodes={ - k: call_map[v] for k, v in self.op_nodes.items() if v in call_map - }, - val_nodes={ - k: val_map[v] for k, v in self.var_nodes.items() if v in val_map - }, - storage=self.storage, - prov_df=self.prov_df, - ) - - @staticmethod - def from_refs( - refs: Iterable[Ref], - storage: Storage, - prov_df: Optional[pd.DataFrame] = None, - name: Optional[str] = None, - ) -> "ComputationFrame": - val_node = ValNode( - constraint=None, - tp=None, - # refs=list(refs), - refs=refs, - ) - name = "v0" if name is None else name - return ComputationFrame( - call_nodes={}, - val_nodes={name: val_node}, - storage=storage, - prov_df=prov_df, - ) - - @staticmethod - def from_op( - func: FuncInterface, - storage: Storage, - prov_df: Optional[pd.DataFrame] = None, - ) -> "ComputationFrame": - """ - Get a ComputationFrame expressing the memoization table for a single function - """ - storage.synchronize(f=func) - reftable = storage.get_table(func, values="lazy", meta=True) - op = func.func_op - if op.is_builtin: - raise ValueError("Cannot create a ComputationFrame from a builtin op") - input_nodes = { - input_name: ValNode( - constraint=None, - tp=op.input_types[input_name], - refs=reftable[input_name].values.tolist(), - ) - for input_name in op.input_types.keys() - } - output_nodes = { - dump_output_name(i): ValNode( - constraint=None, - tp=tp, - refs=reftable[dump_output_name(i)].values.tolist(), - ) - for i, tp in enumerate(op.output_types) - } - call_uids = reftable[Config.causal_uid_col].values.tolist() - calls = storage.cache.call_mget( - uids=call_uids, - versioned_ui_name=op.sig.versioned_ui_name, - by_causal=True, - ) - call_node = CallNode.link( - inputs=input_nodes, - func_op=op, - outputs=output_nodes, - constraint=None, - calls=calls, - orientation=None, - ) - return ComputationFrame( - call_nodes={func.func_op.sig.versioned_ui_name: call_node}, - val_nodes={ - k: v - for k, v in itertools.chain(input_nodes.items(), output_nodes.items()) - }, - storage=storage, - prov_df=prov_df, - ) - - ############################################################################ - ### visualization - ############################################################################ - def get_printer(self) -> GraphPrinter: - printer = GraphPrinter( - vqs=set(self.var_nodes.values()), - fqs=set(self.op_nodes.values()), - names={v: k for k, v in self.var_nodes.items()}, - fnames={v: k for k, v in self.op_nodes.items()}, - value_loader=ValueLoader(storage=self.storage), - ) - return printer - - def _get_string_representation(self) -> str: - printer = self.get_printer() - graph_description = printer.print_computational_graph( - show_sources_as="name_only" - ) - # indent the graph description - graph_description = textwrap.indent(graph_description, " ") - return f"{self.__class__.__name__} with {self.num_vars} variable(s), {self.num_ops} operation(s) and {len(self)} row(s), representing the computation:\n{graph_description}" - - def __repr__(self) -> str: - return self._get_string_representation() - # return self.to_pandas().head(5).to_string() - - def print(self): - print(self._get_string_representation()) - - def show(self, how: Literal["inline", "browser"] = "browser"): - visualize_graph( - 
vqs=set(self.var_nodes.values()), - fqs=set(self.op_nodes.values()), - layout="computational", - names={v: k for k, v in self.var_nodes.items()}, - show_how=how, - ) - - def get_new_vname(self, hint: Optional[str] = None) -> str: - """ - Return the first name of the form `v{i}` that is not in self.val_nodes - """ - if hint is not None and hint not in self.var_nodes: - return hint - i = 0 - prefix = "v" if hint is None else hint - while f"{prefix}{i}" in self.var_nodes: - i += 1 - return f"{prefix}{i}" - - def get_new_cname(self, op: FuncOp) -> str: - if op.sig.versioned_ui_name not in self.op_nodes: - return op.sig.versioned_ui_name - i = 0 - while f"{op.sig.versioned_ui_name}_{i}" in self.op_nodes: - i += 1 - return f"{op.sig.versioned_ui_name}_{i}" - - def rename(self, columns: Dict[str, str], inplace: bool = False): - for old_name, new_name in columns.items(): - if old_name not in self.var_nodes: - raise ValueError(f"Column {old_name} does not exist") - if new_name in self.var_nodes: - raise ValueError(f"Column {new_name} already exists") - if inplace: - res = self - else: - res = self.copy_subgraph() - for old_name, new_name in columns.items(): - res.var_nodes[new_name] = res.var_nodes.pop(old_name) - return res - - def r( - self, - inplace: bool = False, - **kwargs, - ) -> "ComputationFrame": - """ - Fast alias for rename - """ - return self.rename(columns=kwargs, inplace=inplace) - - @property - def num_vars(self) -> int: - return len(self.var_nodes) - - @property - def num_ops(self) -> int: - return len(self.op_nodes) - - def get_var_info( - self, - include_uniques: bool = False, - small_threshold_bytes: int = 4096, - units: Literal["bytes", "KB", "MB", "GB"] = "MB", - sample_size: int = 20, - ) -> pd.DataFrame: - var_rows = [] - for k, v in self.var_nodes.items(): - if len(v.refs) == 0: - avg_size, std = 0, 0 - else: - avg_size_bytes, std_bytes = estimate_uid_storage( - uids=[ref.uid for ref in v.refs], - storage=self.storage, - units="bytes", - sample_size=sample_size, - ) - avg_size, std = convert_bytes_to( - num_bytes=avg_size_bytes, units=units - ), convert_bytes_to(num_bytes=std_bytes, units=units) - # round to 2 decimal places - avg_size, std = round(avg_size, 2), round(std, 2) - var_data = { - "name": k, - "size": f"{avg_size}±{std} {units}", - "nunique": len(set(ref.uid for ref in v.refs)), - } - if include_uniques: - if avg_size_bytes < small_threshold_bytes: - uniques = {ref.uid: ref for ref in v.refs} - uniques_values = self.storage.unwrap(list(uniques.values())) - try: - uniques_values = sorted(uniques_values) - except: - pass - var_data["unique_values"] = uniques_values - else: - var_data["unique_values"] = "" - var_rows.append(var_data) - var_df = pd.DataFrame(var_rows) - var_df.set_index("name", inplace=True) - var_df = var_df.sort_values(by="size", ascending=False) - return var_df - - def get_op_info(self) -> pd.DataFrame: - rows = [] - for k, op_node in self.op_nodes.items(): - input_types = op_node.func_op.input_types - output_types = { - dump_output_name(index=i): op_node.func_op.output_types[i] - for i in range(len(op_node.func_op.output_types)) - } - input_types_dict = { - k: input_types.get(k, output_types.get(k)) - for k in op_node.inputs.keys() - } - output_types_dict = { - k: output_types.get(k, input_types.get(k)) - for k in op_node.outputs.keys() - } - signature = f'{op_node.func_op.sig.ui_name}({", ".join([f"{k}: {v}" for k, v in input_types_dict.items()])}) -> {", ".join([f"{k}: {v}" for k, v in output_types_dict.items()])}' - rows.append( - { - "name": 
k, - "function": op_node.func_op.sig.ui_name, - "version": op_node.func_op.sig.version, - "num_calls": len(op_node.calls), - "num_unique_calls": len( - set(call.causal_uid for call in op_node.calls) - ), - "signature": signature, - } - ) - op_df = pd.DataFrame(rows) - op_df.set_index("name", inplace=True) - return op_df - - def info( - self, - units: Literal["bytes", "KB", "MB", "GB"] = "MB", - sample_size: int = 20, - show_uniques: bool = False, - small_threshold_bytes: int = 4096, - ): - """ - Print some basic info about the ComputationFrame - """ - # print(self.__class__) - var_df = self.get_var_info( - include_uniques=show_uniques, - small_threshold_bytes=small_threshold_bytes, - units=units, - sample_size=sample_size, - ) - op_df = self.get_op_info() - print( - f"{self.__class__.__name__} with {self.num_vars} variable(s), {self.num_ops} operation(s), {len(self)} row(s)" - ) - printer = self.get_printer() - print("Computation graph:") - print( - textwrap.indent( - printer.print_computational_graph(show_sources_as="name_only"), " " - ) - ) - try: - print("Variables:") - import prettytable - from io import StringIO - - output = StringIO() - var_df.to_csv(output) - output.seek(0) - pt = prettytable.from_csv(output) - print(textwrap.indent(pt.get_string(), " ")) - print("Operations:") - output = StringIO() - op_df.to_csv(output) - output.seek(0) - pt = prettytable.from_csv(output) - print(textwrap.indent(pt.get_string(), " ")) - except ImportError: - print("Variables:") - print(textwrap.indent(var_df.to_string(), " ")) - print("Operations:") - print(textwrap.indent(op_df.to_string(), " ")) - # representative_node = self.val_nodes[list(self.val_nodes.keys())[0]] - # num_rows = len(representative_node.refs) - # print(f"ComputationFrame with {self.num_vars} variable(s) and {self.num_ops} operations(s), representing {num_rows} computations") - - ############################################################################ - ### `Transactable` interface - ############################################################################ - def _get_connection(self) -> Connection: - return self.storage.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.storage.rel_storage._end_transaction(conn=conn) diff --git a/mandala/ui/cfs_utils.py b/mandala/ui/cfs_utils.py deleted file mode 100644 index 9d99375..0000000 --- a/mandala/ui/cfs_utils.py +++ /dev/null @@ -1,60 +0,0 @@ -from ..common_imports import * -from ..core.config import * -from .storage import Storage - -################################################################################ -### tools to summarize tables of refs -################################################################################ -def convert_bytes_to( - num_bytes: float, units: Literal["bytes", "KB", "MB", "GB"] -) -> float: - if units == "KB": - return num_bytes / 1024 - elif units == "MB": - return num_bytes / (1024**2) - elif units == "GB": - return num_bytes / (1024**3) - elif units == "bytes": - return num_bytes - else: - raise ValueError(f"Unknown units: {units}") - - -def estimate_uid_storage( - uids: List[str], - storage: Storage, - units: Literal["bytes", "KB", "MB", "GB"] = "bytes", - sample_size: int = 20, -) -> Tuple[float, float]: - # sample 5 random elements from the column - sample_uids = pd.Series(uids).sample(sample_size, replace=True).values - query = ( - "SELECT length(value) AS size_in_bytes FROM __vrefs__ WHERE __uid__ IN " - + str(tuple(sample_uids)) - ) - sample_counts = storage.rel_storage.execute_df(query) 
- mean, std = ( - sample_counts["size_in_bytes"].astype(int).mean(), - sample_counts["size_in_bytes"].astype(int).std(), - ) - # check if std is nan - if np.isnan(std): - std = 0 - # convert to requested units - if units == "KB": - avg_column_size = mean / 1024 - std = std / 1024 - elif units == "MB": - avg_column_size = mean / (1024**2) - std = std / (1024**2) - elif units == "GB": - avg_column_size = mean / (1024**3) - std = std / (1024**3) - elif units == "bytes": - avg_column_size = mean - else: - raise ValueError(f"Unknown units: {units}") - # round to 2 decimal places - avg_column_size = round(avg_column_size, 2) - std = round(std, 2) - return avg_column_size, std diff --git a/mandala/ui/cfs_wip.py b/mandala/ui/cfs_wip.py deleted file mode 100644 index f51d89b..0000000 --- a/mandala/ui/cfs_wip.py +++ /dev/null @@ -1,126 +0,0 @@ -import textwrap -from ..common_imports import * -from ..core.config import * -from ..core.tps import Type -from ..core.model import Ref, FuncOp, Call -from ..core.builtins_ import StructOrientations, Builtins -from ..storages.rel_impls.utils import Transactable, transaction, Connection -from ..queries.graphs import copy_subgraph -from ..core.prov import propagate_struct_provenance -from ..queries.weaver import CallNode, ValNode, PaddedList -from ..queries.viz import GraphPrinter, visualize_graph -from .funcs import FuncInterface -from .storage import Storage, ValueLoader -from .cfs_utils import estimate_uid_storage, convert_bytes_to - - - -class CanonicalCF: - """ - Some experiments with more general cfs - """ - pass - - - -class PaddedList(Sequence[Optional[T]]): - """ - A list-like object that is backed by a list of values and a list of indices, - and has length `length`. When indexed, it returns the value from the list at - the corresponding index, or None if the index is not in the list of indices. 
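# --- Editorial aside (not part of the diff) ----------------------------------
# Illustration of the sparse representation the `PaddedList` docstring above
# describes: values live in an {index: value} dict plus an explicit length,
# and indices without an entry read back as None. Toy variables only.
support = {0: "a", 3: "b"}
length = 5
materialized = [support.get(i) for i in range(length)]
assert materialized == ["a", None, None, "b", None]
# ------------------------------------------------------------------------------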
- """ - - def __init__(self, support: Dict[int, T], length: int): - self.support = support - self.length = length - - def __repr__(self) -> str: - return f"PaddedList({self.tolist()})" - - def tolist(self) -> List[Optional[T]]: - return [self.support.get(i, None) for i in range(self.length)] - - def copy(self) -> "PaddedList": - return PaddedList(support=self.support.copy(), length=self.length) - - def copy_item(self, i: int, times: int, inplace: bool = False) -> "PaddedList": - res = self if inplace else self.copy() - if i not in res.support: - res.length = res.length + times - else: - for j in range(res.length, res.length + times): - res.support[j] = res.support[i] - res.length += times - return res - - def change_length(self, length: int, inplace: bool = False) -> "PaddedList": - res = self if inplace else self.copy() - res.length = length - return res - - def append_items(self, items: List[T], inplace: bool = False) -> "PaddedList": - res = self if inplace else self.copy() - for i, item in enumerate(items): - res.support[res.length + i] = item - res.length += len(items) - return res - - @staticmethod - def from_list(lst: List[Optional[T]], length: Optional[int] = None) -> "PaddedList": - if length is None: - length = len(lst) - items = {i: v for i, v in enumerate(lst) if v is not None} - return PaddedList(support=items, length=length) - - def dropna(self) -> List[T]: - return [self.support[k] for k in sorted(self.support.keys())] - - @staticmethod - def padded_like(plist: "PaddedList[T1]", values: List[T]) -> "PaddedList[T]": - support = {i: values[j] for j, i in enumerate(sorted(plist.support.keys()))} - return PaddedList(support=support, length=len(plist)) - - def __len__(self) -> int: - return self.length - - def __iter__(self) -> Iterator[Optional[T]]: - return (self.support.get(i, None) for i in range(self.length)) - - def __getitem__( - self, idx: Union[int, slice, List[bool], np.ndarray] - ) -> Union[T, "PaddedList"]: - if isinstance(idx, int): - return self.support.get(idx, None) - elif isinstance(idx, slice): - raise NotImplementedError - elif isinstance(idx, (list, np.ndarray)): - return self.masked(idx) - else: - raise NotImplementedError( - "Indexing only supported for integers, slices, and boolean arrays" - ) - - def masked(self, mask: Union[List[bool], np.ndarray]) -> "PaddedList": - """ - Return a new `PaddedList` object with the values masked by the given - boolean array, and the indices updated accordingly. - """ - if len(mask) != self.length: - raise ValueError("Boolean mask must have the same length as the list") - result_items = {} - cur_masked_idx = 0 - for mask_idx, m in enumerate(mask): - if m: - if mask_idx in self.support: - result_items[cur_masked_idx] = self.support[mask_idx] - cur_masked_idx += 1 - return PaddedList(support=result_items, length=cur_masked_idx) - - def keep_only(self, indices: Set[int]) -> "PaddedList": - """ - Return a new `PaddedList` object with the values masked by the given - list of indices. 
- """ - items = {i: v for i, v in self.support.items() if i in indices} - return PaddedList(support=items, length=self.length) - diff --git a/mandala/ui/context_cache.py b/mandala/ui/context_cache.py deleted file mode 100644 index 848290f..0000000 --- a/mandala/ui/context_cache.py +++ /dev/null @@ -1,255 +0,0 @@ -from ..common_imports import * -from ..core.model import Call, Ref, collect_detached -from ..storages.rels import RelAdapter, VersionAdapter -from ..deps.versioner import Versioner -from ..storages.kv import KVCache, InMemoryStorage, MultiProcInMemoryStorage -from ..storages.rel_impls.utils import Transactable, transaction, Connection - - -class Cache(Transactable): - """ - A layer between calls happening in the context and the persistent storage. - Also responsible for detaching objects from the computational graph. - - All calls represented in the cache are detached (see `Call.detached`), but - objects are represented as `Ref` instances (which are not detached). - - TODO: protect the objects in the cache from being modified. - """ - - def __init__( - self, - rel_adapter: RelAdapter, - ): - self.rel_adapter = rel_adapter - # uid -> detached call - self.call_cache_by_uid = InMemoryStorage() - # causal uid -> detached call - self.call_cache_by_causal = InMemoryStorage() - # uid -> unlinked ref without causal - self.obj_cache = InMemoryStorage() - - def mcache_call_and_objs(self, calls: List[Call]) -> None: - # a more efficient version of `cache_call_and_objs` for multiple calls - # that avoids calling `unlinked` multiple times on the same object, - # which could be expensive for large collections. - unique_vrefs = {} - unique_calls = {} - for call in calls: - for vref in itertools.chain(call.inputs.values(), call.outputs): - unique_vrefs[vref.uid] = vref - unique_calls[call.causal_uid] = call - for vref in unique_vrefs.values(): - self.obj_cache[vref.uid] = vref.unlinked(keep_causal=False) - for call in unique_calls.values(): - self.cache_call(causal_uid=call.causal_uid, call=call) - - def cache_call_and_objs(self, call: Call) -> None: - for vref in itertools.chain(call.inputs.values(), call.outputs): - self.obj_cache[vref.uid] = vref.unlinked(keep_causal=False) - self.cache_call(causal_uid=call.causal_uid, call=call) - - def cache_call(self, causal_uid: str, call: Call) -> None: - self.call_cache_by_causal.set(causal_uid, call.detached()) - self.call_cache_by_uid.set(call.uid, call.detached()) - - def mattach( - self, vrefs: List[Ref], shallow: bool = False, _attach_atoms: bool = True - ): - """ - Regardless of `shallow`, recursively find all the refs that can be found - in the cache. Then pass what's left to the storage method. - """ - vrefs = collect_detached(vrefs, include_transient=False) - cur_frontier = vrefs - new_frontier = [] - not_found = [] - while len(cur_frontier) > 0: - for vref in cur_frontier: - if ( - vref.uid in self.obj_cache.keys() - and self.obj_cache[vref.uid].in_memory - ): - vref.attach(reference=self.obj_cache[vref.uid]) - new_frontier.extend( - collect_detached([vref], include_transient=False) - ) - else: - not_found.append(vref) - cur_frontier = new_frontier - new_frontier = [] - if len(not_found) > 0: - self.rel_adapter.mattach( - vrefs=not_found, shallow=shallow, _attach_atoms=_attach_atoms - ) - - def obj_get(self, obj_uid: str, causal_uid: Optional[str] = None) -> Ref: - """ - Get the given object from the cache or the storage. - """ - #! note that this is not transactional to avoid creating a connection - #! 
when the object is already in the cache - if self.obj_cache.exists(obj_uid): - res = self.obj_cache.get(obj_uid) - else: - res = self.rel_adapter.obj_get(uid=obj_uid) - res = res.clone() - if causal_uid is not None: - res.causal_uid = causal_uid - return res - - def call_exists(self, uid: str, by_causal: bool) -> bool: - #! note that this is not transactional to avoid creating a connection - #! when the object is already in the cache - if by_causal: - return self.call_cache_by_causal.exists( - uid - ) or self.rel_adapter.call_exists(uid=uid, by_causal=True) - else: - return self.call_cache_by_uid.exists(uid) or self.rel_adapter.call_exists( - uid=uid, by_causal=False - ) - - @transaction() - def call_mget( - self, - uids: List[str], - versioned_ui_name: str, - by_causal: bool = True, - lazy: bool = True, - conn: Optional[Connection] = None, - ) -> List[Call]: - if not by_causal: - raise NotImplementedError() - if not lazy: - raise NotImplementedError() - res = [None for _ in uids] - missing_indices = [] - missing_uids = [] - for i, uid in enumerate(uids): - if self.call_cache_by_causal.exists(uid): - res[i] = self.call_cache_by_causal.get(uid) - else: - missing_indices.append(i) - missing_uids.append(uid) - if len(missing_uids) > 0: - lazy_calls = self.rel_adapter.mget_call_lazy( - versioned_ui_name=versioned_ui_name, - uids=missing_uids, - by_causal=True, - conn=conn, - ) - for i, lazy_call in zip(missing_indices, lazy_calls): - res[i] = lazy_call - return res - - def call_get(self, uid: str, by_causal: bool, lazy: bool = True) -> Call: - """ - Return a *detached* call with the given UID, if it exists. - """ - #! note that this is not transactional to avoid creating a connection - #! when the object is already in the cache - if by_causal and self.call_cache_by_causal.exists(uid): - return self.call_cache_by_causal.get(uid) - elif not by_causal and self.call_cache_by_uid.exists(uid): - return self.call_cache_by_uid.get(uid) - else: - lazy_call = self.rel_adapter.call_get_lazy(uid=uid, by_causal=by_causal) - if not lazy: - #! you need to be more careful here about guarantees provided by - #! the cache - raise NotImplementedError - # # load the values of the inputs and outputs - # inputs = { - # k: self.obj_get(v.uid) - # for k, v in lazy_call.inputs.items() - # } - # outputs = [self.obj_get(v.uid) for v in lazy_call.outputs] - # call_without_outputs = lazy_call.set_input_values(inputs=inputs) - # call = call_without_outputs.set_output_values(outputs=outputs) - # return call - else: - return lazy_call - - @transaction() - def commit( - self, - calls: Optional[List[Call]] = None, - versioner: Optional[Versioner] = None, - version_adapter: VersionAdapter = None, - conn: Optional[Connection] = None, - ): - """ - Flush dirty (written since last time) calls and objs from the cache to - persistent storage, and mark them as clean. - - Note that the cache keeps calls and objects in memory in case they are - needed again, but the storage is the only source of truth. - """ - if calls is None: - new_objs = { - key: self.obj_cache.get(key) for key in self.obj_cache.dirty_entries - } - new_calls = [ - self.call_cache_by_causal.get(key) - for key in self.call_cache_by_causal.dirty_entries - ] - else: - #! 
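# --- Editorial aside (not part of the diff) ----------------------------------
# Generic form of the lookup pattern `call_mget` above uses: serve what we can
# from the in-memory cache, remember the positions of the misses, then fill
# them with one bulk fetch from the slower backing store. `mget` and
# `bulk_fetch` are hypothetical names used only for illustration.
from typing import Any, Callable, Dict, List

def mget(keys: List[str], cache: Dict[str, Any],
         bulk_fetch: Callable[[List[str]], List[Any]]) -> List[Any]:
    result = [cache.get(k) for k in keys]
    missing = [(i, k) for i, k in enumerate(keys) if k not in cache]
    if missing:
        fetched = bulk_fetch([k for _, k in missing])
        for (i, _), value in zip(missing, fetched):
            result[i] = value
    return result

assert mget(["a", "b"], {"a": 1}, lambda ks: [len(k) for k in ks]) == [1, 1]
# ------------------------------------------------------------------------------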
if calls are provided, we assume they are attached - new_objs = {} - for call in calls: - for vref in itertools.chain(call.inputs.values(), call.outputs): - new_obj = self.obj_cache[vref.uid] - assert new_obj.in_memory - new_objs[vref.uid] = new_obj - new_calls = calls - self.rel_adapter.obj_sets(new_objs, conn=conn) - self.rel_adapter.upsert_calls(new_calls, conn=conn) - # if self.evict_on_commit: - # self.evict_caches() - if versioner is not None: - version_adapter.dump_state(state=versioner, conn=conn) - self.clear_all() - - @transaction() - def preload_objs(self, uids: List[str], conn: Optional[Connection] = None): - """ - Put the objects with the given UIDs in the cache. Should be used for - bulk loading b/c it opens a connection - """ - uids_not_in_cache = [uid for uid in uids if not self.obj_cache.exists(uid)] - for uid, vref in zip( - uids_not_in_cache, - self.rel_adapter.obj_gets(uids=uids_not_in_cache, conn=conn), - ): - self.obj_cache.set(k=uid, v=vref) - - def evict_all(self): - """ - Remove all entries from the cache. - """ - self.call_cache_by_causal.evict_all() - self.call_cache_by_uid.evict_all() - self.obj_cache.evict_all() - - def clear_all(self): - """ - Mark all entries as clean, but don't remove them from the cache. - """ - self.call_cache_by_causal.clear_all() - self.call_cache_by_uid.clear_all() - self.obj_cache.clear_all() - - def detach_all(self): - for k, v in self.call_cache_by_causal.items(): - self.call_cache_by_causal[k] = v.detached() - for k, v in self.call_cache_by_uid.items(): - self.call_cache_by_uid[k] = v.detached() - for k, v in self.obj_cache.items(): - self.obj_cache[k] = v.detached() - - def _get_connection(self) -> Connection: - return self.rel_adapter._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_adapter._end_transaction(conn=conn) diff --git a/mandala/ui/contexts.py b/mandala/ui/contexts.py deleted file mode 100644 index 5c65d0e..0000000 --- a/mandala/ui/contexts.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import Literal - -from ..common_imports import * -from ..core.config import Config -from ..core.model import Call -from ..queries.workflow import Workflow -from ..queries.graphs import get_canonical_order -from ..queries.weaver import traverse_all -from ..queries.viz import get_names, extract_names_from_scope -from ..queries.main import Querier - -from ..deps.versioner import CodeState, Versioner -from .utils import MODES - -from ..queries.weaver import ( - qwrap, -) - - -class GlobalContext: - current: Optional["Context"] = None - - -class Context: - def __init__( - self, - storage: "storage.Storage" = None, - mode: str = MODES.run, - lazy: bool = False, - allow_calls: bool = True, - debug_calls: bool = False, - recompute_transient: bool = False, - _attach_call_to_outputs: bool = False, # for debugging - debug_truncate: Optional[int] = 20, - ): - self.storage = storage - self.mode = mode - self.lazy = lazy - self.allow_calls = allow_calls - self.debug_calls = debug_calls - self.recompute_transient = recompute_transient - self._attach_call_to_outputs = _attach_call_to_outputs - self.debug_truncate = debug_truncate - self.updates = {} - self._updates_stack = [] - self._call_depth = 0 - self._call_structs = [] - self._call_uids: Dict[Tuple[str, int], List[str]] = defaultdict(list) - self._defined_funcs: List["FuncInterface"] = [] - self._call_buffer: List[Call] = [] - self._code_state: CodeState = None - self._cached_versioner: Versioner = None - - def _backup_state(self, keys: Iterable[str]) -> 
Dict[str, Any]: - res = {} - for k in keys: - cur_v = self.__dict__[f"{k}"] - if k == "storage": # gotta use a pointer - res[k] = cur_v - else: - res[k] = copy.deepcopy(cur_v) - return res - - def __enter__(self) -> "Context": - is_top = len(self._updates_stack) == 0 - ### verify update keys - updates = self.updates - if not all( - k - in ( - "storage", - "mode", - "lazy", - "allow_calls", - "debug_calls", - "recompute_transient", - "_attach_call_to_outputs", - ) - for k in updates.keys() - ): - raise ValueError(updates.keys()) - if "mode" in updates.keys() and updates["mode"] not in MODES.all_: - raise ValueError(updates["mode"]) - ### backup state - before_update = self._backup_state(keys=updates.keys()) - # self._updates_stack.append(before_update) - ### apply updates - for k, v in updates.items(): - if v is not None: - self.__dict__[f"{k}"] = v - # Load state from remote - if self.storage is not None: - # self.storage.sync_with_remote() - self.storage.sync_from_remote() - if ( - self.mode in (MODES.run, MODES.query) - and self.storage is not None - and self.storage.versioned - ): - storage = self.storage - if is_top: - versioner, code_state = storage.sync_code() - self._cached_versioner = versioner - self._code_state = code_state - # this is last so that any exceptions don't leave the context in an - # inconsistent state - self._updates_stack.append(before_update) - return self - - def _undo_updates(self): - """ - Roll back the updates from the current level - """ - if not self._updates_stack: - raise InternalError("No context to exit from") - ascent_updates = self._updates_stack.pop() - for k, v in ascent_updates.items(): - self.__dict__[f"{k}"] = v - # unlink from global if done - if len(self._updates_stack) == 0: - GlobalContext.current = None - - def __exit__(self, exc_type, exc_value, exc_traceback): - if exc_type is None: - if self.mode == MODES.run: - if self.storage is not None: - # commit calls from temp partition to main and tabulate them - if Config.autocommit: - self.storage.commit(versioner=self._cached_versioner) - self.storage.sync_to_remote() - elif self.mode == MODES.query: - pass - elif self.mode == MODES.noop: - pass - elif self.mode == MODES.batch: - executor = SimpleWorkflowExecutor() - workflow = Workflow.from_call_structs(self._call_structs) - calls = executor.execute(workflow=workflow, storage=self.storage) - self.storage.commit(calls=calls, versioner=self._cached_versioner) - else: - raise InternalError(self.mode) - self._undo_updates() - return None - - def __call__( - self, - storage: Optional["Storage"] = None, - allow_calls: bool = True, - debug_calls: bool = False, - recompute_transient: bool = False, - _attach_call_to_outputs: bool = False, - **updates, - ): - self.updates = { - "storage": storage, - "allow_calls": allow_calls, - "debug_calls": debug_calls, - "recompute_transient": recompute_transient, - "_attach_call_to_outputs": _attach_call_to_outputs, - **updates, - } - return self - - -from . 
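# --- Editorial aside (not part of the diff) ----------------------------------
# Stripped-down version of the update/backup/restore pattern that `Context`
# above implements in `_backup_state`, `__enter__` and `__exit__`: keyword
# overrides are applied on entry, previous values are kept on a stack (so
# nesting works) and restored on exit. `ToyContext` is hypothetical and only
# mirrors the control flow, not the real class.
import copy

class ToyContext:
    def __init__(self):
        self.mode = "run"
        self.updates = {}
        self._updates_stack = []

    def __call__(self, **updates):
        self.updates = updates
        return self

    def __enter__(self):
        before = {k: copy.deepcopy(getattr(self, k)) for k in self.updates}
        for k, v in self.updates.items():
            if v is not None:
                setattr(self, k, v)
        self._updates_stack.append(before)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        for k, v in self._updates_stack.pop().items():
            setattr(self, k, v)

ctx = ToyContext()
with ctx(mode="query"):
    assert ctx.mode == "query"
assert ctx.mode == "run"
# ------------------------------------------------------------------------------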
import storage -from .executors import SimpleWorkflowExecutor diff --git a/mandala/ui/executors.py b/mandala/ui/executors.py deleted file mode 100644 index 034484c..0000000 --- a/mandala/ui/executors.py +++ /dev/null @@ -1,52 +0,0 @@ -from ..common_imports import * -from ..queries.workflow import Workflow -from abc import ABC, abstractmethod -from ..core.model import Call -from ..core.config import MODES -from ..core.wrapping import unwrap - - -class WorkflowExecutor(ABC): - @abstractmethod - def execute(self, workflow: Workflow, storage: "storage.Storage") -> List[Call]: - pass - - -class SimpleWorkflowExecutor(WorkflowExecutor): - def execute(self, workflow: Workflow, storage: "storage.Storage") -> List[Call]: - result = [] - for op_node in workflow.op_nodes: - call_structs = workflow.op_node_to_call_structs[op_node] - for call_struct in call_structs: - func_op, inputs, outputs = ( - call_struct.func_op, - call_struct.inputs, - call_struct.outputs, - ) - assert all([not inp.is_delayed() for inp in inputs.values()]) - inputs = unwrap(inputs) - fi = funcs.FuncInterface(func_op=func_op) - with storage.run(): - _, call = fi.call(args=tuple(), kwargs=inputs) - vref_outputs = call.outputs - # vref_outputs, call, _ = r.call(args=tuple(), kwargs=inputs, conn=None) - # vref_outputs, call, _ = storage.call_run( - # func_op=func_op, - # inputs=inputs, - # _call_depth=0, - # ) - # overwrite things - for output, vref_output in zip(outputs, vref_outputs): - output._obj = vref_output.obj - output.uid = vref_output.uid - output.in_memory = True - result.append(call) - # filter out repeated calls - result = list({call.uid: call for call in result}.values()) - return result - - -from . import storage -from . import runner -from . import contexts -from . import funcs diff --git a/mandala/ui/funcs.py b/mandala/ui/funcs.py deleted file mode 100644 index de5331e..0000000 --- a/mandala/ui/funcs.py +++ /dev/null @@ -1,243 +0,0 @@ -from functools import wraps -from ..common_imports import * -from ..core.model import FuncOp, TransientObj, Call -from ..core.utils import unwrap_decorators -from ..core.sig import Signature, _postprocess_outputs -from ..core.tps import AnyType -from ..queries.weaver import ValNode, qwrap, call_query -from ..deps.tracers.dec_impl import DecTracer - -from . import contexts -from .utils import bind_inputs, format_as_outputs, MODES, wrap_atom - - -def Q(pattern: Optional[Any] = None) -> "pattern": - """ - Create a `ValQuery` instance to be used as a placeholder in a query - """ - if pattern is None: - # return ValQuery(creator=None, created_as=None) - return ValNode(creators=[], created_as=[], constraint=None, tp=AnyType()) - else: - return qwrap(obj=pattern) - - -T = TypeVar("T") - - -def Transient(obj: T, unhashable: bool = False) -> T: - if contexts.GlobalContext.current is not None: - return TransientObj(obj=obj, unhashable=unhashable) - else: - return obj - - -class FuncInterface: - """ - Wrapper around a memoized function. - - This is the object the `@op` decorator converts functions into. 
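# --- Editorial aside (not part of the diff) ----------------------------------
# Generic illustration of the pattern the docstring above describes: a
# decorator that replaces a function with a wrapper object which intercepts
# calls. This toy version only memoizes in a dict; it is NOT mandala's `@op`,
# which additionally persists calls, hashes inputs and tracks versions.
import functools

class ToyInterface:
    def __init__(self, func):
        self.func = func
        self._cache = {}
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if key not in self._cache:
            self._cache[key] = self.func(*args, **kwargs)
        return self._cache[key]

def toy_op(func):
    return ToyInterface(func)

@toy_op
def inc(x):
    return x + 1

assert inc(1) == 2 and inc(1) == 2  # second call is served from the cache
# ------------------------------------------------------------------------------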
- """ - - def __init__( - self, - func_op: FuncOp, - executor: str = "python", - ): - self.func_op = func_op - self.__name__ = self.func_op.sig.ui_name - self._is_synchronized = False - self._is_invalidated = False - self._storage_id = None - self.executor = executor - if ( - contexts.GlobalContext.current is not None - and contexts.GlobalContext.current.mode == MODES.define - ): - contexts.GlobalContext.current._defined_funcs.append(self) - - @property - def sig(self) -> Signature: - return self.func_op.sig - - def __repr__(self) -> str: - sig = self.func_op.sig - if self._is_invalidated: - # clearly distinguish stale functions - return f"FuncInterface(func_name={sig.ui_name}, is_invalidated=True)" - else: - from rich.text import Text - - return f"FuncInterface(signature={sig})" - - def invalidate(self): - self._is_invalidated = True - self._is_synchronized = False - - @property - def is_invalidated(self) -> bool: - return self._is_invalidated - - def _preprocess_call( - self, *args, **kwargs - ) -> Tuple[Dict[str, Any], str, "storage.Storage", contexts.Context]: - context = contexts.GlobalContext.current - storage = context.storage - if self._is_invalidated: - raise RuntimeError( - "This function has been invalidated due to a change in the signature, and cannot be called" - ) - # synchronize if necessary - storage.synchronize(self) - # synchronize(func=self, storage=context.storage) - inputs = bind_inputs(args, kwargs, mode=context.mode, func_op=self.func_op) - mode = context.mode - return inputs, mode, storage, context - - def call(self, args, kwargs) -> Tuple[Union[None, Any, Tuple[Any]], Optional[Call]]: - # low-level API for more control over internal machinery - r = runner.Runner(context=contexts.GlobalContext.current, func_op=self.func_op) - # inputs, mode, storage, context = self._preprocess_call(*args, - # **kwargs) - if r.storage is not None: - r.storage.synchronize(self) - if r.mode == MODES.run: - func_op = r.func_op - r.preprocess(args, kwargs) - if r.must_execute: - r.pre_execute(conn=None) - if r.tracer_option is not None: - tracer = r.tracer_option - with tracer: - if isinstance(tracer, DecTracer): - node = tracer.register_call(func=func_op.func) - result = func_op.func(**r.func_inputs) - if isinstance(tracer, DecTracer): - tracer.register_return(node=node) - else: - result = func_op.func(**r.func_inputs) - outputs = _postprocess_outputs(sig=func_op.sig, result=result) - call = r.post_execute(outputs=outputs) - else: - call = r.load_call(conn=None) - return r.postprocess(call=call), call - return r.process_other_modes(args, kwargs), None - - def __call__(self, *args, **kwargs) -> Union[None, Any, Tuple[Any]]: - return self.call(args, kwargs)[0] - - -class AsyncFuncInterface(FuncInterface): - async def call( - self, args, kwargs - ) -> Tuple[Union[None, Any, Tuple[Any]], Optional[Call]]: - # low-level API for more control over internal machinery - r = runner.Runner(context=contexts.GlobalContext.current, func_op=self.func_op) - # inputs, mode, storage, context = self._preprocess_call(*args, - # **kwargs) - if r.storage is not None: - r.storage.synchronize(self) - if r.mode == MODES.run: - func_op = r.func_op - r.preprocess(args, kwargs) - if r.must_execute: - r.pre_execute(conn=None) - if r.tracer_option is not None: - tracer = r.tracer_option - with tracer: - if isinstance(tracer, DecTracer): - node = tracer.register_call(func=func_op.func) - result = await func_op.func(**r.func_inputs) - if isinstance(tracer, DecTracer): - tracer.register_return(node=node) - else: - 
result = await func_op.func(**r.func_inputs) - outputs = _postprocess_outputs(sig=func_op.sig, result=result) - call = r.post_execute(outputs=outputs) - else: - call = r.load_call(conn=None) - return r.postprocess(call=call), call - return r.process_other_modes(args, kwargs), None - - async def __call__(self, *args, **kwargs) -> Union[None, Any, Tuple[Any]]: - return (await self.call(args, kwargs))[0] - - -class FuncDecorator: - # parametrized version of `@op` decorator - def __init__( - self, - **kwargs, - ): - self.kwargs = kwargs - - def __call__(self, func: Callable) -> "func": - # func = unwrap_decorators(func, strict=True) - func_op = FuncOp( - func=func, - n_outputs_override=self.kwargs.get("n_outputs"), - version=self.kwargs.get("version"), - ui_name=self.kwargs.get("ui_name"), - is_super=self.kwargs.get("is_super", False), - ) - if inspect.iscoroutinefunction(func): - InterfaceCls = AsyncFuncInterface - else: - InterfaceCls = FuncInterface - return wraps(func)( - InterfaceCls( - func_op=func_op, - executor=self.kwargs.get("executor", "python"), - ) - ) - - -def op( - version: Union[Callable, Optional[int]] = None, - nout: Optional[int] = None, - ui_name: Optional[str] = None, - executor: str = "python", -) -> "version": # a hack to make mypy/autocomplete happy - if callable(version): - # a hack to handle the @op case - func = version - # func = unwrap_decorators(func, strict=True) - func_op = FuncOp(func=func, n_outputs_override=nout) - if inspect.iscoroutinefunction(func): - return wraps(func)(AsyncFuncInterface(func_op=func_op)) - else: - return wraps(func)(FuncInterface(func_op=func_op)) - else: - # @op(...) case - return FuncDecorator( - version=version, - n_outputs=nout, - ui_name=ui_name, - executor=executor, - ) - - -def superop( - version: Union[Callable, Optional[int]] = None, - ui_name: Optional[str] = None, - executor: str = "python", -) -> "version": - if callable(version): - # a hack to handle the @op case - func = version - # func_op = FuncOp(func=unwrap_decorators(func, strict=True), - # is_super=True) - func_op = FuncOp(func=func, is_super=True) - if inspect.iscoroutinefunction(func): - return AsyncFuncInterface(func_op=func_op) - else: - return FuncInterface(func_op=func_op) - else: - # @op(...) case - return FuncDecorator( - version=version, ui_name=ui_name, executor=executor, is_super=True - ) - - -from . import storage -from . 
import runner diff --git a/mandala/ui/provenance.py b/mandala/ui/provenance.py deleted file mode 100644 index ab965fd..0000000 --- a/mandala/ui/provenance.py +++ /dev/null @@ -1,234 +0,0 @@ -from ..common_imports import * -from .storage import Storage -from ..core.model import Ref, Call -from ..core.builtins_ import StructRef, ListRef, DictRef, SetRef -from ..core.config import Provenance, parse_output_idx -from ..queries.weaver import ValNode, CallNode, StructOrientations, traverse_all -from ..queries.viz import GraphPrinter - -BUILTIN_OP_IDS = ("__list___0", "__dict___0", "__set___0") -OP_ID_TO_STRUCT_NAME = { - "__list___0": "lst", - "__dict___0": "dct", - "__set___0": "st", -} - -OP_ID_TO_ELT_NAME = { - "__list___0": "elt", - "__dict___0": "val", - "__set___0": "elt", -} - -ELT_NAMES = ("elt", "val") -STRUCT_NAMES = ("lst", "dct", "st") -IDX_NAMES = ("idx", "key") - - -class ProvHelpers: - """ - Tools for producing provenance graphs identical to runtime provenance graphs - """ - - def __init__(self, storage: Storage, prov_df: pd.DataFrame): - self.storage = storage - self.prov = prov_df.sort_index() # improve performance of .loc - - def get_call_df(self, call_causal_uid: str) -> pd.DataFrame: - """ - Get the provenance dataframe for the given call causal uid - """ - #! very inefficient - return self.prov[self.prov[Provenance.call_causal_uid] == call_causal_uid] - - def get_df_as_input(self, causal_uid: str) -> pd.DataFrame: - """ - Get the provenance dataframe for the given causal uid as an input - """ - key = (causal_uid, "input") - if key not in self.prov.index: - # return empty like self.prov - return self.prov.loc[[]] - else: - return self.prov.loc[[key]] # passing a list to .loc guarantees a dataframe - - def get_df_as_output(self, causal_uid: str) -> pd.DataFrame: - """ - Get the provenance dataframe for the given causal uid as an output - """ - key = (causal_uid, "output") - if key not in self.prov.index: - # return empty like self.prov - return self.prov.loc[[]] - else: - res = self.prov.loc[[key]] # passing a list to .loc guarantees a dataframe - if res.shape[0] > 1: - raise NotImplementedError( - f"causal uid {causal_uid} is the output of multiple calls" - ) - return res - - def is_op_output(self, causal_uid: str) -> bool: - """ - Check if the given causal uid is the output of a (non-struct) op call - and verify that there is a unique such call - """ - n = self.get_df_as_output(causal_uid=causal_uid).shape[0] - if n > 1: - raise NotImplementedError( - f"causal uid {causal_uid} is the output of multiple op calls" - ) - return n == 1 - - def get_containing_structs(self, causal_uid: str) -> Tuple[List[str], List[str]]: - """ - Return two lists: the causal UIDs of structs containing this ref, and - the causal call UIDs of the corresponding structural calls - """ - elt_rows = self.get_df_as_input(causal_uid=causal_uid) - # restrict to structs containing this ref only - elt_rows = elt_rows[elt_rows[Provenance.op_id].isin(BUILTIN_OP_IDS)] - elt_rows = elt_rows[elt_rows[Provenance.name].isin(ELT_NAMES)] - # get the call causal UIDs - call_causal_uids = elt_rows[Provenance.call_causal_uid].tolist() - # restrict to the rows containing the structs themselves - x = self.prov[self.prov[Provenance.call_causal_uid].isin(call_causal_uids)] - x = x[x[Provenance.name].isin(STRUCT_NAMES)] - return ( - x.index.get_level_values(0).tolist(), - x[Provenance.call_causal_uid].tolist(), - ) - - def get_struct_call_uids(self, causal_uid: str) -> List[str]: - key = (causal_uid, "input") - df = 
self.prov.loc[[key]] - df = df[df[Provenance.op_id].isin(BUILTIN_OP_IDS)] - df = df[df[Provenance.name].isin(STRUCT_NAMES)] - return df[Provenance.call_causal_uid].tolist() - - def get_creator_chain(self, causal_uid: str) -> List[str]: - """ - Return a list of call UIDs for the chain of calls creating a value (if any) - """ - as_output_df = self.get_df_as_output(causal_uid=causal_uid) - if as_output_df.shape[0] > 0: - return as_output_df[Provenance.call_causal_uid].tolist() - else: - # gotta check for containing structs - struct_uids, struct_call_uids = self.get_containing_structs( - causal_uid=causal_uid - ) - for struct_uid, struct_call_uid in zip(struct_uids, struct_call_uids): - recursive_result = self.get_creator_chain(causal_uid=struct_uid) - if len(recursive_result) > 0: - return [struct_call_uid] + self.get_creator_chain( - causal_uid=struct_uid - ) - return [] - - def link_call_into_graph( - self, - call_causal_uid: str, - refs: Dict[str, Ref], - orientation: Optional[str], - calls: Dict[str, Call], - ) -> Call: - """ - Load the call object from storage, link in into the current graph - (creating `Ref` instances when they are not already in `refs`), and add - pointwise constraints to each ref (via the full UID). - - ! Note that for structs we don't unify the indices of the struct with the - rest of the graph to avoid unnatural clashes between refs. We also - don't put the indices in the `refs` object. - """ - # load the call - call = self.storage.rel_adapter.call_get_lazy( - uid=call_causal_uid, by_causal=True - ) - if call.full_uid in calls: - return calls[call.full_uid] - # look up inputs/outputs in `refs` or add them there - for name, inp in call.inputs.items(): - if call.func_op.is_builtin and name in IDX_NAMES: - continue - if inp.causal_uid in refs: - call.inputs[name] = refs[inp.causal_uid] - else: - refs[inp.causal_uid] = inp - for idx, outp in enumerate(call.outputs): - if outp.causal_uid in refs: - call.outputs[idx] = refs[outp.causal_uid] - else: - refs[outp.causal_uid] = outp - # actual linking step - call.link(orientation=orientation) - # add pointwise constraints - for ref in itertools.chain(call.inputs.values(), call.outputs): - ref.query.constraint = [ref.full_uid] - calls[call.full_uid] = call - return call - - def step(self, ref: Ref, refs: Dict[str, Ref], calls: Dict[str, Call]) -> List[Ref]: - """ - Given the (*non-index*) refs constructed so far and a ref in the graph, - link the chain of calls creating this ref. If some of the refs along the - way already exist, use those instead of creating new refs. 
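# --- Editorial aside (not part of the diff) ----------------------------------
# Skeleton of the backward traversal that `get_graph` (later in this file)
# performs: start from one value, repeatedly ask a step function for its
# upstream neighbours, and stop once nothing new is produced. The toy graph
# and `traverse_backward` are for illustration only.
from typing import Callable, Dict, List, Set

def traverse_backward(start: str, step: Callable[[str], List[str]]) -> Set[str]:
    seen = {start}
    queue = [start]
    while queue:
        node = queue.pop()
        for parent in step(node):
            if parent not in seen:
                seen.add(parent)
                queue.append(parent)
    return seen

parents = {"c": ["b"], "b": ["a"], "a": []}  # c was computed from b, b from a
assert traverse_backward("c", lambda n: parents[n]) == {"a", "b", "c"}
# ------------------------------------------------------------------------------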
- """ - # build one step of the provenance graph - creator_chain = self.get_creator_chain(causal_uid=ref.causal_uid) - if not creator_chain: - if isinstance(ref, StructRef): - constituent_call_uids = self.get_struct_call_uids( - causal_uid=ref.causal_uid - ) - elts = [] - for constituent_call_uid in constituent_call_uids: - call = self.link_call_into_graph( - call_causal_uid=constituent_call_uid, - refs=refs, - orientation=StructOrientations.construct, - calls=calls, - ) - op_id = call.func_op.sig.versioned_internal_name - elts.append(call.inputs[OP_ID_TO_ELT_NAME[op_id]]) - return elts - else: - return [] - elif len(creator_chain) == 1: - op_call = self.link_call_into_graph( - call_causal_uid=creator_chain[0], - refs=refs, - orientation=None, - calls=calls, - ) - return list(op_call.inputs.values()) - else: - for struct_call_uid in creator_chain[:-1]: - self.link_call_into_graph( - call_causal_uid=struct_call_uid, - refs=refs, - orientation=StructOrientations.destruct, - calls=calls, - ) - op_call = self.link_call_into_graph( - call_causal_uid=creator_chain[-1], - refs=refs, - orientation=None, - calls=calls, - ) - return list(op_call.inputs.values()) - - def get_graph(self, full_uid: str) -> Tuple[Set[ValNode], Set[CallNode]]: - """ - Given the full UID of a ref, recover the full provenance graph for that - ref - """ - refs = {} - calls = {} - res = Ref.from_full_uid(full_uid=full_uid) - refs[res.causal_uid] = res - queue = [res] - while queue: - ref = queue.pop() - queue.extend(self.step(ref=ref, refs=refs, calls=calls)) - return traverse_all(vqs=[res.query], direction="backward") diff --git a/mandala/ui/remote_utils.py b/mandala/ui/remote_utils.py deleted file mode 100644 index d62661d..0000000 --- a/mandala/ui/remote_utils.py +++ /dev/null @@ -1,142 +0,0 @@ -from pypika import Table, Query -import pyarrow.parquet as pq - -from ..common_imports import * -from ..core.config import Config -from ..storages.rels import RelAdapter, RemoteEventLogEntry -from ..storages.rel_impls.bases import RelStorage -from ..storages.sigs import SigSyncer, SigAdapter -from ..storages.remote_storage import RemoteStorage -from ..storages.rel_impls.utils import Transactable, transaction, Connection - - -class RemoteManager(Transactable): - def __init__( - self, - rel_adapter: RelAdapter, - sig_adapter: SigAdapter, - rel_storage: RelStorage, - sig_syncer: SigSyncer, - root: RemoteStorage, - ): - self.rel_adapter = rel_adapter - self.sig_adapter = sig_adapter - self.rel_storage = rel_storage - self.sig_syncer = sig_syncer - self.root = root - - def _get_connection(self) -> Connection: - return self.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_storage._end_transaction(conn=conn) - - @transaction() - def bundle_to_remote( - self, conn: Optional[Connection] = None - ) -> RemoteEventLogEntry: - """ - Collect the new calls according to the event log, and pack them into a - dict of binary blobs to be sent off to the remote server. - - NOTE: this also renames tables and columns to their immutable internal - names. - """ - # Bundle event log and referenced calls into tables. 
- event_log_df = self.rel_adapter.get_event_log(conn=conn) - tables_with_changes = {} - table_names_with_changes = event_log_df["table"].unique() - - event_log_table = Table(self.rel_adapter.EVENT_LOG_TABLE) - for table_name in table_names_with_changes: - table = Table(table_name) - tables_with_changes[table_name] = self.rel_storage.execute_arrow( - query=Query.from_(table) - .join(event_log_table) - .on(table[Config.uid_col] == event_log_table[Config.uid_col]) - .select(table.star), - conn=conn, - ) - # pass to internal names - tables_with_changes = self.sig_adapter.rename_tables( - tables_with_changes, to="internal", conn=conn - ) - output = {} - for table_name, table in tables_with_changes.items(): - buffer = io.BytesIO() - pq.write_table(table, buffer) - output[table_name] = buffer.getvalue() - return output - - @transaction() - def apply_from_remote( - self, changes: List[RemoteEventLogEntry], conn: Optional[Connection] = None - ): - """ - Apply new calls from the remote server. - - NOTE: this also renames tables and columns to their UI names. - """ - for raw_changeset in changes: - changeset_data = {} - for table_name, serialized_table in raw_changeset.items(): - buffer = io.BytesIO(serialized_table) - deserialized_table = pq.read_table(buffer) - changeset_data[table_name] = deserialized_table - # pass to UI names - changeset_data = self.sig_adapter.rename_tables( - tables=changeset_data, to="ui", conn=conn - ) - for table_name, deserialized_table in changeset_data.items(): - self.rel_storage.upsert(table_name, deserialized_table, conn=conn) - - @transaction() - def sync_from_remote(self, conn: Optional[Connection] = None): - """ - Pull new calls from the remote server. - - Note that the server's schema (i.e. signatures) can be a super-schema of - the local schema, but all local schema elements must be present in the - remote schema, because this is enforced by how schema updates are - performed. - """ - if not isinstance(self.root, RemoteStorage): - return - # apply signature changes from the server first, because the new calls - # from the server may depend on the new schema. - self.sig_syncer.sync_from_remote(conn=conn) - # next, pull new calls - new_log_entries, timestamp = self.root.get_log_entries_since( - self.last_timestamp - ) - self.apply_from_remote(new_log_entries, conn=conn) - self.last_timestamp = timestamp - logger.debug("synced from remote") - - @transaction() - def sync_to_remote(self, conn: Optional[Connection] = None): - """ - Send calls to the remote server. - - As with `sync_from_remote`, the server may have a super-schema of the - local schema. The current signatures are first pulled and applied to the - local schema. 
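# --- Editorial aside (not part of the diff) ----------------------------------
# The in-memory Parquet round-trip used by `bundle_to_remote` and
# `apply_from_remote` above: serialize an Arrow table to bytes with
# `pq.write_table`, ship the bytes, and restore them with `pq.read_table`.
# The example table and column names are made up for illustration.
import io
import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"__uid__": ["a", "b"], "value": [1, 2]})
buffer = io.BytesIO()
pq.write_table(table, buffer)
blob = buffer.getvalue()  # bytes, as carried in a RemoteEventLogEntry dict
restored = pq.read_table(io.BytesIO(blob))
assert restored.equals(table)
# ------------------------------------------------------------------------------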
- """ - if not isinstance(self.root, RemoteStorage): - # todo: there should be a way to completely ignore the event log - # when there's no remote - self.rel_adapter.clear_event_log(conn=conn) - else: - # collect new work and send it to the server - changes = self.bundle_to_remote(conn=conn) - self.root.save_event_log_entry(changes) - # clear the event log only *after* the changes have been received - self.rel_adapter.clear_event_log(conn=conn) - logger.debug("synced to remote") - - @transaction() - def sync_with_remote(self, conn: Optional[Connection] = None): - if not isinstance(self.root, RemoteStorage): - return - self.sync_to_remote(conn=conn) - self.sync_from_remote(conn=conn) diff --git a/mandala/ui/runner.py b/mandala/ui/runner.py deleted file mode 100644 index 97c2ce6..0000000 --- a/mandala/ui/runner.py +++ /dev/null @@ -1,344 +0,0 @@ -from ..common_imports import * -from ..core.model import Call, Ref -from ..core.config import Config -from ..core.builtins_ import StructOrientations -from .storage import FuncOp, Connection -from .utils import bind_inputs, format_as_outputs, wrap_atom -from .contexts import Context - -from ..queries.weaver import call_query - -from .utils import MODES, debug_call, get_terminal_data, check_determinism - -from ..common_imports import * -from ..core.config import Config -from ..core.model import Ref, Call, FuncOp -from ..core.builtins_ import Builtins -from ..core.wrapping import ( - wrap_inputs, - wrap_outputs, - causify_outputs, - decausify, - unwrap, - contains_transient, - contains_not_in_memory, -) - -from ..storages.rel_impls.utils import Transactable, transaction, Connection - -from ..deps.tracers import TracerABC -from ..deps.versioner import Versioner - -from ..queries.weaver import StructOrientations - - -class Runner(Transactable): # this is terrible - def __init__(self, context: Optional[Context], func_op: FuncOp): - self.context: Context = context - self.storage = context.storage if context is not None else None - self.func_op = func_op - self.mode = context.mode if context is not None else MODES.noop - self.code_state = context._code_state if context is not None else None - #! 
todo - these should be set by the context - self.recurse = False - self.collect_calls = False - - ### set by preprocess - self.linking_on: bool = None - self.must_execute: bool = None - self.suspended_trace_obj: Optional[Any] = None - self.tracer_option: Optional[TracerABC] = None - self.versioner: Optional[Versioner] = None - self.wrapped_inputs: Dict[str, Ref] = None - self.input_calls: List[Call] = None - self.pre_call_uid: str = None - self.call_option: Optional[Call] = None - - ### set by prep_execute - self.must_save: bool = None - self.is_recompute: bool = None - self.func_inputs: Dict[str, Any] = None - - def process_other_modes( - self, args: Tuple[Any, ...], kwargs: Dict[str, Any] - ) -> Union[None, Any, Tuple[Any]]: - if self.mode == MODES.query: - inputs = bind_inputs(args, kwargs, mode=self.mode, func_op=self.func_op) - return self.process_query(inputs=inputs) - elif self.mode == MODES.batch: - inputs = bind_inputs(args, kwargs, mode=self.mode, func_op=self.func_op) - return self.process_batch(inputs=inputs) - elif self.mode == MODES.noop: - return self.func_op.func(*args, **kwargs) - else: - raise NotImplementedError - - def process_query(self, inputs: Dict[str, Any]) -> Union[None, Any, Tuple[Any]]: - return format_as_outputs( - outputs=call_query(func_op=self.func_op, inputs=inputs) - ) - - def process_batch(self, inputs: Dict[str, Any]) -> Union[None, Any, Tuple[Any]]: - wrapped_inputs = {k: wrap_atom(v) for k, v in inputs.items()} - outputs, call_struct = self.storage.call_batch( - func_op=self.func_op, inputs=wrapped_inputs - ) - self.context._call_structs.append(call_struct) - return format_as_outputs(outputs=outputs) - - def preprocess( - self, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - conn: Optional[Connection] = None, - ): - self.context._call_depth += 1 - self.linking_on = self.context._call_depth == 1 - func_op = self.func_op - inputs = bind_inputs(args, kwargs, mode=self.mode, func_op=func_op) - - if self.storage.versioned: - self.versioner = self.context._cached_versioner - self.suspended_trace_obj = self.versioner.TracerCls.get_active_trace_obj() - self.versioner.TracerCls.set_active_trace_obj(trace_obj=None) - else: - self.versioner = None - self.suspended_trace_obj = None - - wrapped_inputs, input_calls = wrap_inputs( - objs=inputs, - annotations=func_op.input_annotations, - ) - if self.linking_on: - for input_call in input_calls: - input_call.link( - orientation=StructOrientations.construct, - ) - pre_call_uid = func_op.get_pre_call_uid( - input_uids={k: v.uid for k, v in wrapped_inputs.items()} - ) - call_option = self.storage.lookup_call( - func_op=func_op, - pre_call_uid=pre_call_uid, - input_uids={k: v.uid for k, v in wrapped_inputs.items()}, - input_causal_uids={k: v.causal_uid for k, v in wrapped_inputs.items()}, - conn=conn, - code_state=self.code_state, - versioner=self.versioner, - ) - self.tracer_option = ( - self.versioner.make_tracer() if self.storage.versioned else None - ) - - # condition determining whether we will actually call the underlying function - self.must_execute = ( - call_option is None - or (self.recurse and func_op.is_super) - or ( - call_option is not None - and call_option.transient - and self.context.recompute_transient - ) - ) - self.wrapped_inputs = wrapped_inputs - self.input_calls = input_calls - self.pre_call_uid = pre_call_uid - self.call_option = call_option - - @transaction() - def pre_execute( - self, - conn: Optional[Connection], - ): - call_option = self.call_option - wrapped_inputs = self.wrapped_inputs - 
func_op = self.func_op - self.is_recompute = ( - call_option is not None - and call_option.transient - and self.context.recompute_transient - ) - needs_input_values = ( - call_option is None or call_option is not None and call_option.transient - ) - pass_inputs_unwrapped = Config.autounwrap_inputs and not func_op.is_super - self.must_save = call_option is None - if not (self.recurse and func_op.is_super) and not self.context.allow_calls: - raise ValueError( - f"Call to {func_op.sig.ui_name} not found in call storage." - ) - if needs_input_values: - self.storage.cache.mattach(vrefs=list(wrapped_inputs.values())) - if any(contains_not_in_memory(ref=ref) for ref in wrapped_inputs.values()): - msg = ( - "Cannot execute function whose inputs are transient values " - "that are not in memory. " - "Use `recompute_transient=True` to force recomputation of these inputs." - ) - raise ValueError(msg) - if pass_inputs_unwrapped: - self.func_inputs = unwrap(obj=wrapped_inputs, through_collections=True) - else: - self.func_inputs = wrapped_inputs - - def post_execute(self, outputs: List[Any]): - call_option = self.call_option - wrapped_inputs = self.wrapped_inputs - pre_call_uid = self.pre_call_uid - input_calls = self.input_calls - func_op = self.func_op - - if self.tracer_option is not None: - # check the trace against the code state hypothesis - self.versioner.apply_state_hypothesis( - hypothesis=self.code_state, trace_result=self.tracer_option.graph.nodes - ) - # update the global topology and code state - self.versioner.update_global_topology(graph=self.tracer_option.graph) - self.code_state.add_globals_from(graph=self.tracer_option.graph) - - content_version, semantic_version = ( - self.versioner.get_version_ids( - pre_call_uid=pre_call_uid, - tracer_option=self.tracer_option, - is_recompute=self.is_recompute, - ) - if self.storage.versioned - else (None, None) - ) - call_uid = func_op.get_call_uid( - pre_call_uid=pre_call_uid, semantic_version=semantic_version - ) - wrapped_outputs, output_calls = wrap_outputs( - objs=outputs, - annotations=func_op.output_annotations, - ) - transient = any(contains_transient(ref) for ref in wrapped_outputs) - if self.is_recompute: - check_determinism( - observed_semver=semantic_version, - stored_semver=call_option.semantic_version, - observed_output_uids=[v.uid for v in wrapped_outputs], - stored_output_uids=[w.uid for w in call_option.outputs], - func_op=func_op, - ) - if self.linking_on: - wrapped_outputs = [x.unlinked(keep_causal=True) for x in wrapped_outputs] - call = Call( - uid=call_uid, - func_op=func_op, - inputs=wrapped_inputs, - outputs=wrapped_outputs, - content_version=content_version, - semantic_version=semantic_version, - transient=transient, - ) - for outp in wrapped_outputs: - decausify(ref=outp) # for now, a "clean slate" approach - causify_outputs(refs=wrapped_outputs, call_causal_uid=call.causal_uid) - if self.linking_on: - call.link(orientation=None) - output_calls = [Builtins.collect_all_calls(x) for x in wrapped_outputs] - output_calls = [x for y in output_calls for x in y] - for output_call in output_calls: - output_call.link( - orientation=StructOrientations.destruct, - ) - if self.must_save: - self.storage.cache.mcache_call_and_objs( - calls=[call] + input_calls + output_calls - ) - return call - - @transaction() - def load_call(self, conn: Optional[Connection] = None): - call_option = self.call_option - wrapped_inputs = self.wrapped_inputs - input_calls = self.input_calls - func_op = self.func_op - - assert call_option is not None - - if 
not self.context.lazy: - self.storage.cache.preload_objs( - [v.uid for v in call_option.outputs], conn=conn - ) - wrapped_outputs = [ - self.storage.cache.obj_get(v.uid) for v in call_option.outputs - ] - else: - wrapped_outputs = [v for v in call_option.outputs] - # recreate call - call = Call( - uid=call_option.uid, - func_op=func_op, - inputs=wrapped_inputs, - outputs=wrapped_outputs, - content_version=call_option.content_version, - semantic_version=call_option.semantic_version, - transient=call_option.transient, - ) - if self.linking_on: - call.outputs = [x.unlinked(keep_causal=False) for x in call.outputs] - causify_outputs(refs=call.outputs, call_causal_uid=call.causal_uid) - if not self.context.lazy: - output_calls = [Builtins.collect_all_calls(x) for x in call.outputs] - output_calls = [x for y in output_calls for x in y] - - else: - output_calls = [] - call.link(orientation=None) - for output_call in output_calls: - output_call.link( - orientation=StructOrientations.destruct, - ) - else: - causify_outputs(refs=call.outputs, call_causal_uid=call.causal_uid) - if call.causal_uid != call_option.causal_uid: - # this is a new call causally; must save it and its constituents - output_calls = [Builtins.collect_all_calls(x) for x in call.outputs] - output_calls = [x for y in output_calls for x in y] - self.storage.cache.mcache_call_and_objs( - calls=[call] + input_calls + output_calls - ) - return call - - def postprocess( - self, call: Call, output_format: str = "returns" - ) -> Union[None, Any, Tuple[Any]]: - func_op = self.func_op - if self.storage.versioned and self.suspended_trace_obj is not None: - self.versioner.TracerCls.set_active_trace_obj( - trace_obj=self.suspended_trace_obj - ) - terminal_data = get_terminal_data(func_op=func_op, call=call) - # Tracer.leaf_signal(data=terminal_data) - self.versioner.TracerCls.register_leaf_event( - trace_obj=self.suspended_trace_obj, data=terminal_data - ) - # self.tracer_impl.suspended_tracer.register_leaf(data=terminal_data) - if self.context.debug_calls: - debug_call( - func_name=func_op.sig.ui_name, - memoized=self.call_option is not None, - wrapped_inputs=self.wrapped_inputs, - wrapped_outputs=call.outputs, - ) - if self.collect_calls: - self.context._call_buffer.append(call) - sig = func_op.sig - self.context._call_uids[(sig.internal_name, sig.version)].append(call.uid) - self.context._call_depth -= 1 - if self.context._attach_call_to_outputs: - for output in call.outputs: - output._call = call.detached() - if output_format == "list": - return call.outputs - elif output_format == "returns": - return format_as_outputs(outputs=call.outputs) - - def _get_connection(self) -> Connection: - return self.storage.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.storage.rel_storage._end_transaction(conn=conn) diff --git a/mandala/ui/storage.py b/mandala/ui/storage.py deleted file mode 100644 index 96351a2..0000000 --- a/mandala/ui/storage.py +++ /dev/null @@ -1,1258 +0,0 @@ -import datetime -import tqdm -from typing import Literal -from collections import deque - -from .viz import _get_colorized_diff -from .utils import MODES -from .remote_utils import RemoteManager -from . 
import contexts -from .context_cache import Cache - -from ..common_imports import * -from ..core.config import Config, Provenance, dump_output_name -from ..core.model import Ref, Call, FuncOp, ValueRef, Delayed -from ..core.builtins_ import Builtins, ListRef, DictRef, SetRef -from ..core.wrapping import ( - unwrap, - compare_dfs_as_relations, -) -from ..core.tps import Type, AnyType, ListType, DictType, SetType -from ..core.sig import Signature -from ..core.utils import get_uid, OpKey - -from ..storages.rel_impls.utils import Transactable, transaction, Connection -from ..storages.kv import InMemoryStorage, MultiProcInMemoryStorage, KVCache - -if Config.has_duckdb: - from ..storages.rel_impls.duckdb_impl import DuckDBRelStorage -from ..storages.rel_impls.sqlite_impl import SQLiteRelStorage -from ..storages.rels import RelAdapter, VersionAdapter -from ..storages.sigs import SigSyncer -from ..storages.remote_storage import RemoteStorage -from ..deps.tracers import DecTracer -from ..deps.versioner import Versioner, CodeState -from ..deps.utils import get_dep_key_from_func, extract_func_obj -from ..deps.model import DepKey, TerminalData - -from ..queries.workflow import CallStruct -from ..queries.weaver import ( - ValNode, - CallNode, - traverse_all, -) -from ..queries.viz import ( - visualize_graph, - print_graph, - get_names, - extract_names_from_scope, - GraphPrinter, - ValueLoaderABC, -) -from ..queries.main import Querier -from ..queries.graphs import get_canonical_order, InducedSubgraph - -from ..core.prov import propagate_struct_provenance - - -class Storage(Transactable): - """ - Groups together all the components of the storage system. - - Responsible for things that require multiple components to work together, - e.g. - - committing: moving calls from the "temporary" partition to the "main" - partition. See also `CallStorage`. 
- - synchronizing: connecting an operation with the storage and performing - any necessary updates - """ - - def __init__( - self, - db_path: Optional[Union[str, Path]] = None, - db_backend: str = Config.db_backend, - spillover_dir: Optional[Union[str, Path]] = None, - spillover_threshold_mb: Optional[float] = None, - root: Optional[Union[Path, RemoteStorage]] = None, - timestamp: Optional[datetime.datetime] = None, - multiproc: bool = False, - evict_on_commit: bool = None, - signatures: Optional[Dict[Tuple[str, int], Signature]] = None, - _read_only: bool = False, - ### dependency tracking config - deps_path: Optional[Union[Path, str]] = None, - deps_package: Optional[str] = None, - track_methods: bool = True, - _strict_deps: bool = True, # for testing only - tracer_impl: Optional[type] = None, - ): - self.root = root - # all objects (inputs and outputs to operations, defaults) are saved here - # stores the memoization tables - if db_path is None and Config._persistent_storage_testing: - # get a temp db path - # generate a random filename - db_name = f"db_{get_uid()}.db" - db_path = Path( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - f"../../temp_dbs/{db_name}", - ) - ).resolve() - self.db_path = db_path - self.db_backend = db_backend - self.evict_on_commit = ( - Config.evict_on_commit if evict_on_commit is None else evict_on_commit - ) - if Config.has_duckdb and db_backend == "duckdb": - DBImplementation = DuckDBRelStorage - else: - DBImplementation = SQLiteRelStorage - self.rel_storage = DBImplementation( - address=None if db_path is None else str(db_path), - _read_only=_read_only, - ) - - # manipulates the memoization tables - self.rel_adapter = RelAdapter( - rel_storage=self.rel_storage, - spillover_dir=Path(spillover_dir) if spillover_dir is not None else None, - spillover_threshold_mb=spillover_threshold_mb, - ) - - self.cache = Cache(rel_adapter=self.rel_adapter) - - # self.versions_adapter = VersionAdapter(rel_adapter=rel_adapter) - self.sig_adapter = self.rel_adapter.sig_adapter - self.sig_syncer = SigSyncer(sig_adapter=self.sig_adapter, root=self.root) - if signatures is not None: - self.sig_adapter.dump_state(state=signatures) - self.last_timestamp = ( - timestamp if timestamp is not None else datetime.datetime.fromtimestamp(0) - ) - - self.version_adapter = VersionAdapter(rel_adapter=self.rel_adapter) - if deps_path is not None: - deps_path = ( - Path(deps_path).absolute().resolve() - if deps_path != "__main__" - else "__main__" - ) - roots = [] if deps_path == "__main__" else [deps_path] - self._versioned = True - current_versioner = self.version_adapter.load_state() - if current_versioner is not None: - if current_versioner.paths != roots: - raise ValueError( - f"Found existing versioner with roots {current_versioner.paths}, but " - f"was asked to use {roots}" - ) - else: - versioner = Versioner( - paths=roots, - TracerCls=DecTracer if tracer_impl is None else tracer_impl, - strict=_strict_deps, - track_methods=track_methods, - package_name=deps_package, - ) - self.version_adapter.dump_state(state=versioner) - else: - self._versioned = False - - if root is not None: - self.remote_manager = RemoteManager( - rel_adapter=self.rel_adapter, - sig_adapter=self.sig_adapter, - rel_storage=self.rel_storage, - sig_syncer=self.sig_syncer, - root=self.root, - ) - else: - self.remote_manager = None - - # set up builtins - for func_op in Builtins.OPS.values(): - self.synchronize_op(func_op=func_op) - - @property - def in_memory(self) -> bool: - return self.db_path is 
None - - @transaction() - def get_versioner(self, conn: Optional[Connection] = None) -> Versioner: - result = self.version_adapter.load_state(conn=conn) - if result is None: - raise ValueError("This storage is not versioned.") - return result - - @property - def versioned(self) -> bool: - return self._versioned - - ############################################################################ - ### `Transactable` interface - ############################################################################ - def _get_connection(self) -> Connection: - return self.rel_storage._get_connection() - - def _end_transaction(self, conn: Connection): - return self.rel_storage._end_transaction(conn=conn) - - def unwrap(self, obj: Any) -> Any: - with self.run(): - res = unwrap(obj) - return res - - @transaction() - def commit( - self, - calls: Optional[List[Call]] = None, - versioner: Optional[Versioner] = None, - conn: Optional[Connection] = None, - ): - """ - Flush calls and objs from the cache that haven't yet been written to the database. - """ - self.cache.commit( - calls=calls, - versioner=versioner, - version_adapter=self.version_adapter, - conn=conn, - ) - - @transaction() - def eval_df( - self, - full_uids_df: pd.DataFrame, - drop_duplicates: bool = False, - values: Literal["objs", "refs", "uids", "full_uids", "lazy"] = "objs", - conn: Optional[Connection] = None, - ) -> pd.DataFrame: - """ - - ! this function loads objects in the cache; this is probably not - transparent to the user - - Note that currently we pass the full UIDs as input, and thus we return - values with causal UIDs. Maybe it is desirable in some settings to - disable this behavior. - - Can handle nulls in the input dataframe. - """ - if values == "full_uids": - return full_uids_df - has_meta = set(Config.special_call_cols).issubset(full_uids_df.columns) - ref_cols = [ - col for col in full_uids_df.columns if col not in Config.special_call_cols - ] - uids_df = full_uids_df[ref_cols].applymap( - lambda uid: uid.rsplit(".", 1)[0] if uid is not None else None - ) - if has_meta: - uids_df[Config.special_call_cols] = full_uids_df[Config.special_call_cols] - if values in ("objs", "refs"): - uids_to_collect = [ - item - for _, column in uids_df.items() - for _, item in column.items() - if item is not None - ] - self.cache.preload_objs(uids_to_collect, conn=conn) - if values == "objs": - result = uids_df.applymap( - lambda uid: unwrap(self.cache.obj_get(uid)) - if uid is not None - else None - ) - else: - result = full_uids_df.applymap( - lambda full_uid: self.cache.obj_get( - obj_uid=full_uid.rsplit(".", 1)[0], - causal_uid=full_uid.rsplit(".", 1)[1], - ) - if full_uid is not None - else None - ) - elif values == "uids": - result = uids_df - elif values == "lazy": - result = full_uids_df[ref_cols].applymap( - lambda full_uid: Ref.from_uid( - uid=full_uid.rsplit(".", 1)[0], - causal_uid=full_uid.rsplit(".", 1)[1], - ) - if full_uid is not None - else None - ) - if has_meta: - result[Config.special_call_cols] = full_uids_df[ - Config.special_call_cols - ] - else: - raise ValueError( - f"Invalid value for `values`: {values}. 
Must be one of " - "['objs', 'refs', 'uids', 'lazy']" - ) - if drop_duplicates: - result = result.drop_duplicates() - return result - - @transaction() - def get_table( - self, - func_interface: Union["funcs.FuncInterface", Any], - meta: bool = False, - values: Literal["objs", "uids", "full_uids", "refs", "lazy"] = "objs", - drop_duplicates: bool = False, - conn: Optional[Connection] = None, - ) -> pd.DataFrame: - full_uids_df = self.rel_storage.get_data( - table=func_interface.func_op.sig.versioned_ui_name, conn=conn - ) - if not meta: - full_uids_df = full_uids_df.drop(columns=Config.special_call_cols) - df = self.eval_df( - full_uids_df=full_uids_df, - values=values, - drop_duplicates=drop_duplicates, - conn=conn, - ) - return df - - ############################################################################ - ### synchronization - ############################################################################ - @transaction() - def synchronize_op( - self, - func_op: FuncOp, - conn: Optional[Connection] = None, - ): - # first, pull the current data from the remote! - self.sig_syncer.sync_from_remote(conn=conn) - # this step also sends the signature to the remote - new_sig = self.sig_syncer.sync_from_local(sig=func_op.sig, conn=conn) - func_op.sig = new_sig - # to send any default values that were created by adding inputs - self.sync_to_remote(conn=conn) - - @transaction() - def synchronize( - self, f: Union["funcs.FuncInterface", Any], conn: Optional[Connection] = None - ): - if f._is_invalidated: - raise RuntimeError( - "This function has been invalidated due to a change in the signature, and cannot be called" - ) - # if f._is_synchronized: - # if f._storage_id != id(self): - # raise RuntimeError( - # "This function is already synchronized with a different storage object. Re-define the function to synchronize it with this storage object." 
- # ) - # return - self.synchronize_op(func_op=f.func_op, conn=conn) - f._is_synchronized = True - f._storage_id = id(self) - - ############################################################################ - ### versioning - ############################################################################ - @transaction() - def guess_code_state( - self, versioner: Optional[Versioner] = None, conn: Optional[Connection] = None - ) -> CodeState: - if versioner is None: - versioner = self.get_versioner(conn=conn) - return versioner.guess_code_state() - - @transaction() - def sync_code( - self, conn: Optional[Connection] = None - ) -> Tuple[Versioner, CodeState]: - versioner = self.get_versioner(conn=conn) - code_state = self.guess_code_state(versioner=versioner, conn=conn) - versioner.sync_codebase(code_state=code_state) - return versioner, code_state - - @transaction() - def sync_component( - self, - component: types.FunctionType, - is_semantic_change: Optional[bool], - conn: Optional[Connection] = None, - ): - # low-level versioning - dep_key = get_dep_key_from_func(func=component) - versioner = self.get_versioner(conn=conn) - code_state = self.guess_code_state(versioner=versioner, conn=conn) - result = versioner.sync_component( - component=dep_key, - is_semantic_change=is_semantic_change, - code_state=code_state, - ) - self.version_adapter.dump_state(state=versioner, conn=conn) - return result - - @transaction() - def _show_version_data( - self, - f: Union[Callable, "funcs.FuncInterface"], - deps: bool = True, - meta: bool = False, - plain: bool = False, - compact: bool = False, - conn: Optional[Connection] = None, - ): - # show the versions of a function, with/without its dependencies - func = extract_func_obj(obj=f, strict=True) - component = get_dep_key_from_func(func=func) - versioner = self.get_versioner(conn=conn) - if deps: - versioner.show_versions( - component=component, - include_metadata=meta, - plain=plain, - ) - else: - versioner.component_dags[component].show( - compact=compact, plain=plain, include_metadata=meta - ) - - @transaction() - def versions( - self, - f: Union[Callable, "funcs.FuncInterface"], - meta: bool = False, - plain: bool = False, - conn: Optional[Connection] = None, - ): - self._show_version_data( - f=f, - deps=True, - meta=meta, - plain=plain, - compact=False, - conn=conn, - ) - - @transaction() - def sources( - self, - f: Union[Callable, "funcs.FuncInterface"], - meta: bool = False, - plain: bool = False, - compact: bool = False, - conn: Optional[Connection] = None, - ): - func = extract_func_obj(obj=f, strict=True) - component = get_dep_key_from_func(func=func) - versioner = self.get_versioner(conn=conn) - print( - f"Revision history for the source code of function {component[1]} from module {component[0]} " - '("===HEAD===" is the current version):' - ) - versioner.component_dags[component].show( - compact=compact, plain=plain, include_metadata=meta - ) - - @transaction() - def code( - self, version_id: str, meta: bool = False, conn: Optional[Connection] = None - ): - # show a copy-pastable version of the code for a given version id. Plain - # by design. 
- result = self.get_code(version_id=version_id, show=False, meta=meta, conn=conn) - print(result) - - @transaction() - def get_code( - self, - version_id: str, - show: bool = True, - meta: bool = False, - conn: Optional[Connection] = None, - ) -> str: - versioner = self.get_versioner(conn=conn) - for dag in versioner.component_dags.values(): - if version_id in dag.commits.keys(): - text = dag.get_content(commit=version_id) - if show: - print(text) - return text - for ( - content_version, - version, - ) in versioner.get_flat_versions().items(): - if version_id == content_version: - raw_string = versioner.present_dependencies( - commits=version.semantic_expansion, - include_metadata=meta, - ) - if show: - print(raw_string) - return raw_string - raise ValueError(f"version id {version_id} not found") - - @transaction() - def diff( - self, - id_1: str, - id_2: str, - context_lines: int = 2, - conn: Optional[Connection] = None, - ): - code_1: str = self.get_code(version_id=id_1, show=False, conn=conn) - code_2: str = self.get_code(version_id=id_2, show=False, conn=conn) - print( - _get_colorized_diff(current=code_1, new=code_2, context_lines=context_lines) - ) - - ############################################################################ - ### make calls in contexts - ############################################################################ - @transaction() - def _load_memoization_tables( - self, evaluate: bool = False, conn: Optional[Connection] = None - ) -> Dict[str, pd.DataFrame]: - """ - Get a dict of {versioned internal name: memoization table} for all - functions. Note that memoization tables are labeled by UI arg names. - """ - sigs = self.sig_adapter.load_state(conn=conn) - ui_to_internal = { - sig.versioned_ui_name: sig.versioned_internal_name for sig in sigs.values() - } - ui_call_data = self.rel_adapter.get_all_call_data(conn=conn) - call_data = {ui_to_internal[k]: v for k, v in ui_call_data.items()} - if evaluate: - call_data = { - k: self.eval_df(full_uids_df=v, values="objs", conn=conn) - for k, v in call_data.items() - } - return call_data - - @transaction() - def get_compatible_semantic_versions( - self, - fqs: Set[CallNode], - conn: Optional[Connection] = None, - ) -> Tuple[Optional[Dict[OpKey, Set[str]]], Optional[Dict[DepKey, Set[str]]]]: - if not self.versioned: - return None, None - if contexts.GlobalContext.current is not None: - code_state = contexts.GlobalContext.current._code_state - else: - code_state = self.guess_code_state( - versioner=self.get_versioner(), conn=conn - ) - result_ops = {} - result_deps = {} - versioner = self.get_versioner(conn=conn) - for func_query in fqs: - sig = func_query.func_op.sig - op_key = (sig.internal_name, sig.version) - # dep_key = get_dep_key_from_func(func=func_query.func_op.func) - dep_key = (func_query.func_op._module, func_query.func_op._qualname) - if func_query.func_op._is_builtin: - result_ops[op_key] = None - result_deps[dep_key] = None - else: - versions = versioner.get_semantically_compatible_versions( - component=dep_key, code_state=code_state - ) - result_ops[op_key] = set([v.semantic_version for v in versions]) - result_deps[dep_key] = result_ops[op_key] - return result_ops, result_deps - - @transaction() - def execute_query( - self, - selection: List[ValNode], - vqs: Set[ValNode], - fqs: Set[CallNode], - names: Dict[ValNode, str], - values: Literal["objs", "refs", "uids", "lazy"] = "objs", - engine: Optional[Literal["sql", "naive", "_test"]] = None, - local: bool = False, - verbose: bool = True, - drop_duplicates: 
bool = True, - visualize_steps_at: Optional[Path] = None, - conn: Optional[Connection] = None, - ) -> pd.DataFrame: - """ - Execute the given queries and return the result as a pandas DataFrame. - """ - if engine is None: - engine = Config.query_engine - - def rename_cols( - df: pd.DataFrame, selection: List[ValNode], names: Dict[ValNode, str] - ): - df.columns = [str(i) for i in range(len(df.columns))] - cols = [names[query] for query in selection] - df.rename(columns=dict(zip(df.columns, cols)), inplace=True) - - Querier.validate_query(vqs=vqs, fqs=fqs, selection=selection, names=names) - context = contexts.GlobalContext.current - if verbose: - print( - "Pattern-matching to the following computational graph (all constraints apply):" - ) - print_graph(vqs=vqs, fqs=fqs, names=names, selection=selection) - if visualize_steps_at is not None: - assert engine == "naive" - if engine in ["sql", "_test"]: - version_constraints, _ = self.get_compatible_semantic_versions( - fqs=fqs, conn=conn - ) - call_uids = context._call_uids if local else None - query = Querier.compile( - selection=selection, - vqs=vqs, - fqs=fqs, - version_constraints=version_constraints, - filter_duplicates=drop_duplicates, - call_uids=call_uids, - ) - start = time.time() - sql_uids_df = self.rel_storage.execute_df(query=str(query), conn=conn) - end = time.time() - logger.debug(f"SQL query took {round(end - start, 3)} seconds") - rename_cols(df=sql_uids_df, selection=selection, names=names) - uids_df = sql_uids_df - if engine in ["naive", "_test"]: - memoization_tables = self._load_memoization_tables(conn=conn) - logger.debug("Executing query naively...") - naive_uids_df = Querier.execute_naive( - vqs=vqs, - fqs=fqs, - selection=selection, - memoization_tables=memoization_tables, - filter_duplicates=drop_duplicates, - table_evaluator=self.eval_df, - visualize_steps_at=visualize_steps_at, - ) - rename_cols(df=naive_uids_df, selection=selection, names=names) - uids_df = naive_uids_df - if engine == "_test": - outcome, reason = compare_dfs_as_relations( - df_1=sql_uids_df, df_2=naive_uids_df, return_reason=True - ) - assert outcome, reason - return self.eval_df(full_uids_df=uids_df, values=values, conn=conn) - - def _get_graph_and_names( - self, - objs: Tuple[Union[Ref, ValNode]], - direction: Literal["forward", "backward", "both"] = "both", - scope: Optional[Dict[str, Any]] = None, - project: bool = False, - ): - vqs = {obj.query if isinstance(obj, Ref) else obj for obj in objs} - vqs, fqs = traverse_all(vqs=vqs, direction=direction) - hints = extract_names_from_scope(scope=scope) if scope is not None else {} - g = InducedSubgraph(vqs=vqs, fqs=fqs) - if project: - v_proj, f_proj, _ = g.project() - proj_hints = {v_proj[vq]: hints[vq] for vq in v_proj if vq in hints} - names = get_names( - hints=proj_hints, - canonical_order=get_canonical_order( - vqs=set(v_proj.values()), fqs=set(f_proj.values()) - ), - ) - final_vqs = set(v_proj.values()) - final_fqs = set(f_proj.values()) - else: - v_proj = {vq: vq for vq in vqs} - f_proj = {fq: fq for fq in fqs} - names = get_names( - hints=hints, - canonical_order=get_canonical_order(vqs=set(vqs), fqs=set(fqs)), - ) - final_vqs = vqs - final_fqs = fqs - return final_vqs, final_fqs, names, v_proj, f_proj - - def draw_graph( - self, - *objs: Union[Ref, ValNode], - traverse: Literal["forward", "backward", "both"] = "backward", - project: bool = False, - show_how: Literal["none", "browser", "inline", "open"] = "browser", - ): - scope = inspect.currentframe().f_back.f_locals - vqs, fqs, names, 
v_proj, f_proj = self._get_graph_and_names( - objs, - direction=traverse, - scope=scope, - project=project, - ) - visualize_graph(vqs, fqs, names=names, show_how=show_how) - - def print_graph( - self, - *objs: Union[Ref, ValNode], - project: bool = False, - traverse: Literal["forward", "backward", "both"] = "backward", - ): - scope = inspect.currentframe().f_back.f_locals - vqs, fqs, names, v_proj, f_proj = self._get_graph_and_names( - objs, - direction=traverse, - scope=scope, - project=project, - ) - print_graph( - vqs=vqs, - fqs=fqs, - names=names, - selection=[ - v_proj[obj.query] if isinstance(obj, Ref) else obj for obj in objs - ], - ) - - @transaction() - def similar( - self, - *objs: Union[Ref, ValNode], - values: Literal["objs", "refs", "uids", "lazy"] = "objs", - context: bool = False, - verbose: Optional[bool] = None, - local: bool = False, - drop_duplicates: bool = True, - engine: Literal["sql", "naive", "_test"] = None, - _visualize_steps_at: Optional[Path] = None, - conn: Optional[Connection] = None, - ) -> pd.DataFrame: - scope = inspect.currentframe().f_back.f_back.f_locals - return self.df( - *objs, - direction="backward", - scope=scope, - values=values, - context=context, - skip_objs=False, - verbose=verbose, - local=local, - drop_duplicates=drop_duplicates, - engine=engine, - _visualize_steps_at=_visualize_steps_at, - conn=conn, - ) - - @transaction() - def df( - self, - *objs: Union[Ref, ValNode], - direction: Literal["forward", "backward", "both"] = "both", - values: Literal["objs", "refs", "uids", "lazy"] = "objs", - context: bool = False, - skip_objs: bool = False, - verbose: Optional[bool] = None, - local: bool = False, - drop_duplicates: bool = True, - engine: Literal["sql", "naive", "_test"] = None, - _visualize_steps_at: Optional[Path] = None, - scope: Optional[Dict[str, Any]] = None, - conn: Optional[Connection] = None, - ) -> pd.DataFrame: - """ - Universal query method over computational graphs, both imperative and - declarative. - """ - if verbose is None: - verbose = Config.verbose_queries - if not all(isinstance(obj, (Ref, ValNode)) for obj in objs): - raise ValueError( - "All arguments to df() must be either `Ref`s or `ValQuery`s." - ) - #! important - # We must sync any dirty cache elements to the db before performing a query. - # If we don't, we'll query a store that might be missing calls and objs. 
- self.commit(versioner=None) - selection = [obj.query if isinstance(obj, Ref) else obj for obj in objs] - # deps = get_deps(nodes=set(selection)) - vqs, fqs = traverse_all(vqs=set(selection), direction=direction) - if scope is None: - scope = inspect.currentframe().f_back.f_back.f_locals - name_hints = extract_names_from_scope(scope=scope) - v_map, f_map, target_selection, target_names = Querier.prepare_projection_query( - vqs=vqs, fqs=fqs, selection=selection, name_hints=name_hints - ) - target_vqs, target_fqs = set(v_map.values()), set(f_map.values()) - if context: - g = InducedSubgraph(vqs=target_vqs, fqs=target_fqs) - _, _, topsort = g.canonicalize() - target_selection = [vq for vq in topsort if isinstance(vq, ValNode)] - df = self.execute_query( - selection=target_selection, - vqs=set(v_map.values()), - fqs=set(f_map.values()), - values=values, - names=target_names, - verbose=verbose, - drop_duplicates=drop_duplicates, - visualize_steps_at=_visualize_steps_at, - engine=engine, - local=local, - conn=conn, - ) - for col in df.columns: - try: - df = df.sort_values(by=col) - except Exception: - continue - if skip_objs: - # drop the dtypes that are objects - df = df.select_dtypes(exclude=["object"]) - return df - - def _make_terminal_data(self, func_op: FuncOp, call: Call) -> TerminalData: - terminal_data = TerminalData( - op_internal_name=func_op.sig.internal_name, - op_version=func_op.sig.version, - call_content_version=call.content_version, - call_semantic_version=call.semantic_version, - dep_key=get_dep_key_from_func(func=func_op.func), - ) - return terminal_data - - @transaction() - def lookup_call( - self, - func_op: FuncOp, - pre_call_uid: str, - input_uids: Dict[str, str], - input_causal_uids: Dict[str, str], - code_state: Optional[CodeState] = None, - versioner: Optional[Versioner] = None, - conn: Optional[Connection] = None, - ) -> Optional[Call]: - """ - Return a *detached* call for the given function and inputs, if it - exists. 
- """ - if not self.versioned: - semantic_version = None - else: - assert code_state is not None - component = get_dep_key_from_func(func=func_op.func) - lookup_outcome = versioner.lookup_call( - component=component, pre_call_uid=pre_call_uid, code_state=code_state - ) - if lookup_outcome is None: - return - else: - _, semantic_version = lookup_outcome - causal_uid = func_op.get_call_causal_uid( - input_uids=input_uids, - input_causal_uids=input_causal_uids, - semantic_version=semantic_version, - ) - if self.cache.call_exists(uid=causal_uid, by_causal=True): - return self.cache.call_get(uid=causal_uid, by_causal=True, lazy=True) - call_uid = func_op.get_call_uid( - pre_call_uid=pre_call_uid, semantic_version=semantic_version - ) - if self.cache.call_exists(uid=call_uid, by_causal=False): - return self.cache.call_get(uid=call_uid, by_causal=False, lazy=True) - return None - - def call_batch( - self, func_op: FuncOp, inputs: Dict[str, Ref] - ) -> Tuple[List[Ref], CallStruct]: - output_types = [Type.from_annotation(a) for a in func_op.output_annotations] - outputs = [make_delayed(tp=tp) for tp in output_types] - call_struct = CallStruct(func_op=func_op, inputs=inputs, outputs=outputs) - return outputs, call_struct - - ############################################################################ - ### low-level provenance interfaces - ############################################################################ - @transaction() - def get_creators( - self, - refs: List[Ref], - prov_df: Optional[pd.DataFrame] = None, - conn: Optional[Connection] = None, - ) -> Tuple[List[Optional[Call]], List[Optional[str]]]: - """ - Given some Refs, return the - - calls that created them (there may be at most one such call per Ref), - or None if there was no such call. - - the output name under which the refs were created - - ! This currently fails for refs created as a list of inputs. - """ - if not refs: - return [], [] - if prov_df is None: - prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn) - prov_df = propagate_struct_provenance(prov_df=prov_df) - causal_uids = list([ref.causal_uid for ref in refs]) - assert all(x is not None for x in causal_uids) - res_df = prov_df.query('causal in @causal_uids and direction_new == "output"')[ - ["causal", "call_causal", "name", "op_id"] - ].set_index("causal")[["call_causal", "name", "op_id"]] - if len(res_df) == 0: - return [None] * len(refs), [None] * len(refs) - if not res_df.index.is_unique: - logging.warning( - "Work in progress: Detected ref w/ multiple creators (this happens when a data structure is created explicitly out of its elements), choosing arbitrary creator." 
- ) - causal_to_creator_call_uid = res_df["call_causal"].to_dict() - causal_to_output_name = res_df["name"].to_dict() - causal_to_op_id = res_df["op_id"].to_dict() - op_groups = res_df.groupby("op_id")["call_causal"].apply(list).to_dict() - call_causal_to_call = {} - for op_id, call_causal_list in op_groups.items(): - internal_name, version = Signature.parse_versioned_name( - versioned_name=op_id - ) - versioned_ui_name = self.sig_adapter.load_state(conn=conn)[ - internal_name, version - ].versioned_ui_name - op_calls = self.cache.call_mget( - uids=call_causal_list, - by_causal=True, - versioned_ui_name=versioned_ui_name, - conn=conn, - ) - call_causal_to_call.update({call.causal_uid: call for call in op_calls}) - calls = [ - call_causal_to_call[causal_to_creator_call_uid[causal_uid]] - if causal_uid in causal_to_creator_call_uid - else None - for causal_uid in causal_uids - ] - output_names = [ - causal_to_output_name[causal_uid] - if causal_uid in causal_to_output_name - else None - for causal_uid in causal_uids - ] - sess.d() - return calls, output_names - - @transaction() - def get_consumers( - self, - refs: List[Ref], - prov_df: Optional[pd.DataFrame] = None, - conn: Optional[Connection] = None, - ) -> Tuple[List[List[Call]], List[List[str]]]: - """ - Given some Refs, return the - - calls that use them (there may be multiple such calls per Ref), or an empty list if there were no such calls. - - the input names under which the refs were used - """ - if prov_df is None: - prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn) - prov_df = propagate_struct_provenance(prov_df=prov_df) - causal_uids = [ref.causal_uid for ref in refs] - assert all(x is not None for x in causal_uids) - res_groups = prov_df.query( - 'causal in @causal_uids and direction_new == "input"' - )[["causal", "call_causal", "name", "op_id"]] - op_to_causal_to_call_uids_and_inp_names = defaultdict(dict) - for causal, call_causal, name, op_id in res_groups.itertuples(index=False): - if causal not in op_to_causal_to_call_uids_and_inp_names[op_id]: - op_to_causal_to_call_uids_and_inp_names[op_id][causal] = [] - op_to_causal_to_call_uids_and_inp_names[op_id][causal].append( - (call_causal, name) - ) - op_id_to_versioned_ui_name = {} - for op_id in op_to_causal_to_call_uids_and_inp_names.keys(): - internal_name, version = Signature.parse_versioned_name( - versioned_name=op_id - ) - op_id_to_versioned_ui_name[op_id] = self.sig_adapter.load_state(conn=conn)[ - internal_name, version - ].versioned_ui_name - op_to_causal_to_calls_and_inp_names = defaultdict(dict) - for ( - op_id, - causal_to_call_uids_and_inp_names, - ) in op_to_causal_to_call_uids_and_inp_names.items(): - versioned_ui_name = op_id_to_versioned_ui_name[op_id] - op_calls = self.cache.call_mget( - uids=[ - elt[0] - for v in causal_to_call_uids_and_inp_names.values() - for elt in v - ], - by_causal=True, - versioned_ui_name=versioned_ui_name, - conn=conn, - ) - call_causal_to_call = {call.causal_uid: call for call in op_calls} - op_to_causal_to_calls_and_inp_names[op_id] = { - causal: [ - (call_causal_to_call[call_causal], name) - for call_causal, name in call_causal_list - ] - for causal, call_causal_list in causal_to_call_uids_and_inp_names.items() - } - concat_lists = lambda l: [elt for sublist in l for elt in sublist] - calls = [ - concat_lists( - [ - [ - v[0] - for v in op_to_causal_to_calls_and_inp_names[op_id].get( - causal_uid, [] - ) - ] - for op_id in op_to_causal_to_calls_and_inp_names.keys() - ] - ) - for causal_uid in causal_uids - ] - 
input_names = [ - concat_lists( - [ - [ - v[1] - for v in op_to_causal_to_calls_and_inp_names[op_id].get( - causal_uid, [] - ) - ] - for op_id in op_to_causal_to_calls_and_inp_names.keys() - ] - ) - for causal_uid in causal_uids - ] - return calls, input_names - - @transaction() - def get_dependent_calls( - self, - refs: List[Ref], - prov_df: Optional[pd.DataFrame] = None, - conn: Optional[Connection] = None, - ) -> List[Call]: - """ - Get all calls that depend on the given refs. - """ - if prov_df is None: - prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn) - prov_df = propagate_struct_provenance(prov_df=prov_df) - res = {} - current = refs - while True: - calls_list, _ = self.get_consumers(refs=current, prov_df=prov_df, conn=conn) - if not calls_list: - break - for calls in calls_list: - for call in calls: - res[call.causal_uid] = call - current = list( - { - ref.causal_uid: ref - for calls in calls_list - for call in calls - for ref in call.outputs - }.values() - ) - return list(res.values()) - - ############################################################################ - ### provenance - ############################################################################ - @transaction() - def prov( - self, - ref: Ref, - conn: Optional[Connection] = None, - uids_only: bool = False, - debug: bool = False, - ): - prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn) - prov_df = prov_df.set_index([Provenance.causal_uid, Provenance.direction]) - x = provenance.ProvHelpers(storage=self, prov_df=prov_df) - val_nodes, call_nodes = x.get_graph(full_uid=ref.full_uid) - show_sources_as = "values" if not uids_only else "uids" - printer = GraphPrinter( - vqs=val_nodes, - fqs=call_nodes, - names=None, - value_loader=ValueLoader(storage=self), - ) - print(printer.print_computational_graph(show_sources_as=show_sources_as)) - if debug: - visualize_graph( - vqs=val_nodes, fqs=call_nodes, names=None, show_how="browser" - ) - - ############################################################################ - ### spawning contexts - ############################################################################ - def _nest(self, **updates) -> contexts.Context: - if contexts.GlobalContext.current is not None: - return contexts.GlobalContext.current(**updates) - else: - result = contexts.Context(**updates) - contexts.GlobalContext.current = result - return result - - def __call__(self, **updates) -> contexts.Context: - return self.run(**updates) - - def run( - self, - allow_calls: bool = True, - debug_calls: bool = False, - attach_call_to_outputs: bool = False, - recompute_transient: bool = False, - lazy: Optional[bool] = None, - **updates, - ) -> contexts.Context: - # spawn context to execute or retrace calls - lazy = not self.in_memory if lazy is None else lazy - return self._nest( - storage=self, - allow_calls=allow_calls, - debug_calls=debug_calls, - recompute_transient=recompute_transient, - _attach_call_to_outputs=attach_call_to_outputs, - mode=MODES.run, - lazy=lazy, - **updates, - ) - - def query(self, **updates) -> contexts.Context: - # spawn a context to define a query - return self._nest( - storage=self, - mode=MODES.query, - **updates, - ) - - def batch(self, **updates) -> contexts.Context: - # spawn a context to execute calls in batch - return self._nest( - storage=self, - mode=MODES.batch, - **updates, - ) - - def noop(self) -> contexts.Context: - return self._nest( - storage=self, - mode=MODES.noop, - ) - - 
############################################################################ - ### remote sync operations - ############################################################################ - @transaction() - def sync_from_remote(self, conn: Optional[Connection] = None): - if self.remote_manager is not None: - self.remote_manager.sync_from_remote(conn=conn) - - @transaction() - def sync_to_remote(self, conn: Optional[Connection] = None): - if self.remote_manager is not None: - self.remote_manager.sync_to_remote(conn=conn) - - @transaction() - def sync_with_remote(self, conn: Optional[Connection] = None): - if self.remote_manager is not None: - self.sync_to_remote(conn=conn) - self.sync_from_remote(conn=conn) - - ############################################################################ - ### refactoring - ############################################################################ - @property - def is_clean(self) -> bool: - """ - Check that the storage has no uncommitted calls or objects. - """ - return ( - self.cache.call_cache_by_causal.is_clean and self.cache.obj_cache.is_clean - ) - - def _check_rename_precondition(self, func: "funcs.FuncInterface"): - """ - In order to rename function data, the function must be synced with the - storage, and the storage must be clean - """ - if not func._is_synchronized: - raise RuntimeError("Cannot rename while function is not synchronized.") - if not self.is_clean: - raise RuntimeError("Cannot rename while there is uncommited work.") - - @transaction() - def rename_func( - self, - func: "funcs.FuncInterface", - new_name: str, - conn: Optional[Connection] = None, - ) -> Signature: - """ - Rename a memoized function. - - What happens here: - - check renaming preconditions - - check there is no name clash with the new name - - rename the memoization table - - update signature object - - invalidate the function (making it impossible to compute with it) - """ - self._check_rename_precondition(func=func) - sig = self.sig_syncer.sync_rename_sig( - sig=func.func_op.sig, new_name=new_name, conn=conn - ) - func.invalidate() - return sig - - @transaction() - def rename_arg( - self, - func: "funcs.FuncInterface", - name: str, - new_name: str, - conn: Optional[Connection] = None, - ) -> Signature: - """ - Rename memoized function argument. - - What happens here: - - check renaming preconditions - - update signature object - - rename table - - invalidate the function (making it impossible to compute with it) - """ - self._check_rename_precondition(func=func) - sig = self.sig_syncer.sync_rename_input( - sig=func.func_op.sig, input_name=name, new_input_name=new_name, conn=conn - ) - func.invalidate() - return sig - - -from . import funcs -from . 
import provenance - -FuncInterface = funcs.FuncInterface - - -TP_TO_CLS = { - AnyType: ValueRef, - ListType: ListRef, - DictType: DictRef, - SetType: SetRef, -} - - -def make_delayed(tp: Type) -> Ref: - return TP_TO_CLS[type(tp)](uid="", obj=Delayed(), in_memory=False) - - -class ValueLoader(ValueLoaderABC): - def __init__(self, storage: Storage): - self.storage = storage - - def load_value(self, full_uid: str) -> Any: - uid, _ = Ref.parse_full_uid(full_uid) - return self.storage.rel_adapter.obj_get(uid=uid) diff --git a/mandala/ui/utils.py b/mandala/ui/utils.py deleted file mode 100644 index 7f0ec04..0000000 --- a/mandala/ui/utils.py +++ /dev/null @@ -1,115 +0,0 @@ -from ..common_imports import * -from ..core.model import FuncOp, Ref, wrap_atom, Call -from ..core.wrapping import unwrap, causify_atom -from ..core.config import Config, MODES -from ..queries.weaver import ValNode, qwrap, prepare_query -from ..deps.model import TerminalData -from ..deps.utils import get_dep_key_from_func -from textwrap import shorten - - -T = TypeVar("T") - - -def check_determinism( - observed_semver: Optional[str], - stored_semver: Optional[str], - stored_output_uids: List[str], - observed_output_uids: List[str], - func_op: FuncOp, -): - # check deterministic behavior - if stored_semver != observed_semver: - raise ValueError( - f"Detected non-deterministic dependencies for function " - f"{func_op.sig.ui_name} after recomputation of transient values." - ) - if len(stored_output_uids) != len(observed_output_uids): - raise ValueError( - f"Detected non-deterministic number of outputs for function " - f"{func_op.sig.ui_name} after recomputation of transient values." - ) - if observed_output_uids != stored_output_uids: - raise ValueError( - f"Detected non-deterministic outputs for function " - f"{func_op.sig.ui_name} after recomputation of transient values. " - f"{observed_output_uids} != {stored_output_uids}" - ) - - -def get_terminal_data(func_op: FuncOp, call: Call) -> TerminalData: - return TerminalData( - op_internal_name=func_op.sig.internal_name, - op_version=func_op.sig.version, - call_content_version=call.content_version, - call_semantic_version=call.semantic_version, - dep_key=get_dep_key_from_func(func=func_op.func), - ) - - -def wrap_ui(obj: T, recurse: bool = True) -> T: - if isinstance(obj, Ref): - return obj - elif type(obj) in (list, tuple): - if recurse: - return type(obj)(wrap_ui(v, recurse=recurse) for v in obj) - else: - return obj - elif type(obj) is dict: - if recurse: - return {k: wrap_ui(v, recurse=recurse) for k, v in obj.items()} - else: - return obj - else: - res = wrap_atom(obj) - causify_atom(ref=res) - return res - - -def bind_inputs(args, kwargs, mode: str, func_op: FuncOp) -> Dict[str, Any]: - """ - Given args and kwargs passed by the user from python, this adds defaults - and returns a dict where they are indexed via internal names. 
- """ - if mode == MODES.query: - bound_args = func_op.py_sig.bind_partial(*args, **kwargs) - inputs_dict = dict(bound_args.arguments) - input_tps = func_op.input_types - inputs_dict = { - k: qwrap(obj=v, tp=input_tps[k], strict=True) - for k, v in inputs_dict.items() - } - else: - bound_args = func_op.py_sig.bind(*args, **kwargs) - bound_args.apply_defaults() - inputs_dict = dict(bound_args.arguments) - return inputs_dict - return inputs_dict - - -def format_as_outputs( - outputs: Union[List[Ref], List[ValNode]] -) -> Union[None, Any, Tuple[Any]]: - if len(outputs) == 0: - return None - elif len(outputs) == 1: - return outputs[0] - else: - return tuple(outputs) - - -def debug_call( - func_name: str, - memoized: bool, - wrapped_inputs: Dict[str, Ref], - wrapped_outputs: List[Ref], - io_truncate: Optional[int] = 20, -): - shortener = lambda s: shorten( - repr(unwrap(s)), width=io_truncate, break_long_words=True - ) - inputs_str = ", ".join(f"{k}={shortener(v)}" for k, v in wrapped_inputs.items()) - outputs_str = ", ".join(shortener(v) for v in wrapped_outputs) - logging.info( - f'{"(memoized)" if memoized else ""}: {func_name}({inputs_str}) ---> {outputs_str}' - ) diff --git a/mandala/ui/viz.py b/mandala/ui/viz.py deleted file mode 100644 index 8326bb9..0000000 --- a/mandala/ui/viz.py +++ /dev/null @@ -1,375 +0,0 @@ -""" -A minimal OOP wrapper around dot/graphviz -""" -import difflib -from ..common_imports import * -from ..core.config import Config -import tempfile -import subprocess -import webbrowser -from typing import Literal - -if Config.has_pil: - from PIL import Image - - -class Color: - def __init__(self, r: int, g: int, b: int, opacity: float = 1.0): - self.r, self.g, self.b, self.opacity = r, g, b, opacity - - def __str__(self) -> str: - opacity_int = int(self.opacity * 255) - return f"#{self.r:02x}{self.g:02x}{self.b:02x}{opacity_int:02x}" - - -SOLARIZED_LIGHT = { - "base03": Color(0, 43, 54, 1), - "base02": Color(7, 54, 66, 1), - "base01": Color(88, 110, 117, 1), - "base00": Color(101, 123, 131, 1), - "base0": Color(131, 148, 150, 1), - "base1": Color(147, 161, 161, 1), - "base2": Color(238, 232, 213, 1), - "base3": Color(253, 246, 227, 1), - "yellow": Color(181, 137, 0, 1), - "orange": Color(203, 75, 22, 1), - "red": Color(220, 50, 47, 1), - "magenta": Color(211, 54, 130, 1), - "violet": Color(108, 113, 196, 1), - "blue": Color(38, 139, 210, 1), - "cyan": Color(42, 161, 152, 1), - "green": Color(133, 153, 0, 1), -} - - -def _colorize(text: str, color: str) -> str: - """ - Return `text` with ANSI color codes for `color` added. - """ - colors = { - "red": 31, - "green": 32, - "blue": 34, - "yellow": 33, - "magenta": 35, - "cyan": 36, - "white": 37, - } - return f"\033[{colors[color]}m{text}\033[0m" - - -def _get_diff(current: str, new: str) -> str: - """ - Return a line-by-line diff of the changes between `current` and `new`. each - line removed from `current` is prefixed with a '-', and each line added to - `new` is prefixed with a '+'. 
- """ - lines = [] - for line in difflib.unified_diff( - current.splitlines(), - new.splitlines(), - n=2, # number of lines of context around changes to show - # fromfile="current", tofile="new" - lineterm="", - ): - if line.startswith("@@") or line.startswith("+++") or line.startswith("---"): - continue - lines.append(line) - return "\n".join(lines) - - -def _get_colorized_diff( - current: str, new: str, style: str = "multiline", context_lines: int = 2 -) -> str: - """ - Return a line-by-line colorized diff of the changes between `current` and - `new`. each line removed from `current` is colored red, and each line added - to `new` is colored green. - """ - lines = [] - for line in difflib.unified_diff( - current.splitlines(), - new.splitlines(), - n=context_lines, # number of lines of context around changes to show - # fromfile="current", tofile="new" - lineterm="", - ): - if line.startswith("@@") or line.startswith("+++") or line.startswith("---"): - continue - if line.startswith("-"): - if style == "inline": - line = line[1:] - lines.append(_colorize(line, "red")) - elif line.startswith("+"): - if style == "inline": - line = line[1:] - lines.append(_colorize(line, "green")) - else: - lines.append(line) - if style == "multiline": - return "\n".join(lines) - elif style == "inline": - return " ---> ".join(lines) - else: - raise ValueError(f"Unknown style: {style}") - - -################################################################################ -### tiny graphviz model -################################################################################ -class Cell: - def __init__( - self, - text: str, - port: Optional[str] = None, - colspan: int = 1, - bgcolor: Color = SOLARIZED_LIGHT["base3"], - border_color: Color = SOLARIZED_LIGHT["base03"], # border color - font_color: Optional[Color] = None, - bold: bool = False, - ): - self.port = port - self.text = text - self.colspan = colspan - self.bgcolor = bgcolor - self.border_color = border_color - self.font_color = font_color - self.bold = bold - - def to_dot_string(self) -> str: - text_str = self.text - if self.bold: - text_str = f"{text_str}" - if self.font_color is not None: - text_str = f'{text_str}' - return f'{text_str}' - - -class HTMLBuilder: - def __init__(self): - self.rows: List[List[Cell]] = [] - - def add_row(self, cells: List[Cell]): - self.rows.append(cells) - - def to_html_like_label(self) -> str: - start = '' - end = "
" - # broadcast colspan - row_sizes = set([len(row) for row in self.rows]) - lcm = np.lcm.reduce(list(row_sizes)) - for row in self.rows: - elt_colspan = lcm // len(row) - for cell in row: - cell.colspan = elt_colspan - - rows = [] - for row in self.rows: - rows.append("") - for cell in row: - rows.append(cell.to_dot_string()) - rows.append("") - return start + "".join(rows) + end - - -class Node: - def __init__( - self, - internal_name: str, - label: str, - color: Color = SOLARIZED_LIGHT["base3"], - shape: str = "rect", - ): - """ - `shape` can be "rect", "record" or "Mrecord" for a record with rounded corners. - """ - self.internal_name = internal_name - self.label = label - self.color = color - self.shape = shape - - def to_dot_string(self) -> str: - dot_label = f'"{self.label}"' if self.shape != "plain" else f"<{self.label}>" - return f'"{self.internal_name}" [label={dot_label}, color="{self.color}", shape="{self.shape}"];' - - -class Edge: - def __init__( - self, - source_node: Node, - target_node: Node, - source_port: Optional[str] = None, - target_port: Optional[str] = None, - arrowtail: Optional[str] = None, - arrowhead: Optional[str] = None, - label: str = "", - color: Color = SOLARIZED_LIGHT["base03"], - ): - self.source_node = source_node - self.target_node = target_node - self.color = color - self.label = label - self.source_port = source_port - self.target_port = target_port - self.arrowtail = arrowtail - self.arrowhead = arrowhead - - def to_dot_string(self) -> str: - source = f'"{self.source_node.internal_name}"' - target = f'"{self.target_node.internal_name}"' - if self.source_port is not None: - source += f":{self.source_port}" - if self.target_port is not None: - target += f":{self.target_port}" - attrs = [f'label="{self.label}"', f'color="{self.color}"'] - if self.arrowtail is not None: - attrs.append(f'arrowtail="{self.arrowtail}"') - if self.arrowhead is not None: - attrs.append(f'arrowhead="{self.arrowhead}"') - return f"{source} -> {target} [{', '.join(attrs)}];" - - -class Group: - def __init__( - self, - label: str, - nodes: List[Node], - parent: Optional["Group"] = None, - border_color: Color = SOLARIZED_LIGHT["base03"], - ): - self.label = label - self.nodes = nodes - self.border_color = border_color - self.parent = parent - - -FONT = "Helvetica" -FONT_SIZE = 10 - -GRAPH_CONFIG = { - # "overlap": "scalexy", - "overlap": "scale", - "rankdir": "TB", # top to bottom - "fontname": FONT, - "fontsize": FONT_SIZE, - "fontcolor": SOLARIZED_LIGHT["base03"], -} - -NODE_CONFIG = { - "style": "rounded", - "shape": "rect", - "fontname": FONT, - "fontsize": FONT_SIZE, - "fontcolor": SOLARIZED_LIGHT["base03"], -} - -EDGE_CONFIG = { - "fontname": FONT, - "fontsize": FONT_SIZE, - "fontcolor": SOLARIZED_LIGHT["base03"], -} - - -def dict_to_dot_string(d: Dict[str, Any]) -> str: - """Converts a dict to a dot string""" - return ",".join([f'{k}="{v}"' for k, v in d.items()]) - - -def _get_group_string_shallow(group: Group, children_string: str) -> str: - nodes_string = " ".join([f'"{node.internal_name}"' for node in group.nodes]) - return f'subgraph "cluster_{group.label}" {{style="rounded"; label="{group.label}"; color="{group.border_color}"; {nodes_string};\n {children_string} }}' - - -def get_group_string(group: Group, groups_forest: Dict[Group, List[Group]]) -> str: - children = groups_forest.get(group, []) - return _get_group_string_shallow( - group, - children_string="\n".join( - [get_group_string(child, groups_forest=groups_forest) for child in children] - ), - ) - - -def 
to_dot_string( - nodes: List[Node], - edges: List[Edge], - groups: List[Group], - rankdir: Literal["TB", "BT", "LR", "RL"] = "TB", -) -> str: - """Converts a graph to a dot string""" - joiner = "\n" + " " * 12 - ### global config - graph_config = copy.deepcopy(GRAPH_CONFIG) - graph_config["rankdir"] = rankdir - graph_config = f"graph [ {dict_to_dot_string(graph_config)} ];" - node_config = f"node [ {dict_to_dot_string(NODE_CONFIG)} ];" - edge_config = f"edge [ {dict_to_dot_string(EDGE_CONFIG)} ];" - preamble = joiner.join([graph_config, node_config, edge_config]) - ### nodes - node_strings = [] - for node in nodes: - node_strings.append(node.to_dot_string()) - nodes_part = joiner.join(node_strings) - ### edges - edge_strings = [] - for edge in edges: - edge_strings.append(edge.to_dot_string()) - edges_part = joiner.join(edge_strings) - ### groups - groups_forest = { - group: [g for g in groups if g.parent is group] for group in groups - } - roots = [group for group in groups if group.parent is None] - group_strings = [] - for group in roots: - group_strings.append(get_group_string(group, groups_forest=groups_forest)) - groups_part = joiner.join(group_strings) - result = f""" - digraph G {{ - {preamble} - {nodes_part} - {edges_part} - {groups_part} - }} - """ - return result - - -def write_output( - dot_string: str, - output_ext: str, - output_path: Optional[Path] = None, - show_how: Literal["none", "browser", "inline", "open"] = "none", -): - """ - Writes the given dot string to a dot file (temp file by default) and then - optionally shows it in the browser, opens it in a program, or does nothing. - """ - # make a temp file and write the dot string to it - if output_path is None: - tfile = tempfile.NamedTemporaryFile(suffix=f".{output_ext}", delete=False) - output_path = Path(tfile.name) - with tempfile.NamedTemporaryFile(mode="w", delete=True) as f: - path = f.name - with open(path, "w") as f: - f.write(dot_string) - cmd = f"dot -T{output_ext} -o{output_path} {path}" - subprocess.call(cmd, shell=True) - if show_how == "browser": - assert output_ext in [ - "png", - "jpg", - "jpeg", - "svg", - ], "Can only show png, jpg, jpeg, or svg in browser" - webbrowser.open(str(output_path)) - return - if show_how == "inline" or show_how == "open": - assert ( - Config.has_pil - ), "Pillow is not installed. 
Please install it to show images inline" - img = Image.open(output_path, "r") - if show_how == "inline": - return img - else: - img.show() diff --git a/mandala/utils.py b/mandala/utils.py index d3de54a..bbfb65a 100644 --- a/mandala/utils.py +++ b/mandala/utils.py @@ -1,4 +1,25 @@ from .common_imports import * +import joblib +import io +import inspect +import prettytable +import sqlite3 +from .config import * +from abc import ABC, abstractmethod + +def dataframe_to_prettytable(df: pd.DataFrame) -> str: + # Initialize a PrettyTable object + table = prettytable.PrettyTable() + + # Set the column names + table.field_names = df.columns.tolist() + + # Add rows to the table + for row in df.itertuples(index=False): + table.add_row(row) + + # Return the pretty-printed table as a string + return table.get_string() def serialize(obj: Any) -> bytes: @@ -18,25 +39,201 @@ def deserialize(value: bytes) -> Any: return joblib.load(buffer) -def _rename_cols_pandas(df: pd.DataFrame, mapping: Dict[str, str]) -> pd.DataFrame: - return df.rename(columns=mapping, inplace=False) +def get_content_hash(obj: Any) -> str: + if hasattr(obj, "__get_mandala_dict__"): + obj = obj.__get_mandala_dict__() + if Config.has_torch: + # TODO: ideally, should add a label to distinguish this from a numpy + # array with the same contents! + obj = tensor_to_numpy(obj) + if isinstance(obj, pd.DataFrame): + # DataFrames cause collisions for joblib hashing for some reason + # TODO: the below may be incomplete + obj = { + "columns": obj.columns, + "values": obj.values, + "index": obj.index, + } + result = joblib.hash(obj) # this hash is canonical wrt python collections + if result is None: + raise RuntimeError("joblib.hash returned None") + return result + + +def dump_output_name(index: int, output_names: Optional[List[str]] = None) -> str: + if output_names is not None and index < len(output_names): + return output_names[index] + else: + return f"output_{index}" + + +def parse_output_name(name: str) -> int: + return int(name.split("_")[-1]) + + +def get_setdict_union( + a: Dict[str, Set[str]], b: Dict[str, Set[str]] +) -> Dict[str, Set[str]]: + return {k: a.get(k, set()) | b.get(k, set()) for k in a.keys() | b.keys()} + + +def get_setdict_intersection( + a: Dict[str, Set[str]], b: Dict[str, Set[str]] +) -> Dict[str, Set[str]]: + return {k: a[k] & b[k] for k in a.keys() & b.keys()} + + +def get_dict_union_over_keys(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]: + return {k: a[k] if k in a else b[k] for k in a.keys() | b.keys()} + + +def get_dict_intersection_over_keys( + a: Dict[str, Any], b: Dict[str, Any] +) -> Dict[str, Any]: + return {k: a[k] for k in a.keys() & b.keys()} + + +def get_adjacency_union( + a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]] +) -> Dict[str, Dict[str, Set[str]]]: + return { + k: get_setdict_union(a.get(k, {}), b.get(k, {})) for k in a.keys() | b.keys() + } + + +def get_adjacency_intersection( + a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]] +) -> Dict[str, Dict[str, Set[str]]]: + return {k: get_setdict_intersection(a[k], b[k]) for k in a.keys() & b.keys()} -def _rename_cols_arrow(table: pa.Table, mapping: Dict[str, str]) -> pa.Table: - columns = table.column_names - new_columns = [mapping.get(col, col) for col in columns] - table = table.rename_columns(new_columns) - return table +def get_nullable_union(*sets: Set[str]) -> Set[str]: + return set.union(*sets) if len(sets) > 0 else set() -def _rename_cols(table: TableType, mapping: Dict[str, str]) -> 
TableType: - if isinstance(table, pd.DataFrame): - return _rename_cols_pandas(df=table, mapping=mapping) - elif isinstance(table, pa.Table): - return _rename_cols_arrow(table=table, mapping=mapping) +def get_nullable_intersection(*sets: Set[str]) -> Set[str]: + return set.intersection(*sets) if len(sets) > 0 else set() + + +def get_adj_from_edges( + edges: Set[Tuple[str, str, str]], node_support: Optional[Set[str]] = None +) -> Tuple[Dict[str, Dict[str, Set[str]]], Dict[str, Dict[str, Set[str]]]]: + """ + Given edges, convert them into the adjacency representation used by the + `ComputationFrame` class. + """ + out = {} + inp = {} + for src, dst, label in edges: + if src not in out: + out[src] = {} + if label not in out[src]: + out[src][label] = set() + out[src][label].add(dst) + if dst not in inp: + inp[dst] = {} + if label not in inp[dst]: + inp[dst][label] = set() + inp[dst][label].add(src) + if node_support is not None: + for node in node_support: + if node not in out: + out[node] = {} + if node not in inp: + inp[node] = {} + return out, inp + + +def parse_returns( + sig: inspect.Signature, + returns: Any, + nout: Union[Literal["auto", "var"], int], + output_names: Optional[List[str]] = None, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Return two dicts based on the returns: + - {output name: output value} + - {output name: output type annotation}, where things like `Tuple[T, ...]` are expanded. + """ + ### figure out the number of outputs, and convert them to a tuple + if nout == "auto": # infer from the returns + if isinstance(returns, tuple): + nout = len(returns) + returns_tuple = returns + else: + nout = 1 + returns_tuple = (returns,) + elif nout == "var": + assert isinstance(returns, tuple) + nout = len(returns) + returns_tuple = returns + else: # nout is an integer + assert isinstance(nout, int) + assert isinstance(returns, tuple) + assert len(returns) == nout + returns_tuple = returns + ### get the dict of outputs + outputs_dict = { + dump_output_name(i, output_names): returns_tuple[i] for i in range(nout) + } + ### figure out the annotations + annotations_dict = {} + output_annotation = sig.return_annotation + if output_annotation is inspect._empty: # no annotation + annotations_dict = {k: Any for k in outputs_dict.keys()} else: - raise NotImplementedError(f"rename_cols not implemented for {type(table)}") + if ( + hasattr(output_annotation, "__origin__") + and output_annotation.__origin__ is tuple + ): + if ( + len(output_annotation.__args__) == 2 + and output_annotation.__args__[1] == Ellipsis + ): + annotations_dict = { + k: output_annotation.__args__[0] for k in outputs_dict.keys() + } + else: + annotations_dict = { + k: output_annotation.__args__[i] + for i, k in enumerate(outputs_dict.keys()) + } + else: + assert nout == 1 + annotations_dict = {k: output_annotation for k in outputs_dict.keys()} + return outputs_dict, annotations_dict + +def unwrap_decorators( + obj: Callable, strict: bool = True +) -> Union[types.FunctionType, types.MethodType]: + while hasattr(obj, "__wrapped__"): + obj = obj.__wrapped__ + if not isinstance(obj, (types.FunctionType, types.MethodType)): + msg = f"Expected a function or method, but got {type(obj)}" + if strict: + raise RuntimeError(msg) + else: + logger.debug(msg) + return obj + +def is_subdict(a: Dict, b: Dict) -> bool: + """ + Check that all keys in `a` are in `b` with the same value. 
+ """ + return all((k in b and a[k] == b[k]) for k in a) + +_KT, _VT = TypeVar("_KT"), TypeVar("_VT") +def invert_dict(d: Dict[_KT, _VT]) -> Dict[_VT, List[_KT]]: + """ + Invert a dictionary + """ + out = {} + for k, v in d.items(): + if v not in out: + out[v] = [] + out[v].append(k) + return out def ask_user(question: str, valid_options: List[str]) -> str: """ diff --git a/mandala/_next/viz.py b/mandala/viz.py similarity index 100% rename from mandala/_next/viz.py rename to mandala/viz.py