diff --git a/.gitignore b/.gitignore
index 65711ab..9ffaecb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,8 +56,6 @@
# eggs
mandala.egg-info/*
-mandala/demos/data/
-mandala/__archive__
scratchpad/
mandala/scipy/
@@ -83,4 +81,7 @@ mandala/tests/output/*
.idea/**
# ignore docs build
-site/
\ No newline at end of file
+site/
+
+# Ignore previous version
+mandala/_prev/
\ No newline at end of file
diff --git a/README.md b/README.md
index 84d04f8..4ab2e29 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,7 @@
-[![Gitter](https://img.shields.io/gitter/room/amakelov/mandala)](https://app.gitter.im/#/room/#mandala:gitter.im)
-
# Computations that save, query and version themselves
-![tldr](https://user-images.githubusercontent.com/1467702/231244639-f318af0e-3993-4ad1-8822-e8d889003dc1.gif)
-
-`mandala` is a framework for experiment tracking and [incremental
-computing](https://en.wikipedia.org/wiki/Incremental_computing) with a simple
-plain-Python interface. It automates away the pain of experiment data management
-(*did I run this already? what's in this file? where's the result of that
-computation?* ...) with the following mix of features:
-- **plain-Python**: decorate the functions whose calls you want to save, and
- just write ordinary Python code using them - including data structures
-and control flow. The results are automatically accessible upon a
-re-run, queriable, and versioned. No need to save, load, or name anything by yourself
-- **never compute the same thing twice**: `mandala` saves the result of each
- function call, and (hashes of) the inputs and the dependencies
- (functions/globals) accessed by the call. If later the inputs and dependencies
- are the same, it just loads the results from storage.
-- **query by pattern-matching directly to computational code**: your code
- already knows the relationships between variables in your project! `mandala`
- lets you produce tables relating any variables by [directly pointing to
- the code establishing the wanted relationship between them](#query-by-directly-pattern-matching-python-code).
-- **fine-grained versioning that's under your control**: each function's source
- code has its own "mini git repo" behind the scenes, and each call tracks the
- versions of all dependencies that went into it. You can decide when a change
- to a dependency (e.g. refactoring a function to improve readability) doesn't change its semantics (so calls dependent on it won't be recomputed).
-
-
-## Install
+`mandala` eliminates the effort and code overhead of ML experiment tracking with
+two tools:
+
+1. The `@op` decorator:
+ - Automatically captures inputs, outputs and code (+dependencies) of Python function calls
+ - Automatically reuses past results & never computes the same call twice
+ - Designed to be composed into end-to-end persisted programs, enabling
+ efficient iterative development in plain-Python, without thinking about the
+ storage backend.
+
+2. The `ComputationFrame` data structure:
+ - Generalized dataframes: "columns" are a computation graph, "rows" are
+ (partial) executions of the graph
+ - Automates exploration, querying and high-level operations over
+ heterogeneous "webs" of `@op` calls
+ - Can be converted to a `DataFrame` of execution traces for downstream
+ analysis of relationships between variables in a project
+
+# Install
```
pip install git+https://github.com/amakelov/mandala
```
-## Testimonials
-
-> "`mandala` addresses a core challenge in my notebook workflow: being able to
-> explore data with code, without having to worry about losing the results of
-> expensive calculations." - *Adam Jermyn, Member of Technical Staff, Anthropic*
-
-
-## Video walkthroughs
-### Rapidly iterate on a project with memoization
-![mem](https://user-images.githubusercontent.com/1467702/231246050-21855bb2-6ce0-43d6-b7c0-3e0ed2a68f28.gif)
-
-Decorate the functions you want to memoize with `@op`, and compose programs out
-of them by chaining their inputs and outputs using ordinary **control flow** and
-**collections**. Every such program is **end-to-end memoized**:
-- it becomes an **imperative query interface** to its own results by
- (quickly) re-executing the same code, or parts of it
-- it is **incrementally extensible** with new logic and parameters in-place,
- which makes it both easy and efficient to interact with experiments
-
-### Query by directly pattern-matching Python code
-![query](https://user-images.githubusercontent.com/1467702/231246102-276d7ae9-3a7f-46f8-9899-ae9dcf4f0484.gif)
-
-Any computation in a `with storage.run():` block is also a
-**declarative query interface** to analogous computations in the entire storage:
-- **get a table of all values with the same computational history**:
-`storage.similar(x, y, ...)` returns a table of all values in the storage that
-were computed in the same ways as `x, y, ...`, but possibly starting from
-different inputs
-- **query through collections**: queries propagate provenance from a collection
- to its elements and vice-versa. This means you can query through operations
- that aggregate many objects into one, or produce many objects from a fixed
- number.
- - **NOTE**: a collection in a computation can pattern-match any collection in
- the storage with the same *kinds* of elements (as given by their computational
- history), but not necessarily in the same order or quantity. This ensures that
- you don't only match to the specific computation you have, but all *analogous*
- computations too.
-- **define queries explicitly**: for full control, use the `with storage.query():` context manager. For more on how this works, see [below](#explicit-declarative-queries-with-storagequery)
-
-### Automatic per-call versioning and dependency tracking
-![deps](https://user-images.githubusercontent.com/1467702/231246159-fc8996a1-0987-4cec-9f0d-f0408609886e.gif)
-
-`mandala` comes with a very fine-grained versioning system:
-- **per-call dependency tracking**: automatically track the functions and global
-variables accessed by each memoized call, and alert you to changes in them, so
-you can (carefully) choose whether a change to a dependency requires
-recomputation of dependent calls (like bug fixes and changes to logic) or not
-(like refactoring, comments, and logging)
-- **the code determines all versions automatically**: use the current state of
-each dependency in your codebase to automatically determine the currently
-compatible versions of each memoized function to use in computation and queries.
-In particular, this means that:
- - **you can go "back in time"** and access the storage relative to an earlier
- state of the code (or even branch in a new direction like in `git`) by
- just restoring this state
- - **the code is the truth**: when in doubt about the meaning of a result, you
- can just look at the current code.
-
-## Quickstart
-```python
-from mandala.imports import *
-
-# the storage saves calls and tracks dependencies, versions, etc.
-storage = Storage(
- deps_path='__main__' # track dependencies in current session
- )
-
-@op # memoization (and more) decorator
-def increment(x: int) -> int: # always indicate number of outputs in return type
- print('hi from increment!')
- return x + 1
-
-increment(23) # function acts normally
-
-with storage.run(): # context manager that triggers `mandala`
- y = increment(23) # now it's memoized w.r.t. this version of `increment`
-
-print(y) # result wrapped with metadata.
-print(unwrap(y)) # `unwrap` gets the raw value
-
-with storage.run():
- y = increment(23) # loads result from `storage`; doesn't execute `increment`
-
-@op # type-annotate data structures to store elts separately
-def average(nums: list) -> float:
- print('hi from average!')
- return sum(nums) / len(nums)
-
-# memoized functions are designed to be composed!
-with storage.run():
- # sliding averages of `increment`'s results over 3 elts
- nums = [increment(i) for i in range(5)]
- for i in range(3):
- result = average(nums[i:i+3])
-
-# get a table of all values similar to `result` in `storage`,
-# i.e., computed as average([increment(something), ...])
-# read the message this prints out!
-print(storage.similar(result, context=True))
-
-# change implementation of `increment` and re-run
-# you'll be asked if the change requires recomputing dependencies (say yes)
-@op
-def increment(x: int) -> int:
- print('hi from new increment!')
- return x + 2
-
-with storage.run():
- nums = [increment(i) for i in range(5)]
- for i in range(3):
- # only one call to `average` is executed!
- result = average(nums[i:i+3])
-
-# query is ran against the *new* version of `increment`
-print(storage.similar(result, context=True))
-```
-
-## Basic usage
-This is a quick guide on how to get up to speed with the core features and avoid
-common pitfalls.
-
-- [defining storages and memoized functions](#storage-and-the-op-decorator)
-- [memoization basics](#compute--memoize-with-storagerun)
-- [query storage directly from computational code](#implicit-declarative-queries)
-- [explicit query interface](#explicit-declarative-queries-with-storagequery)
-- [versioning and dependency tracking](#versioning-and-dependency-tracking)
-
-### `Storage` and the `@op` decorator
-A `Storage` instance holds all the data (saved calls and metadata) for a
-collection of memoized functions. In a given project, you should have just one
-`Storage` and many memoized functions connected to it. This way, the calls to
-memoized functions create a queriable web of interlinked objects.
-
-```python
-from mandala.all import Storage, op
-
-storage = Storage(
- db_path='my_persistent_storage.db', # omit for an in-memory storage
- deps_path='path_to_code_folder/', # omit to disable automatic dependency tracking
- spillover_dir='spillover_dir/', # spillover storage for large objects
- # see docs for more options
-)
-```
-The `@op` decorator marks a function `f` as memoizable. Some notes apply:
-- `f` must have a **fixed number of arguments** (defaults are allowed, and
- arguments can always be added backward-compatibly)
-- `f` must have a **fixed number of return values**, and this must be annotated in
- the function signature **or** declared as `@op(nout=...)`
-- `f` must have a **compatible interface** throughout its life.
- - **The only way to change `f`'s interface once it already has memoized calls
- is to add new arguments with default values.**
- - if you want to change the interface in an incompatible way, you should
- either just make a new function (under a new name), or increment the
- `@op(version=...)` argument.
-- the `list`, `dict` and `set` collections, when used as argument/return
- annotations, cause elements of these collections to be stored separately. To
- avoid confusion, `tuple`s are reserved for specifying the number of outputs.
-
-```python
-from sklearn.datasets import load_digits
-from sklearn.ensemble import RandomForestClassifier
-from typing import Tuple
-import numpy as np
-
-@op # core mandala decorator
-def load_data(n_class: int = 10) -> Tuple[np.ndarray, np.ndarray]:
- return load_digits(n_class=n_class, return_X_y=True)
-
-@op
-def train_model(X: np.ndarray, y: np.ndarray,
- n_estimators:int = 5) -> RandomForestClassifier:
- return RandomForestClassifier(n_estimators=n_estimators,
- max_depth=2).fit(X, y)
-```
-
-**Calling an `@op`-decorated function "normally" does not memoize**. To actually
-put data in the storage, you must put calls inside a `with storage.run():`
-block.
+# Quickstart
-### Compute & memoize `with storage.run():`
-**`@op`-decorated functions are designed to be composed** with one another
-inside `storage.run()` blocks. This composability lets
-you use the same piece of ordinary Python code to compute, save, load, *and any
-combination of the three*:
-```python
-# generate the dataset. This saves `X, y` to storage.
-with storage.run():
- X, y = load_data()
+[Run in Colab](https://colab.research.google.com/github/amakelov/mandala/blob/master/mandala/_next/tutorials/hello.ipynb)
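+
+Before opening the notebook, here is a minimal sketch of the intended workflow.
+The API names used here (`Storage`, the `with storage:` block, `storage.cf(...)`,
+`.df()`) are assumptions based on the tutorials; treat the linked notebook as the
+authoritative reference:
+
+```python
+from mandala.imports import Storage, op
+
+storage = Storage()  # holds all saved calls and their metadata
+
+@op  # calls to this function are captured and reused
+def increment(x: int) -> int:
+    return x + 1
+
+with storage:           # calls inside the block are recorded/memoized
+    y = increment(41)   # executed and saved once; loaded on later re-runs
+
+cf = storage.cf(increment)  # ComputationFrame over the saved `increment` calls
+print(cf.df())              # flatten to a DataFrame of execution traces
+```
+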
-# later, train a model by directly adding on top of this code. `load_data` is
-# not computed again
-with storage.run():
- X, y = load_data()
- model = train_model(X, y)
-
-# iterate on this with more parameters & logic
-from sklearn.metrics import accuracy_score
+# Documentation
+TODO: link
-@op
-def get_acc(model:RandomForestClassifier, X:np.ndarray, y:np.ndarray) -> float:
- return round(accuracy_score(y_pred=model.predict(X), y_true=y), 2)
-
-with storage.run():
- for n_class in (10, 5, 2):
- X, y = load_data(n_class)
- for n_estimators in (5, 10, 20):
- model = train_model(X, y, n_estimators=n_estimators)
- acc = get_acc(model, X, y)
- print(acc)
-```
-```python
-ValueRef(0.66, uid=15a...)
-ValueRef(0.73, uid=79e...)
-ValueRef(0.81, uid=5a4...)
-ValueRef(0.84, uid=6c4...)
-ValueRef(0.89, uid=fb8...)
-ValueRef(0.93, uid=c3d...)
-ValueRef(1.0, uid=b67...)
-ValueRef(1.0, uid=b67...)
-ValueRef(1.0, uid=b67...)
-```
-
-Memoized functions return `Ref` instances (`ValueRef`, `ListRef`,
-...), which bundle the actual return value with storage metadata. To get the
-return value itself, use `unwrap`. It works recursively on collections (lists,
-dicts, sets, tuples) as well.
-
-You can **imperatively query** storage just by retracing some code that's been
-entirely memoized:
-```python
-from mandala.all import unwrap
-# use composable memoization as imperative query interface
-with storage.run():
- X, y = load_data(5)
- for n_estimators in (5, 20):
- model = train_model(X, y, n_estimators=n_estimators)
- acc = get_acc(model, X, y)
- print(unwrap(acc))
-```
-```python
-0.84
-0.93
-```
-### Implicit declarative queries
-You can point to local variables in memoized code, and get a table of all values
-in storage with the same functional dependencies as these variables have in the
-code. For example, the `storage.similar()` method can be used to get values with
-the same **joint computational history** as the given variables. To be able to
-use a local variable in `storage.similar()`, it needs to be `wrap`ped as a `Ref`
-(which has no effect on computation):
-
-```python
-from mandala.all import wrap
-# use composable memoization as imperative query interface
-with storage.run():
- n_class = wrap(5)
- X, y = load_data(n_class)
- for n_estimators in wrap((5, 20)): # `wrap` maps over list, set and tuple
- model = train_model(X, y, n_estimators=n_estimators)
- acc = get_acc(model, X, y)
-
-storage.similar(n_class, n_estimators, acc)
-```
-```python
-Pattern-matching to the following computational graph (all constraints apply):
- n_estimators = Q() # input to computation; can match anything
- n_class = Q() # input to computation; can match anything
- X, y = load_data(n_class=n_class)
- model = train_model(X=X, y=y, n_estimators=n_estimators)
- acc = get_acc(model=model, X=X, y=y)
- result = storage.df(n_class, n_estimators, acc)
- n_class n_estimators acc
-1 10 5 0.66
-0 10 10 0.73
-2 10 20 0.81
-7 5 5 0.84
-6 5 10 0.89
-8 5 20 0.93
-4 2 5 1.00
-3 2 10 1.00
-5 2 20 1.00
-```
-
-The computational graph printed out by the query (default `verbose=True`) is
-also a good starting point for running an explicit query where you directly
-provide the computational graph instead of extracting it from a program.
-
-### Explicit declarative queries `with storage.query():`
-The kind of printout above can be directly copy-pasted into a `with
-storage.query():` block. Here it is with some more explanation:
-```python
-with storage.query():
- n_class = Q() # creates a variable that can match any value in storage
- # @op calls impose constraints between the values variables can match
- X, y = load_data(n_class)
- n_estimators = Q() # another variable that can match anything
- model = train_model(X, y, n_estimators=n_estimators)
- acc = get_acc(model, X, y)
- # get a table where each row is a matching of the given variables
- # that satisfies the constraints
- result = storage.df(n_class, n_estimators, acc)
-```
-#### How the `.query()` context works
-- a query variable, generated with `Q()` or as return value from an `@op`
- (like `X`, `y`, ... above), can in
-principle match any value in the storage.
-- by chaining together calls to `@op`s, you impose constraints between the
- inputs and outputs to the op. For exampe, `X, y = load_data(n_class)` imposes
- the constraint that a matching of values `(a, b, c)` to `(n_class, X, y)` must
- satisfy `b, c = load_data(a)`.
-- You can omit even required function arguments. This leaves them unconstrained.
-- the `df` method takes any sequence of variables, and returns a table
-where each row is a matching of values to the respective variables that
-satisfies **all** the constraints.
-
-**The query implementation has not been optimized for performance at this point**. Keep in mind that
-- if your query is not sufficiently constrained, there may be a combinatorial
-explosion of results;
-- if you query involves many variables and constraints, the default `SQL` query
-solver may have a hard time, or flat out raise an error. Try using
-`engine='naive'` in the `storage.similar()` or `storage.df()` methods instead.
-
-### Versioning and dependency tracking
-Passing a value to the `deps_path` parameter of the `Storage` class enables
-dependency tracking and versioning. This means that any time a memoized function
-*actually executes* (instead of loading an already saved call), it keeps track of
-the functions and global variables it accesses along the way.
-
-The number of tracked functions should be limited for efficiency (you typically
-don't want to track changes in installed libraries!). Setting `deps_path` to
-`"__main__"` will only look for dependencies defined in the current interactive
-session or process. Setting it to a folder will only look for dependencies
-defined in this folder.
-
-#### NOTE: The `@track` decorator
-The most efficient and reliable implementation of dependency tracking currently
-requires you to explicitly put `@track` on non-memoized functions and classes
-you want to track. This limitation may be lifted in the future, but at the cost
-of more magic.
-
-#### What is a version?
-A **version** for a memoized function is (to a first approximation) a set of
-source codes for functions/methods/global variables accessed by some call to
-this function. Even if you don't change anything in the code, a single function
-can have multiple versions if it invokes different dependencies for different calls. For example, consider this code:
-```python
-import numpy as np
-from sklearn.datasets import load_digits
-from sklearn.linear_model import LogisticRegression
-from sklearn.preprocessing import StandardScaler
-from mandala.imports import Storage, op, track
-from typing import Tuple, Any
-
-N_CLASS = 10
-
-@track # to track a non-memoized function as a dependency
-def scale_data(X):
- return StandardScaler(with_mean=True, with_std=False).fit_transform(X)
-
-@op
-def load_data() -> Tuple[np.ndarray, np.ndarray]:
- X, y = load_digits(n_class=N_CLASS, return_X_y=True)
- return X, y
-
-@op
-def train_model(X, y, scale=False) -> LogisticRegression:
- if scale:
- X = scale_data(X)
- return LogisticRegression().fit(X, y)
-
-@op
-def eval_model(model, X, y, scale=False) -> Any:
- if scale:
- X = scale_data(X)
- return model.score(X, y)
-
-storage = Storage(deps_path='__main__')
-
-with storage.run():
- X, y = load_data()
- for scale in [False, True]:
- model = train_model(X, y, scale=scale)
- acc = eval_model(model, X, y, scale=scale)
-```
-When you run it, `train_model` and `eval_model` will each have two versions -
-one that depends on `scale_data` and one that doesn't. You can confirm this by
-calling `storage.versions(train_model)`. Now suppose we make some changes
-and re-run:
-```python
-N_CLASS = 5
-
-@track
-def scale_data(X):
- return StandardScaler(with_mean=True, with_std=True).fit_transform(X)
-
-@op
-def eval_model(model, X, y, scale=False) -> Any:
- if scale:
- X = scale_data(X)
- return round(model.score(X, y), 2)
-
-with storage.run():
- X, y = load_data()
- for scale in [False, True]:
- model = train_model(X, y, scale=scale)
- acc = eval_model(model, X, y, scale=scale)
-```
-When entering the `storage.run()` block, the storage will detect the changes in
-the tracked components, and for each change will present you with the functions
-affected:
-- `N_CLASS` is a dependency for `load_data`;
-- `scale_data` is a dependency for the calls to `train_model` and `eval_model`
- which had `scale=True`;
-- `eval_model` is a dependency for itself.
-
-#### Semantic vs content changes and versions
-For each change to the content of some dependency (the source code of a function
-or the value of a global variable), you can choose whether this content change
-is also a **semantic** change. A semantic change will cause all calls that
-have accessed this dependency to not appear memoized **with respect to the new
-state of the code**. The content versions of a single dependency are organized
-in a `git`-like DAG (currently, tree) that can be inspected using
-`storage.sources(f)` for functions.
-
-#### Going back in time
-Since the versioning system is content-based, simply restoring an old state of
-the code makes the storage automatically recognize which "world" it's in, and
-which calls are memoized in this world.
-
-#### A warning about non-semantic changes
-The main motivation for allowing non-semantic changes is to maintain clarity in
-the storage when doing routine code improvements (refactoring, comments,
-logging). **However**, non-semantic changes should be applied with care. Apart from
-being prone to errors (you wrongly conclude that a change has no effect on
-semantics when it does), they can also introduce **invisible dependencies**:
-suppose you factor a function out of some dependency and mark the change
-non-semantic. Then the newly extracted function may in reality be a dependency
-of the existing calls, but this goes unnoticed by the system.
-
-## Other gotchas
-
-- **under development**: the biggest gotcha is that this project is under active
-development, which means things can change unpredictably.
-- **slow**: it hasn't been optimized for performance, so many things are quite
-inefficient
-- **pure functions**: you should probably only use it for functions with a
- deterministic input-output behavior if you're new to this project:
- - **changing a `Ref`'s object in-place will generally break things**. If you
- really need to update an object in-place, wrap the update in an `@op` so
- that you get instead a new `Ref` (with updated metadata) pointing to the
- same (changed) object, and discard the old `Ref`.
- - if a function does not have a **deterministic set of dependencies**
- it invokes for each given call, this may break the versioning system's
- invariants.
-- **avoid long (e.g. > 50) chains of calls in queries**: you should keep your
- workflows relatively shallow for queries to be efficient. This means e.g. no
- long recursive chains of calling a function repeatedly on its output
-- **don't rename anything (yet)**: there isn't good support yet for renaming
-functions, or moving functions around files. It's possible to rename functions
-and their arguments, but this is still undocumented.
-- **deletion**: no interfaces are currently exposed for deleting results.
-- **examine complex queries manually**: the color refinement algorithm used to
- extract a declarative query from a computational graph can in rare cases fail
- to realize that two vertices have different roles in the computational graph
- when projecting to the query. When in doubt, you should examine the printout
- of the query and tweak it if necessary.
-
-## Tutorials
+# Tutorials
- see the ["Hello world!"
tutorial](https://github.com/amakelov/mandala/blob/master/tutorials/00_hello.ipynb)
for a 2-minute introduction to the library's main features
-- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_logistic.ipynb)
+- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_random_forest.ipynb)
for a more realistic example of a machine learning project managed by Mandala.
-- [dependency
- tracking](https://github.com/amakelov/mandala/blob/master/tutorials/02_dependencies.ipynb)
- tutorial
+- TODO: dependency tracking
+
+# FAQs
+
+## How is this different from other experiment tracking frameworks?
+Compared to popular tools like W&B, MLFlow or Comet, `mandala`:
+- **is tightly integrated with the actual Python code execution**, as
+opposed to being an external logging framework. This makes it much easier to
+compose and iterate on multi-step experiments with non-trivial control flow.
+ - For instance, Python's collections can be (if so desired) made
+ transparent to the storage system, so that individual elements are
+ stored separately and can be reused across collections and calls.
+- **is founded on memoization, which allows direct and transparent reuse of
+results**.
+ - While in other frameworks you need to come up with arbitrary names
+for artifacts (which can later cause confusion), with `mandala` results are
+addressed by the code and inputs that produced them.
+- **allows reuse, queries and versioning on a more granular and flexible
+level** - the function call - as opposed to entire scripts and/or notebooks.
+- **provides the `ComputationFrame` data structure**, a much more natural way to
+represent and manipulate persisted computations.
+- **automatically resolves the version of every `@op`** from the current state
+of the codebase.
+
+## How is the `@op` cache invalidated?
+- given the inputs for a call to an `@op` `f`, `mandala` searches for a past call
+to `f` on inputs with the same contents (as determined by a hash function), where
+the dependencies accessed by that call (including `f` itself) have versions
+compatible with their current state in the codebase (see the sketch after this list).
+- compatibility between versions of a function is decided by the user: you
+have the freedom to mark certain changes as compatible with past results.
+- internally, `mandala` uses slightly modified `joblib` hashing to compute a
+content hash for Python objects. This is practical for many use cases, but
+not perfect, as discussed in the "gotchas" notebook TODO.
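+
+To illustrate the idea (a simplified toy, not `mandala`'s actual implementation),
+content-hash-based memoization with `joblib` hashing looks roughly like this; the
+real system additionally keys on the versions of the dependencies accessed by the
+call:
+
+```python
+import joblib
+
+_cache = {}  # (function name, content hash of inputs) -> saved result
+
+def memoized_call(f, *args, **kwargs):
+    # joblib.hash computes a content hash of (almost) arbitrary Python objects
+    key = (f.__qualname__, joblib.hash((args, kwargs)))
+    if key not in _cache:  # no saved call with the same input contents
+        _cache[key] = f(*args, **kwargs)
+    return _cache[key]
+```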
+
+## Can I change the code of `@op`s, and what happens if I do?
+- a frequent use case: you have some `@op` you've been using, then want to
+extend its functionality in a way that doesn't invalidate the past results.
+The recommended way is to add a new argument `a` and provide a default
+value for it wrapped with `NewArgDefault(x)`. When a value equal to `x` is
+passed for this argument, the storage falls back on calls saved before the
+argument was added (see the sketch below).
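+
+A sketch of this pattern (the import path of `NewArgDefault` is an assumption,
+and the wrapped value is assumed to be substituted when the argument is omitted):
+
+```python
+from mandala.imports import op, NewArgDefault  # import path assumed
+
+# the @op as originally written, with calls already saved in storage
+@op
+def accuracy(n_correct: int, n_total: int) -> float:
+    return n_correct / n_total
+
+# extended backward-compatibly: calls that pass percent=False (the value
+# wrapped by NewArgDefault) reuse results saved before `percent` existed;
+# other values trigger fresh computation.
+@op
+def accuracy(n_correct: int, n_total: int, percent=NewArgDefault(False)) -> float:
+    frac = n_correct / n_total
+    return 100 * frac if percent else frac
+```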
+
+## How self-contained is it?
+- `mandala`'s core is simple (only a few kLoCs) and only depends on `pandas`
+and `joblib`.
+- for visualization of `ComputationFrame`s, you should have `dot` installed
+at the system level, and/or the Python `graphviz` library.
+
+# Limitations
+
+# Roadmap
+
+# Testimonials
+
+> "`mandala` addresses a core challenge in my notebook workflow: being able to
+> explore data with code, without having to worry about losing the results of
+> expensive calculations." - *Adam Jermyn, Member of Technical Staff, Anthropic*
-## Related work
+# Galaxybrained vision
+Aspirationally, `mandala` is about much more than ML experiment tracking. The
+main goal is to **make persistence logic & best practices a natural extension of Python**.
+Once this is achieved, the purely "computational" code you must write anyway
+doubles as a storage interface. It's hard to think of a simpler and more
+reliable way to manage computational artifacts.
+
+## A first-principles approach to managing computational artifacts
+What we want from our storage are ways to
+- refer to artifacts with short, unambiguous descriptions: "here's [big messy Python object] I computed, which to me
+means [human-readable description]"
+- save artifacts: "save [big messy Python object]"
+- refer to artifacts and load them at a later time: "give me [human-readable description] that I computed before"
+- know when you've already computed something: "have I computed [human-readable description]?"
+- query results in more complicated ways: "give me all the things that satisfy
+[higher-level human-readable description]", which in practice means some
+predicate over combinations of artifacts.
+- get a report of how artifacts were generated: "what code went into [human-readable description]?"
+
+The key observation is that **execution traces** can already answer ~all of
+these questions.
+
+# Related work
`mandala` combines ideas from, and shares similarities with, many technologies.
Here are some useful points of comparison:
- **memoization**:
@@ -544,16 +148,9 @@ Here are some useful points of comparison:
computation data processing framework that unifies over different resource
types (files or services). It also uses an analogous notion of hashing to
keep track of computations.
-- **queries**:
- - all queries in `mandala` are [conjunctive queries](https://en.wikipedia.org/wiki/Conjunctive_query), a
- fundamental class of queries in relational algebra.
- - conjunctive queries are also related to category theory, see e.g.
+- **computation frames**:
+  - computation frames are related to ideas from category theory, see e.g.
[here](https://blog.algebraicjulia.org/post/2020/12/cset-conjunctive-queries/).
- - the [color refinement
- algorithm](https://en.wikipedia.org/wiki/Colour_refinement_algorithm) used
- to extract a query from an arbitrary computational graph is a standard tool
- for finding similar substructure in graphs and testing for graph
- isomorphism.
- **versioning**:
- the revision history of each function in the codebase is organized in a "mini-[`git`](https://git-scm.com/) repository" that shares only the most basic
features with `git`: it is a
@@ -566,6 +163,5 @@ Here are some useful points of comparison:
make backward-compatible changes to the interface and logic of dependencies.
It is different in that versions are still labeled by content, instead of by
"non-canonical" numbers.
- - the [unison programming
- language](https://www.unison-lang.org/learn/the-big-idea/) represents
+ - the [unison programming language](https://www.unison-lang.org/learn/the-big-idea/) represents
functions by the hash of their content (syntax tree, to be exact).
diff --git a/mandala/_next/README.md b/mandala/_next/README.md
deleted file mode 100644
index 4ab2e29..0000000
--- a/mandala/_next/README.md
+++ /dev/null
@@ -1,167 +0,0 @@
-
-
-# Computations that save, query and version themselves
-
-
-
-`mandala` eliminates the effort and code overhead of ML experiment tracking with
-two tools:
-
-1. The `@op` decorator:
- - Automatically captures inputs, outputs and code (+dependencies) of Python function calls
- - Automatically reuses past results & never computes the same call twice
- - Designed to be composed into end-to-end persisted programs, enabling
- efficient iterative development in plain-Python, without thinking about the
- storage backend.
-
-2. The `ComputationFrame` data structure:
- - Generalized dataframes: "columns" are a computation graph, "rows" are
- (partial) executions of the graph
- - Automates exploration, querying and high-level operations over
- heterogeneous "webs" of `@op` calls
- - Can be converted to a `DataFrame` of execution traces for downstream
- analysis of relationships between variables in a project
-
-# Install
-```
-pip install git+https://github.com/amakelov/mandala
-```
-
-# Quickstart
-
-[Run in Colab](https://colab.research.google.com/github/amakelov/mandala/blob/master/mandala/_next/tutorials/hello.ipynb)
-
-# Documentation
-TODO: link
-
-# Tutorials
-- see the ["Hello world!"
- tutorial](https://github.com/amakelov/mandala/blob/master/tutorials/00_hello.ipynb)
- for a 2-minute introduction to the library's main features
-- See [this notebook](https://github.com/amakelov/mandala/blob/master/tutorials/01_random_forest.ipynb)
-for a more realistic example of a machine learning project managed by Mandala.
-- TODO: dependency tracking
-
-# FAQs
-
-## How is this different from other experiment tracking frameworks?
-Compared to popular tools like W&B, MLFlow or Comet, `mandala`:
-- **is tightly integrated with the actual Python code execution**, as
-opposed to being an external logging framework. This makes it much easier to
-compose and iterate on multi-step experiments with non-trivial control flow.
- - For instance, Python's collections can be (if so desired) made
- transparent to the storage system, so that individual elements are
- stored separately and can be reused across collections and calls.
-- **is founded on memoization, which allows direct and transparent reuse of
-results**.
- - While in other frameworks you need to come up with arbitrary names
-for artifacts (which can later cause confusion)
-- **allows reuse, queries and versioning on a more granular and flexible
-level** - the function call - as opposed to entire scripts and/or notebooks.
-- **provides the `ComputationFrame` data structure**, a much more natural way to
-represent and manipulate persisted computations.
-- **automatically resolves the version of every `@op`** from the current state
-of the codebase.
-
-## How is the `@op` cache invalidated?
-- given inputs for a call to an `@op`, e.g. `f`, it searches for a past call
-to `f` on inputs with the same contents (as determined by a hash function) where the dependencies accessed by this call (including `f`
-itself) have versions compatible with their current state.
-- compatibility between versions of a function is decided by the user: you
-have the freedom to mark certain changes as compatible with past results.
-- internally, `mandala` uses slightly modified `joblib` hashing to compute a
-content hash for Python objects. This is practical for many use cases, but
-not perfect, as discussed in the "gotchas" notebook TODO.
-
-## Can I change the code of `@op`s, and what happens if I do?
-- a frequent use case: you have some `@op` you've been using, then want to
-extend its functionality in a way that doesn't invalidate the past results.
-The recommended way is to add a new argument `a`, and provide a default
-value for it wrapped with `NewArgDefault(x)`. When a value equal to `x` is
-passed for this argument, the storage falls back on calls before
-
-## How self-contained is it?
-- `mandala`'s core is simple (only a few kLoCs) and only depends on `pandas`
-and `joblib`.
-- for visualization of `ComputationFrame`s, you should have `dot` installed
-on the system level, and/or the Python `graphviz` library installed.
-
-# Limitations
-
-# Roadmap
-
-# Testimonials
-
-> "`mandala` addresses a core challenge in my notebook workflow: being able to
-> explore data with code, without having to worry about losing the results of
-> expensive calculations." - *Adam Jermyn, Member of Technical Staff, Anthropic*
-
-# Galaxybrained vision
-Aspirationally, `mandala` is about much more than ML experiment tracking. The
-main goal is to **make persistence logic & best practices a natural extension of Python**.
-Once this is achieved, the purely "computational" code you must write anyway
-doubles as a storage interface. It's hard to think of a simpler and more
-reliable way to manage computational artifacts.
-
-## A first-principles approach to managing computational artifacts
-What we want from our storage are ways to
-- refer to artifacts with short, unambiguous descriptions: "here's [big messy Python object] I computed, which to me
-means [human-readable description]"
-- save artifacts: "save [big messy Python object]"
-- refer to artifacts and load them at a later time: "give me [human-readable description] that I computed before"
-- know when you've already computed something: "have I computed [human-readable description]?"
-- query results in more complicated ways: "give me all the things that satisfy
-[higher-level human-readable description]", which in practice means some
-predicate over combinations of artifacts.
-- get a report of how artifacts were generated: "what code went into [human-readable description]?"
-
-The key observation is that **execution traces** can already answer ~all of
-these questions.
-
-# Related work
-`mandala` combines ideas from, and shares similarities with, many technologies.
-Here are some useful points of comparison:
-- **memoization**:
- - standard Python memoization solutions are [`joblib.Memory`](https://joblib.readthedocs.io/en/latest/generated/joblib.Memory.html)
- and
- [`functools.lru_cache`](https://docs.python.org/3/library/functools.html#functools.lru_cache).
- `mandala` uses `joblib` serialization and hashing under the hood.
- - [`incpy`](https://github.com/pajju/IncPy) is a project that integrates
- memoization with the python interpreter itself.
- - [`funsies`](https://github.com/aspuru-guzik-group/funsies) is a
- memoization-based distributed workflow executor that uses an analogous notion
- of hashing to `mandala` to keep track of which computations have already been done. It
- works on the level of scripts (not functions), and lacks queriability and
- versioning.
- - [`koji`](https://arxiv.org/abs/1901.01908) is a design for an incremental
- computation data processing framework that unifies over different resource
- types (files or services). It also uses an analogous notion of hashing to
- keep track of computations.
-- **computation frames**:
- - computation frames are related to the idea of using certain functions category theory, see e.g.
- [here](https://blog.algebraicjulia.org/post/2020/12/cset-conjunctive-queries/).
-- **versioning**:
- - the revision history of each function in the codebase is organized in a "mini-[`git`](https://git-scm.com/) repository" that shares only the most basic
- features with `git`: it is a
- [content-addressable](https://en.wikipedia.org/wiki/Content-addressable_storage)
- tree, where each edge tracks a diff from the content at one endpoint to that
- at the other. Additional metadata indicates equivalence classes of
- semantically equivalent contents.
- - [semantic versioning](https://semver.org/) is another popular code
- versioning system. `mandala` is similar to `semver` in that it allows you to
- make backward-compatible changes to the interface and logic of dependencies.
- It is different in that versions are still labeled by content, instead of by
- "non-canonical" numbers.
- - the [unison programming language](https://www.unison-lang.org/learn/the-big-idea/) represents
- functions by the hash of their content (syntax tree, to be exact).
diff --git a/mandala/_next/common_imports.py b/mandala/_next/common_imports.py
deleted file mode 100644
index c65c395..0000000
--- a/mandala/_next/common_imports.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import time
-import traceback
-import random
-import logging
-import itertools
-import copy
-import hashlib
-import io
-import os
-import shutil
-import sys
-import joblib
-import inspect
-import binascii
-import asyncio
-import ast
-import types
-import tempfile
-from collections import defaultdict, OrderedDict
-from typing import (
- Any,
- Dict,
- List,
- Callable,
- Tuple,
- Iterable,
- Optional,
- Set,
- Union,
- TypeVar,
- Literal,
-)
-from pathlib import Path
-
-import pandas as pd
-import pyarrow as pa
-import numpy as np
-
-try:
- import rich
-
- has_rich = True
-except ImportError:
- has_rich = False
-
-if has_rich:
- from rich.logging import RichHandler
-
- logger = logging.getLogger("mandala")
- logging_handler = RichHandler(enable_link_path=False)
- FORMAT = "%(message)s"
- logging.basicConfig(
- level="INFO", format=FORMAT, datefmt="[%X]", handlers=[logging_handler]
- )
-else:
- logger = logging.getLogger("mandala")
- # logger.addHandler(logging.StreamHandler())
- FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
- logging.basicConfig(format=FORMAT)
- logger.setLevel(logging.INFO)
-
-
-from mandala.common_imports import sess
diff --git a/mandala/_next/deps/crawler.py b/mandala/_next/deps/crawler.py
deleted file mode 100644
index 25f109c..0000000
--- a/mandala/_next/deps/crawler.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import types
-from ..common_imports import *
-from ..utils import unwrap_decorators
-import importlib
-from .model import (
- DepKey,
- CallableNode,
-)
-from .utils import (
- is_callable_obj,
- extract_func_obj,
- unknown_function,
-)
-
-
-def crawl_obj(
- obj: Any,
- module_name: str,
- include_methods: bool,
- result: Dict[DepKey, CallableNode],
- strict: bool,
- objs_result: Dict[DepKey, Callable],
-):
- """
- Find functions and optionally methods native to the module of this object.
- """
- if is_callable_obj(obj=obj, strict=strict):
- if isinstance(unwrap_decorators(obj, strict=False), types.BuiltinFunctionType):
- return
- v = extract_func_obj(obj=obj, strict=strict)
- if v is not unknown_function and v.__module__ != module_name:
- # exclude non-local functions
- return
- dep_key = (module_name, v.__qualname__)
- node = CallableNode.from_obj(obj=v, dep_key=dep_key)
- result[dep_key] = node
- objs_result[dep_key] = obj
- if isinstance(obj, type):
- if include_methods:
- if obj.__module__ != module_name:
- return
- for k in obj.__dict__.keys():
- v = obj.__dict__[k]
- crawl_obj(
- obj=v,
- module_name=module_name,
- include_methods=include_methods,
- result=result,
- strict=strict,
- objs_result=objs_result,
- )
-
-
-def crawl_static(
- root: Optional[Path],
- strict: bool,
- package_name: Optional[str] = None,
- include_methods: bool = False,
-) -> Tuple[Dict[DepKey, CallableNode], Dict[DepKey, Callable]]:
- """
- Find all python files in the root directory, and use importlib to import
- them, look for callable objects, and create callable nodes from them.
- """
- result: Dict[DepKey, CallableNode] = {}
- objs_result: Dict[DepKey, Callable] = {}
- paths = []
- if root is not None:
- if root.is_file():
- assert package_name is not None # needs this to be able to import
- paths = [root]
- else:
- paths.extend(list(root.rglob("*.py")))
- paths.append("__main__")
- for path in paths:
- filename = path.name if path != "__main__" else "__main__"
- if filename in ("setup.py", "console.py"):
- continue
- if path != "__main__" and root is not None:
- if root.is_file():
- module_name = root.stem
- else:
- module_name = (
- path.with_suffix("").relative_to(root).as_posix().replace("/", ".")
- )
- if package_name is not None:
- module_name = ".".join([package_name, module_name])
- else:
- module_name = "__main__"
- try:
- module = importlib.import_module(module_name)
- except:
- msg = f"Failed to import {module_name}:"
- if strict:
- raise ValueError(msg)
- else:
- logger.warning(msg)
- continue
- keys = list(module.__dict__.keys())
- for k in keys:
- v = module.__dict__[k]
- crawl_obj(
- obj=v,
- module_name=module_name,
- strict=strict,
- include_methods=include_methods,
- result=result,
- objs_result=objs_result,
- )
- return result, objs_result
diff --git a/mandala/_next/deps/deep_versions.py b/mandala/_next/deps/deep_versions.py
deleted file mode 100644
index dc4a0ba..0000000
--- a/mandala/_next/deps/deep_versions.py
+++ /dev/null
@@ -1,188 +0,0 @@
-from ..common_imports import *
-from .utils import DepKey, hash_dict
-from .model import (
- Node,
- CallableNode,
- GlobalVarNode,
- TerminalNode,
-)
-
-
-from .shallow_versions import DAG
-
-
-class Version:
- """
- Model of a "deep" version of a component that includes versions of its
- dependencies.
- """
-
- def __init__(
- self,
- component: DepKey,
- dynamic_deps_commits: Dict[DepKey, str],
- memoized_deps_content_versions: Dict[DepKey, Set[str]],
- ):
- ### raw data from the trace
- # the component whose dependencies are traced
- self.component = component
- # the content hashes of the direct dependencies
- self.direct_deps_commits = dynamic_deps_commits
- # pointers to content hashes of versions of memoized calls
- self.memoized_deps_content_versions = memoized_deps_content_versions
-
- ### cached data. These are set against a dependency state
- self._is_synced = False
- # the expanded set of dependencies, including all transitive
- # dependencies. Note this is a set of *content* hashes per dependency
- self._content_expansion: Dict[DepKey, Set[str]] = None
- # a hash uniquely identifying the content of dependencies of this version
- self._content_version: str = None
- # the semantic hashes of all dependencies for this version;
- # the system enforces that the semantic hash of a dependency is the same
- # for all commits of a component referenced by this version
- self._semantic_expansion: Dict[DepKey, str] = None
- # overall semantic hash of this version
- self._semantic_version: str = None
-
- @property
- def presentation(self) -> str:
- return f'Version of "{self.component[1]}" from module "{self.component[0]}" (content: {self.content_version}, semantic: {self.semantic_version})'
-
- @staticmethod
- def from_trace(
- component: DepKey, nodes: Dict[DepKey, Node], strict: bool = True
- ) -> "Version":
- dynamic_deps_commits = {}
- memoized_deps_content_versions = defaultdict(set)
- for dep_key, node in nodes.items():
- if isinstance(node, (CallableNode, GlobalVarNode)):
- dynamic_deps_commits[dep_key] = node.content_hash
- elif isinstance(node, TerminalNode):
- terminal_data = node.representation
- pointer_dep_key = terminal_data.dep_key
- version_content_hash = terminal_data.call_content_version
- memoized_deps_content_versions[pointer_dep_key].add(
- version_content_hash
- )
- else:
- raise ValueError(f"Unexpected node type {type(node)}")
- return Version(
- component=component,
- dynamic_deps_commits=dynamic_deps_commits,
- memoized_deps_content_versions=dict(memoized_deps_content_versions),
- )
-
- ############################################################################
- ### methods for setting cached data from a versioning state
- ############################################################################
- def _set_content_expansion(self, all_versions: Dict[DepKey, Dict[str, "Version"]]):
- result = defaultdict(set)
- for dep_key, content_hash in self.direct_deps_commits.items():
- result[dep_key].add(content_hash)
- for (
- dep_key,
- memoized_content_versions,
- ) in self.memoized_deps_content_versions.items():
- for memoized_content_version in memoized_content_versions:
- referenced_version = all_versions[dep_key][memoized_content_version]
- for (
- referenced_dep_key,
- referenced_content_hashes,
- ) in referenced_version.content_expansion.items():
- result[referenced_dep_key].update(referenced_content_hashes)
- self._content_expansion = dict(result)
-
- def _set_content_version(self):
- self._content_version = hash_dict(
- {
- dep_key: tuple(sorted(self.content_expansion[dep_key]))
- for dep_key in self.content_expansion
- }
- )
-
- def _set_semantic_expansion(
- self,
- component_dags: Dict[DepKey, DAG],
- all_versions: Dict[DepKey, Dict[str, "Version"]],
- ):
- result = {}
- # from own deps
- for dep_key, dep_content_hash in self.direct_deps_commits.items():
- dag = component_dags[dep_key]
- semantic_hash = dag.commits[dep_content_hash].semantic_hash
- result[dep_key] = semantic_hash
- # from pointers
- for (
- dep_key,
- memoized_content_versions,
- ) in self.memoized_deps_content_versions.items():
- for memoized_content_version in memoized_content_versions:
- dep_version_semantic_hashes = all_versions[dep_key][
- memoized_content_version
- ].semantic_expansion
- overlap = set(result.keys()).intersection(
- dep_version_semantic_hashes.keys()
- )
- if any(result[k] != dep_version_semantic_hashes[k] for k in overlap):
- raise ValueError(
- f"Version {self} has conflicting semantic hashes for {overlap}"
- )
- result.update(dep_version_semantic_hashes)
- self._semantic_expansion = result
- self._semantic_version = hash_dict(result)
-
- def sync(
- self,
- component_dags: Dict[DepKey, DAG],
- all_versions: Dict[DepKey, Dict[str, "Version"]],
- ):
- """
- Set all the cached things in the correct order
- """
- self._set_content_expansion(all_versions=all_versions)
- self._set_content_version()
- self._set_semantic_expansion(
- component_dags=component_dags, all_versions=all_versions
- )
- self.set_synced()
-
- @property
- def content_version(self) -> str:
- assert self._content_version is not None
- return self._content_version
-
- @property
- def semantic_version(self) -> str:
- assert self._semantic_version is not None
- return self._semantic_version
-
- @property
- def semantic_expansion(self) -> Dict[DepKey, str]:
- assert self._semantic_expansion is not None
- return self._semantic_expansion
-
- @property
- def content_expansion(self) -> Dict[DepKey, Set[str]]:
- assert self._content_expansion is not None
- return self._content_expansion
-
- @property
- def support(self) -> Iterable[DepKey]:
- return self.content_expansion.keys()
-
- @property
- def is_synced(self) -> bool:
- return self._is_synced
-
- def set_synced(self):
- # it can only go from unsynced to synced
- if self._is_synced:
- raise ValueError("Version is already synced")
- self._is_synced = True
-
- def __repr__(self) -> str:
- return f"""
-Version(
- dependencies={['.'.join(elt) for elt in self.support]},
-)"""
diff --git a/mandala/_next/deps/model.py b/mandala/_next/deps/model.py
deleted file mode 100644
index 731bb0e..0000000
--- a/mandala/_next/deps/model.py
+++ /dev/null
@@ -1,311 +0,0 @@
-import textwrap
-from abc import abstractmethod, ABC
-import types
-
-from ..common_imports import *
-from ..utils import get_content_hash
-from ..viz import (
- write_output,
-)
-
-from .utils import (
- DepKey,
- load_obj,
- get_runtime_description,
- extract_code,
- unknown_function,
- UNKNOWN_GLOBAL_VAR,
-)
-
-
-class Node(ABC):
- def __init__(self, module_name: str, obj_name: str, representation: Any):
- self.module_name = module_name
- self.obj_name = obj_name
- self.representation = representation
-
- @property
- def key(self) -> DepKey:
- return (self.module_name, self.obj_name)
-
- def present_key(self) -> str:
- raise NotImplementedError()
-
- @staticmethod
- @abstractmethod
- def represent(obj: Any) -> Any:
- raise NotImplementedError
-
- @abstractmethod
- def content(self) -> Any:
- raise NotImplementedError
-
- @abstractmethod
- def readable_content(self) -> str:
- raise NotImplementedError
-
- @property
- @abstractmethod
- def content_hash(self) -> str:
- raise NotImplementedError
-
- def load_obj(self, allow_fallback: bool) -> Any:
- obj, found = load_obj(module_name=self.module_name, obj_name=self.obj_name)
- if not found:
- msg = f"{self.present_key()} not found"
- if allow_fallback:
- logger.warning(msg)
- if hasattr(self, "FALLBACK_OBJ"):
- return self.FALLBACK_OBJ
- else:
- raise ValueError(f"No fallback object defined for {self.__class__}")
- else:
- raise ValueError(msg)
- return obj
-
-
-class CallableNode(Node):
- FALLBACK_OBJ = unknown_function
-
- def __init__(
- self,
- module_name: str,
- obj_name: str,
- representation: Optional[str],
- runtime_description: str,
- ):
- self.module_name = module_name
- self.obj_name = obj_name
- self.runtime_description = runtime_description
- if representation is not None:
- self._set_representation(value=representation)
- else:
- self._representation = None
- self._content_hash = None
-
- @staticmethod
- def from_obj(obj: Any, dep_key: DepKey) -> "CallableNode":
- representation = CallableNode.represent(obj=obj)
- code_obj = extract_code(obj)
- runtime_description = get_runtime_description(code=code_obj)
- return CallableNode(
- module_name=dep_key[0],
- obj_name=dep_key[1],
- representation=representation,
- runtime_description=runtime_description,
- )
-
- @staticmethod
- def from_runtime(
- module_name: str,
- obj_name: str,
- code_obj: types.CodeType,
- ) -> "CallableNode":
- return CallableNode(
- module_name=module_name,
- obj_name=obj_name,
- representation=None,
- runtime_description=get_runtime_description(code=code_obj),
- )
-
- @property
- def representation(self) -> str:
- return self._representation
-
- def _set_representation(self, value: str):
- assert isinstance(value, str)
- self._representation = value
- self._content_hash = get_content_hash(value)
-
- @representation.setter
- def representation(self, value: str):
- self._set_representation(value)
-
- @property
- def is_method(self) -> bool:
- return "." in self.obj_name
-
- def present_key(self) -> str:
- return f"function {self.obj_name} from module {self.module_name}"
-
- @property
- def class_name(self) -> str:
- assert self.is_method
- return ".".join(self.obj_name.split(".")[:-1])
-
- @staticmethod
- def represent(
- obj: Union[types.FunctionType, types.CodeType, Callable],
- allow_fallback: bool = False,
- ) -> str:
- if type(obj).__name__ == "Op":
- obj = obj.f
- if not isinstance(obj, (types.FunctionType, types.MethodType, types.CodeType)):
- logger.warning(f"Found {obj} of type {type(obj)}")
- try:
- source = inspect.getsource(obj)
- except Exception as e:
- msg = f"Could not get source for {obj} because {e}"
- if allow_fallback:
- source = inspect.getsource(CallableNode.FALLBACK_OBJ)
- logger.warning(msg)
- else:
- raise RuntimeError(msg)
- # strip whitespace to prevent different sources looking the same in the
- # ui
- lines = source.splitlines()
- lines = [line.rstrip() for line in lines]
- source = "\n".join(lines)
- return source
-
- def content(self) -> str:
- return self.representation
-
- def readable_content(self) -> str:
- return self.representation
-
- @property
- def content_hash(self) -> str:
- assert isinstance(self._content_hash, str)
- return self._content_hash
-
-
-class GlobalVarNode(Node):
- FALLBACK_OBJ = UNKNOWN_GLOBAL_VAR
-
- def __init__(
- self,
- module_name: str,
- obj_name: str,
- # (content hash, truncated repr)
- representation: Tuple[str, str],
- ):
- self.module_name = module_name
- self.obj_name = obj_name
- self._representation = representation
-
- @staticmethod
- def from_obj(obj: Any, dep_key: DepKey) -> "GlobalVarNode":
- representation = GlobalVarNode.represent(obj=obj)
- return GlobalVarNode(
- module_name=dep_key[0],
- obj_name=dep_key[1],
- representation=representation,
- )
-
- @property
- def representation(self) -> Tuple[str, str]:
- return self._representation
-
- @staticmethod
- def represent(obj: Any, allow_fallback: bool = False) -> Tuple[str, str]:
- truncated_repr = textwrap.shorten(text=repr(obj), width=80)
- try:
- content_hash = get_content_hash(obj=obj)
- except Exception as e:
- shortened_exception = textwrap.shorten(text=str(e), width=80)
- msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}"
- if allow_fallback:
- content_hash = UNKNOWN_GLOBAL_VAR
- logger.warning(msg)
- else:
- raise RuntimeError(msg)
- return content_hash, truncated_repr
-
- def present_key(self) -> str:
- return f"global variable {self.obj_name} from module {self.module_name}"
-
- def content(self) -> str:
- return self.representation
-
- def readable_content(self) -> str:
- return self.representation[1]
-
- @property
- def content_hash(self) -> str:
- assert isinstance(self.representation, tuple)
- return self.representation[0]
-
-
-class TerminalData:
- def __init__(
- self,
- op_internal_name: str,
- op_version: int,
- call_content_version: str,
- call_semantic_version: str,
- # data: Tuple[Tuple[str, int], Tuple[str, str]],
- dep_key: DepKey,
- ):
- # ((internal name, version), (content_version, semantic_version))
- self.op_internal_name = op_internal_name
- self.op_version = op_version
- self.call_content_version = call_content_version
- self.call_semantic_version = call_semantic_version
- self.dep_key = dep_key
-
-
-class TerminalNode(Node):
- def __init__(self, module_name: str, obj_name: str, representation: TerminalData):
- self.module_name = module_name
- self.obj_name = obj_name
- self.representation = representation
-
- @property
- def key(self) -> DepKey:
- return self.module_name, self.obj_name
-
- def present_key(self) -> str:
- raise NotImplementedError
-
- @property
- def content_hash(self) -> str:
- raise NotImplementedError
-
- def content(self) -> Any:
- raise NotImplementedError
-
- def readable_content(self) -> str:
- raise NotImplementedError
-
- @staticmethod
- def represent(obj: Any) -> Any:
- raise NotImplementedError
-
-
-class DependencyGraph:
- def __init__(self):
- self.nodes: Dict[DepKey, Node] = {}
- self.roots: Set[DepKey] = set()
- self.edges: Set[Tuple[DepKey, DepKey]] = set()
-
- def get_trace_state(self) -> Tuple[DepKey, Dict[DepKey, Node]]:
- if len(self.roots) != 1:
- raise ValueError(f"Expected exactly one root, got {len(self.roots)}")
- component = list(self.roots)[0]
- return component, self.nodes
-
- def show(self, path: Optional[Path] = None, how: str = "none"):
- dot = to_dot(self)
- output_ext = "svg" if how in ["browser"] else "png"
- return write_output(
- dot_string=dot, output_path=path, output_ext=output_ext, show_how=how
- )
-
- def __repr__(self) -> str:
- if len(self.nodes) == 0:
- return "DependencyGraph()"
- return to_string(self)
-
- def add_node(self, node: Node):
- self.nodes[node.key] = node
-
- def add_edge(self, source: Node, target: Node):
- if source.key not in self.nodes:
- self.add_node(source)
- if target.key not in self.nodes:
- self.add_node(target)
- self.edges.add((source.key, target.key))
-
-
-from .viz import to_dot, to_string
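
The `DependencyGraph` removed above is a plain adjacency structure: nodes are keyed by `DepKey = (module name, qualified name)`, edges are ordered pairs of such keys, and roots mark the call that started a trace. A small self-contained sketch of that shape, using made-up project names rather than the removed API:

```python
# Illustrative sketch of the DependencyGraph shape: (module, qualname) keys,
# directed edges between them, and a root per trace. All names are hypothetical.
from typing import Dict, Set, Tuple

DepKey = Tuple[str, str]  # (module name, object name within the module)

nodes: Dict[DepKey, str] = {
    ("proj.pipeline", "train"): "callable",
    ("proj.config", "SEED"): "global variable",
}
# "train" reads the global "SEED", so the trace records a directed edge
edges: Set[Tuple[DepKey, DepKey]] = {
    (("proj.pipeline", "train"), ("proj.config", "SEED")),
}
roots: Set[DepKey] = {("proj.pipeline", "train")}
print(len(nodes), len(edges), roots)
```
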
diff --git a/mandala/_next/deps/shallow_versions.py b/mandala/_next/deps/shallow_versions.py
deleted file mode 100644
index 98282f8..0000000
--- a/mandala/_next/deps/shallow_versions.py
+++ /dev/null
@@ -1,463 +0,0 @@
-from typing import Literal
-import textwrap
-from ..common_imports import *
-from ..utils import get_content_hash
-from ..config import Config
-from ..utils import ask_user
-from ..viz import _get_colorized_diff, _get_diff
-
-if Config.has_rich:
- from rich.tree import Tree
- from rich.panel import Panel
- from rich.text import Text
- from rich.syntax import Syntax
-
-
-# TODO: figure out how to apply compact diffs
-def get_diff(a: str, b: str) -> Tuple[str, str]:
- """
- Get a diff between two strings
- """
- return (a, b)
-
-
-def apply_diff(b: str, diff: Tuple[str, str]) -> str:
- """
- Apply a diff to a string
- """
- return diff[0]
-
-
-class Commit:
- """
- Tracks versions of the "shallow" content of a single component. For
- functions/methods, this is just the function source code, without any
- dependencies. For global variables, this is just (the content hash of) the
- value of the variable.
- """
-
- def __init__(
- self,
- parents: List[str],
- diffs: List[Any],
- content_hash: str,
- semantic_hash: str,
- content: Optional[str],
- ):
- # content hashes of parent commits
- self.parents = parents # currently there may be at most one parent
- # diffs between this commit and its parents
- self.diffs = diffs
- self.content_hash = content_hash
- # content hash of the semantic version this commit is associated with
- self.semantic_hash = semantic_hash
- # content of this commit, if it is a root commit
- self._content = content
- self.check_invariants()
-
- def check_invariants(self):
- assert len(self.parents) == len(self.diffs)
- assert len(self.parents) > 0 or self._content is not None
-
- def __repr__(self) -> str:
- return f"Commit(content_hash={self.content_hash}, semantic_hash={self.semantic_hash}, parents={self.parents})"
-
-
-T = TypeVar("T")
-from typing import Generic
-
-
-class ContentAdapter(Generic[T]):
- def get_diff(self, a: T, b: T) -> Any:
- """
- Get a diff between two objects
- """
- raise NotImplementedError()
-
- def apply_diff(self, b: T, diff: Any) -> T:
- """
- Apply a diff to an object
- """
- raise NotImplementedError()
-
- def get_presentable_content(self, content: T) -> str:
- """
- Get a presentable string representation of the content
- """
- raise NotImplementedError()
-
- def get_content_hash(self, content: T) -> str:
- raise NotImplementedError()
-
-
-class StringContentAdapter(ContentAdapter[str]):
- def get_diff(self, a: str, b: str) -> Tuple[str, str]:
- """
- Get a diff between two strings
- """
- return get_diff(a, b)
-
- def apply_diff(self, b: str, diff: Tuple[str, str]) -> str:
- """
- Apply a diff to a string
- """
- return apply_diff(b, diff)
-
- def get_presentable_content(self, content: str) -> str:
- """
- Get a presentable string representation of the content
- """
- return content
-
- def get_content_hash(self, content: str) -> str:
- return get_content_hash(content)
-
-
-GVContent = Tuple[str, str] # (content hash, repr)
-
-
-class GlobalVariableContentAdapter(ContentAdapter[GVContent]):
- def get_diff(self, a: GVContent, b: GVContent) -> Tuple[GVContent, GVContent]:
- """
- Get a diff between two global variable contents
- """
- return (a, b)
-
- def apply_diff(self, b: GVContent, diff: Tuple[GVContent, GVContent]) -> GVContent:
- """
- Apply a diff to a global variable content
- """
- return diff[0]
-
- def get_presentable_content(self, content: GVContent) -> str:
- return content[1]
-
- def get_content_hash(self, content: GVContent) -> str:
- return content[0]
-
-
-class DAG(Generic[T]):
- """
- Organizes the shallow versions of a single component in a `git`-like DAG.
- """
-
- def __init__(self, content_type: Literal["code", "global_variable"] = "code"):
- # content hash of the current head
- self.head: Optional[str] = None
- self.commits: Dict[str, Commit] = {}
- self._initial_commit: Optional[T] = None
- if content_type == "code":
- self.content_adapter = StringContentAdapter()
- elif content_type == "global_variable":
- self.content_adapter = GlobalVariableContentAdapter()
- else:
- raise ValueError(f"Invalid content_type: {content_type}")
- self.check_invariants()
-
- @property
- def initial_commit(self) -> str:
- assert self.head is not None
- return self._initial_commit
-
- def check_invariants(self):
- if self.head is not None:
- assert self.head in self.commits
- for commit, commit_obj in self.commits.items():
- commit_obj.check_invariants()
- assert all(p in self.commits for p in commit_obj.parents)
- assert commit_obj.content_hash == commit
- assert commit_obj.semantic_hash in self.commits
-
- def get_current_content(self) -> T:
- # return the full content of the current head
- assert self.head is not None
- return self.get_content(commit=self.head)
-
- def get_presentable_content(self, commit: str) -> str:
- return self.content_adapter.get_presentable_content(
- content=self.get_content(commit=commit)
- )
-
- def get_content(self, commit: str) -> T:
- # return the full content of a commit given its content hash
- if commit not in self.commits:
- raise ValueError(f"Commit {commit} not in DAG")
- commit_obj = self.commits[commit]
- if commit_obj._content is not None:
- return commit_obj._content
- else:
- parent_content = self.get_content(commit_obj.parents[0])
- return self.content_adapter.apply_diff(parent_content, commit_obj.diffs[0])
-
- def init(self, initial_content: T) -> str:
- """
- Initialize the DAG with the initial content, and set the head to the
- initial commit. Return the content hash of the initial commit.
- """
- # initialize the DAG with the initial content
- assert self.head is None
- content_hash = self.content_adapter.get_content_hash(content=initial_content)
- semantic_hash = content_hash
- commit = Commit(
- parents=[],
- diffs=[],
- content_hash=content_hash,
- semantic_hash=semantic_hash,
- content=initial_content,
- )
- self.head = content_hash
- self.commits[content_hash] = commit
- self._initial_commit = content_hash
- return content_hash
-
- def checkout(self, commit: str, implicit_merge: bool = False):
- """
- checkout an *existing* commit, i.e. set the head to the given commit.
- if implicit_merge is True, an edge will be added from the new head to
- the old head
- """
- if commit not in self.commits:
- raise ValueError(f"Commit {commit} not in DAG")
- if implicit_merge:
- raise NotImplementedError
- self.head = commit
-
- def commit(self, content: T, is_semantic_change: Optional[bool] = None) -> str:
- """
- Commit a *new* version of the content and return the content hash
- """
- assert self.head is not None
- content_hash = self.content_adapter.get_content_hash(content=content)
- assert content_hash not in self.commits
- head_commit = self.commits[self.head]
- if is_semantic_change is None:
- presentable_diff = get_diff(
- self.content_adapter.get_presentable_content(content),
- self.get_presentable_content(commit=self.head),
- )
- print(
- _get_colorized_diff(
- current=presentable_diff[1], new=presentable_diff[0]
- )
- )
- answer = ask_user(
- question="Does this change require recomputation of dependent calls?\nWARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\nAnswer: [y]es/[n]o/[a]bort",
- valid_options=["y", "n", "a"],
- )
- print(f'You answered: "{answer}"')
- if answer == "a":
- raise ValueError("Aborting commit")
- is_semantic_change = answer == "y"
- if is_semantic_change:
- semantic_hash = content_hash
- else:
- semantic_hash = head_commit.semantic_hash
- diff = self.content_adapter.get_diff(
- content, self.get_content(commit=self.head)
- )
- commit = Commit(
- parents=[self.head],
- diffs=[diff],
- content_hash=content_hash,
- semantic_hash=semantic_hash,
- content=None,
- )
- self.head = content_hash
- self.commits[content_hash] = commit
- return content_hash
-
- def sync(self, content: T, is_semantic_change: Optional[bool] = None) -> str:
- """
- if the content:
- - is the current head, do nothing
- - is in the DAG, checkout the content
- - is not in the DAG, commit the content
-
- return the content hash
- """
- assert self.head is not None
- content_hash = self.content_adapter.get_content_hash(content)
- if self.head == content_hash:
- result = self.head
- elif content_hash in self.commits:
- self.checkout(content_hash)
- result = self.head
- else:
- result = self.commit(content, is_semantic_change=is_semantic_change)
- assert self.head == content_hash
- return result
-
- ############################################################################
- ### visualization and printing
- ############################################################################
- def _get_tree_neighbors_representation(self) -> Dict[str, Set[str]]:
- """
- Get a {parent: {children}} representation of the tree underlying the DAG
- (obtained by following the first parent of each commit).
- """
- result = defaultdict(set)
- for commit in self.commits.values():
- if len(commit.parents) > 0:
- parent = commit.parents[0]
- result[parent].add(commit.content_hash)
- return dict(result)
-
- def get_commit_presentation(
- self, commit: str, diff_only: bool, include_metadata: bool = False
- ) -> Tuple[str, str]:
- if diff_only:
- if commit == self.initial_commit:
- content_to_show = self.get_presentable_content(commit)
- content_type = "code"
- else:
- parent_content = self.get_content(self.commits[commit].parents[0])
- child_content = self.content_adapter.apply_diff(
- b=parent_content, diff=self.commits[commit].diffs[0]
- )
- parent_presentable_content = (
- self.content_adapter.get_presentable_content(content=parent_content)
- )
- child_presentable_content = (
- self.content_adapter.get_presentable_content(content=child_content)
- )
- content_to_show = _get_diff(
- current=parent_presentable_content,
- new=child_presentable_content,
- )
- content_type = "diff"
- else:
- content_to_show = self.get_presentable_content(commit)
- content_type = "code"
-
- content_version = commit
- semantic_version = self.commits[commit].semantic_hash
-
- header_lines = []
- if commit == self.head:
- header_lines.append(f"### ===HEAD===")
- if include_metadata:
- header_lines.append(f"### content_commit={content_version}")
- header_lines.append(f"### semantic_commit={semantic_version}")
- if len(header_lines) > 0:
- header = "\n".join(header_lines)
- content_to_show = f"{header}\n{content_to_show}"
- return content_to_show, content_type
-
- if Config.has_rich:
- from rich.panel import Panel
- from rich.tree import Tree
-
- def get_commit_content_rich(
- self,
- commit: str,
- diff_only: bool = False,
- title: Optional[str] = None,
- include_metadata: bool = False,
- ) -> Panel:
- """
- Get a rich panel representing the content and metadata of a commit.
- """
- content_to_show, content_type = self.get_commit_presentation(
- commit=commit, diff_only=diff_only, include_metadata=include_metadata
- )
- if title is not None:
- title = Text(title, style="bold")
- lexer = "python" if content_type == "code" else "diff"
- content = Syntax(
- content_to_show,
- lexer=lexer,
- line_numbers=False,
- theme="solarized-light",
- )
- return Panel(renderable=content, title=title)
-
- def get_tree_rich(
- self, compact: bool = False, include_metadata: bool = False
- ) -> Tree:
- """
- Get a rich tree representing the tree underlying the DAG (obtained
- by following the first parent of each commit).
- """
- assert Config.has_rich
- if self.head is None:
- return Tree(label="DAG(head=None)")
- tree_neighbors = self._get_tree_neighbors_representation()
- tree_objs = {}
- initial_commit = self.initial_commit
-
- result = Tree(
- label=self.get_commit_content_rich(
- initial_commit, include_metadata=include_metadata
- )
- )
- tree_objs[initial_commit] = result
-
- def grow(commit: str):
- if commit in tree_neighbors:
- for child in tree_neighbors[commit]:
- current_tree = tree_objs[commit]
- new_tree = current_tree.add(
- self.get_commit_content_rich(
- child,
- diff_only=compact,
- include_metadata=include_metadata,
- )
- )
- tree_objs[child] = new_tree
- grow(child)
-
- grow(initial_commit)
- return result
-
- def __repr__(self) -> str:
- num_content = len(self.commits)
- num_semantic = len(set(c.semantic_hash for c in self.commits.values()))
- return f"DAG(head={self.head}) with {num_content} content version(s) and {num_semantic} semantic version(s)"
-
- def show(
- self, compact: bool = False, plain: bool = False, include_metadata: bool = False
- ):
- if Config.has_rich and not plain:
- rich.print(
- self.get_tree_rich(compact=compact, include_metadata=include_metadata)
- )
- return
- else:
- if self.head is None:
- return "DAG(head=None)"
- commits = list(self.commits.keys())
- commits = [self.head] + [k for k in commits if k != self.head]
- lines = []
- lines.append(
- self.get_commit_presentation(
- commit=self.initial_commit,
- diff_only=compact,
- include_metadata=include_metadata,
- )[0]
- )
- lines.append("--------")
- tree_neighbors = self._get_tree_neighbors_representation()
-
- def visit(commit: str, depth: int):
- if commit in tree_neighbors:
- for child in tree_neighbors[commit]:
- child_text = self.get_commit_presentation(
- commit=child,
- diff_only=compact,
- include_metadata=include_metadata,
- )[0]
- child_text = textwrap.indent(child_text, " " * (depth + 1))
- lines.append(child_text)
- lines.append("--------")
- visit(child, depth + 1)
-
- visit(self.initial_commit, 0)
- print("\n".join(lines))
-
- @property
- def size(self) -> int:
- return len(self.commits)
-
- @property
- def semantic_size(self) -> int:
- return len(set(c.semantic_hash for c in self.commits.values()))
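
The removed `DAG` keeps two hashes per commit: a content hash that changes on every edit, and a semantic hash that is inherited from the parent commit when the user answers that an edit does not require recomputation. A minimal self-contained sketch of that bookkeeping (a hypothetical `MiniDAG`, not the removed class):

```python
# Minimal sketch of content-vs-semantic versioning: every edit gets a new
# content hash, but the semantic hash is inherited unless the edit is marked
# as a semantic change. `MiniDAG` is hypothetical, not the removed class.
import hashlib
from typing import Dict, Optional, Tuple

def content_hash(text: str) -> str:
    return hashlib.sha256(text.encode()).hexdigest()[:8]

class MiniDAG:
    def __init__(self, initial: str):
        h = content_hash(initial)
        # content hash -> (parent content hash, semantic hash)
        self.commits: Dict[str, Tuple[Optional[str], str]] = {h: (None, h)}
        self.head = h

    def commit(self, new_content: str, is_semantic_change: bool) -> str:
        h = content_hash(new_content)
        parent_semantic = self.commits[self.head][1]
        semantic = h if is_semantic_change else parent_semantic
        self.commits[h] = (self.head, semantic)
        self.head = h
        return h

dag = MiniDAG("def f(x): return x + 1")
dag.commit("def f(x):\n    return x + 1  # cosmetic edit", is_semantic_change=False)
dag.commit("def f(x): return x + 2", is_semantic_change=True)
num_semantic = len({sem for _, sem in dag.commits.values()})
print(len(dag.commits), num_semantic)  # 3 content versions, 2 semantic versions
```
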
diff --git a/mandala/_next/deps/tracers/__init__.py b/mandala/_next/deps/tracers/__init__.py
deleted file mode 100644
index 5211fa9..0000000
--- a/mandala/_next/deps/tracers/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .tracer_base import TracerABC
-from .dec_impl import DecTracer
-from .sys_impl import SysTracer
diff --git a/mandala/_next/deps/tracers/dec_impl.py b/mandala/_next/deps/tracers/dec_impl.py
deleted file mode 100644
index 4de9a3a..0000000
--- a/mandala/_next/deps/tracers/dec_impl.py
+++ /dev/null
@@ -1,271 +0,0 @@
-import types
-from functools import wraps, update_wrapper
-from ...common_imports import *
-from ...utils import unwrap_decorators
-from ...config import Config
-from ..model import (
- DependencyGraph,
- CallableNode,
- GlobalVarNode,
- TerminalData,
- TerminalNode,
-)
-from ..utils import (
- get_global_names_candidates,
- is_global_val,
- extract_code,
- extract_func_obj,
- DepKey,
- GlobalsStrictness,
-)
-from .tracer_base import TracerABC, get_module_flow, KEEP, get_closure_names
-
-
-class DecTracerConfig:
- allow_class_tracking = True
- restrict_global_accesses = True
- allow_owned_class_accesses = False
- allow_nonfunc_attributes = False
-
-
-class TracerState:
- tracer: Optional["DecTracer"] = None
- registry: Dict[DepKey, Any] = {}
-
- @staticmethod
- def is_tracked(f: Union[types.FunctionType, type]) -> bool:
- assert isinstance(f, (types.FunctionType, type))
- dep_key = (f.__module__, f.__qualname__)
- return dep_key in TracerState.registry and TracerState.registry[dep_key] is f
-
-
-class TrackedDict(dict):
- """
- A dictionary that tracks global variable accesses.
- """
- def __init__(self, original: dict):
- self.__original__ = original
-
- def __getitem__(self, __key: str) -> Any:
- result = self.__original__.__getitem__(__key)
- if TracerState.tracer is not None:
- tracer = TracerState.tracer
- unwrapped_result = unwrap_decorators(result, strict=False)
- if isinstance(unwrapped_result, (types.FunctionType, type)):
- is_owned = tracer.is_owned_obj(obj=unwrapped_result)
- is_cls_access = isinstance(result, type)
- if (
- is_owned
- and is_cls_access
- and not DecTracerConfig.allow_owned_class_accesses
- ):
- raise ValueError(
- f"Attempting to access class {result} from module {unwrapped_result.__module__}."
- )
- is_tracked = TracerState.is_tracked(unwrapped_result)
- if is_owned and not is_tracked:
- raise ValueError(
- f"Function/class {result} from module {unwrapped_result.__module__} is accessed but not tracked"
- )
- elif is_global_val(result):
- TracerState.tracer.register_global_access(key=__key, value=result)
- else:
- if (
- DecTracerConfig.restrict_global_accesses
- and not GlobalsStrictness.is_excluded(result)
- ):
- raise ValueError(
- f"Accessing global value {result} of type {type(result)} is not allowed"
- )
- return result
-
-
-def make_tracked_copy(f: types.FunctionType) -> types.FunctionType:
- result = types.FunctionType(
- code=f.__code__,
- globals=TrackedDict(f.__globals__),
- name=f.__name__,
- argdefs=f.__defaults__,
- closure=f.__closure__,
- )
- result = update_wrapper(result, f)
- result.__module__ = f.__module__
- result.__kwdefaults__ = copy.deepcopy(f.__kwdefaults__)
- result.__annotations__ = copy.deepcopy(f.__annotations__)
- return result
-
-
-def get_nonfunc_attributes(cls: type) -> Dict[str, Any]:
- result = {}
- for k, v in cls.__dict__.items():
- if not k.startswith("__") and not isinstance(
- unwrap_decorators(v, strict=False), (types.FunctionType, type)
- ):
- result[k] = v
- return result
-
-
-def track(obj: Union[types.FunctionType, type]) -> "obj":
- if isinstance(obj, type):
- if not DecTracerConfig.allow_class_tracking:
- raise ValueError("Class tracking is not allowed")
- if not DecTracerConfig.allow_nonfunc_attributes:
- nonfunc_attributes = get_nonfunc_attributes(obj)
- if len(nonfunc_attributes) > 0:
- raise ValueError(
- f"Class tracking for {obj} is not allowed: found non-function attributes {nonfunc_attributes}"
- )
- # decorate all methods/classes in the class
- for k, v in obj.__dict__.items():
- if isinstance(v, (types.FunctionType, type)):
- setattr(obj, k, track(v))
- TracerState.registry[(obj.__module__, obj.__qualname__)] = obj
- return obj
- elif isinstance(obj, types.FunctionType):
- obj = make_tracked_copy(unwrap_decorators(obj, strict=True))
-
- @wraps(obj)
- def wrapper(*args, **kwargs) -> Any:
- tracer = DecTracer.get_active_trace_obj()
- if tracer is not None:
- node = tracer.register_call(func=obj)
- outcome = obj(*args, **kwargs)
- tracer.register_return(node)
- return outcome
- else:
- return obj(*args, **kwargs)
-
- TracerState.registry[(obj.__module__, obj.__qualname__)] = unwrap_decorators(
- obj, strict=True
- )
- return wrapper
- elif type(obj).__name__ == Config.func_interface_cls_name:
- obj.func_op.func = make_tracked_copy(f=obj.func_op.func)
- TracerState.registry[
- (obj.func_op.func.__module__, obj.func_op.func.__qualname__)
- ] = obj.func_op.func
- return obj
- else:
- raise TypeError("Can only track callable objects")
-
-
-class DecTracer(TracerABC):
- """
- A decorator-based tracer that tracks function calls and global variable accesses.
- """
- def __init__(
- self,
- paths: List[Path],
- graph: Optional[DependencyGraph] = None,
- strict: bool = True,
- allow_methods: bool = False,
- ):
- self.call_stack: List[CallableNode] = []
- self.graph = DependencyGraph() if graph is None else graph
- self.paths = paths
- self.strict = strict
- self.allow_methods = allow_methods
-
- self._traced = {}
- self._traced_funcs = {}
-
- def is_owned_obj(self, obj: Union[types.FunctionType, type]) -> bool:
- module_name = obj.__module__
- return get_module_flow(module_name=module_name, paths=self.paths) == KEEP
-
- @staticmethod
- def get_active_trace_obj() -> Optional["DecTracer"]:
- return TracerState.tracer
-
- @staticmethod
- def set_active_trace_obj(trace_obj: Optional["DecTracer"]):
- if trace_obj is not None and TracerState.tracer is not None:
- raise ValueError("Tracer already active")
- TracerState.tracer = trace_obj
-
- def get_globals(self, func: Callable) -> List[GlobalVarNode]:
- result = []
- code_obj = extract_code(obj=func)
- global_scope = extract_func_obj(obj=func, strict=self.strict).__globals__
- for name in get_global_names_candidates(code=code_obj):
- # names used by the function; not all of them are global variables
- if name in global_scope.keys():
- global_val = global_scope[name]
- if not is_global_val(global_val):
- continue
- node = GlobalVarNode.from_obj(
- obj=global_val, dep_key=(func.__module__, name)
- )
- result.append(node)
- return result
-
- def register_call(self, func: Callable) -> CallableNode:
- module_name = func.__module__
- qualname = func.__qualname__
- # check for closure variables
- closure_names = get_closure_names(
- code_obj=func.__code__, func_qualname=qualname
- )
- if len(closure_names) > 0:
- msg = f"Found closure variables accessed by function {module_name}.{qualname}:\n{closure_names}"
- raise ValueError(msg)
- ### get call node
- node = CallableNode.from_runtime(
- module_name=module_name, obj_name=qualname, code_obj=extract_code(obj=func)
- )
- self.call_stack.append(node)
- self.graph.add_node(node)
- if len(self.call_stack) > 1:
- parent = self.call_stack[-2]
- assert parent is not None
- self.graph.add_edge(parent, node)
- ### get globals
- global_nodes = self.get_globals(func=func)
- for global_node in global_nodes:
- self.graph.add_edge(node, global_node)
- if len(self.call_stack) == 1:
- # this is the root of the graph
- self.graph.roots.add(node.key)
- return node
-
- def register_global_access(self, key: str, value: Any):
- assert len(self.call_stack) > 0
- calling_node = self.call_stack[-1]
- node = GlobalVarNode.from_obj(
- obj=value, dep_key=(calling_node.module_name, key)
- )
- self.graph.add_edge(calling_node, node)
-
- def register_return(self, node: CallableNode):
- assert self.call_stack[-1] == node
- self.call_stack.pop()
-
- @staticmethod
- def register_leaf_event(trace_obj: "DecTracer", data: TerminalData):
- unique_id = "_".join(
- [
- data.op_internal_name,
- str(data.op_version),
- data.call_content_version,
- data.call_semantic_version,
- ]
- )
- module_name = data.dep_key[0]
- node = TerminalNode(
- module_name=module_name, obj_name=unique_id, representation=data
- )
- if len(trace_obj.call_stack) > 0:
- trace_obj.graph.add_edge(trace_obj.call_stack[-1], node)
- return
-
- @staticmethod
- def leaf_signal(data):
- # a way to detect the end of a trace
- raise NotImplementedError
-
- def __enter__(self):
- DecTracer.set_active_trace_obj(self)
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- DecTracer.set_active_trace_obj(None)
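
The decorator-based tracer above follows a common pattern: `track` wraps a function, and while a tracer is active the wrapper pushes a node onto a call stack on entry, records a caller-to-callee edge, and pops on exit. A self-contained sketch of that pattern with hypothetical names (`MiniTracer`, `track`), not the removed API:

```python
# Sketch of decorator-based call tracing: wrappers maintain a shared call
# stack and record caller -> callee edges only while a tracer is active.
from functools import wraps
from typing import List, Optional, Set, Tuple

class MiniTracer:
    active: Optional["MiniTracer"] = None

    def __init__(self):
        self.stack: List[str] = []
        self.edges: Set[Tuple[str, str]] = set()

    def __enter__(self):
        MiniTracer.active = self
        return self

    def __exit__(self, *exc):
        MiniTracer.active = None

def track(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        tracer = MiniTracer.active
        if tracer is None:
            return f(*args, **kwargs)  # no tracing overhead when inactive
        if tracer.stack:
            tracer.edges.add((tracer.stack[-1], f.__qualname__))
        tracer.stack.append(f.__qualname__)
        try:
            return f(*args, **kwargs)
        finally:
            tracer.stack.pop()
    return wrapper

@track
def helper(x):
    return x * 2

@track
def main(x):
    return helper(x) + 1

with MiniTracer() as t:
    main(3)
print(t.edges)  # {('main', 'helper')}
```
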
diff --git a/mandala/_next/deps/tracers/sys_impl.py b/mandala/_next/deps/tracers/sys_impl.py
deleted file mode 100644
index c3b0d05..0000000
--- a/mandala/_next/deps/tracers/sys_impl.py
+++ /dev/null
@@ -1,243 +0,0 @@
-import types
-from ...common_imports import *
-from ..utils import (
- get_func_qualname,
- is_global_val,
- get_global_names_candidates,
-)
-from ..model import (
- DependencyGraph,
- CallableNode,
- TerminalData,
- TerminalNode,
- GlobalVarNode,
-)
-import sys
-import importlib
-from .tracer_base import TracerABC, get_closure_names
-
-################################################################################
-### tracer
-################################################################################
-# control flow constants
-from .tracer_base import BREAK, CONTINUE, KEEP, MAIN, get_module_flow
-
-LEAF_SIGNAL = "leaf_signal"
-
-# constants for special Python function names
-LAMBDA = "<lambda>"
-COMPREHENSIONS = ("<listcomp>", "<dictcomp>", "<setcomp>", "<genexpr>")
-SKIP_FRAMES = tuple(list(COMPREHENSIONS) + [LAMBDA])
-
-
-class SysTracer(TracerABC):
- def __init__(
- self,
- paths: List[Path],
- graph: Optional[DependencyGraph] = None,
- strict: bool = True,
- allow_methods: bool = False,
- ):
- self.call_stack: List[Optional[CallableNode]] = []
- self.graph = DependencyGraph() if graph is None else graph
- self.paths = paths
- self.path_strs = [str(path) for path in paths]
- self.strict = strict
- self.allow_methods = allow_methods
-
- @staticmethod
- def leaf_signal(data):
- # a way to detect the end of a trace
- pass
-
- @staticmethod
- def register_leaf_event(trace_obj: types.FunctionType, data: Any):
- SysTracer.leaf_signal(data)
-
- @staticmethod
- def get_active_trace_obj() -> Optional[Any]:
- return sys.gettrace()
-
- @staticmethod
- def set_active_trace_obj(trace_obj: Any):
- sys.settrace(trace_obj)
-
- def _process_failure(self, msg: str):
- if self.strict:
- raise RuntimeError(msg)
- else:
- logger.warning(msg)
-
- def find_most_recent_call(self) -> Optional[CallableNode]:
- if len(self.call_stack) == 0:
- return None
- else:
- # return the most recent non-None obj on the stack
- for i in range(len(self.call_stack) - 1, -1, -1):
- call = self.call_stack[i]
- if isinstance(call, CallableNode):
- return call
- return None
-
- def __enter__(self):
- if sys.gettrace() is not None:
- # pre-check this is used correctly
- raise RuntimeError("Another tracer is already active")
-
- def tracer(frame: types.FrameType, event: str, arg: Any):
- if event not in ("call", "return"):
- return
- module_name = frame.f_globals.get("__name__")
- # fast check to rule out non-user code
- if event == "call":
- try:
- module = importlib.import_module(module_name)
- if not any(
- [
- module.__file__.startswith(path_str)
- for path_str in self.path_strs
- ]
- ):
- return
- except:
- if module_name != MAIN:
- return
- code_obj = frame.f_code
- func_name = code_obj.co_name
- if event == "return":
- logging.debug(f"Returning from {func_name}")
- if len(self.call_stack) > 0:
- popped = self.call_stack.pop()
- logging.debug(f"Popped {popped} from call stack")
- # some sanity checks
- if func_name in SKIP_FRAMES:
- if popped != func_name:
- self._process_failure(
- f"Expected to pop {func_name} from call stack, but popped {popped}"
- )
- else:
- if popped.obj_name.split(".")[-1] != func_name:
- self._process_failure(
- f"Expected to pop {func_name} from call stack, but popped {popped.obj_name}"
- )
- else:
- # something went wrong
- raise RuntimeError("Call stack is empty")
- return
-
- if func_name == LEAF_SIGNAL:
- data: TerminalData = frame.f_locals["data"]
- unique_id = "_".join(
- [
- data.op_internal_name,
- str(data.op_version),
- data.call_content_version,
- data.call_semantic_version,
- ]
- )
- node = TerminalNode(
- module_name=module_name, obj_name=unique_id, representation=data
- )
- most_recent_option = self.find_most_recent_call()
- if most_recent_option is not None:
- self.graph.add_edge(source=most_recent_option, target=node)
- # self.call_stack.append(None)
- return
-
- module_control_flow = get_module_flow(
- paths=self.paths, module_name=module_name
- )
- if module_control_flow in (BREAK, CONTINUE):
- frame.f_trace = None
- return
-
- logger.debug(f"Tracing call to {module_name}.{func_name}")
-
- ### get the qualified name of the function/method
- func_qualname = get_func_qualname(
- func_name=func_name, code=code_obj, frame=frame
- )
- if "." in func_qualname:
- if not self.allow_methods:
- raise RuntimeError(
- f"Methods are currently not supported: {func_qualname} from {module_name}"
- )
-
- ### detect use of closure variables
- closure_names = get_closure_names(
- code_obj=code_obj, func_qualname=func_qualname
- )
- if len(closure_names) > 0 and func_name not in SKIP_FRAMES:
- closure_values = {
- var: frame.f_locals.get(var, frame.f_globals.get(var, None))
- for var in closure_names
- }
- msg = f"Found closure variables accessed by function {module_name}.{func_name}:\n{closure_values}"
- self._process_failure(msg=msg)
-
- ### get the global variables used by the function
- globals_nodes = []
- for name in get_global_names_candidates(code=code_obj):
- # names used by the function; not all of them are global variables
- if name in frame.f_globals:
- global_val = frame.f_globals[name]
- if not is_global_val(global_val):
- continue
- node = GlobalVarNode.from_obj(
- obj=global_val, dep_key=(module_name, name)
- )
- globals_nodes.append(node)
-
- ### if this is a comprehension call, add the globals to the most
- ### recent tracked call
- if func_name in SKIP_FRAMES:
- most_recent_tracked_call = self.find_most_recent_call()
- assert most_recent_tracked_call is not None
- for global_node in globals_nodes:
- self.graph.add_edge(
- source=most_recent_tracked_call, target=global_node
- )
- self.call_stack.append(func_name)
- return tracer
-
- ### manage the call stack
- call_node = CallableNode.from_runtime(
- module_name=module_name, obj_name=func_qualname, code_obj=code_obj
- )
- self.graph.add_node(node=call_node)
- ### global variable edges from this function always exist
- for global_node in globals_nodes:
- self.graph.add_edge(source=call_node, target=global_node)
- ### call edges exist only if there is a caller on the stack
- if len(self.call_stack) > 0:
- # find the most recent tracked call
- most_recent_tracked_call = self.find_most_recent_call()
- if most_recent_tracked_call is not None:
- self.graph.add_edge(
- source=most_recent_tracked_call, target=call_node
- )
- self.call_stack.append(call_node)
- if len(self.call_stack) == 1:
- self.graph.roots.add(call_node.key)
- return tracer
-
- sys.settrace(tracer)
-
- def __exit__(self, *exc_info):
- sys.settrace(None) # Stop tracing
-
-
-class SuspendSysTraceContext:
- def __init__(self):
- self.suspended_trace = None
-
- def __enter__(self) -> "SuspendSysTraceContext":
- if sys.gettrace() is not None:
- self.suspended_trace = sys.gettrace()
- sys.settrace(None)
- return self
-
- def __exit__(self, *exc_info):
- if self.suspended_trace is not None:
- sys.settrace(self.suspended_trace)
- self.suspended_trace = None
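
`SysTracer` builds on `sys.settrace`: the trace function receives a "call" event for every new Python frame and, by returning itself, also receives that frame's "return" event, which is enough to maintain a call stack. A minimal self-contained sketch of that mechanism (illustrative only; the removed class adds path filtering, globals capture, and comprehension handling on top):

```python
# Sketch of sys.settrace-based call tracking: record caller -> callee pairs
# on "call" events and unwind the stack on "return" events.
import sys
from typing import List, Tuple

calls: List[Tuple[str, str]] = []
stack: List[str] = []

def tracer(frame, event, arg):
    if event == "call":
        name = frame.f_code.co_name
        if stack:
            calls.append((stack[-1], name))
        stack.append(name)
        return tracer  # keep receiving events for this frame
    if event == "return" and stack:
        stack.pop()
    return None

def inner():
    return 42

def outer():
    return inner()

sys.settrace(tracer)
try:
    outer()
finally:
    sys.settrace(None)  # always stop tracing, as in SysTracer.__exit__
print(calls)  # [('outer', 'inner')]
```
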
diff --git a/mandala/_next/deps/tracers/tracer_base.py b/mandala/_next/deps/tracers/tracer_base.py
deleted file mode 100644
index 0081ca4..0000000
--- a/mandala/_next/deps/tracers/tracer_base.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from ...common_imports import *
-from ...config import Config
-import importlib
-from ..model import DependencyGraph, CallableNode
-from abc import ABC, abstractmethod
-
-
-class TracerABC(ABC):
- def __init__(
- self,
- paths: List[Path],
- strict: bool = True,
- allow_methods: bool = False,
- ):
- self.call_stack: List[Optional[CallableNode]] = []
- self.graph = DependencyGraph()
- self.paths = paths
- self.strict = strict
- self.allow_methods = allow_methods
-
- @abstractmethod
- def __enter__(self):
- raise NotImplementedError
-
- @abstractmethod
- def __exit__(self, exc_type, exc_val, exc_tb):
- raise NotImplementedError
-
- @staticmethod
- @abstractmethod
- def get_active_trace_obj() -> Optional[Any]:
- raise NotImplementedError
-
- @staticmethod
- @abstractmethod
- def set_active_trace_obj(trace_obj: Any):
- raise NotImplementedError
-
- @staticmethod
- @abstractmethod
- def register_leaf_event(trace_obj: Any, data: Any):
- raise NotImplementedError
-
-
-BREAK = "break" # stop tracing (currently doesn't really work b/c python)
-CONTINUE = "continue" # continue tracing, but don't add call to dependencies
-KEEP = "keep" # continue tracing and add call to dependencies
-MAIN = "__main__"
-
-
-def get_closure_names(code_obj: types.CodeType, func_qualname: str) -> Tuple[str]:
- closure_vars = code_obj.co_freevars
- if "." in func_qualname and "__class__" in closure_vars:
- closure_vars = tuple([var for var in closure_vars if var != "__class__"])
- return closure_vars
-
-
-def get_module_flow(module_name: Optional[str], paths: List[Path]) -> str:
- if module_name is None:
- return BREAK
- if module_name == MAIN:
- return KEEP
- try:
- module = importlib.import_module(module_name)
- is_importable = True
- except ModuleNotFoundError:
- is_importable = False
- if not is_importable:
- return BREAK
- try:
- module_path = Path(inspect.getfile(module))
- except TypeError:
- # this happens when the module is a built-in module
- return BREAK
- if (
- not any(root in module_path.parents for root in paths)
- and module_path not in paths
- ):
- # module is not in the paths we're inspecting; stop tracing
- logger.debug(f" Module {module_name} not in paths, BREAK")
- return BREAK
- elif module_name.startswith(Config.module_name) and not module_name.startswith(
- Config.tests_module_name
- ):
- # this function is part of `mandala` functionality. Continue tracing
- # but don't add it to the dependency state
- logger.debug(f" Module {module_name} is mandala, CONTINUE")
- return CONTINUE
- else:
- logger.debug(f" Module {module_name} is not mandala but in paths, KEEP")
- return KEEP
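
`get_module_flow` above classifies a module by where its source file lives: modules outside the tracked roots stop tracing, `mandala` itself is traced but not recorded, and everything else is kept. A rough sketch of the path check alone, with hypothetical names and a made-up root path:

```python
# Sketch of the path-based check: a module is kept only if its source file
# lives under one of the tracked roots. The root path here is made up.
import importlib
import inspect
from pathlib import Path

def classify(module_name: str, roots) -> str:
    try:
        module = importlib.import_module(module_name)
        path = Path(inspect.getfile(module))
    except (ModuleNotFoundError, TypeError):
        return "break"  # unimportable or built-in module: stop tracing
    return "keep" if any(r in path.parents or r == path for r in roots) else "break"

print(classify("json", roots=[Path("/my/project")]))  # "break"
```
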
diff --git a/mandala/_next/deps/utils.py b/mandala/_next/deps/utils.py
deleted file mode 100644
index bad7405..0000000
--- a/mandala/_next/deps/utils.py
+++ /dev/null
@@ -1,228 +0,0 @@
-import types
-import dis
-import importlib
-import gc
-from typing import Literal
-
-from ..common_imports import *
-from ..utils import get_content_hash, unwrap_decorators
-from ..config import Config
-
-DepKey = Tuple[str, str] # (module name, object address in module)
-
-
-class GlobalsStrictness:
- SCALARS = "scalars"
- DATA = "data"
- ALL = "all"
-
- @staticmethod
- def is_excluded(obj: Any) -> bool:
- return (
- inspect.ismodule(obj) # exclude modules
- or isinstance(obj, type) # exclude classes
- or inspect.isfunction(obj) # exclude functions
- or callable(obj) # exclude callables
- or type(obj).__name__
- == Config.func_interface_cls_name #! a hack to exclude memoized functions
- )
-
- @staticmethod
- def is_scalar(obj: Any) -> bool:
- result = isinstance(obj, (int, float, str, bool, type(None)))
- return result
-
- @staticmethod
- def is_data(obj: Any) -> bool:
- if GlobalsStrictness.is_scalar(obj):
- result = True
- elif type(obj) in (tuple, list):
- result = all(GlobalsStrictness.is_data(x) for x in obj)
- elif type(obj) is dict:
- result = all(GlobalsStrictness.is_data((x, y)) for (x, y) in obj.items())
- elif type(obj) in (np.ndarray, pd.DataFrame, pd.Series, pd.Index):
- result = True
- else:
- result = False
- # if not result and not GlobalsStrictness.is_callable(obj):
- # logger.warning(f'Access to global variable "{obj}" is not tracked because it is not a scalar or a data structure')
- return result
-
-
-def is_global_val(obj: Any, strictness: str = "data") -> bool:
- if strictness == GlobalsStrictness.SCALARS:
- return GlobalsStrictness.is_scalar(obj=obj)
- elif strictness == GlobalsStrictness.DATA:
- return GlobalsStrictness.is_data(obj=obj)
- elif strictness == GlobalsStrictness.ALL:
- return not (
- inspect.ismodule(obj) # exclude modules
- or isinstance(obj, type) # exclude classes
- or inspect.isfunction(obj) # exclude functions
- or callable(obj) # exclude callables
- or type(obj).__name__
- == Config.func_interface_cls_name #! a hack to exclude memoized functions
- )
- else:
- raise ValueError(
- f"Unknown strictness level for tracking global variables: {strictness}"
- )
-
-
-def is_callable_obj(obj: Any, strict: bool) -> bool:
- if type(obj).__name__ == Config.func_interface_cls_name:
- return True
- if isinstance(obj, types.FunctionType):
- return True
- if not strict and callable(obj): # quite permissive
- return True
- return False
-
-
-def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType:
- if type(obj).__name__ == Config.func_interface_cls_name:
- return obj.f
- obj = unwrap_decorators(obj, strict=strict)
- if isinstance(obj, types.BuiltinFunctionType):
- raise ValueError(f"Expected a non-built-in function, but got {obj}")
- if not isinstance(obj, types.FunctionType):
- if not strict:
- if (
- isinstance(obj, type)
- and hasattr(obj, "__init__")
- and isinstance(obj.__init__, types.FunctionType)
- ):
- return obj.__init__
- else:
- return unknown_function
- else:
- raise ValueError(f"Expected a function, but got {obj} of type {type(obj)}")
- return obj
-
-
-def extract_code(obj: Callable) -> types.CodeType:
- if type(obj).__name__ == Config.func_interface_cls_name:
- obj = obj.f
- if isinstance(obj, property):
- obj = obj.fget
- obj = unwrap_decorators(obj, strict=True)
- if not isinstance(obj, (types.FunctionType, types.MethodType)):
- logger.debug(f"Expected a function or method, but got {type(obj)}")
- # raise ValueError(f"Expected a function or method, but got {obj}")
- return obj.__code__
-
-
-def get_runtime_description(code: types.CodeType) -> Any:
- assert isinstance(code, types.CodeType)
- return get_sanitized_bytecode_representation(code=code)
-
-
-def get_global_names_candidates(code: types.CodeType) -> Set[str]:
- result = set()
- instructions = list(dis.get_instructions(code))
- for instr in instructions:
- if instr.opname == "LOAD_GLOBAL":
- result.add(instr.argval)
- if isinstance(instr.argval, types.CodeType):
- result.update(get_global_names_candidates(instr.argval))
- return result
-
-
-def get_sanitized_bytecode_representation(
- code: types.CodeType,
-) -> List[dis.Instruction]:
- instructions = list(dis.get_instructions(code))
- result = []
- for instr in instructions:
- if isinstance(instr.argval, types.CodeType):
- result.append(
- dis.Instruction(
- instr.opname,
- instr.opcode,
- instr.arg,
- get_sanitized_bytecode_representation(instr.argval),
- "",
- instr.offset,
- instr.starts_line,
- is_jump_target=instr.is_jump_target,
- )
- )
- else:
- result.append(instr)
- return result
-
-
-def unknown_function():
- # this is a placeholder function that we use to get the source of
- # functions that we can't get the source of
- pass
-
-
-UNKNOWN_GLOBAL_VAR = "UNKNOWN_GLOBAL_VAR"
-
-
-def get_bytecode(f: Union[types.FunctionType, types.CodeType, str]) -> str:
- if isinstance(f, str):
-        f = compile(f, "<string>", "exec")
- instructions = dis.get_instructions(f)
- return "\n".join([str(i) for i in instructions])
-
-
-def hash_dict(d: dict) -> str:
- return get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())])
-
-
-def load_obj(module_name: str, obj_name: str) -> Tuple[Any, bool]:
- module = importlib.import_module(module_name)
- parts = obj_name.split(".")
- current = module
- found = True
- for part in parts:
- if not hasattr(current, part):
- found = False
- break
- else:
- current = getattr(current, part)
- return current, found
-
-
-def get_dep_key_from_func(func: types.FunctionType) -> DepKey:
- module_name = func.__module__
- qualname = func.__qualname__
- return module_name, qualname
-
-
-def get_func_qualname(
- func_name: str,
- code: types.CodeType,
- frame: types.FrameType,
-) -> str:
- # this is evil
- referrers = gc.get_referrers(code)
- func_referrers = [r for r in referrers if isinstance(r, types.FunctionType)]
- matching_name = [r for r in func_referrers if r.__name__ == func_name]
- if len(matching_name) != 1:
- return get_func_qualname_fallback(func_name=func_name, code=code, frame=frame)
- else:
- return matching_name[0].__qualname__
-
-
-def get_func_qualname_fallback(
- func_name: str, code: types.CodeType, frame: types.FrameType
-) -> str:
- # get the argument names to *try* to tell if the function is a method
- arg_names = code.co_varnames[: code.co_argcount]
- # a necessary but not sufficient condition for this to
- # be a method
- is_probably_method = (
- len(arg_names) > 0
- and arg_names[0] == "self"
- and hasattr(frame.f_locals["self"].__class__, func_name)
- )
- if is_probably_method:
- # handle nested classes via __qualname__
- cls_qualname = frame.f_locals["self"].__class__.__qualname__
- func_qualname = f"{cls_qualname}.{func_name}"
- else:
- func_qualname = func_name
- return func_qualname
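
`get_global_names_candidates` above relies on `dis`: every `LOAD_GLOBAL` instruction in a function's bytecode names a global the function might read, and those candidates are then filtered by `is_global_val`. A small self-contained sketch of that scan, using a made-up function and constant:

```python
# Sketch of discovering global-name candidates from bytecode: every
# LOAD_GLOBAL instruction names a global the function may read.
import dis

THRESHOLD = 10

def f(x):
    return max(x, THRESHOLD)  # reads two globals: the builtin max and THRESHOLD

names = {
    instr.argval
    for instr in dis.get_instructions(f.__code__)
    if instr.opname == "LOAD_GLOBAL"
}
print(names)  # {'THRESHOLD', 'max'} (set order may vary)
```
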
diff --git a/mandala/_next/deps/versioner.py b/mandala/_next/deps/versioner.py
deleted file mode 100644
index 3552080..0000000
--- a/mandala/_next/deps/versioner.py
+++ /dev/null
@@ -1,596 +0,0 @@
-import typing
-from collections import OrderedDict
-import textwrap
-
-from ..common_imports import *
-from ..utils import is_subdict
-from ..config import Config
-from .utils import DepKey, hash_dict
-from .model import (
- Node,
- CallableNode,
- GlobalVarNode,
- TerminalNode,
- DependencyGraph,
-)
-from .crawler import crawl_static
-from .tracers import TracerABC
-from ..viz import _get_colorized_diff
-
-if Config.has_rich:
- from rich.panel import Panel
- from rich.text import Text
- from rich.syntax import Syntax
- from rich.console import Group
-
-from .shallow_versions import DAG
-from .deep_versions import Version
-
-
-class CodeState:
- def __init__(self, nodes: Dict[DepKey, Node]):
- self.nodes = nodes
-
- def __repr__(self) -> str:
- lines = []
- for dep_key, node in self.nodes.items():
- lines.append(f"{dep_key}:")
- lines.append(f"{node.content()}")
- return "\n".join(lines)
-
- def get_content_version(self, support: Iterable[DepKey]) -> str:
- return hash_dict({k: self.nodes[k].content_hash for k in support})
-
- def add_globals_from(self, graph: DependencyGraph):
- for node in graph.nodes.values():
- if isinstance(node, GlobalVarNode) and node.key not in self.nodes:
- self.nodes[node.key] = node
-
-
-class Versioner:
- def __init__(
- self,
- paths: List[Path],
- TracerCls: type,
- strict,
- track_methods,
- package_name: Optional[str] = None,
- ):
- assert len(paths) in [0, 1]
- self.paths = paths
- self.TracerCls = TracerCls
- self.strict = strict
- self.allow_methods = track_methods
- self.package_name = package_name
- self.global_topology: DependencyGraph = DependencyGraph()
- self.nodes: Dict[DepKey, Node] = {}
- self.component_dags: Dict[DepKey, DAG] = {}
- # all versions here must be synced with the DAGs already
- self.versions: Dict[DepKey, Dict[str, Version]] = {}
- columns = [
- "pre_call_uid",
- "semantic_version",
- "content_version",
- "outputs",
- ]
- self.df = pd.DataFrame(columns=columns)
-
- def get_version_ids(
- self,
- pre_call_uid: str,
- tracer_option: Optional[TracerABC],
- is_recompute: bool,
- ) -> Tuple[Optional[str], Optional[str]]:
- """
- Get the content and semantic IDs for the version corresponding to the
- given pre-call uid.
-
- Inputs:
- - `is_recompute`: this should be true only if this is a call with
- transient outputs that we already computed once.
- """
- assert tracer_option is not None
- version = self.process_trace(
- graph=tracer_option.graph,
- pre_call_uid=pre_call_uid,
- outputs=None,
- is_recompute=is_recompute,
- )
- content_version = version.content_version
- semantic_version = version.semantic_version
- return content_version, semantic_version
-
- def update_global_topology(self, graph: DependencyGraph):
- for node in graph.nodes.values():
- if isinstance(node, (CallableNode, GlobalVarNode)):
- self.global_topology.add_node(node)
- for edge in graph.edges:
- if (
- edge[0] in self.global_topology.nodes.keys()
- and edge[1] in self.global_topology.nodes.keys()
- ):
- self.global_topology.edges.add(edge)
-
- def make_tracer(self) -> TracerABC:
- return self.TracerCls(
- paths=[Config.mandala_path] + self.paths,
- strict=self.strict,
- allow_methods=self.allow_methods,
- )
-
- def guess_code_state(self) -> CodeState:
- result_graph = DependencyGraph()
- fallback_result = {}
- for dep_key in self.global_topology.nodes.keys():
- node = self.global_topology.nodes[dep_key]
- if (
- isinstance(node, (GlobalVarNode, CallableNode))
- and dep_key not in result_graph.nodes.keys()
- ):
- obj = node.load_obj(allow_fallback=not self.strict)
- fallback_result[dep_key] = node.from_obj(obj=obj, dep_key=dep_key)
- nodes = {**result_graph.nodes, **fallback_result}
- # fill in the gaps with a static crawl
- static_result, objs = crawl_static(
- root=None if len(self.paths) == 0 else self.paths[0],
- strict=self.strict,
- package_name=self.package_name,
- include_methods=self.allow_methods,
- )
- for dep_key, node in static_result.items():
- if dep_key not in nodes.keys():
- nodes[dep_key] = node
- result = CodeState(nodes=nodes)
- return result
-
- def get_codestate_semantic_hashes(
- self,
- code_state: CodeState,
- ) -> Optional[Dict[DepKey, str]]:
- """
- Given a code state, return the semantic hashes of the components found
- in this code state that *also* appear in the global component topology,
- or None if the code state is not fully compatible with the commits we
- have in the DAGs.
- """
- result = {}
- if not self.global_topology.nodes.keys() <= code_state.nodes.keys():
- extra_keys = self.global_topology.nodes.keys() - code_state.nodes.keys()
- raise ValueError(
- f"Found extra keys in global topology not in code state: {extra_keys}."
- )
- for component in code_state.nodes.keys():
- if component in self.global_topology.nodes.keys():
- component_content_hash = code_state.nodes[component].content_hash
- dag = self.component_dags[component]
- if component_content_hash not in dag.commits:
- print(
- f"Could not find commit for {component} with content hash {component_content_hash}"
- )
- return None
- result[component] = dag.commits[component_content_hash].semantic_hash
- return result
-
- def apply_state_hypothesis(
- self, hypothesis: CodeState, trace_result: Dict[DepKey, Node]
- ):
- keys_to_remove = []
- for trace_dep_key, trace_node in trace_result.items():
- if isinstance(trace_node, TerminalNode):
- continue
- if trace_dep_key not in hypothesis.nodes.keys() and isinstance(
- trace_node, GlobalVarNode
- ):
- continue
- if trace_dep_key in hypothesis.nodes.keys():
- if isinstance(trace_node, GlobalVarNode):
- if (
- not hypothesis.nodes[trace_dep_key].content()
- == trace_node.content()
- ):
- print(f"Content mismatch for {trace_dep_key}")
- print(f"Expected: {hypothesis.nodes[trace_dep_key].content()}")
- print(f"Actual: {trace_node.content()}")
- print(
- f"Diff:\n{_get_colorized_diff(hypothesis.nodes[trace_dep_key].content(), trace_node.content())}"
- )
- raise ValueError(f"Content mismatch for {trace_dep_key}")
- elif isinstance(trace_node, CallableNode):
- runtime_description_hypothesis = hypothesis.nodes[
- trace_dep_key
- ].runtime_description
- if runtime_description_hypothesis != trace_node.runtime_description:
- if self.strict:
- raise ValueError(
- f"Bytecode mismatch for {trace_dep_key} (you probably need to re-import the module)"
- )
- trace_node.representation = hypothesis.nodes[
- trace_dep_key
- ].representation
- else:
- continue
- else:
- if self.strict:
- raise ValueError(
- f"Unexpected dependency {trace_dep_key}; expected {textwrap.shorten(str(hypothesis.nodes.keys()), width=80)}"
- )
- else:
- keys_to_remove.append(trace_dep_key)
- for key in keys_to_remove:
- del trace_result[key]
-
- def get_semantic_version(
- self, semantic_hashes: Dict[DepKey, str], support: Iterable[DepKey]
- ) -> str:
- return hash_dict({k: semantic_hashes[k] for k in support})
-
- def init_component(self, component: DepKey, node: Node, initial_content: str):
- """
- Initialize a new component with an initial state.
- """
- if isinstance(node, CallableNode):
- content_type = "code"
- elif isinstance(node, GlobalVarNode):
- content_type = "global_variable"
- else:
- raise ValueError(f"Unexpected node type {type(node)}")
- dag = DAG(content_type=content_type)
- dag.init(initial_content=initial_content)
- self.nodes[component] = node
- self.component_dags[component] = dag
- self.versions[component] = {}
-
- def sync_codebase(self, code_state: CodeState):
- """
- Sync all the known components from the current state of the codebase.
- """
- dags = copy.deepcopy(self.component_dags)
- for component, dag in dags.items():
- content = code_state.nodes[component].content()
- content_hash = code_state.nodes[component].content_hash
- if content_hash not in dag.commits.keys() and dag.head is not None:
- dependent_versions = self.get_dependent_versions(
- dep_key=component, commit=dag.head
- )
- dependent_versions_presentation = textwrap.indent(
- text="\n".join([v.presentation for v in dependent_versions]),
- prefix=" ",
- )
- print(f"CHANGE DETECTED in {component[1]} from module {component[0]}")
- print(f"Dependent components:\n{dependent_versions_presentation}")
- print(f"===DIFF===:")
- dag.sync(content=content)
- # update the DAGs if all commits succeeded
- self.component_dags = dags
-
- def sync_component(
- self,
- component: DepKey,
- is_semantic_change: Optional[bool],
- code_state: CodeState,
- ) -> str:
- """
- Sync a single component from the current state of the codebase. Useful
- as a low-level API for testing.
- """
- commit = self.component_dags[component].sync(
- content=code_state.nodes[component].content(),
- is_semantic_change=is_semantic_change,
- )
- return commit
-
- def get_current_versions(
- self, component: DepKey, code_state: CodeState
- ) -> List[Version]:
- code_semantic_hashes = self.get_codestate_semantic_hashes(code_state=code_state)
- result = []
- if code_semantic_hashes is None:
- return result
- for _, version in self.versions[component].items():
- if is_subdict(version.semantic_expansion, code_semantic_hashes):
- result.append(version)
- return result
-
- def get_semantically_compatible_versions(
- self, component: DepKey, code_state: CodeState
- ) -> List[Version]:
- code_semantic_hashes = self.get_codestate_semantic_hashes(code_state=code_state)
- if code_semantic_hashes is None:
- return []
- result = []
- for version in self.versions[component].values():
- if all(
- [
- version.semantic_expansion[dep_key] == code_semantic_hashes[dep_key]
- for dep_key in version.semantic_expansion.keys()
- ]
- ):
- result.append(version)
- return result
-
- ############################################################################
- ### processing traces
- ############################################################################
- def create_new_components_from_nodes(self, nodes: Dict[DepKey, Node]):
- """
- Given the result of a trace, create any components necessary.
- """
- ### new components must be found among the nodes in the trace result
- for dep_key, node in nodes.items():
- if dep_key not in self.nodes and not isinstance(node, TerminalNode):
- content = node.content()
- self.init_component(
- component=dep_key, node=node, initial_content=content
- )
-
- def sync_version(self, version: Version, require_exists: bool = False) -> Version:
- # TODO - this is impure
- version.sync(component_dags=self.component_dags, all_versions=self.versions)
- if version.content_version not in self.versions[version.component]:
- if require_exists:
- raise ValueError(f"Version {version} does not exist in VersioningState")
- # logging.info(f'Adding new version for {version.component}')
- self.versions[version.component][version.content_version] = version
- return version
-
- def lookup_call(
- self, component: DepKey, pre_call_uid: str, code_state: CodeState
- ) -> Optional[Tuple[str, str]]:
- """
- Return a tuple of (content_version, semantic_version), or None if the
- call is not found.
-
- Inputs:
- - `pre_call_uid`: this is a hash of the content IDs of the inputs,
- together with the function's name.
-
- This works as follows:
- - we figure out the semantic hashes (i.e. shallow semantic versions) of
- the elements of the code state present in the global topology we have on
- record
- - we restrict to the records that match the given `pre_call_uid`
- - we search among these
- """
- codebase_semantic_hashes = self.get_codestate_semantic_hashes(
- code_state=code_state
- )
- if codebase_semantic_hashes is None:
- return None
- candidates = self.df[self.df["pre_call_uid"] == pre_call_uid]
- if len(candidates) == 0:
- return None
- else:
- content_versions = candidates["content_version"].values.tolist()
- semantic_versions = candidates["semantic_version"].values.tolist()
- for content_version, semantic_version in zip(
- content_versions, semantic_versions
- ):
- version = self.versions[component][content_version]
- codebase_semantic = self.get_semantic_version(
- semantic_hashes=codebase_semantic_hashes, support=version.support
- )
- if codebase_semantic == semantic_version:
- return content_version, semantic_version
- return None
-
- def process_trace(
- self,
- graph: DependencyGraph,
- pre_call_uid: str,
- outputs: Any,
- is_recompute: bool,
- ) -> Version:
- component, nodes = graph.get_trace_state()
- self.create_new_components_from_nodes(nodes=nodes)
- version = Version.from_trace(
- component=component, nodes=nodes, strict=self.strict
- )
- version = self.sync_version(version=version)
- row = {
- "pre_call_uid": pre_call_uid,
- "semantic_version": version.semantic_version,
- "content_version": version.content_version,
- "outputs": outputs,
- }
- if not is_recompute:
- self._check_semantic_distinguishability(
- component=component, pre_call_uid=pre_call_uid, call_version=version
- )
- # logging.info(f"Adding new call for {pre_call_uid} for {component}")
- self.df = pd.concat([self.df, pd.DataFrame([row])], ignore_index=True)
- return version
-
- def _check_semantic_distinguishability(
- self, component: DepKey, pre_call_uid: str, call_version: Version
- ):
- ### check semantic distinguishability between calls
- # TODO: make this more efficient
- candidates = self.df[self.df["pre_call_uid"] == pre_call_uid]
- existing_semantic_expansions = {}
- for content_version, semantic_version in zip(
- candidates["content_version"].values.tolist(),
- candidates["semantic_version"].values.tolist(),
- ):
- if semantic_version not in existing_semantic_expansions.keys():
- existing_call_version = self.versions[component][content_version]
- existing_semantic_expansions[
- semantic_version
- ] = existing_call_version.semantic_expansion
- new_semantic_dep_hashes = call_version.semantic_expansion
- for semantic_version, semantic_hashes in existing_semantic_expansions.items():
- overlap = set(semantic_hashes.keys()).intersection(
- set(new_semantic_dep_hashes.keys())
- )
- if all([semantic_hashes[k] == new_semantic_dep_hashes[k] for k in overlap]):
- raise ValueError(
- f"Call to {component} with pre_call_uid={pre_call_uid} is not semantically distinguishable from call for semantic version {semantic_version}"
- )
-
- ############################################################################
- ### inspecting the state
- ############################################################################
- def get_flat_versions(self) -> Dict[str, Version]:
- return {
- k: v
- for component, versions in self.versions.items()
- for k, v in versions.items()
- }
-
- def get_dependent_versions(self, dep_key: DepKey, commit: str) -> List[Version]:
- """
- Get a list of versions of components dependent on a given commit to a
- given component
- """
- dep_semantic = self.component_dags[dep_key].commits[commit].semantic_hash
- return [
- version
- for version in self.get_flat_versions().values()
- if version.semantic_expansion.get(dep_key) == dep_semantic
- ]
-
- def present_dependencies(
- self,
- commits: Dict[DepKey, str],
- include_metadata: bool = True,
- header: Optional[str] = None,
- ) -> str:
- """
- Get a code snippet for a given state of some dependencies
- """
- result_lines = []
- if header is not None:
- result_lines.extend(header.splitlines())
- module_groups = self.get_canonical_groups(components=commits.keys())
- for module_name, components_in_module in module_groups.items():
- result_lines.append(80 * "#")
- result_lines.append(f'### IN MODULE "{module_name}"')
- result_lines.append(80 * "#")
- commits_in_module = {k: commits[k] for k in components_in_module}
- nodes = {k: self.nodes[k] for k in commits_in_module.keys()}
- semantic_hashes = {
- k: self.component_dags[k].commits[v].semantic_hash
- for k, v in commits_in_module.items()
- }
- global_keys = {k for k, v in nodes.items() if isinstance(v, GlobalVarNode)}
- callable_keys = {k for k, v in nodes.items() if isinstance(v, CallableNode)}
- metadatas = {
- k: f"### {nodes[k].present_key()}\n### content_commit={commits_in_module[k]}\n### semantic_commit={semantic_hashes[k]}"
- for k in commits_in_module.keys()
- }
- for global_key in sorted(global_keys):
- if include_metadata:
- result_lines.append(metadatas[global_key])
- global_name = global_key[1]
- result_lines.append(
- f"{global_name} = {self.component_dags[global_key].get_presentable_content(commits_in_module[global_key])}"
- )
- result_lines.append("")
- callable_keys = list(sorted(callable_keys))
- is_class_start = []
- for i, callable_key in enumerate(callable_keys):
- this_is_class = "." in callable_key[1]
- if i == 0 and this_is_class:
- is_class_start.append(True)
- continue
- prev_is_class = "." in callable_keys[i - 1][1]
- if this_is_class and not prev_is_class:
- is_class_start.append(True)
- continue
- if (
- this_is_class
- and prev_is_class
- and callable_keys[i - 1][1].rsplit(".", maxsplit=1)[0]
- != callable_key[1].rsplit(".", maxsplit=1)[0]
- ):
- is_class_start.append(True)
- continue
- is_class_start.append(False)
- for i, callable_key in enumerate(callable_keys):
- if is_class_start[i]:
- result_lines.append(
- f"### in class {callable_key[1].rsplit('.', 1)[0]}:"
- )
- if include_metadata:
- result_lines.append(metadatas[callable_key])
- result_lines.append(
- self.component_dags[callable_key].get_presentable_content(
- commits_in_module[callable_key]
- )
- )
- result_lines.append("")
- return "\n".join(result_lines)
-
- def show_versions(
- self,
- component: DepKey,
- only_semantic: bool = False,
- include_metadata: bool = True,
- plain: bool = False,
- ):
- versions_dict = self.versions[component]
- if only_semantic:
-            # keep just one representative (content) version per semantic version
- versions = list(
- {v.semantic_version: v for v in versions_dict.values()}.values()
- )
- else:
- versions = list(versions_dict.values())
- if Config.has_rich and not plain:
- version_panels: List[Panel] = []
- for version in versions:
- header_lines = [
- f"### Dependencies for version of {self.nodes[component].present_key()}"
- ]
- header_lines.append(f"### content_version_id={version.content_version}")
- header_lines.append(
- f"### semantic_version_id={version.semantic_version}\n\n"
- )
- version_panels.append(
- Panel(
- Syntax(
- self.present_dependencies(
- header="\n".join(header_lines),
- commits=version.semantic_expansion,
- include_metadata=include_metadata,
- ),
- lexer="python",
- theme="solarized-light",
- ),
- title=None,
- expand=True,
- )
- )
- rich.print(Group(*version_panels))
- else:
- for version in versions:
- print(version.presentation)
- print(
- textwrap.indent(
- self.present_dependencies(
- commits=version.semantic_expansion,
- include_metadata=include_metadata,
- ),
- prefix=" ",
- )
- )
-
- def get_canonical_groups(
- self, components: Iterable[DepKey]
- ) -> typing.OrderedDict[str, List[DepKey]]:
- """
- Order components by module name alphabetically, and within each module,
- put the global variables first, then the callables.
- """
- result = OrderedDict()
- for component in components:
- module_name = component[0]
- if module_name not in result:
- result[module_name] = []
- result[module_name].append(component)
- for module_name, module_components in result.items():
- result[module_name] = sorted(
- module_components,
- key=lambda x: (isinstance(self.nodes[x], CallableNode), x),
- )
- result = OrderedDict(sorted(result.items(), key=lambda x: x[0]))
- return result
diff --git a/mandala/_next/deps/viz.py b/mandala/_next/deps/viz.py
deleted file mode 100644
index 82907fb..0000000
--- a/mandala/_next/deps/viz.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import textwrap
-from ..common_imports import *
-from ..viz import (
- Node as DotNode,
- Edge as DotEdge,
- Group as DotGroup,
- to_dot_string,
- SOLARIZED_LIGHT,
-)
-
-from .utils import (
- DepKey,
-)
-
-
-def to_string(graph: "model.DependencyGraph") -> str:
- """
- Get a string for pretty-printing.
- """
- # group the nodes by module
- module_groups: Dict[str, List["model.Node"]] = {}
- for key, node in graph.nodes.items():
- module_name, _ = key
- module_groups.setdefault(module_name, []).append(node)
- lines = []
- for module_name, nodes in module_groups.items():
- global_nodes = [node for node in nodes if isinstance(node, model.GlobalVarNode)]
- callable_nodes = [
- node for node in nodes if isinstance(node, model.CallableNode)
- ]
- module_desc = f"MODULE: {module_name}"
- lines.append(module_desc)
- lines.append("-" * len(module_desc))
- lines.append("===Global Variables===")
- for node in global_nodes:
- desc = f"{node.obj_name} = {node.readable_content()}"
- lines.append(textwrap.indent(desc, 4 * " "))
- # lines.append(f" {node.diff_representation()}")
- lines.append("")
- lines.append("===Functions===")
- # group the methods by class
- method_nodes = [node for node in callable_nodes if node.is_method]
- func_nodes = [node for node in callable_nodes if not node.is_method]
- methods_by_class: Dict[str, List["model.CallableNode"]] = {}
- for method_node in method_nodes:
- methods_by_class.setdefault(method_node.class_name, []).append(method_node)
- for class_name, method_nodes in methods_by_class.items():
- lines.append(textwrap.indent(f"class {class_name}:", 4 * " "))
- for node in method_nodes:
- desc = node.readable_content()
- lines.append(textwrap.indent(textwrap.dedent(desc), 8 * " "))
- lines.append("")
- for node in func_nodes:
- desc = node.readable_content()
- lines.append(textwrap.indent(textwrap.dedent(desc), 4 * " "))
- lines.append("")
- return "\n".join(lines)
-
-
-def to_dot(graph: "model.DependencyGraph") -> str:
- nodes: Dict[DepKey, DotNode] = {}
- module_groups: Dict[str, DotGroup] = {} # module name -> Group
- class_groups: Dict[str, DotGroup] = {} # class name -> Group
- for key, node in graph.nodes.items():
- module_name, obj_addr = key
- if module_name not in module_groups:
- module_groups[module_name] = DotGroup(
- label=module_name, nodes=[], parent=None
- )
- if isinstance(node, model.GlobalVarNode):
- color = SOLARIZED_LIGHT["red"]
- elif isinstance(node, model.CallableNode):
- color = (
- SOLARIZED_LIGHT["blue"]
- if not node.is_method
- else SOLARIZED_LIGHT["violet"]
- )
- else:
- color = SOLARIZED_LIGHT["base03"]
- dot_node = DotNode(
- internal_name=".".join(key), label=node.obj_name, color=color
- )
- nodes[key] = dot_node
- module_groups[module_name].nodes.append(dot_node)
- if isinstance(node, model.CallableNode) and node.is_method:
- class_name = node.class_name
- class_groups.setdefault(
- class_name,
- DotGroup(
- label=class_name,
- nodes=[],
- parent=module_groups[module_name],
- ),
- ).nodes.append(dot_node)
- edges: Dict[Tuple[DotNode, DotNode], DotEdge] = {}
- for source, target in graph.edges:
- source_node = nodes[source]
- target_node = nodes[target]
- edge = DotEdge(source_node=source_node, target_node=target_node)
- edges[(source_node, target_node)] = edge
- dot_string = to_dot_string(
- nodes=list(nodes.values()),
- edges=list(edges.values()),
- groups=list(module_groups.values()) + list(class_groups.values()),
- rankdir="BT",
- )
- return dot_string
-
-
-from . import model
diff --git a/mandala/_next/imports.py b/mandala/_next/imports.py
deleted file mode 100644
index 66ff18c..0000000
--- a/mandala/_next/imports.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from .storage import Storage
-from .model import op, Ignore, NewArgDefault
-from .tps import MList, MDict
-from .deps.tracers.dec_impl import track
-
-from .common_imports import sess
-
-
-def pprint_dict(d) -> str:
- return '\n'.join([f" {k}: {v}" for k, v in d.items()])
\ No newline at end of file
diff --git a/mandala/_next/tests/test_cfs.py b/mandala/_next/tests/test_cfs.py
deleted file mode 100644
index b75ad16..0000000
--- a/mandala/_next/tests/test_cfs.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from mandala._next.imports import *
-
-
-def test_single_func():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage:
- for i in range(10):
- inc(i)
-
- cf = storage.cf(inc)
- df = cf.df()
- assert df.shape == (10, 3)
- assert (df['output_0'] == df['x'] + 1).all()
\ No newline at end of file
diff --git a/mandala/_next/tests/test_memoization.py b/mandala/_next/tests/test_memoization.py
deleted file mode 100644
index 2f47bda..0000000
--- a/mandala/_next/tests/test_memoization.py
+++ /dev/null
@@ -1,133 +0,0 @@
-from mandala._next.imports import *
-
-
-def test_storage():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage:
- x = 1
- y = inc(x)
- z = inc(2)
- w = inc(y)
-
- assert w.cid == z.cid
- assert w.hid != y.hid
- assert w.cid != y.cid
- assert storage.unwrap(y) == 2
- assert storage.unwrap(z) == 3
- assert storage.unwrap(w) == 3
- for ref in (y, z, w):
- assert storage.attach(ref).in_memory
- assert storage.attach(ref).obj == storage.unwrap(ref)
-
-
-def test_signatures():
- storage = Storage()
-
- @op # a function with a wild input/output signature
- def add(x, *args, y: int = 1, **kwargs):
- # just sum everything
- res = x + sum(args) + y + sum(kwargs.values())
- if kwargs:
- return res, kwargs
- elif args:
- return None
- else:
- return res
-
- with storage:
- # call the func in all the ways
- sum_1 = add(1)
- sum_2 = add(1, 2, 3, 4, )
- sum_3 = add(1, 2, 3, 4, y=5)
- sum_4 = add(1, 2, 3, 4, y=5, z=6)
- sum_5 = add(1, 2, 3, 4, z=5, w=7)
-
- assert storage.unwrap(sum_1) == 2
- assert storage.unwrap(sum_2) == None
- assert storage.unwrap(sum_3) == None
- assert storage.unwrap(sum_4) == (21, {'z': 6})
- assert storage.unwrap(sum_5) == (23, {'z': 5, 'w': 7})
-
-
-def test_retracing():
- storage = Storage()
-
- @op
- def inc(x):
- return x + 1
-
- ### iterating a function
- with storage:
- start = 1
- for i in range(10):
- start = inc(start)
-
- with storage:
- start = 1
- for i in range(10):
- start = inc(start)
-
- ### composing functions
- @op
- def add(x, y):
- return x + y
-
- with storage:
- inp = [1, 2, 3, 4, 5]
- stage_1 = [inc(x) for x in inp]
- stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]
-
- with storage:
- inp = [1, 2, 3, 4, 5]
- stage_1 = [inc(x) for x in inp]
- stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]
-
-
-def test_lists():
- storage = Storage()
-
- @op
- def get_sum(elts: MList[int]) -> int:
- return sum(elts)
-
- @op
- def primes_below(n: int) -> MList[int]:
- primes = []
- for i in range(2, n):
- for p in primes:
- if i % p == 0:
- break
- else:
- primes.append(i)
- return primes
-
- @op
- def chunked_square(elts: MList[int]) -> MList[int]:
- # a model for an op that does something on chunks of a big thing
- # to prevent OOM errors
- return [x*x for x in elts]
-
- with storage:
- n = 10
- primes = primes_below(n)
- sum_primes = get_sum(primes)
- assert len(primes) == 4
- # check indexing
- assert storage.unwrap(primes[0]) == 2
- assert storage.unwrap(primes[:2]) == [2, 3]
-
- ### lists w/ overlapping elements
- with storage:
- n = 100
- primes = primes_below(n)
- for i in range(0, len(primes), 2):
- sum_primes = get_sum(primes[:i+1])
-
- with storage:
- elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- squares = chunked_square(elts)
diff --git a/mandala/_next/utils.py b/mandala/_next/utils.py
deleted file mode 100644
index bbfb65a..0000000
--- a/mandala/_next/utils.py
+++ /dev/null
@@ -1,249 +0,0 @@
-from .common_imports import *
-import joblib
-import io
-import inspect
-import prettytable
-import sqlite3
-from .config import *
-from abc import ABC, abstractmethod
-
-def dataframe_to_prettytable(df: pd.DataFrame) -> str:
- # Initialize a PrettyTable object
- table = prettytable.PrettyTable()
-
- # Set the column names
- table.field_names = df.columns.tolist()
-
- # Add rows to the table
- for row in df.itertuples(index=False):
- table.add_row(row)
-
- # Return the pretty-printed table as a string
- return table.get_string()
-
-
-def serialize(obj: Any) -> bytes:
- """
-    ! this may lead to different serializations for objects x, y such that
-    ! x == y in Python. This is because of things like set iteration order,
-    ! which is not determined solely by the contents of the set. For example,
-    ! {1, 2} and {2, 1} could `serialize()` to different things, but they
-    ! would be equal in Python.
- """
- buffer = io.BytesIO()
- joblib.dump(obj, buffer)
- return buffer.getvalue()
-
-
-def deserialize(value: bytes) -> Any:
- buffer = io.BytesIO(value)
- return joblib.load(buffer)
-
-
-def get_content_hash(obj: Any) -> str:
- if hasattr(obj, "__get_mandala_dict__"):
- obj = obj.__get_mandala_dict__()
- if Config.has_torch:
- # TODO: ideally, should add a label to distinguish this from a numpy
- # array with the same contents!
- obj = tensor_to_numpy(obj)
- if isinstance(obj, pd.DataFrame):
- # DataFrames cause collisions for joblib hashing for some reason
- # TODO: the below may be incomplete
- obj = {
- "columns": obj.columns,
- "values": obj.values,
- "index": obj.index,
- }
- result = joblib.hash(obj) # this hash is canonical wrt python collections
- if result is None:
- raise RuntimeError("joblib.hash returned None")
- return result
-
-
-def dump_output_name(index: int, output_names: Optional[List[str]] = None) -> str:
- if output_names is not None and index < len(output_names):
- return output_names[index]
- else:
- return f"output_{index}"
-
-
-def parse_output_name(name: str) -> int:
- return int(name.split("_")[-1])
-
-
-def get_setdict_union(
- a: Dict[str, Set[str]], b: Dict[str, Set[str]]
-) -> Dict[str, Set[str]]:
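-    # key-wise union of two {key: set} dicts, e.g.
-    # {"a": {1}}, {"a": {2}, "b": {3}}  ->  {"a": {1, 2}, "b": {3}}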
- return {k: a.get(k, set()) | b.get(k, set()) for k in a.keys() | b.keys()}
-
-
-def get_setdict_intersection(
- a: Dict[str, Set[str]], b: Dict[str, Set[str]]
-) -> Dict[str, Set[str]]:
- return {k: a[k] & b[k] for k in a.keys() & b.keys()}
-
-
-def get_dict_union_over_keys(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
- return {k: a[k] if k in a else b[k] for k in a.keys() | b.keys()}
-
-
-def get_dict_intersection_over_keys(
- a: Dict[str, Any], b: Dict[str, Any]
-) -> Dict[str, Any]:
- return {k: a[k] for k in a.keys() & b.keys()}
-
-
-def get_adjacency_union(
- a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]]
-) -> Dict[str, Dict[str, Set[str]]]:
- return {
- k: get_setdict_union(a.get(k, {}), b.get(k, {})) for k in a.keys() | b.keys()
- }
-
-
-def get_adjacency_intersection(
- a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]]
-) -> Dict[str, Dict[str, Set[str]]]:
- return {k: get_setdict_intersection(a[k], b[k]) for k in a.keys() & b.keys()}
-
-
-def get_nullable_union(*sets: Set[str]) -> Set[str]:
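-    # plain set union, but tolerating zero arguments (returns an empty set)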
- return set.union(*sets) if len(sets) > 0 else set()
-
-
-def get_nullable_intersection(*sets: Set[str]) -> Set[str]:
- return set.intersection(*sets) if len(sets) > 0 else set()
-
-
-def get_adj_from_edges(
- edges: Set[Tuple[str, str, str]], node_support: Optional[Set[str]] = None
-) -> Tuple[Dict[str, Dict[str, Set[str]]], Dict[str, Dict[str, Set[str]]]]:
- """
- Given edges, convert them into the adjacency representation used by the
- `ComputationFrame` class.
- """
- out = {}
- inp = {}
- for src, dst, label in edges:
- if src not in out:
- out[src] = {}
- if label not in out[src]:
- out[src][label] = set()
- out[src][label].add(dst)
- if dst not in inp:
- inp[dst] = {}
- if label not in inp[dst]:
- inp[dst][label] = set()
- inp[dst][label].add(src)
- if node_support is not None:
- for node in node_support:
- if node not in out:
- out[node] = {}
- if node not in inp:
- inp[node] = {}
- return out, inp
-
-
-def parse_returns(
- sig: inspect.Signature,
- returns: Any,
- nout: Union[Literal["auto", "var"], int],
- output_names: Optional[List[str]] = None,
-) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- """
- Return two dicts based on the returns:
- - {output name: output value}
- - {output name: output type annotation}, where things like `Tuple[T, ...]` are expanded.
- """
- ### figure out the number of outputs, and convert them to a tuple
- if nout == "auto": # infer from the returns
- if isinstance(returns, tuple):
- nout = len(returns)
- returns_tuple = returns
- else:
- nout = 1
- returns_tuple = (returns,)
- elif nout == "var":
- assert isinstance(returns, tuple)
- nout = len(returns)
- returns_tuple = returns
- else: # nout is an integer
- assert isinstance(nout, int)
- assert isinstance(returns, tuple)
- assert len(returns) == nout
- returns_tuple = returns
- ### get the dict of outputs
- outputs_dict = {
- dump_output_name(i, output_names): returns_tuple[i] for i in range(nout)
- }
- ### figure out the annotations
- annotations_dict = {}
- output_annotation = sig.return_annotation
- if output_annotation is inspect._empty: # no annotation
- annotations_dict = {k: Any for k in outputs_dict.keys()}
- else:
- if (
- hasattr(output_annotation, "__origin__")
- and output_annotation.__origin__ is tuple
- ):
- if (
- len(output_annotation.__args__) == 2
- and output_annotation.__args__[1] == Ellipsis
- ):
- annotations_dict = {
- k: output_annotation.__args__[0] for k in outputs_dict.keys()
- }
- else:
- annotations_dict = {
- k: output_annotation.__args__[i]
- for i, k in enumerate(outputs_dict.keys())
- }
- else:
- assert nout == 1
- annotations_dict = {k: output_annotation for k in outputs_dict.keys()}
- return outputs_dict, annotations_dict
-
-
-def unwrap_decorators(
- obj: Callable, strict: bool = True
-) -> Union[types.FunctionType, types.MethodType]:
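-    # follow the chain of `__wrapped__` attributes (set by functools.wraps)
-    # down to the underlying function or method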
- while hasattr(obj, "__wrapped__"):
- obj = obj.__wrapped__
- if not isinstance(obj, (types.FunctionType, types.MethodType)):
- msg = f"Expected a function or method, but got {type(obj)}"
- if strict:
- raise RuntimeError(msg)
- else:
- logger.debug(msg)
- return obj
-
-def is_subdict(a: Dict, b: Dict) -> bool:
- """
- Check that all keys in `a` are in `b` with the same value.
- """
- return all((k in b and a[k] == b[k]) for k in a)
-
-_KT, _VT = TypeVar("_KT"), TypeVar("_VT")
-def invert_dict(d: Dict[_KT, _VT]) -> Dict[_VT, List[_KT]]:
- """
- Invert a dictionary
- """
- out = {}
- for k, v in d.items():
- if v not in out:
- out[v] = []
- out[v].append(k)
- return out
-
-def ask_user(question: str, valid_options: List[str]) -> str:
- """
- Ask the user a question and return their response.
- """
- prompt = f"{question} "
- while True:
- print(prompt)
- response = input().strip().lower()
- if response in valid_options:
- return response
- else:
- print(f"Invalid response: {response}")
diff --git a/mandala/all.py b/mandala/all.py
deleted file mode 100644
index badf8c7..0000000
--- a/mandala/all.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from .common_imports import *
-from .core.model import ValueRef, Call, FuncOp, wrap_atom, TransientObj
-from .core.wrapping import unwrap
-from .core.builtins_ import DictRef, ListRef, SetRef
-from .queries.weaver import BuiltinQueries
-from .queries.main import Querier
-from .queries.viz import show
-from .core.sig import Signature
-from .core.config import Config
-from .ui.storage import Storage, MODES, FuncInterface
-from .ui.contexts import Context, GlobalContext
-from .ui.funcs import op, Q, superop, Transient
-from .ui.utils import wrap_ui as wrap
-from .ui.cfs import ComputationFrame
-
-# from .storages.rel_impls.duckdb_impl import DuckDBRelStorage
-from .storages.rel_impls.sqlite_impl import SQLiteRelStorage
-from .storages.rels import serialize, deserialize
-from .deps.tracers.dec_impl import track, TracerState
-from .deps.tracers import TracerABC, DecTracer, SysTracer
-
-from .queries import ListQ, SetQ, DictQ
diff --git a/mandala/_next/cf.py b/mandala/cf.py
similarity index 100%
rename from mandala/_next/cf.py
rename to mandala/cf.py
diff --git a/mandala/_next/cf_examples.ipynb b/mandala/cf_examples.ipynb
similarity index 100%
rename from mandala/_next/cf_examples.ipynb
rename to mandala/cf_examples.ipynb
diff --git a/mandala/common_imports.py b/mandala/common_imports.py
index 68c8962..c65c395 100644
--- a/mandala/common_imports.py
+++ b/mandala/common_imports.py
@@ -28,6 +28,7 @@
Set,
Union,
TypeVar,
+ Literal,
)
from pathlib import Path
@@ -35,26 +36,6 @@
import pyarrow as pa
import numpy as np
-
-class Session:
- # for debugging
-
- def __init__(self):
- self.items = []
- self._scope = None
-
- def d(self):
- scope = inspect.currentframe().f_back.f_locals
- self._scope = scope
-
- def dump(self):
- # put the scope into the current locals
- assert self._scope is not None
- scope = inspect.currentframe().f_back.f_locals
- print(f"Dumping {self._scope.keys()} into local scope")
- scope.update(self._scope)
-
-
try:
import rich
@@ -78,10 +59,5 @@ def dump(self):
logging.basicConfig(format=FORMAT)
logger.setLevel(logging.INFO)
-sess = Session()
-
-TableType = TypeVar("TableType", pa.Table, pd.DataFrame)
-
-class InternalError(Exception):
- pass
+from mandala.common_imports import sess
diff --git a/mandala/_next/config.py b/mandala/config.py
similarity index 100%
rename from mandala/_next/config.py
rename to mandala/config.py
diff --git a/mandala/core/builtins_.py b/mandala/core/builtins_.py
deleted file mode 100644
index 9056d2e..0000000
--- a/mandala/core/builtins_.py
+++ /dev/null
@@ -1,431 +0,0 @@
-from collections.abc import Sequence, Mapping, Set as SetABC
-from ..common_imports import *
-
-from .config import MODES
-from .model import Ref, FuncOp, Call, wrap_atom, ValueRef
-from .utils import Hashing
-from .tps import AnyType
-
-HASHERS = {
- "__list__": Hashing.hash_list,
- "__set__": Hashing.hash_multiset,
- "__dict__": Hashing.hash_dict,
-}
-
-
-class StructRef(Ref):
- builtin_id = None
-
- def __init__(
- self,
- uid: Optional[str],
- obj: Optional[Any],
- in_memory: bool,
- transient: bool = False,
- ):
- if uid is None:
- builtin_id = self.builtin_id
- hasher = HASHERS[builtin_id]
- uids = Builtins.map(func=lambda elt: elt.uid, obj=obj, struct_id=builtin_id)
- uid = Builtins._make_builtin_uid(uid=hasher(uids), builtin_id=builtin_id)
- super().__init__(uid=uid, obj=obj, in_memory=in_memory, transient=transient)
-
- def as_inputs_list(self) -> List[Dict[str, Ref]]:
- raise NotImplementedError
-
- def causify_up(self):
- raise NotImplementedError
-
- def get_call(self, wrapped_inputs: Dict[str, Ref]) -> Call:
- call_uid = Builtins.OPS[self.builtin_id].get_pre_call_uid(
- input_uids={k: v.uid for k, v in wrapped_inputs.items()}
- )
- return Call(
- uid=call_uid,
- inputs=wrapped_inputs,
- outputs=[],
- func_op=Builtins.OPS[self.builtin_id],
- transient=False,
- )
-
- def get_calls(self) -> List[Call]:
- inputs_list = self.as_inputs_list()
- res = []
- for wrapped_inputs in inputs_list:
- res.append(self.get_call(wrapped_inputs=wrapped_inputs))
- return res
-
- @staticmethod
- def map(obj: Iterable, func: Callable) -> Iterable:
- raise NotImplementedError
-
- @staticmethod
- def elts(obj: Iterable) -> Iterable:
- raise NotImplementedError
-
- def children(self) -> Iterable[Ref]:
- return self.elts(self.obj)
-
- def unlinked(self, keep_causal: bool) -> "StructRef":
- if not self.in_memory:
- res = type(self)(uid=self.uid, obj=None, in_memory=False)
- if keep_causal:
- res._causal_uid = self._causal_uid
- return res
- else:
- unlinked_elts = type(self).map(
- obj=self.obj, func=lambda elt: elt.unlinked(keep_causal=keep_causal)
- )
- res = type(self)(uid=self.uid, obj=unlinked_elts, in_memory=True)
- if keep_causal:
- res._causal_uid = self._causal_uid
- return res
-
-
-class ListRef(StructRef, Sequence):
- """
- Immutable list of Refs.
- """
-
- builtin_id = "__list__"
-
- def as_inputs_list(self) -> List[Dict[str, Ref]]:
- idxs = [wrap_atom(idx) for idx in range(len(self))]
- for idx in idxs:
- causify_atom(idx)
- wrapped_inputs_list = [
- {"lst": self, "elt": elt, "idx": idx} for elt, idx in zip(self.obj, idxs)
- ]
- return wrapped_inputs_list
-
- def causify_up(self):
- assert all(elt.causal_uid is not None for elt in self.obj)
- self.causal_uid = Hashing.hash_list([elt.causal_uid for elt in self.obj])
-
- @staticmethod
- def map(obj: list, func: Callable) -> list:
- return [func(elt) for elt in obj]
-
- @staticmethod
- def elts(obj: list) -> list:
- return obj
-
- def dump(self) -> "ListRef":
- return ListRef(
- uid=self.uid, obj=[vref.detached() for vref in self.obj], in_memory=True
- )
-
- ############################################################################
- ### list interface
- ############################################################################
- def __getitem__(
- self, idx: Union[int, "ValNode", Ref, slice]
- ) -> Union[Ref, "ValNode"]:
- self._auto_attach()
- if isinstance(idx, Ref):
- prepare_query(ref=idx, tp=AnyType())
- res_query = BuiltinQueries.GetListItemQuery(lst=self.query, idx=idx.query)
- res = self.obj[idx.obj].unlinked(keep_causal=True)
- res._query = res_query
- return res
- elif isinstance(idx, ValNode):
- res = BuiltinQueries.GetListItemQuery(lst=self.query, idx=idx)
- return res
- elif isinstance(idx, int):
- if (
- GlobalContext.current is not None
- and GlobalContext.current.mode != MODES.run
- ):
- raise ValueError
- res: Ref = self.obj[idx]
- res = res.unlinked(keep_causal=True)
- wrapped_idx = wrap_atom(obj=idx)
- causify_atom(ref=wrapped_idx)
- call = self.get_call(
- wrapped_inputs={"lst": self, "idx": wrapped_idx, "elt": res}
- )
- call.link(orientation=StructOrientations.destruct)
- return res
- elif isinstance(idx, slice):
- if (
- GlobalContext.current is not None
- and GlobalContext.current.mode != MODES.run
- ):
- raise ValueError
- res = self.obj[idx]
- res = [elt.unlinked(keep_causal=True) for elt in res]
- wrapped_idxs = [wrap_atom(obj=i) for i in range(*idx.indices(len(self)))]
- for wrapped_idx in wrapped_idxs:
- causify_atom(ref=wrapped_idx)
- for wrapped_idx, elt in zip(wrapped_idxs, res):
- call = self.get_call(
- wrapped_inputs={"lst": self, "idx": wrapped_idx, "elt": elt}
- )
- call.link(orientation=StructOrientations.destruct)
- return res
-
- def __iter__(self):
- self._auto_attach()
- return iter(self.obj)
-
- def __len__(self) -> int:
- self._auto_attach()
- return len(self.obj)
-
-
-class DictRef(StructRef): # don't inherit from Mapping because it's not hashable
- """
- Immutable string-keyed dict of Refs.
- """
-
- builtin_id = "__dict__"
-
- def as_inputs_list(self) -> List[Dict[str, Ref]]:
- keys = {k: wrap_atom(k) for k in self.obj.keys()}
- for k in keys.values():
- causify_atom(k)
- wrapped_inputs_list = [
- {"dct": self, "key": keys[k], "val": v} for k, v in self.obj.items()
- ]
- return wrapped_inputs_list
-
- def causify_up(self):
- assert all(elt.causal_uid is not None for elt in self.obj.values())
- self.causal_uid = Hashing.hash_dict(
- {k: v.causal_uid for k, v in self.obj.items()}
- )
-
- @staticmethod
- def map(obj: dict, func: Callable) -> dict:
- return {k: func(v) for k, v in obj.items()}
-
- @staticmethod
- def elts(obj: dict) -> Iterable:
- return iter(obj.values())
-
- def dump(self) -> "DictRef":
- assert self.in_memory
- return DictRef(
- uid=self.uid,
- obj={k: vref.detached() for k, vref in self.obj.items()},
- in_memory=True,
- )
-
- ############################################################################
- ### dict interface
- ############################################################################
- def __getitem__(self, key: Union[str, "ValNode", Ref]) -> Union[Ref, "ValNode"]:
- self._auto_attach()
- if isinstance(key, str):
- if (
- GlobalContext.current is not None
- and GlobalContext.current.mode != MODES.run
- ):
- raise ValueError
- return self.obj[key]
- if isinstance(key, Ref):
- prepare_query(ref=key, tp=AnyType())
- res_query = BuiltinQueries.GetDictItemQuery(dct=self.query, key=key.query)
- res = self.obj[key.obj].unlinked(keep_causal=True)
- res._query = res_query
- return res
- elif isinstance(key, ValNode):
- res = BuiltinQueries.GetDictItemQuery(dct=self.query, key=key)
- return res
- else:
- raise ValueError
-
- def __iter__(self):
- self._auto_attach()
- return iter(self.obj)
-
- def __len__(self) -> int:
- self._auto_attach()
- return len(self.obj)
-
-
-class SetRef(StructRef): # don't subclass from set, because it's not hashable
- """
- Immutable set of Refs.
- """
-
- builtin_id = "__set__"
-
- def as_inputs_list(self) -> List[Dict[str, Ref]]:
- wrapped_inputs_list = [{"st": self, "elt": elt} for elt in self.obj]
- return wrapped_inputs_list
-
- def causify_up(self):
- assert all(elt.causal_uid is not None for elt in self.obj)
- self.causal_uid = Hashing.hash_multiset([elt.causal_uid for elt in self.obj])
-
- @staticmethod
- def map(obj: set, func: Callable) -> list:
- return [func(elt) for elt in obj]
-
- @staticmethod
- def elts(obj: set) -> Iterable:
- return iter(obj)
-
- def dump(self) -> "SetRef":
- assert self.in_memory
- return SetRef(
- uid=self.uid, obj={vref.detached() for vref in self.obj}, in_memory=True
- )
-
- ############################################################################
- ### set interface
- ############################################################################
- def __contains__(self, item: Ref) -> bool:
- from .wrapping import unwrap
-
- if not self.in_memory:
- logging.warning(
- "Checking membership in a lazy SetRef requires loading the entire set into memory."
- )
- self._auto_attach(shallow=False)
- return item in unwrap(self.obj)
- else:
- return item in unwrap(self.obj)
-
- def __iter__(self):
- self._auto_attach()
- return iter(self.obj)
-
- def __len__(self) -> int:
- self._auto_attach()
- return len(self.obj)
-
-
-class Builtins:
- IDS = ("__list__", "__dict__", "__set__")
-
- @staticmethod
- def list_func(lst: List[Any], elt: Any, idx: Any):
- assert lst[idx] is elt
-
- @staticmethod
- def dict_func(dct: Dict[str, Any], key: str, val: Any):
- assert dct[key] is val
-
- @staticmethod
- def set_func(st: Set[Any], elt: Any):
- assert elt in st
-
- list_op = FuncOp(
- func=list_func.__func__, _is_builtin=True, version=0, ui_name="__list__"
- )
- dict_op = FuncOp(
- func=dict_func.__func__, _is_builtin=True, version=0, ui_name="__dict__"
- )
- set_op = FuncOp(
- func=set_func.__func__, _is_builtin=True, version=0, ui_name="__set__"
- )
-
- OPS = {
- "__list__": list_op,
- "__dict__": dict_op,
- "__set__": set_op,
- }
-
- REF_CLASSES = {
- "__list__": ListRef,
- "__dict__": DictRef,
- "__set__": SetRef,
- }
-
- PY_TYPES = {
- "__list__": list,
- "__dict__": dict,
- "__set__": set,
- }
-
- IO = {
- "construct": {
- "__list__": {"in": {"elt", "idx"}, "out": {"lst"}},
- "__dict__": {"in": {"key", "val"}, "out": {"dct"}},
- "__set__": {"in": {"elt"}, "out": {"st"}},
- },
- "destruct": {
- "__list__": {"in": {"lst", "idx"}, "out": {"elt"}},
- "__dict__": {"in": {"dct", "key"}, "out": {"val"}},
- "__set__": {"in": {"st"}, "out": {"elt"}},
- },
- }
-
- @staticmethod
- def reassign_io_using_orientation(
- in_dict: Dict[str, Any],
- out_dict: Dict[str, Any],
- orientation: str,
- builtin_id: str,
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
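-        # re-split the given input/output dicts according to the roles the
-        # chosen orientation assigns to each name (see `Builtins.IO` above)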
- io_assignment = Builtins.IO[orientation][builtin_id]
- in_names = io_assignment["in"]
- out_names = io_assignment["out"]
- res_in = {}
- res_out = {}
- for k, v in itertools.chain(in_dict.items(), out_dict.items()):
- if k in in_names:
- res_in[k] = v
- if k in out_names:
- res_out[k] = v
- return res_in, res_out
-
- @staticmethod
- def _make_builtin_uid(uid: str, builtin_id: str) -> str:
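-        # e.g. uid="abc", builtin_id="__list__"  ->  "__list__.abc";
-        # `parse_builtin_uid` below inverts this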
- return f"{builtin_id}.{uid}"
-
- @staticmethod
- def is_builtin_uid(uid: str) -> bool:
- return ("." in uid) and (uid.split(".")[0] in Builtins.IDS)
-
- @staticmethod
- def parse_builtin_uid(uid: str) -> Tuple[str, str]:
- assert Builtins.is_builtin_uid(uid)
- builtin_id, uid = uid.split(".", 1)
- return builtin_id, uid
-
- @staticmethod
- def spawn_builtin(
- builtin_id: str, uid: str, causal_uid: Optional[str] = None
- ) -> Ref:
- assert builtin_id in Builtins.IDS
- uid = Builtins._make_builtin_uid(uid=uid, builtin_id=builtin_id)
- res = Builtins.REF_CLASSES[builtin_id](uid=uid, obj=None, in_memory=False)
- if causal_uid is not None:
- res._causal_uid = causal_uid
- return res
-
- @staticmethod
- def map(
- func: Callable, obj: Union[List, Dict, Set], struct_id: str
- ) -> Union[List, Dict, Set]:
- if struct_id == "__list__":
- return [func(elt) for elt in obj]
- elif struct_id == "__dict__":
- return {key: func(val) for key, val in obj.items()}
- elif struct_id == "__set__":
- return {func(elt) for elt in obj}
- else:
- raise ValueError(f"Invalid struct_id: {struct_id}")
-
- @staticmethod
- def collect_all_calls(ref: Ref) -> List[Call]:
- if isinstance(ref, ValueRef):
- return []
- elif isinstance(ref, StructRef):
- if not ref.in_memory:
- return []
- else:
- calls = ref.get_calls()
- for elt in ref.elts(ref.obj):
- calls.extend(Builtins.collect_all_calls(elt))
- return calls
- else:
- raise ValueError(f"Unexpected ref type: {type(ref)}")
-
-
-from ..queries.weaver import ValNode, BuiltinQueries, prepare_query, StructOrientations
-from ..ui.contexts import GlobalContext
-from .wrapping import causify_atom
diff --git a/mandala/core/config.py b/mandala/core/config.py
deleted file mode 100644
index 9d6b7a5..0000000
--- a/mandala/core/config.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from ..common_imports import *
-from typing import Literal
-
-
-def get_mandala_path() -> Path:
- import mandala
-
- return Path(os.path.dirname(mandala.__file__))
-
-
-class Config:
- ### options modifying library behavior
- # whether ops automatically wrap their inputs as value references, or
- # require to be explicitly passed value references
- ### settings
- # whether to automatically wrap inputs when a call is made to an op
- autowrap_inputs = True
- # whether to automatically unwrap inputs when an op is actually executed
- autounwrap_inputs = True
- # how to assign UIDs to outputs
- output_wrap_method = "content"
- # whether to empty the call and vref caches upon committing to the RDBMS
- evict_on_commit = True
- # whether to commit on context exit
- autocommit = True
- # whether signatures are verified against the database each time a function
- # is called
- check_signature_on_each_call = False
- # always create storage with a persistent database
- _persistent_storage_testing = False
- enable_ref_magics = False
- warnings = True
- spillover_threshold_mb = 50
- db_backend = "sqlite"
- query_engine: Literal["sql", "naive", "_test"] = "sql"
- verbose_queries: bool = True
- func_interface_cls_name = "FuncInterface"
-
- ### constants
- # used for columns containing UIDs of value references or calls
- uid_col = "__uid__"
- causal_uid_col = "__causal_uid__"
- full_uid_col = "__full_uid__"
- content_version_col = "__content_version__"
- semantic_version_col = "__semantic_version__"
- transient_col = "__transient__"
- # columns that are not inputs or outputs in a memoization table
- special_call_cols = [
- uid_col,
- causal_uid_col,
- content_version_col,
- semantic_version_col,
- transient_col,
- ]
- # name for the table that holds the value reference UIDs
- vref_table = "__vrefs__"
- causal_vref_table = "__causal_vrefs__"
- vref_value_col = "value"
- temp_arrow_table = "__arrow__"
- # name for the event log table
- event_log_table = "__event_log__"
- # todo: currently unused?
- # schema_event_log_table = "__schema_event_log__"
- schema_table = "__schema__"
- # table for keeping track of function dependencies
- deps_table = "__deps__"
- provenance_table = "__provenance__"
- # all output names begin with this string
- # todo: prevent creating inputs with this name
- output_name_prefix = "output_"
-
- ### checking for optional dependencies
- try:
- import dask
-
- has_dask = True
- except ImportError:
- has_dask = False
-
- try:
- import torch
-
- has_torch = True
- except ImportError:
- has_torch = False
-
- try:
- import cityhash
-
- has_cityhash = True
- except ImportError:
- has_cityhash = False
-
- try:
- import PIL
-
- has_pil = True
- except ImportError:
- has_pil = False
-
- try:
- import duckdb
-
- has_duckdb = True
- except ImportError:
- has_duckdb = False
-
- try:
- import rich
-
- has_rich = True
- except ImportError:
- has_rich = False
-
- if has_rich:
- from rich import pretty
-
- pretty.install()
-
- ### some module names needed by internals
- mandala_path = get_mandala_path()
- module_name = "mandala"
- tests_module_name = "mandala.tests"
-
- # hashing method
- content_hasher: Literal["cityhash", "blake2b", "joblib"] = "joblib"
-
-
-def dump_output_name(index: int) -> str:
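-    # e.g. dump_output_name(2) -> "output_2"; parse_output_idx("output_2") -> 2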
- return f"{Config.output_name_prefix}{index}"
-
-
-def parse_output_idx(output_name: str) -> int:
- return int(output_name[len(Config.output_name_prefix) :])
-
-
-def is_output_name(name: str) -> bool:
- return (
- name.startswith(Config.output_name_prefix)
- and name[len(Config.output_name_prefix) :].isdigit()
- )
-
-
-if Config.has_torch:
- import torch
-
-
-class MODES:
- run = "run"
- query = "query"
- batch = "batch"
- noop = "noop"
- define = "define"
- delete = "delete"
-
- all_ = (run, query, batch, noop, define, delete)
-
-
-class Provenance:
- causal_uid = "causal"
- direction = "direction"
- call_causal_uid = "call_causal"
- name = "name"
- op_id = "op_id"
diff --git a/mandala/core/integrations.py b/mandala/core/integrations.py
deleted file mode 100644
index 43c8c9c..0000000
--- a/mandala/core/integrations.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from .sig import Signature
-from ..common_imports import *
-from .config import Config
-
-if Config.has_torch:
- import torch
-
- def sig_from_jit_script(
- f: torch.jit.ScriptFunction, version: int
- ) -> Tuple[Signature, inspect.Signature]:
- """
-        Parse a torch.jit.ScriptFunction into the (`Signature`,
-        `inspect.Signature`) pair used to construct a `FuncOp`.
-
- @torch.jit.script-decorated functions must follow a special format:
- - there can't be a variable number of arguments
- - there can't be keyword arguments with defaults
- """
- # get the names of the inputs
- input_names = [arg.name for arg in f.schema.arguments]
- # get the number of outputs
- n_outputs = len(f.schema.returns)
- # get the default values for the inputs of the ScriptFunction
- defaults = {}
- for argument in f.schema.arguments:
- if argument.default_value is not None:
- # defaults are assigned non-null values by the JIT based on
- # inferred type and default value in the signature.
- defaults[argument.name] = argument.default_value
- parameters = OrderedDict()
- for arg in f.schema.arguments:
- kind = (
- inspect.Parameter.KEYWORD_ONLY
- if arg.kwarg_only
- else inspect.Parameter.POSITIONAL_OR_KEYWORD
- )
- if defaults.get(arg.name) is not None:
- param = inspect.Parameter(
- name=arg.name, kind=kind, default=defaults[arg.name]
- )
- else:
- param = inspect.Parameter(name=arg.name, kind=kind)
- parameters[arg.name] = param
- if n_outputs == 0:
- return_annotation = inspect._empty
- elif n_outputs == 1:
- return_annotation = f.schema.returns[0].type
- else:
- return_annotation = tuple([r.type for r in f.schema.returns])
- py_sig = inspect.Signature(
- parameters=list(parameters.values()),
- return_annotation=return_annotation,
- __validate_parameters__=True,
- )
- return Signature.from_py(name=str(f.name), version=version, sig=py_sig), py_sig
diff --git a/mandala/core/model.py b/mandala/core/model.py
deleted file mode 100644
index 070158f..0000000
--- a/mandala/core/model.py
+++ /dev/null
@@ -1,743 +0,0 @@
-from typing import Type
-import textwrap
-from .config import Config, parse_output_idx, dump_output_name, is_output_name, MODES
-from ..common_imports import *
-from .utils import Hashing, get_uid, get_full_uid, parse_full_uid
-from .sig import (
- Signature,
- _postprocess_outputs,
-)
-from .tps import Type, AnyType, ListType, DictType, SetType
-
-from ..deps.tracers import TracerABC, DecTracer
-
-if Config.has_torch:
- import torch
- from .integrations import sig_from_jit_script
-
-
-class Delayed:
- pass
-
-
-class TransientObj:
- def __init__(self, obj: Any, unhashable: bool = False):
- self.obj = obj
- self.unhashable = unhashable
-
-
-def get_transient_uid(content_hash: str) -> str:
- return f"__transient__.{content_hash}"
-
-
-def is_transient_uid(uid: str) -> bool:
- return uid.startswith("__transient__")
-
-
-################################################################################
-### refs
-################################################################################
-class Ref:
- def __init__(self, uid: str, obj: Any, in_memory: bool, transient: bool):
- self.uid = uid
- self._causal_uid = None
- self._obj = obj
- self.in_memory = in_memory
- self.transient = transient
- # runtime only
- self._query: Optional[ValNode] = None
-
- @property
- def causal_uid(self) -> str:
- return self._causal_uid
-
- @causal_uid.setter
- def causal_uid(self, causal_uid: str):
- if self.causal_uid is not None and self.causal_uid != causal_uid:
- raise ValueError(
- f"causal_uid already set to {self.causal_uid}, cannot set to {causal_uid}"
- )
- self._causal_uid = causal_uid
-
- @property
- def full_uid(self) -> str:
- return get_full_uid(uid=self.uid, causal_uid=self.causal_uid)
-
- @staticmethod
- def parse_full_uid(full_uid: str) -> Tuple[str, str]:
- return parse_full_uid(full_uid=full_uid)
-
- @property
- def obj(self) -> Any:
- return self._obj
-
- @staticmethod
- def from_full_uid(full_uid: str) -> "Ref":
- uid, causal_uid = Ref.parse_full_uid(full_uid=full_uid)
- res = Ref.from_uid(uid=uid)
- res.causal_uid = causal_uid
- return res
-
- @staticmethod
- def from_uid(uid: str, causal_uid: Optional[str] = None) -> "Ref":
- from .builtins_ import Builtins
-
- if Builtins.is_builtin_uid(uid=uid):
- builtin_id, uid = Builtins.parse_builtin_uid(uid=uid)
- return Builtins.spawn_builtin(
- builtin_id=builtin_id, uid=uid, causal_uid=causal_uid
- )
- elif is_transient_uid(uid=uid):
- res = ValueRef(uid=uid, obj=None, in_memory=False, transient=True)
- else:
- res = ValueRef(uid=uid, obj=None, in_memory=False)
- if causal_uid is not None:
- res._causal_uid = causal_uid
- return res
-
- def is_delayed(self) -> bool:
- return isinstance(self.obj, Delayed)
-
- @staticmethod
- def make_delayed(RefCls) -> "Ref":
- return RefCls(uid="", obj=Delayed(), in_memory=False)
-
- def attach(self, reference: "Ref"):
- assert self.uid == reference.uid and reference.in_memory
- self._obj = reference.obj
- self.in_memory = True
-
- def _auto_attach(self, shallow: bool = True):
- if not self.in_memory:
-
- context = GlobalContext.current
- if context is None:
- ## emit a warning
- # logger.warning(
- # "No context found, cannot attach ref."
- # )
- # fail silently
- return
- if context.mode != MODES.run:
- return
- assert context.storage is not None
- storage = context.storage
- storage.cache.mattach(vrefs=[self], shallow=shallow)
- # storage.rel_adapter.mattach(vrefs=[self], shallow=shallow)
- causify_down(ref=self, start=self.causal_uid, stop_at_causal=False)
-
- def detached(self, keep_causal: bool = True) -> "Ref":
- """
- Produce a metadata-only copy of this `Ref` that is unlinked from the
- computational graph.
- """
- res = self.__class__(uid=self.uid, obj=None, in_memory=False)
- if keep_causal:
- res.causal_uid = self.causal_uid
- return res
-
- def unlinked(self, keep_causal: bool) -> "Ref":
- """
- Produce a copy unlinked from the computational graph that preserves the
- `in_memory` status and all other properties.
-
- If `keep_causal` is True, the causal UID is preserved. Otherwise, it is
- removed recursively.
-
- ! Actually, this is quite a lot like a deepcopy for refs, which suggests
- that we could refactor
- """
- raise NotImplementedError
-
- def clone(self) -> "Ref":
- """
- Clone a `Ref` by creating a new object which has the same data
- """
- return self.unlinked(keep_causal=True)
-
- @property
- def _uid_suffix(self) -> str:
- return self.uid.split(".")[-1]
-
- @property
- def _short_uid(self) -> str:
- return self._uid_suffix[:3] + "..."
-
- def __repr__(self, shorten: bool = False) -> str:
- if self.in_memory:
- obj_repr = repr(self.obj)
- if shorten:
- obj_repr = textwrap.shorten(obj_repr, width=50, placeholder="...")
- if "\n" in obj_repr:
- obj_repr = f'\n{textwrap.indent(obj_repr, " ")}'
- return f"{self.__class__.__name__}({obj_repr}, uid={self._short_uid})"
- else:
- return f"{self.__class__.__name__}(in_memory=False, uid={self._short_uid})"
-
- @property
- def query(self) -> "ValNode":
- if self._query is None:
- raise ValueError("Ref has no query")
- return self._query
-
- def pin(self, *values):
- assert self._query is not None
- if len(values) == 0:
- constraint = [self.full_uid]
- else:
- constraint = self.query.get_constraint(*values)
- self._query.constraint = constraint
-
- def unpin(self):
- assert self._query is not None
- self._query.constraint = None
-
-
-class ValueRef(Ref):
- def __init__(self, uid: str, obj: Any, in_memory: bool, transient: bool = False):
- super().__init__(uid=uid, obj=obj, in_memory=in_memory, transient=transient)
-
- def dump(self) -> "ValueRef":
- if not self.transient:
- return ValueRef(uid=self.uid, obj=self.obj, in_memory=True)
- return ValueRef(
- uid=self.uid, obj=TransientObj(obj=None), in_memory=False, transient=True
- )
-
- def unlinked(self, keep_causal: bool) -> "Ref":
- res = ValueRef(
- uid=self.uid,
- obj=self.obj,
- in_memory=self.in_memory,
- transient=self.transient,
- )
- if keep_causal:
- res.causal_uid = self.causal_uid
- return res
-
- ############################################################################
- ### magic methods forwarding
- ############################################################################
- def _init_magic(self) -> Callable:
- if not Config.enable_ref_magics:
- raise RuntimeError(
- "Ref magic methods (typecasting/comparison operators/binary operators) are disabled; enable with Config.enable_ref_magics = True"
- )
- if Config.warnings:
- logging.warning(
- f"Automatically unwrapping `Ref` to run magic method (typecasting/comparison operators/binary operators)."
- )
- from .wrapping import unwrap
-
- self._auto_attach(shallow=True)
- return unwrap
-
- ### typecasting
- def __bool__(self) -> bool:
- self._init_magic()
- return self.obj.__bool__()
-
- def __int__(self) -> int:
- self._init_magic()
- return self.obj.__int__()
-
- def __index__(self) -> int:
- self._init_magic()
- return self.obj.__index__()
-
- ### comparison
- def __lt__(self, other: Any) -> bool:
- unwrap = self._init_magic()
- return self.obj.__lt__(unwrap(other))
-
- def __le__(self, other: Any) -> bool:
- unwrap = self._init_magic()
- return self.obj.__le__(unwrap(other))
-
- def __eq__(self, other: Any) -> bool:
- unwrap = self._init_magic()
- return self.obj.__eq__(unwrap(other))
-
- def __hash__(self) -> int:
- return id(self)
-
- def __ne__(self, other: Any) -> bool:
- unwrap = self._init_magic()
- return self.obj.__ne__(unwrap(other))
-
- def __gt__(self, other: Any) -> bool:
- unwrap = self._init_magic()
- return self.obj.__gt__(unwrap(other))
-
- def __ge__(self, other: Any) -> bool:
- unwrap = self._init_magic()
- return self.obj.__ge__(unwrap(other))
-
- ### binary operations
- def __add__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__add__(unwrap(other))
-
- def __sub__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__sub__(unwrap(other))
-
- def __mul__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__mul__(unwrap(other))
-
- def __floordiv__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__floordiv__(unwrap(other))
-
- def __truediv__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__truediv__(unwrap(other))
-
- def __mod__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__mod__(unwrap(other))
-
- def __or__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__or__(unwrap(other))
-
- def __and__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__and__(unwrap(other))
-
- def __xor__(self, other: Any) -> Any:
- unwrap = self._init_magic()
- return self.obj.__xor__(unwrap(other))
-
-
-################################################################################
-### calls
-################################################################################
-
-
-class Call:
- def __init__(
- self,
- uid: str,
- inputs: Dict[str, Ref],
- outputs: List[Ref],
- func_op: "FuncOp",
- transient: bool,
- causal_uid: Optional[str] = None,
- semantic_version: Optional[str] = None,
- content_version: Optional[str] = None,
- ):
- self.func_op = func_op.detached()
- self.uid = uid
- self.semantic_version = semantic_version
- self.content_version = content_version
- self.inputs = inputs
- self.outputs = outputs
- self.transient = transient # if outputs contain transient objects
- if causal_uid is None:
- input_uids = {k: v.uid for k, v in self.inputs.items()}
- input_causal_uids = {k: v.causal_uid for k, v in self.inputs.items()}
- assert all([v is not None for v in input_causal_uids.values()])
- causal_uid = self.func_op.get_call_causal_uid(
- input_uids=input_uids,
- input_causal_uids=input_causal_uids,
- semantic_version=semantic_version,
- )
- self.causal_uid = causal_uid
- self._func_query = None
-
- @property
- def full_uid(self) -> str:
- return f"{self.uid}.{self.causal_uid}"
-
- def link(
- self,
- orientation: Optional[str] = None,
- ):
- if self._func_query is not None:
- return
- input_types, output_types = self.func_op.input_types, self.func_op.output_types
- for k, v in self.inputs.items():
- prepare_query(ref=v, tp=input_types[k])
- for i, v in enumerate(self.outputs):
- prepare_query(ref=v, tp=output_types[i])
- outputs = {dump_output_name(i): v.query for i, v in enumerate(self.outputs)}
- self._func_query = CallNode.link(
- inputs={k: v.query for k, v in self.inputs.items()},
- func_op=self.func_op.detached(),
- outputs=outputs,
- orientation=orientation,
- constraint=None,
- )
-
- def unlink(self):
- assert self._func_query is not None
- self._func_query.unlink()
- self._func_query = None
-
- @property
- def func_query(self) -> "CallNode":
- assert self._func_query is not None
- return self._func_query
-
- def __repr__(self) -> str:
- tuples: List[Tuple[str, str]] = [
- ("uid", self.uid),
- ]
- if self.semantic_version is not None:
- tuples.append(("semantic_version", self.semantic_version))
- if self.content_version is not None:
- tuples.append(("content_version", self.content_version))
- tuples.extend(
- [
- ("inputs", textwrap.shorten(str(self.inputs), width=80)),
- ("outputs", textwrap.shorten(str(self.outputs), width=80)),
- ("func_op", self.func_op),
- ]
- )
- data_str = ",\n".join([f" {k}={v}" for k, v in tuples])
- return f"Call(\n{data_str}\n)"
-
- @staticmethod
- def from_row(row: Union[pa.Table, dict], func_op: "FuncOp") -> "Call":
- """
- Generate a `Call` from a single-row table encoding the UID, input and
- output UIDs.
-
- NOTE: this does not include the objects for the inputs and outputs to the call!
- """
- columns = row.column_names if isinstance(row, pa.Table) else row.keys()
- output_columns = [column for column in columns if is_output_name(column)]
- input_columns = [
- column
- for column in columns
- if column not in output_columns and column not in Config.special_call_cols
- ]
- process_boolean = lambda x: True if x == "1" else False
- if isinstance(row, pa.Table):
- return Call(
- uid=row.column(Config.uid_col)[0].as_py(),
- causal_uid=row.column(Config.causal_uid_col)[0].as_py(),
- semantic_version=row.column(Config.semantic_version_col)[0].as_py(),
- content_version=row.column(Config.content_version_col)[0].as_py(),
- transient=process_boolean(row.column(Config.transient_col)[0].as_py()),
- inputs={
- k: Ref.from_full_uid(full_uid=row.column(k)[0].as_py())
- for k in input_columns
- },
- outputs=[
- Ref.from_full_uid(full_uid=row.column(k)[0].as_py())
- for k in sorted(output_columns, key=parse_output_idx)
- ],
- func_op=func_op,
- )
- else:
- return Call(
- uid=row[Config.uid_col],
- causal_uid=row[Config.causal_uid_col],
- semantic_version=row[Config.semantic_version_col],
- content_version=row[Config.content_version_col],
- transient=process_boolean(row[Config.transient_col]),
- inputs={k: Ref.from_full_uid(full_uid=row[k]) for k in input_columns},
- outputs=[
- Ref.from_full_uid(full_uid=row[k])
- for k in sorted(output_columns, key=parse_output_idx)
- ],
- func_op=func_op,
- )
-
- def set_input_values(self, inputs: Dict[str, Ref]) -> "Call":
- res = copy.deepcopy(self)
- assert set(inputs.keys()) == set(res.inputs.keys())
- for k, v in inputs.items():
- current = res.inputs[k]
- current._obj = v.obj
- current.in_memory = True
- return res
-
- def set_output_values(self, outputs: List[Ref]) -> "Call":
- res = copy.deepcopy(self)
- assert len(outputs) == len(res.outputs)
- for i, v in enumerate(outputs):
- current = res.outputs[i]
- current._obj = v.obj
- current.in_memory = True
- return res
-
- def detached(self) -> "Call":
- """
- Produce a copy of this call that is unlinked from the computational
- graph, and has all its inputs and outputs detached.
-
- See `Ref.detached` for more details.
- """
- return Call(
- uid=self.uid,
- causal_uid=self.causal_uid,
- inputs={k: v.detached() for k, v in self.inputs.items()},
- outputs=[v.detached() for v in self.outputs],
- func_op=self.func_op,
- semantic_version=self.semantic_version,
- content_version=self.content_version,
- transient=self.transient,
- )
-
-
-################################################################################
-### ops
-################################################################################
-class FuncOp:
- def __init__(
- self,
- func: Optional[Callable] = None,
- sig: Optional[Signature] = None,
- version: Optional[int] = None,
- ui_name: Optional[str] = None,
- is_super: bool = False,
- n_outputs_override: Optional[int] = None,
- _is_builtin: bool = False,
- ):
- self.is_super = is_super
- self._is_builtin = _is_builtin
- if func is None:
- self.sig = sig
- self.py_sig = None
- self._func = None
- self._module = None
- self._qualname = None
- else:
- self._func = func
- self._module = func.__module__
- self._qualname = func.__qualname__
- if Config.has_torch and isinstance(func, torch.jit.ScriptFunction):
- sig, py_sig = sig_from_jit_script(self._func, version=version)
- ui_name = sig.ui_name
- else:
- py_sig = inspect.signature(self._func)
- ui_name = self._func.__name__ if ui_name is None else ui_name
- self.py_sig = py_sig
- self.sig = Signature.from_py(
- sig=self.py_sig, name=ui_name, version=version, _is_builtin=_is_builtin
- )
- self.n_outputs_override = n_outputs_override
- if n_outputs_override is not None:
- self.sig.n_outputs = n_outputs_override
- self.sig.output_annotations = [Any] * n_outputs_override
-
- def __repr__(self) -> str:
- return f"FuncOp({self.sig.ui_name}, version={self.sig.version})"
-
- @property
- def func(self) -> Callable:
- assert self._func is not None
- return self._func
-
- @property
- def is_builtin(self) -> bool:
- return self._is_builtin
-
- @func.setter
- def func(self, func: Optional[Callable]):
- self._func = func
- if func is not None:
- self.py_sig = inspect.signature(self._func)
-
- @property
- def input_annotations(self) -> Dict[str, Any]:
- assert self.sig is not None
- return self.sig.input_annotations
-
- @property
- def input_types(self) -> Dict[str, Type]:
- return {k: Type.from_annotation(v) for k, v in self.input_annotations.items()}
-
- @property
- def output_annotations(self) -> List[Any]:
- assert self.sig is not None
- return self.sig.output_annotations
-
- @property
- def output_types(self) -> List[Type]:
- return [Type.from_annotation(a) for a in self.output_annotations]
-
- def compute(
- self,
- inputs: Dict[str, Any],
- tracer: Optional[TracerABC] = None,
- ) -> List[Any]:
- if tracer is not None:
- with tracer:
- if isinstance(tracer, DecTracer):
- node = tracer.register_call(func=self.func)
- result = self.func(**inputs)
- if isinstance(tracer, DecTracer):
- tracer.register_return(node=node)
- else:
- result = self.func(**inputs)
- return _postprocess_outputs(sig=self.sig, result=result)
-
- @staticmethod
- def _from_data(
- sig: Signature,
- func: Callable,
- ) -> "FuncOp":
- """
- Create a `FuncOp` object based on a signature and maybe a function. For
- internal use only.
- """
- res = FuncOp(
- func=func,
- version=sig.version,
- ui_name=sig.ui_name,
- _is_builtin=sig.is_builtin,
- )
- res.sig = sig
- return res
-
- @staticmethod
- def _from_sig(sig: Signature) -> "FuncOp":
- return FuncOp(func=None, sig=sig, _is_builtin=sig.is_builtin)
-
- def detached(self) -> "FuncOp":
- if self._func is None:
- return copy.deepcopy(self)
- result = FuncOp(
- func=None,
- sig=copy.deepcopy(self.sig),
- version=self.sig.version,
- ui_name=self.sig.ui_name,
- is_super=self.is_super,
- n_outputs_override=self.n_outputs_override,
- _is_builtin=self._is_builtin,
- )
- result._module = self._module
- result._qualname = self._qualname
- result.py_sig = self.py_sig # signature objects are immutable
- return result
-
- def get_active_inputs(self, input_uids: Dict[str, str]) -> Dict[str, str]:
- """
- Return a dict of external -> internal input names for inputs that are
- not set to their default values.
- """
- res = {}
- for k, v in input_uids.items():
- internal_k = self.sig.ui_to_internal_input_map[k]
- if internal_k in self.sig._new_input_defaults_uids:
- internal_uid = Ref.parse_full_uid(
- full_uid=self.sig._new_input_defaults_uids[internal_k]
- )[0]
- if internal_uid == v:
- continue
- res[k] = internal_k
- return res
-
- def get_call_causal_uid(
- self,
- input_uids: Dict[str, str],
- input_causal_uids: Dict[str, str],
- semantic_version: Optional[str],
- ) -> str:
- """
- Combine the causal UIDs of the inputs, the semantic version of the call,
- and the versioned internal name of the function to generate a unique
- causal UID for the call.
- """
- active_inputs = self.get_active_inputs(input_uids=input_uids)
- return Hashing.get_content_hash(
- obj=[
- {
- active_inputs[k]: v
- for k, v in input_causal_uids.items()
- if k in active_inputs.keys()
- },
- semantic_version,
- self.sig.versioned_internal_name,
- ]
- )
-
- def get_pre_call_uid(self, input_uids: Dict[str, str]) -> str:
- # get call UID using *internal names* to guarantee the same UID will be
- # assigned regardless of renamings
- active_inputs = self.get_active_inputs(input_uids=input_uids)
- hashable_input_uids = {
- active_inputs[k]: v
- for k, v in input_uids.items()
- if k in active_inputs.keys()
- }
- call_uid = Hashing.get_content_hash(
- obj=[
- hashable_input_uids,
- self.sig.versioned_internal_name,
- ]
- )
- return call_uid
-
- def get_call_uid(self, pre_call_uid: str, semantic_version: Optional[str]) -> str:
- return Hashing.get_content_hash((pre_call_uid, semantic_version))
-
-
-################################################################################
-### wrapping
-################################################################################
-def wrap_atom(obj: Any, uid: Optional[str] = None) -> ValueRef:
- """
- Wraps a value as a `ValueRef`, if it isn't one already.
-
- The uid is either explicitly set, or a content hash is generated. Note that
- content hashing may take non-trivial time for large objects. When `obj` is
- already a `ValueRef` and `uid` is provided, an error is raised.
- """
- if isinstance(obj, Ref) and not isinstance(obj, ValueRef):
- raise ValueError(f"Cannot wrap {obj} as a ValueRef")
- if isinstance(obj, ValueRef):
- if uid is not None:
- # protect against accidental misuse
- raise ValueError(f"Cannot change uid of ValueRef: {obj}")
- return obj
- elif not isinstance(obj, TransientObj):
- uid = Hashing.get_content_hash(obj) if uid is None else uid
- return ValueRef(uid=uid, obj=obj, in_memory=True)
- else:
- if obj.unhashable:
- uid = get_uid()
- else:
- uid = Hashing.get_content_hash(obj.obj) if uid is None else uid
- uid = get_transient_uid(content_hash=uid)
- return ValueRef(uid=uid, obj=obj.obj, in_memory=True, transient=True)
-
-
-def collect_detached(refs: Iterable[Ref], include_transient: bool) -> List[Ref]:
- """
- Recursively get all detached `Ref`s present in `refs`
- """
- detached_vrefs = []
- for ref in refs:
- if (
- isinstance(ref, Ref)
- and not ref.in_memory
- and (include_transient or not ref.transient)
- ):
- detached_vrefs.append(ref)
- elif isinstance(ref, StructRef) and ref.in_memory:
- detached_vrefs.extend(
- collect_detached(ref.children(), include_transient=include_transient)
- )
- else:
- continue
- return detached_vrefs
-
-
-def clone(ref: Ref) -> Ref:
- """
- Clone a `Ref` by creating a new object which has the same data
- """
- if isinstance(ref, ValueRef):
- return ref.dump()
-
-
-from .builtins_ import ListRef, DictRef, SetRef, Builtins, StructRef
-from ..queries.weaver import ValNode, StructOrientations, CallNode, prepare_query
-from ..ui.contexts import GlobalContext
-from .wrapping import causify_down, causify_atom
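The docstrings of `get_pre_call_uid`, `get_call_causal_uid` and `get_call_uid` deleted above describe a two-stage hashing scheme: inputs are hashed under their *internal* names together with the versioned internal name of the op, and the resulting pre-call UID is then combined with the semantic version. A minimal standalone sketch of that scheme, using a hypothetical `content_hash` helper in place of `Hashing.get_content_hash` (which serializes with joblib); all names and values below are illustrative only:

```python
# Hypothetical stand-in for Hashing.get_content_hash; the real hasher
# serializes the object with joblib before hashing.
import hashlib
import json

def content_hash(obj) -> str:
    return hashlib.blake2b(json.dumps(obj, sort_keys=True).encode()).hexdigest()

def get_pre_call_uid(internal_input_uids: dict, versioned_internal_name: str) -> str:
    # hashing under *internal* names keeps the UID stable across renamings
    return content_hash([internal_input_uids, versioned_internal_name])

def get_call_uid(pre_call_uid: str, semantic_version: str) -> str:
    # the final call UID also depends on the semantic version of the code
    return content_hash([pre_call_uid, semantic_version])

pre = get_pre_call_uid({"inp_a1b2": "uid_of_x", "inp_c3d4": "uid_of_y"}, "op_internal_name_0")
print(get_call_uid(pre, semantic_version="semantic_v1"))
```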
diff --git a/mandala/core/prov.py b/mandala/core/prov.py
deleted file mode 100644
index a2d2292..0000000
--- a/mandala/core/prov.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from ..common_imports import *
-from .builtins_ import ListRef, DictRef, SetRef
-
-BUILTIN_IDS = [ListRef.builtin_id, DictRef.builtin_id, SetRef.builtin_id]
-BUILTIN_OP_IDS = [f"{x}_0" for x in BUILTIN_IDS]
-
-
-def propagate_struct_provenance(prov_df: pd.DataFrame) -> pd.DataFrame:
- """
- Compute directions of structural calls in a new column `direction_new` by
-    inferring them from the data; this is currently kept for backward compatibility.
-
- The algorithm is as follows:
-    - find all the refs that are the direct result of a non-struct op call
- - find all the struct calls where these refs appear as the struct
- - find all the struct calls expressing the items of these structs
- - mark these items as outputs
- - repeat the process for these items, until no new structs are found among them
-
- Note that this assigns `direction_new` for all calls involved in the structs
- found by this process.
-
- For every struct call that hasn't been assigned `direction_new` yet, we mark
- the struct as output, and the items (and indices) as inputs.
- """
- prov_df = prov_df.copy()
- prov_df["direction_new"] = [None for _ in range(len(prov_df))]
- nonstruct_outputs_causal_uids = prov_df.query(
- 'direction == "output" and op_id not in @BUILTIN_OP_IDS'
- ).causal.values
- structs_df = get_structs_df(prov_df, nonstruct_outputs_causal_uids)
- items_df = get_items_df(prov_df, structs_df.call_causal.values)
- while len(structs_df) > 0:
- # mark only the items (not structs or indices) as outputs
- prov_df["direction_new"][
- prov_df.call_causal.isin(items_df.call_causal)
- & ~(prov_df.name.isin(["lst", "dct", "st", "idx", "key"]))
- ] = "output"
- items_causal_uids = items_df.causal.values
- structs_df = get_structs_df(prov_df, items_causal_uids)
- items_df = get_items_df(prov_df, structs_df.call_causal.values)
- remaining_struct_mask = (
- (prov_df["direction_new"] != prov_df["direction_new"])
- & (prov_df["op_id"].isin(BUILTIN_OP_IDS))
- & (prov_df["name"].isin(["lst", "dct", "st"]))
- )
- prov_df.loc[remaining_struct_mask, "direction_new"] = "output"
-
- remaining_things = prov_df.query("direction_new != direction_new").index
- prov_df.loc[remaining_things, "direction_new"] = prov_df.query(
- "direction_new != direction_new"
- ).direction
- return prov_df
-
-
-def get_structs_df(prov_df: pd.DataFrame, causal_uids: Iterable[str]) -> pd.DataFrame:
- """
- Given some causal UIDs and a provenance dataframe, return the sub-dataframe
- where these causal UIDs appear in the role of the struct in a structural
- call
- """
- return prov_df.query(
- 'causal in @causal_uids and op_id in @BUILTIN_OP_IDS and name in ["lst", "dct", "st"]'
- )
-
-
-def get_items_df(
- prov_df: pd.DataFrame, struct_call_uids: Iterable[str]
-) -> pd.DataFrame:
- """
- Given some structural causal call UIDs and a provenance dataframe, return
- the sub-dataframe where these structural calls are associated with items
- (elements/values) of the structs
- """
- # get the sub-dataframe for these structural calls containing the items (elts/values) of the structs
- return prov_df.query('call_causal in @struct_call_uids and name in ["elt", "val"]')
-
-
-def get_idx_df(prov_df: pd.DataFrame, struct_call_uids: Iterable[str]) -> pd.DataFrame:
- return prov_df.query('call_causal in @struct_call_uids and name in ["idx", "key"]')
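The propagation logic deleted above repeatedly relies on the pandas idiom `col != col` to select rows whose `direction_new` has not been assigned yet (missing values compare unequal to themselves). A minimal illustration with a made-up dataframe, not the real provenance schema:

```python
import pandas as pd

df = pd.DataFrame({"direction_new": [None, "output", None, "input"]})
# True exactly where the value is missing -- equivalent to isna()
mask_idiom = df["direction_new"] != df["direction_new"]
assert mask_idiom.equals(df["direction_new"].isna())
print(df[mask_idiom])
```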
diff --git a/mandala/core/sig.py b/mandala/core/sig.py
deleted file mode 100644
index 298aef9..0000000
--- a/mandala/core/sig.py
+++ /dev/null
@@ -1,511 +0,0 @@
-from ..common_imports import *
-from .config import Config, is_output_name
-from .utils import get_uid, Hashing, is_subdict, get_full_uid
-from ..utils import serialize
-from pickle import PicklingError
-
-if Config.has_torch:
- import torch
-
-
-def sanitize_annotation(annotation: Any) -> Any:
- try:
- serialize(annotation)
- return annotation
- except PicklingError:
- return Any
- except Exception as e:
- raise ValueError(f"Invalid annotation: {annotation}") from e
-
-
-class Signature:
- """
- Holds and manipulates the relevant metadata for a memoized function, which
- includes
- - the function's user-interface (human-facing) and internal (used by storage)
- name,
- - the user-interface and internal input names (and the mapping between them),
- - the version,
- - and the default values.
- - (optional) superop status
-
-    Responsible for manipulating this state and keeping it consistent; e.g.,
-    logic for checking whether a refactoring makes sense should be hidden here.
-
- The internal name of the function is an immutable UID that is used to
- identify the function throughout its entire lifetime for the storage it is
- connected to. The UI name is what the function is named in the source
- code, and can be changed. Same for the internal/UI input names.
-
- What goes through most of the system at runtime are the UI names, to make it
- easier to debug and inspect things. The internal names are used only in very
- specific and isolated parts of the architecture.
- """
-
- def __init__(
- self,
- ui_name: str,
- input_names: Set[str],
- n_outputs: int,
- defaults: Dict[str, Any], # ui name -> default value
- version: int,
- input_annotations: Dict[str, Any],
- output_annotations: List[Any],
- _is_builtin: bool = False,
- ):
- self.ui_name = ui_name
- self.input_names = input_names
- self.defaults = defaults
- self.n_outputs = n_outputs
- self.version = version
- self._internal_name = None
- # ui name -> internal name for inputs
- self._ui_to_internal_input_map = None
- # internal input name -> UID of default value
- # this stores the UIDs of default values for inputs that have been
- # added to the function since its creation
- self._new_input_defaults_uids = {}
-
- self.input_annotations = {
- k: sanitize_annotation(v) for k, v in input_annotations.items()
- }
- self.output_annotations = [sanitize_annotation(v) for v in output_annotations]
-
- self._is_builtin = _is_builtin
- if self.is_builtin:
- self._internal_name = ui_name
- self._ui_to_internal_input_map = {name: name for name in input_names}
-
- @property
- def is_builtin(self) -> bool:
- return self._is_builtin
-
- def check_invariants(self):
- assert set(self.defaults.keys()) <= self.input_names
- assert set(self.input_annotations.keys()) == self.input_names
- assert len(self.output_annotations) == self.n_outputs
- if self.has_internal_data:
- assert set(self._ui_to_internal_input_map.keys()) == self.input_names
- assert set(self._new_input_defaults_uids.keys()) <= set(
- self._ui_to_internal_input_map.values()
- )
-
- def __repr__(self) -> str:
- return (
- f"Signature(ui_name={self.ui_name}, input_names={self.input_names}, "
- f"n_outputs={self.n_outputs}, defaults={self.defaults}, "
- f"version={self.version}, internal_name={self._internal_name}, "
- f"ui_to_internal_input_map={self._ui_to_internal_input_map}, "
- f"new_input_defaults_uids={self._new_input_defaults_uids}, "
- f"is_builtin={self.is_builtin})"
- )
-
- @property
- def versioned_ui_name(self) -> str:
- """
- Return the version-qualified human-readable name of this signature, used to
- disambiguate between different versions of the same function.
- """
- return f"{self.ui_name}_{self.version}"
-
- @property
- def versioned_internal_name(self) -> str:
- """
- Return the version-qualified internal name of this signature
- """
- return f"{self.internal_name}_{self.version}"
-
- @property
- def internal_name(self) -> str:
- if self._internal_name is None:
- raise ValueError("Internal name not set")
- return self._internal_name
-
- @staticmethod
- def parse_versioned_name(versioned_name: str) -> Tuple[str, int]:
- """
- Recover the name and version from a version-qualified name
- """
- name, version_string = versioned_name.rsplit("_", 1)
- return name, int(version_string)
-
- @staticmethod
- def dump_versioned_name(name: str, version: int) -> str:
- """
- Return a version-qualified name from a name and version
- """
- return f"{name}_{version}"
-
- @property
- def ui_to_internal_input_map(self) -> Dict[str, str]:
- if self._ui_to_internal_input_map is None:
- raise ValueError("Internal input names not set")
- return self._ui_to_internal_input_map
-
- @property
- def internal_to_ui_input_map(self) -> Dict[str, str]:
- """
- Mapping from internal input names to their UI names
- """
- if not self.has_internal_data:
- raise ValueError()
- return {v: k for k, v in self.ui_to_internal_input_map.items()}
-
- @property
- def has_internal_data(self) -> bool:
- """
- Whether this signature has had its internal data (internal signature
- name and internal input names) set.
- """
- return (
- self._internal_name is not None
- and self._ui_to_internal_input_map is not None
- and self._ui_to_internal_input_map.keys() == self.input_names
- )
-
- @property
- def new_ui_input_default_uids(self) -> Dict[str, str]:
- return {
- self.internal_to_ui_input_map[k]: v
- for k, v in self._new_input_defaults_uids.items()
- }
-
- def __eq__(self, other: Any) -> bool:
- return (
- isinstance(other, Signature)
- and self.ui_name == other.ui_name
- and self.input_names == other.input_names
- and self.n_outputs == other.n_outputs
- and self.defaults == other.defaults
- and self.version == other.version
- and self._internal_name == other._internal_name
- and self._ui_to_internal_input_map == other._ui_to_internal_input_map
- and self._new_input_defaults_uids == other._new_input_defaults_uids
- )
-
- ############################################################################
- ### PURE methods for manipulating the signature
- ### to avoid broken state
- ############################################################################
- def _generate_internal(self, internal_name: Optional[str] = None) -> "Signature":
- """
-        Assign randomly generated UIDs as the internal names.
-
- Providing `internal_name` explicitly can be used to set the same
- internal name for different versions of the same function.
- """
- res = copy.deepcopy(self)
- if not self.is_builtin:
- if internal_name is None:
- internal_name = get_uid()
- ui_to_internal_map = {k: get_uid() for k in self.input_names}
- res._internal_name, res._ui_to_internal_input_map = (
- internal_name,
- ui_to_internal_map,
- )
- if len(self._new_input_defaults_uids) > 0:
- assert self.has_internal_data
- res._new_input_defaults_uids = {
- ui_to_internal_map[self.internal_to_ui_input_map[k]]: v
- for k, v in self._new_input_defaults_uids.items()
- }
- res.check_invariants()
- return res
-
- def is_compatible(self, new: "Signature") -> Tuple[bool, Optional[str]]:
- """
- Check if a new signature (possibly without internal data) is compatible
- with this signature.
-
- Currently, the only way to be compatible is to be either the same object
- or an extension with new arguments.
-
- Returns:
- Tuple[bool, str]: (outcome, (reason if `False`, None if True))
- """
- if new.version != self.version:
- return False, "Versions do not match"
- if new.ui_name != self.ui_name:
- return False, "UI names do not match"
- if new.has_internal_data and self.has_internal_data:
- if new.internal_name != self.internal_name:
- return False, "Internal names do not match"
- if not is_subdict(
- self.ui_to_internal_input_map, new.ui_to_internal_input_map
- ):
- return False, "UI -> internal input mapping is inconsistent"
- new_internal_names = set(new._ui_to_internal_input_map.values())
- current_internal_names = set(self._ui_to_internal_input_map.values())
- if not set.issubset(current_internal_names, new_internal_names):
- return False, "Internal input names must be a superset of current"
- if not set.issubset(set(self.input_names), set(new.input_names)):
- return False, "Removing inputs is not supported"
- # if not self.n_outputs == new.n_outputs:
- # return False, "Changing the number of outputs is not supported"
- if not is_subdict(self.defaults, new.defaults):
- return False, "New defaults are inconsistent with current defaults"
- if new.has_internal_data and not is_subdict(
- self._new_input_defaults_uids, new._new_input_defaults_uids
- ):
- return False, "New default UIDs are inconsistent with current default UIDs"
- for k in new.input_names:
- if k not in self.input_names:
- if k not in new.defaults.keys():
- return False, f"All new arguments must be created with defaults!"
- return True, None
-
- def update(self, new: "Signature") -> Tuple["Signature", dict]:
- """
- Return an updated version of this signature based on a new signature
- (possibly without internal data), plus a description of the updates.
-
- If the new signature has internal data, it is copied over.
-
- NOTE: the new signature need not have internal data. The goal of this
- method is to be able to update from a function provided by the user that
- has not been synchronized yet.
-
- This takes care of
- - checking that the new signature is compatible with the old one
- - generating names for new inputs, if any.
-
- Returns:
- - new `Signature` object
- - a dictionary of {new ui input name: default value} for any new inputs
- that were created
- """
- is_compatible, reason = self.is_compatible(new)
- if not is_compatible:
- raise ValueError(reason)
- new_defaults = new.defaults
- res = copy.deepcopy(self)
- updates = {}
- for k in new.input_names:
- if k not in res.input_names:
- # this means a new input is being created
- if new.has_internal_data:
- internal_name = new.ui_to_internal_input_map[k]
- else:
- internal_name = None
- res = res.create_input(
- name=k,
- default=new_defaults[k],
- internal_name=internal_name,
- annotation=new.input_annotations[k],
- )
- updates[k] = new_defaults[k]
- if new.n_outputs != self.n_outputs:
- res.n_outputs = new.n_outputs
- res.output_annotations = new.output_annotations
- res.check_invariants()
- return res, updates
-
- def create_input(
- self,
- name: str,
- default: Any,
- annotation: Any,
- internal_name: Optional[str] = None,
- ) -> "Signature":
- """
- Add an input with a default value to this signature. This takes care of
- all the internal bookkeeping, including figuring out the UID for the
- default value.
- """
- if name in self.input_names:
- raise ValueError(f'Input "{name}" already exists')
- if not self.has_internal_data:
- raise InternalError(
- "Cannot add inputs to a signature without internal data"
- )
- res = copy.deepcopy(self)
- res.input_names.add(name)
- internal_name = get_uid() if internal_name is None else internal_name
- res.ui_to_internal_input_map[name] = internal_name
- res.defaults[name] = default
- uid = Hashing.get_content_hash(obj=default)
- full_uid = get_full_uid(uid=uid, causal_uid=uid)
- res._new_input_defaults_uids[internal_name] = full_uid
- res.input_annotations[name] = annotation
- res.check_invariants()
- return res
-
- def rename(self, new_name: str) -> "Signature":
- """
- Change the ui name
- """
- res = copy.deepcopy(self)
- res.ui_name = new_name
- res.check_invariants()
- return res
-
- def rename_inputs(self, mapping: Dict[str, str]) -> "Signature":
- """
- Change UI names according to the given mapping.
-
- Supporting only a method that changes multiple names at once is more
- convenient, since we must support applying updates in bulk anyway.
- """
- assert all(k in self.input_names for k in mapping.keys())
- current_names = list(self.input_names)
- new_names = [mapping.get(k, k) for k in current_names]
- if len(set(new_names)) != len(new_names):
- raise ValueError("Input name collision")
- res = copy.deepcopy(self)
- # migrate input names
- for current_name in mapping.keys():
- res.input_names.remove(current_name)
- for new_name in mapping.values():
- res.input_names.add(new_name)
- # migrate defaults
- res.defaults = {mapping.get(k, k): v for k, v in res.defaults.items()}
- # migrate annotations
- res.input_annotations = {
- mapping.get(k, k): v for k, v in res.input_annotations.items()
- }
- # migrate internal data
- for current_name, new_name in mapping.items():
- res.ui_to_internal_input_map[new_name] = res.ui_to_internal_input_map.pop(
- current_name
- )
- res.check_invariants()
- return res
-
- def bump_version(self) -> "Signature":
- res = copy.deepcopy(self)
- res.version += 1
- res.check_invariants()
- return res
-
- @staticmethod
- def from_py(
- name: str,
- version: int,
- sig: inspect.Signature,
- _is_builtin: bool = False,
- ) -> "Signature":
- """
- Create a `Signature` from a Python function's signature and the other
- necessary metadata, and check it satisfies mandala-specific constraints.
- """
- input_names = set(
- [
- param.name
- for param in sig.parameters.values()
- if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
- ]
- )
- # ensure that there will be no collisions with input and output names
- if any(is_output_name(name) for name in input_names):
- raise ValueError(
- f"Input names cannot be of the form {Config.output_name_prefix}_[number]"
- )
- return_annotation = sig.return_annotation
- if (
- hasattr(return_annotation, "__origin__")
- and return_annotation.__origin__ is tuple
- ):
- n_outputs = len(return_annotation.__args__)
- elif return_annotation is inspect._empty:
- n_outputs = 0
- else:
- n_outputs = 1
- defaults = {
- param.name: param.default
- for param in sig.parameters.values()
- if param.default is not inspect.Parameter.empty
- }
- res = Signature(
- ui_name=name,
- input_names=input_names,
- n_outputs=n_outputs,
- defaults=defaults,
- version=version,
- input_annotations=_get_arg_annotations_from_sig(sig, support=input_names),
- output_annotations=_get_return_annotations_from_sig(
- sig=sig, support_size=n_outputs
- ),
- _is_builtin=_is_builtin,
- )
- res.check_invariants()
- return res
-
-
-def _postprocess_outputs(sig: Signature, result) -> List[Any]:
- if sig.n_outputs == 0:
- assert (
- result is None
- ), f"Operation with signature {sig} has zero outputs, but its function returned {result}"
- return []
- elif sig.n_outputs == 1:
- return [result]
- else:
- assert isinstance(
- result, tuple
- ), f"Operation {sig.ui_name} has multiple outputs, but its function returned a non-tuple: {result}"
- assert (
- len(result) == sig.n_outputs
- ), f"Operation {sig.ui_name} has {sig.n_outputs} outputs, but its function returned a tuple of length {len(result)}"
- return list(result)
-
-
-def _get_arg_annotations_from_sig(
- sig: inspect.Signature, support: Set[str]
-) -> Dict[str, Any]:
- # Create an empty dictionary to store the argument annotations
- arg_annotations = {}
- # Iterate over the function's parameters
- for param in sig.parameters.values():
- # Add the parameter name and its type annotation to the dictionary
- arg_annotations[param.name] = param.annotation
- for k in support:
- if k not in arg_annotations:
- arg_annotations[k] = Any
- # Return the dictionary of argument annotations
- return arg_annotations
-
-
-def _get_arg_annotations(func: Callable, support: Set[str]) -> Dict[str, Any]:
- if Config.has_torch and isinstance(func, torch.jit.ScriptFunction):
- return {a.name: a.type for a in func.schema.arguments}
- return _get_arg_annotations_from_sig(sig=inspect.signature(func), support=support)
-
-
-def is_typing_tuple(obj: Any) -> bool:
- """
-    Check whether an object is a typing.Tuple[...] annotation
- """
- try:
- return obj.__origin__ is tuple
- except AttributeError:
- return False
-
-
-def unpack_return_annotation(return_annotation: Any, support_size: int) -> List[Any]:
- if is_typing_tuple(return_annotation):
- # we must unpack the typing tuple in this weird way
- res: List[Any] = list(return_annotation.__args__)
- elif return_annotation is inspect._empty:
- if support_size > 0:
- res = [Any for _ in range(support_size)]
- else:
- res = []
- else:
- res = [return_annotation]
- return res
-
-
-def _get_return_annotations_from_sig(
- sig: inspect.Signature, support_size: int
-) -> List[Any]:
- return_annotation = sig.return_annotation
- return unpack_return_annotation(return_annotation, support_size)
-
-
-def _get_return_annotations(func: Callable, support_size: int) -> List[Any]:
- if Config.has_torch and isinstance(func, torch.jit.ScriptFunction):
- return_annotation = func.schema.returns[0].type
- else:
- sig = inspect.signature(func)
- return_annotation = sig.return_annotation
- return unpack_return_annotation(return_annotation, support_size)
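A self-contained sketch of the output-counting rule implemented by `Signature.from_py` above: a `typing.Tuple[...]` return annotation yields one output per tuple element, a missing annotation yields zero outputs, and anything else yields a single output. The function names are illustrative:

```python
import inspect
from typing import Tuple

def count_outputs(func) -> int:
    ra = inspect.signature(func).return_annotation
    if hasattr(ra, "__origin__") and ra.__origin__ is tuple:
        return len(ra.__args__)   # Tuple[A, B, ...] -> one output per element
    if ra is inspect.Signature.empty:
        return 0                  # no annotation -> no outputs
    return 1                      # any other annotation -> a single output

def f(x: int) -> Tuple[int, str]: return x, str(x)
def g(x: int) -> int: return x
def h(x: int): return x

print(count_outputs(f), count_outputs(g), count_outputs(h))  # 2 1 0
```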
diff --git a/mandala/core/tps.py b/mandala/core/tps.py
deleted file mode 100644
index 7bb54bf..0000000
--- a/mandala/core/tps.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from ..common_imports import *
-import typing
-
-################################################################################
-### types
-################################################################################
-class Type:
- @staticmethod
- def from_annotation(annotation: Any) -> "Type":
- if (annotation is None) or (annotation is inspect._empty):
- return AnyType()
- if annotation is typing.Any:
- return AnyType()
- elif isinstance(annotation, Type):
- return annotation
- elif isinstance(annotation, type):
- if annotation is list:
- return ListType()
- elif annotation is dict:
- return DictType()
- elif annotation is set:
- return SetType()
- else:
- return AnyType()
- elif hasattr(annotation, "__origin__"):
- if annotation.__origin__ is list:
- elt_annotation = annotation.__args__[0]
- return ListType(
- elt_type=Type.from_annotation(annotation=elt_annotation)
- )
- elif annotation.__origin__ is dict:
- value_annotation = annotation.__args__[1]
- return DictType(
- elt_type=Type.from_annotation(annotation=value_annotation)
- )
- elif annotation.__origin__ is set:
- elt_annotation = annotation.__args__[0]
- return SetType(elt_type=Type.from_annotation(annotation=elt_annotation))
- elif annotation.__origin__ is tuple:
- return AnyType()
- else:
- return AnyType()
- else:
- return AnyType()
-
- def __eq__(self, other: Any) -> bool:
- if type(self) != type(other):
- return False
- if isinstance(self, StructType):
- return self.struct_id == other.struct_id and self.elt_type == other.elt_type
- elif isinstance(self, AnyType):
- return True
- else:
- raise NotImplementedError
-
-
-class AnyType(Type):
- def __repr__(self):
- return "AnyType()"
-
-
-class StructType(Type):
- struct_id = None
-
- def __init__(self, elt_type: Optional[Type] = None):
- self.elt_type = AnyType() if elt_type is None else elt_type
-
-
-class ListType(StructType):
- struct_id = "__list__"
- model = list
-
- def __repr__(self):
- return f"ListType(elt_type={self.elt_type})"
-
-
-class DictType(StructType):
- struct_id = "__dict__"
- model = dict
-
- def __repr__(self):
- return f"DictType(elt_type={self.elt_type})"
-
-
-class SetType(StructType):
- struct_id = "__set__"
- model = set
-
- def __repr__(self):
- return f"SetType(elt_type={self.elt_type})"
diff --git a/mandala/core/utils.py b/mandala/core/utils.py
deleted file mode 100644
index 6cb089b..0000000
--- a/mandala/core/utils.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import typing
-import hashlib
-import textwrap
-from weakref import WeakKeyDictionary
-from .config import *
-from ..common_imports import *
-
-if Config.has_cityhash:
- import cityhash
-
-OpKey = Tuple[str, int]
-
-
-def get_uid() -> str:
- """
- Generate a sequence of 32 hexadecimal characters using the operating
- system's "randomness".
- """
- return "{}".format(binascii.hexlify(os.urandom(16)).decode("utf-8"))
-
-
-def get_full_uid(uid: str, causal_uid: str) -> str:
- return f"{uid}.{causal_uid}"
-
-
-def parse_full_uid(full_uid: str) -> Tuple[str, str]:
- uid, causal_uid = full_uid.rsplit(".", 1)
- return uid, causal_uid
-
-
-_KT = TypeVar("_KT")
-_VT = TypeVar("_VT")
-
-
-def get_fibers_as_lists(mapping: Dict[_KT, _VT]) -> Dict[_VT, List[_KT]]:
- fibers = defaultdict(list)
- for vq, name in mapping.items():
- fibers[name].append(vq)
- return fibers
-
-
-def is_subdict(a: Dict, b: Dict) -> bool:
- """
- Check that all keys in `a` are in `b` with the same value.
- """
- return all((k in b and a[k] == b[k]) for k in a)
-
-
-def invert_dict(d: Dict) -> Dict:
- """
- Invert a dictionary, assuming that all values are unique.
- """
- return {v: k for k, v in d.items()}
-
-
-def concat_lists(lists: List[list]) -> list:
- return [x for lst in lists for x in lst]
-
-
-class Hashing:
- """
- Helpers for hashing e.g. function inputs and call metadata.
- """
-
- @staticmethod
- def get_content_hash_blake2b(obj: Any) -> str:
- # if Config.has_torch and isinstance(obj, torch.Tensor):
- # #! torch tensors do not have a deterministic hash under the below
- # # method
- # obj = obj.cpu().numpy()
- stream = io.BytesIO()
- joblib.dump(value=obj, filename=stream)
- stream.seek(0)
- m = hashlib.blake2b()
- m.update(str((stream.read())).encode())
- return m.hexdigest()
-
- @staticmethod
- def get_cityhash(obj: Any) -> str:
- stream = io.BytesIO()
- joblib.dump(value=obj, filename=stream)
- stream.seek(0)
- h = cityhash.CityHash128(stream.read())
- digest = h.to_bytes(16, "little")
- s = binascii.b2a_hex(digest)
- res = s.decode()
- return res
-
- @staticmethod
- def get_joblib_hash(obj: Any) -> str:
- if hasattr(obj, "__get_mandala_dict__"):
- obj = obj.__get_mandala_dict__()
- if Config.has_torch:
- obj = tensor_to_numpy(obj)
- if isinstance(obj, pd.DataFrame):
- # DataFrames cause collisions with joblib hashing
- obj = {
- "columns": obj.columns,
- "values": obj.values,
- "index": obj.index,
- }
- result = joblib.hash(obj)
- if result is None:
- raise RuntimeError("joblib.hash returned None")
- return result
-
- if Config.content_hasher == "blake2b":
- get_content_hash = get_content_hash_blake2b
- elif Config.content_hasher == "cityhash":
- get_content_hash = get_cityhash
- elif Config.content_hasher == "joblib":
- get_content_hash = get_joblib_hash
- else:
- raise ValueError("Unknown content hasher: {}".format(Config.content_hasher))
-
- ### deterministic hashing of common collections
- @staticmethod
- def hash_list(elts: List[str]) -> str:
- return Hashing.get_content_hash(elts)
-
- @staticmethod
- def hash_dict(elts: Dict[str, str]) -> str:
- key_order = sorted(elts.keys())
- return Hashing.get_content_hash([(k, elts[k]) for k in key_order])
-
- @staticmethod
- def hash_set(elts: Set[str]) -> str:
- return Hashing.get_content_hash(sorted(elts))
-
- @staticmethod
- def hash_multiset(elts: List[str]) -> str:
- return Hashing.get_content_hash(sorted(elts))
-
-
-def unwrap_decorators(
- obj: Callable, strict: bool = True
-) -> Union[types.FunctionType, types.MethodType]:
- while hasattr(obj, "__wrapped__"):
- obj = obj.__wrapped__
- if not isinstance(obj, (types.FunctionType, types.MethodType)):
- msg = f"Expected a function or method, but got {type(obj)}"
- if strict:
- raise RuntimeError(msg)
- else:
- logger.debug(msg)
- return obj
-
-
-if Config.has_torch:
-
- def tensor_to_numpy(obj: Union[torch.Tensor, dict, list, tuple, Any]) -> Any:
- """
- Recursively convert PyTorch tensors in a data structure to numpy arrays.
-
- Parameters
- ----------
- obj : any
- The input data structure.
-
- Returns
- -------
- any
- The data structure with tensors converted to numpy arrays.
- """
- if isinstance(obj, torch.Tensor):
- return obj.detach().cpu().numpy()
- elif isinstance(obj, dict):
- return {k: tensor_to_numpy(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [tensor_to_numpy(v) for v in obj]
- elif isinstance(obj, tuple):
- return tuple(tensor_to_numpy(v) for v in obj)
- else:
- return obj
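A minimal, self-contained version of the blake2b content hasher deleted above: the object is serialized with joblib into an in-memory buffer and the bytes are hashed. It assumes `joblib` is installed, as in the project's `common_imports`:

```python
import hashlib
import io

import joblib

def content_hash_blake2b(obj) -> str:
    stream = io.BytesIO()
    joblib.dump(value=obj, filename=stream)  # deterministic serialization
    stream.seek(0)
    m = hashlib.blake2b()
    m.update(str(stream.read()).encode())
    return m.hexdigest()

# equal inputs hash equally -- this is what makes memoization lookups possible
assert content_hash_blake2b({"a": [1, 2, 3]}) == content_hash_blake2b({"a": [1, 2, 3]})
print(content_hash_blake2b({"a": [1, 2, 3]}))
```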
diff --git a/mandala/core/wrapping.py b/mandala/core/wrapping.py
deleted file mode 100644
index cf84fbf..0000000
--- a/mandala/core/wrapping.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import typing
-from ..common_imports import *
-from .tps import Type, ListType, DictType, SetType, AnyType, StructType
-from .utils import Hashing
-from .model import Ref, Call, wrap_atom, ValueRef
-from .builtins_ import ListRef, DictRef, SetRef, Builtins, StructRef
-
-
-def typecheck(obj: Any, tp: Type):
- if isinstance(tp, ListType) and not (
- isinstance(obj, list) or isinstance(obj, ListRef)
- ):
- raise ValueError(f"Expecting a list, got {type(obj)}")
- elif isinstance(tp, DictType) and not (
- isinstance(obj, dict) or isinstance(obj, DictRef)
- ):
- raise ValueError(f"Expecting a dict, got {type(obj)}")
- elif isinstance(tp, SetType) and not (
- isinstance(obj, set) or isinstance(obj, SetRef)
- ):
- raise ValueError(f"Expecting a set, got {type(obj)}")
-
-
-def causify_atom(ref: ValueRef):
- if ref.causal_uid is not None:
- return
- ref.causal_uid = Hashing.get_content_hash(ref.uid)
-
-
-def causify_down(ref: Ref, start: str, stop_at_causal: bool = True):
- """
- In-place top-down assignment of causal hashes to a Ref and its children.
- Requires causal hashes to not be present initially.
- """
- assert start is not None
- if ref.causal_uid is not None and stop_at_causal:
- return
- if isinstance(ref, ValueRef):
- ref.causal_uid = start
- elif isinstance(ref, ListRef):
- if ref.in_memory:
- for i, elt in enumerate(ref.obj):
- causify_down(elt, start=Hashing.hash_list([start, i]))
- ref.causal_uid = start
- elif isinstance(ref, DictRef):
- if ref.in_memory:
- for k, elt in ref.obj.items():
- causify_down(elt, start=Hashing.hash_list([start, k]))
- ref.causal_uid = start
- elif isinstance(ref, SetRef):
- # sort by uid to ensure deterministic ordering
- if ref.in_memory:
- elts_by_uid = {elt.uid: elt for elt in ref.obj}
- sorted_uids = sorted({elt.uid for elt in ref.obj})
- for i, uid in enumerate(sorted_uids):
- causify_down(elts_by_uid[uid], start=Hashing.hash_list([start, uid]))
- ref.causal_uid = start
- else:
- raise ValueError(f"Unknown ref type {type(ref)}")
-
-
-def decausify(ref: Ref, stop_at_first_missing: bool = False):
- """
- In-place recursive removal of causal hashes from a Ref
- """
- if ref._causal_uid is None and stop_at_first_missing:
- return
- ref._causal_uid = None
- if isinstance(ref, StructRef):
- for elt in ref.children():
- decausify(elt, stop_at_first_missing=stop_at_first_missing)
-
-
-def wrap_constructive(obj: Any, annotation: Any) -> Tuple[Ref, List[Call]]:
- tp = Type.from_annotation(annotation=annotation)
- typecheck(obj=obj, tp=tp)
- calls = []
- if isinstance(obj, Ref):
- res = obj, calls
- elif isinstance(tp, AnyType):
- res = wrap_atom(obj=obj), calls
- causify_atom(ref=res[0])
- elif isinstance(tp, StructType):
- RefCls: Union[
- typing.Type[ListRef], typing.Type[DictRef], typing.Type[SetRef]
- ] = Builtins.REF_CLASSES[tp.struct_id]
- assert type(obj) == tp.model
- recursive_result = RefCls.map(
- obj=obj, func=lambda elt: wrap_constructive(elt, annotation=tp.elt_type)
- )
- wrapped_elts = RefCls.map(obj=recursive_result, func=lambda elt: elt[0])
- recursive_calls = RefCls.elts(
- RefCls.map(obj=recursive_result, func=lambda elt: elt[1])
- )
- calls.extend([c for cs in recursive_calls for c in cs])
- obj: StructRef = RefCls(obj=wrapped_elts, uid=None, in_memory=True)
- obj.causify_up()
- obj_calls = obj.get_calls()
- calls.extend(obj_calls)
- res = obj, calls
- else:
- raise ValueError(f"Unknown type {tp}")
- return res
-
-
-def wrap_inputs(
- objs: Dict[str, Any],
- annotations: Dict[str, Any],
-) -> Tuple[Dict[str, Ref], List[Call]]:
- calls = []
- wrapped_objs = {}
- for k, v in objs.items():
- wrapped_obj, wrapping_calls = wrap_constructive(
- obj=v,
- annotation=annotations[k],
- )
- wrapped_objs[k] = wrapped_obj
- calls.extend(wrapping_calls)
- return wrapped_objs, calls
-
-
-def causify_outputs(refs: List[Ref], call_causal_uid: str):
- assert isinstance(call_causal_uid, str)
- for i, ref in enumerate(refs):
- causify_down(ref=ref, start=Hashing.hash_list([call_causal_uid, str(i)]))
-
-
-def wrap_outputs(
- objs: List[Any],
- annotations: List[Any],
-) -> Tuple[List[Ref], List[Call]]:
- calls = []
- wrapped_objs = []
- for i, v in enumerate(objs):
- wrapped_obj, wrapping_calls = wrap_constructive(
- obj=v,
- annotation=annotations[i],
- )
- wrapped_objs.append(wrapped_obj)
- calls.extend(wrapping_calls)
- return wrapped_objs, calls
-
-
-################################################################################
-### unwrapping
-################################################################################
-T = TypeVar("T")
-
-
-def unwrap(obj: Union[T, Ref], through_collections: bool = True) -> T:
- """
-    If an object is a `ValueRef`, return the wrapped object; otherwise, return
-    the object itself.
-
- If `through_collections` is True, recursively unwraps objects in lists,
- tuples, sets, and dict values.
- """
- if isinstance(obj, ValueRef) and obj.transient:
- return obj.obj
- if isinstance(obj, Ref) and not obj.in_memory:
- from ..ui.contexts import GlobalContext
-
- if GlobalContext.current is None:
- raise ValueError(
- "Cannot unwrap a Ref with `in_memory=False` outside a context"
- )
- storage = GlobalContext.current.storage
- storage.cache.mattach(vrefs=[obj])
- if isinstance(obj, ValueRef):
- return obj.obj
- elif isinstance(obj, StructRef):
- return type(obj).map(
- obj=obj.obj,
- func=lambda elt: unwrap(elt, through_collections=through_collections),
- )
- elif type(obj) is tuple and through_collections:
- return tuple(unwrap(v, through_collections=through_collections) for v in obj)
- elif type(obj) is set and through_collections:
- return {unwrap(v, through_collections=through_collections) for v in obj}
- elif type(obj) is list and through_collections:
- return [unwrap(v, through_collections=through_collections) for v in obj]
- elif type(obj) is dict and through_collections:
- return {
- k: unwrap(v, through_collections=through_collections)
- for k, v in obj.items()
- }
- else:
- return obj
-
-
-def contains_transient(ref: Ref) -> bool:
- if isinstance(ref, ValueRef):
- return ref.transient
- elif isinstance(ref, StructRef):
- return any(contains_transient(elt) for elt in ref.children())
- else:
- raise ValueError(f"Unexpected ref type {type(ref)}")
-
-
-def contains_not_in_memory(ref: Ref) -> bool:
- if isinstance(ref, ValueRef):
- return not ref.in_memory
- elif isinstance(ref, StructRef):
- return any(contains_not_in_memory(elt) for elt in ref.children())
- else:
- raise ValueError(f"Unexpected ref type {type(ref)}")
-
-
-def _sanitize_value(value: Any) -> Any:
- if isinstance(value, Ref):
- return (_sanitize_value(value.obj), value.in_memory, value.uid)
- try:
- hash(value)
- return value
- except TypeError:
- if isinstance(value, bytearray):
- return value.hex()
- elif isinstance(value, list):
- return tuple([_sanitize_value(v) for v in value])
- else:
- raise NotImplementedError(f"Got value of type {type(value)}")
-
-
-def compare_dfs_as_relations(
- df_1: pd.DataFrame, df_2: pd.DataFrame, return_reason: bool = False
-) -> Union[bool, Tuple[bool, str]]:
- if df_1.shape != df_2.shape:
- result, reason = False, f"Shapes differ: {df_1.shape} vs {df_2.shape}"
- if set(df_1.columns) != set(df_2.columns):
- result, reason = False, f"Columns differ: {df_1.columns} vs {df_2.columns}"
- # reorder columns of df_2 to match df_1
- df_2 = df_2[df_1.columns]
- # sanitize values to make them hashable
- df_1 = df_1.applymap(_sanitize_value)
- df_2 = df_2.applymap(_sanitize_value)
- # compare as sets of tuples
- result = set(map(tuple, df_1.itertuples(index=False))) == set(
- map(tuple, df_2.itertuples(index=False))
- )
- if result:
- reason = ""
- else:
- reason = f"Dataframe rows differ: {df_1} vs {df_2}"
- if return_reason:
- return result, reason
- else:
- return result
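A minimal sketch of the recursive traversal behind `unwrap` above, using a hypothetical `Box` wrapper in place of the `Ref`/`ValueRef` classes and skipping the storage-attachment step:

```python
from typing import Any

class Box:
    """Hypothetical stand-in for a ValueRef holding a wrapped object."""
    def __init__(self, obj: Any):
        self.obj = obj

def unwrap(obj: Any, through_collections: bool = True) -> Any:
    if isinstance(obj, Box):
        return unwrap(obj.obj, through_collections)
    if through_collections and type(obj) is tuple:
        return tuple(unwrap(v, through_collections) for v in obj)
    if through_collections and type(obj) is list:
        return [unwrap(v, through_collections) for v in obj]
    if through_collections and type(obj) is dict:
        return {k: unwrap(v, through_collections) for k, v in obj.items()}
    return obj

print(unwrap([Box(1), {"x": Box(2)}, (Box(3),)]))  # [1, {'x': 2}, (3,)]
```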
diff --git a/mandala/deps/crawler.py b/mandala/deps/crawler.py
index 75bee40..25f109c 100644
--- a/mandala/deps/crawler.py
+++ b/mandala/deps/crawler.py
@@ -1,6 +1,6 @@
import types
from ..common_imports import *
-from ..core.utils import unwrap_decorators
+from ..utils import unwrap_decorators
import importlib
from .model import (
DepKey,
diff --git a/mandala/deps/model.py b/mandala/deps/model.py
index b0599d2..731bb0e 100644
--- a/mandala/deps/model.py
+++ b/mandala/deps/model.py
@@ -3,8 +3,8 @@
import types
from ..common_imports import *
-from ..core.utils import Hashing
-from ..ui.viz import (
+from ..utils import get_content_hash
+from ..viz import (
write_output,
)
@@ -115,7 +115,7 @@ def representation(self) -> str:
def _set_representation(self, value: str):
assert isinstance(value, str)
self._representation = value
- self._content_hash = Hashing.get_content_hash(value)
+ self._content_hash = get_content_hash(value)
@representation.setter
def representation(self, value: str):
@@ -138,8 +138,8 @@ def represent(
obj: Union[types.FunctionType, types.CodeType, Callable],
allow_fallback: bool = False,
) -> str:
- if type(obj).__name__ == "FuncInterface":
- obj = obj.func_op.func
+ if type(obj).__name__ == "Op":
+ obj = obj.f
if not isinstance(obj, (types.FunctionType, types.MethodType, types.CodeType)):
logger.warning(f"Found {obj} of type {type(obj)}")
try:
@@ -201,7 +201,7 @@ def representation(self) -> Tuple[str, str]:
def represent(obj: Any, allow_fallback: bool = False) -> Tuple[str, str]:
truncated_repr = textwrap.shorten(text=repr(obj), width=80)
try:
- content_hash = Hashing.get_content_hash(obj=obj)
+ content_hash = get_content_hash(obj=obj)
except Exception as e:
shortened_exception = textwrap.shorten(text=str(e), width=80)
msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}"
diff --git a/mandala/deps/shallow_versions.py b/mandala/deps/shallow_versions.py
index c01dc14..98282f8 100644
--- a/mandala/deps/shallow_versions.py
+++ b/mandala/deps/shallow_versions.py
@@ -1,10 +1,10 @@
from typing import Literal
import textwrap
from ..common_imports import *
-from ..core.utils import Hashing
-from ..core.config import Config
+from ..utils import get_content_hash
+from ..config import Config
from ..utils import ask_user
-from ..ui.viz import _get_colorized_diff, _get_diff
+from ..viz import _get_colorized_diff, _get_diff
if Config.has_rich:
from rich.tree import Tree
@@ -110,7 +110,7 @@ def get_presentable_content(self, content: str) -> str:
return content
def get_content_hash(self, content: str) -> str:
- return Hashing.get_content_hash(content)
+ return get_content_hash(content)
GVContent = Tuple[str, str] # (content hash, repr)
@@ -241,7 +241,7 @@ def commit(self, content: T, is_semantic_change: Optional[bool] = None) -> str:
)
)
answer = ask_user(
- question="Does this change require recomputation of dependent calls? [y]es/[n]o/[a]bort",
+ question="Does this change require recomputation of dependent calls?\nWARNING: if the change created new dependencies and you choose 'no', you should add them by hand or risk missing changes in them.\nAnswer: [y]es/[n]o/[a]bort",
valid_options=["y", "n", "a"],
)
print(f'You answered: "{answer}"')
diff --git a/mandala/deps/tracers/dec_impl.py b/mandala/deps/tracers/dec_impl.py
index 66b01bc..4de9a3a 100644
--- a/mandala/deps/tracers/dec_impl.py
+++ b/mandala/deps/tracers/dec_impl.py
@@ -1,8 +1,8 @@
import types
from functools import wraps, update_wrapper
from ...common_imports import *
-from ...core.utils import unwrap_decorators
-from ...core.config import Config
+from ...utils import unwrap_decorators
+from ...config import Config
from ..model import (
DependencyGraph,
CallableNode,
@@ -40,6 +40,9 @@ def is_tracked(f: Union[types.FunctionType, type]) -> bool:
class TrackedDict(dict):
+ """
+ A dictionary that tracks global variable accesses.
+ """
def __init__(self, original: dict):
self.__original__ = original
@@ -147,6 +150,9 @@ def wrapper(*args, **kwargs) -> Any:
class DecTracer(TracerABC):
+ """
+ A decorator-based tracer that tracks function calls and global variable accesses.
+ """
def __init__(
self,
paths: List[Path],
diff --git a/mandala/deps/tracers/tracer_base.py b/mandala/deps/tracers/tracer_base.py
index 9542d2e..0081ca4 100644
--- a/mandala/deps/tracers/tracer_base.py
+++ b/mandala/deps/tracers/tracer_base.py
@@ -1,5 +1,5 @@
from ...common_imports import *
-from ...core.config import Config
+from ...config import Config
import importlib
from ..model import DependencyGraph, CallableNode
from abc import ABC, abstractmethod
diff --git a/mandala/deps/utils.py b/mandala/deps/utils.py
index f2dbda4..bad7405 100644
--- a/mandala/deps/utils.py
+++ b/mandala/deps/utils.py
@@ -5,8 +5,8 @@
from typing import Literal
from ..common_imports import *
-from ..core.utils import Hashing, unwrap_decorators
-from ..core.config import Config
+from ..utils import get_content_hash, unwrap_decorators
+from ..config import Config
DepKey = Tuple[str, str] # (module name, object address in module)
@@ -81,7 +81,7 @@ def is_callable_obj(obj: Any, strict: bool) -> bool:
def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType:
if type(obj).__name__ == Config.func_interface_cls_name:
- return obj.func_op.func
+ return obj.f
obj = unwrap_decorators(obj, strict=strict)
if isinstance(obj, types.BuiltinFunctionType):
raise ValueError(f"Expected a non-built-in function, but got {obj}")
@@ -102,7 +102,7 @@ def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType:
def extract_code(obj: Callable) -> types.CodeType:
if type(obj).__name__ == Config.func_interface_cls_name:
- obj = obj.func_op.func
+ obj = obj.f
if isinstance(obj, property):
obj = obj.fget
obj = unwrap_decorators(obj, strict=True)
@@ -169,7 +169,7 @@ def get_bytecode(f: Union[types.FunctionType, types.CodeType, str]) -> str:
def hash_dict(d: dict) -> str:
- return Hashing.get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())])
+ return get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())])
def load_obj(module_name: str, obj_name: str) -> Tuple[Any, bool]:
diff --git a/mandala/deps/versioner.py b/mandala/deps/versioner.py
index 3fd9a67..3552080 100644
--- a/mandala/deps/versioner.py
+++ b/mandala/deps/versioner.py
@@ -3,8 +3,8 @@
import textwrap
from ..common_imports import *
-from ..core.utils import is_subdict
-from ..core.config import Config
+from ..utils import is_subdict
+from ..config import Config
from .utils import DepKey, hash_dict
from .model import (
Node,
@@ -15,7 +15,7 @@
)
from .crawler import crawl_static
from .tracers import TracerABC
-from ..ui.viz import _get_colorized_diff
+from ..viz import _get_colorized_diff
if Config.has_rich:
from rich.panel import Panel
@@ -84,6 +84,10 @@ def get_version_ids(
"""
Get the content and semantic IDs for the version corresponding to the
given pre-call uid.
+
+ Inputs:
+ - `is_recompute`: this should be true only if this is a call with
+ transient outputs that we already computed once.
"""
assert tracer_option is not None
version = self.process_trace(
@@ -335,6 +339,17 @@ def lookup_call(
"""
Return a tuple of (content_version, semantic_version), or None if the
call is not found.
+
+ Inputs:
+    - `pre_call_uid`: this is a hash of the content IDs of the inputs,
+    together with the function's versioned internal name.
+
+ This works as follows:
+ - we figure out the semantic hashes (i.e. shallow semantic versions) of
+ the elements of the code state present in the global topology we have on
+ record
+ - we restrict to the records that match the given `pre_call_uid`
+    - we search among these records for a matching version
"""
codebase_semantic_hashes = self.get_codestate_semantic_hashes(
code_state=code_state
@@ -408,7 +423,7 @@ def _check_semantic_distinguishability(
)
if all([semantic_hashes[k] == new_semantic_dep_hashes[k] for k in overlap]):
raise ValueError(
- f"Call {pre_call_uid} is not semantically distinguishable from call for semantic version {semantic_version}"
+ f"Call to {component} with pre_call_uid={pre_call_uid} is not semantically distinguishable from call for semantic version {semantic_version}"
)
############################################################################
diff --git a/mandala/deps/viz.py b/mandala/deps/viz.py
index c5baa9b..82907fb 100644
--- a/mandala/deps/viz.py
+++ b/mandala/deps/viz.py
@@ -1,6 +1,6 @@
import textwrap
from ..common_imports import *
-from ..ui.viz import (
+from ..viz import (
Node as DotNode,
Edge as DotEdge,
Group as DotGroup,
diff --git a/mandala/_next/docs/01_storage_and_ops.ipynb b/mandala/docs/01_storage_and_ops.ipynb
similarity index 100%
rename from mandala/_next/docs/01_storage_and_ops.ipynb
rename to mandala/docs/01_storage_and_ops.ipynb
diff --git a/mandala/_next/docs/02_retracing.ipynb b/mandala/docs/02_retracing.ipynb
similarity index 100%
rename from mandala/_next/docs/02_retracing.ipynb
rename to mandala/docs/02_retracing.ipynb
diff --git a/mandala/_next/docs/03_cf.ipynb b/mandala/docs/03_cf.ipynb
similarity index 100%
rename from mandala/_next/docs/03_cf.ipynb
rename to mandala/docs/03_cf.ipynb
diff --git a/mandala/_next/docs/04_versions.ipynb b/mandala/docs/04_versions.ipynb
similarity index 100%
rename from mandala/_next/docs/04_versions.ipynb
rename to mandala/docs/04_versions.ipynb
diff --git a/mandala/_next/docs/05_collections.ipynb b/mandala/docs/05_collections.ipynb
similarity index 100%
rename from mandala/_next/docs/05_collections.ipynb
rename to mandala/docs/05_collections.ipynb
diff --git a/mandala/_next/docs/06_advanced_cf.ipynb b/mandala/docs/06_advanced_cf.ipynb
similarity index 100%
rename from mandala/_next/docs/06_advanced_cf.ipynb
rename to mandala/docs/06_advanced_cf.ipynb
diff --git a/mandala/_next/docs/make_docs.py b/mandala/docs/make_docs.py
similarity index 100%
rename from mandala/_next/docs/make_docs.py
rename to mandala/docs/make_docs.py
diff --git a/mandala/_next/docs/readme.md b/mandala/docs/readme.md
similarity index 100%
rename from mandala/_next/docs/readme.md
rename to mandala/docs/readme.md
diff --git a/mandala/imports.py b/mandala/imports.py
index 75de7f7..66ff18c 100644
--- a/mandala/imports.py
+++ b/mandala/imports.py
@@ -1,11 +1,10 @@
-"""
-Intended way for users to import mandala
-"""
-from .core.model import wrap_atom
-from .core.wrapping import unwrap
+from .storage import Storage
+from .model import op, Ignore, NewArgDefault
+from .tps import MList, MDict
from .deps.tracers.dec_impl import track
-from .queries import ListQ, SetQ, DictQ
-from .core.config import Config
-from .ui.storage import Storage
-from .ui.funcs import op, superop, Q, Transient
-from .ui.utils import wrap_ui as wrap
+
+from .common_imports import sess
+
+
+def pprint_dict(d) -> str:
+ return '\n'.join([f" {k}: {v}" for k, v in d.items()])
\ No newline at end of file
diff --git a/mandala/_next/model.py b/mandala/model.py
similarity index 100%
rename from mandala/_next/model.py
rename to mandala/model.py
diff --git a/mandala/queries/__init__.py b/mandala/queries/__init__.py
deleted file mode 100644
index e7c481c..0000000
--- a/mandala/queries/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .weaver import BuiltinQueries
-
-ListQ = BuiltinQueries.ListQ
-DictQ = BuiltinQueries.DictQ
-SetQ = BuiltinQueries.SetQ
diff --git a/mandala/queries/compiler.py b/mandala/queries/compiler.py
deleted file mode 100644
index 902d43b..0000000
--- a/mandala/queries/compiler.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from ..common_imports import *
-from ..core.model import FuncOp
-from ..core.sig import Signature
-from ..core.config import Config, dump_output_name
-from .weaver import ValNode, CallNode, traverse_all
-from .viz import visualize_graph, get_names
-from ..core.utils import invert_dict, OpKey
-from pypika import Query, Table, Criterion
-
-
-class Compiler:
- def __init__(self, vqs: List[ValNode], fqs: List[CallNode]):
- if not len(vqs) == len({id(x) for x in vqs}):
- raise InternalError
- if not len(fqs) == len({id(x) for x in fqs}):
- raise InternalError
- self.vqs = vqs
- self.fqs = fqs
- self.val_aliases, self.func_aliases = self._generate_aliases()
-
- def _generate_aliases(self) -> Tuple[Dict[ValNode, Table], Dict[CallNode, Table]]:
- func_aliases = {}
- func_counter = 0
- for i, func_query in enumerate(self.fqs):
- op_table = Table(func_query.func_op.sig.versioned_ui_name)
- func_aliases[func_query] = op_table.as_(f"_func_{func_counter}")
- func_counter += 1
- val_aliases = {}
- val_counter = 0
- for val_query in self.vqs:
- val_table = Table(Config.causal_vref_table)
- val_aliases[val_query] = val_table.as_(f"_var_{val_counter}")
- val_counter += 1
- return val_aliases, func_aliases
-
- def compile_func(
- self, fq: CallNode, semantic_versions: Optional[Set[str]] = None
- ) -> Tuple[list, list]:
- """
- Compile the query corresponding to an op, including built-in ops
- """
- constraints = []
- select_fields = []
- func_alias = self.func_aliases[fq]
- for input_name, val_query in fq.inputs.items():
- val_alias = self.val_aliases[val_query]
- constraints.append(val_alias[Config.full_uid_col] == func_alias[input_name])
- for output_name, val_query in fq.outputs.items():
- val_alias = self.val_aliases[val_query]
- constraints.append(
- val_alias[Config.full_uid_col] == func_alias[output_name]
- )
- if semantic_versions is not None:
- constraints.append(
- func_alias[Config.semantic_version_col].isin(semantic_versions)
- )
- if fq.constraint is not None:
- constraints.append(func_alias[Config.causal_uid_col].isin(fq.constraint))
- return constraints, select_fields
-
- def compile_val(self, val_query: ValNode) -> Tuple[list, list]:
- """
- Compile the query corresponding to a variable
- """
- constraints = []
- select_fields = []
- val_alias = self.val_aliases[val_query]
- if val_query.constraint is not None:
- constraints.append(
- val_alias[Config.full_uid_col].isin(val_query.constraint)
- )
- select_fields.append(val_alias[Config.full_uid_col])
- return constraints, select_fields
-
- def compile(
- self,
- select_queries: List[ValNode],
- semantic_version_constraints: Optional[Dict[OpKey, Optional[Set[str]]]] = None,
- filter_duplicates: bool = False,
- ):
- """
- Compile the query induced by the data of this compiler instance to
- an SQL select query. If `semantic_version_constraints` is not provided,
- no constraints are placed.
-
- NOTE:
- - for each value query, we select both columns of the variable
- table: the index and the partition. This is to be able to convert
- the query result directly into locations.
- - The list of columns, partitioned into sublists per value query, is
- also returned.
- """
- if not len(select_queries) == len({id(x) for x in select_queries}):
- raise InternalError
- assert all([vq in self.vqs for vq in select_queries])
- from_tables = []
- all_constraints = []
- select_cols = [
- self.val_aliases[vq][Config.full_uid_col] for vq in select_queries
- ]
- if semantic_version_constraints is None:
- semantic_version_constraints = {
- (op_query.func_op.sig.internal_name, op_query.func_op.sig.version): None
- for op_query in self.fqs
- }
- for func_query in self.fqs:
- op_key = (
- func_query.func_op.sig.internal_name,
- func_query.func_op.sig.version,
- )
- constraints, select_fields = self.compile_func(
- func_query, semantic_versions=semantic_version_constraints[op_key]
- )
- func_alias = self.func_aliases[func_query]
- from_tables.append(func_alias)
- all_constraints.extend(constraints)
- for val_query in self.vqs:
- val_alias = self.val_aliases[val_query]
- constraints, select_fields = self.compile_val(val_query)
- from_tables.append(val_alias)
- all_constraints.extend(constraints)
- query = Query
- for table in from_tables:
- query = query.from_(table)
- query = query.select(*select_cols)
- if filter_duplicates:
- query = query.distinct()
- query = query.where(Criterion.all(all_constraints))
- return query
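A minimal sketch of how the deleted `Compiler` assembles its SQL with `pypika`: one aliased table per op/variable node, equality constraints linking value columns to op columns, and a single `WHERE` built via `Criterion.all`. The table and column names below are illustrative, not the real storage schema:

```python
from pypika import Criterion, Query, Table

func_t = Table("my_op_0").as_("_func_0")      # calls to a (hypothetical) op
var_x = Table("causal_vrefs").as_("_var_0")   # variable matched to the op's input
var_y = Table("causal_vrefs").as_("_var_1")   # variable matched to the op's output

constraints = [
    var_x["full_uid"] == func_t["x"],          # input `x` of the op
    var_y["full_uid"] == func_t["output_0"],   # first output of the op
]

query = (
    Query.from_(func_t).from_(var_x).from_(var_y)
    .select(var_x["full_uid"], var_y["full_uid"])
    .distinct()
    .where(Criterion.all(constraints))
)
print(query)  # the generated SQL string
```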
diff --git a/mandala/queries/graphs.py b/mandala/queries/graphs.py
deleted file mode 100644
index aefc120..0000000
--- a/mandala/queries/graphs.py
+++ /dev/null
@@ -1,340 +0,0 @@
-from typing import Literal
-from collections import deque
-from ..core.utils import Hashing, invert_dict
-from ..core.tps import StructType
-from ..core.config import Provenance
-from ..common_imports import *
-from .weaver import ValNode, CallNode, Node
-
-
-def get_deps(nodes: Set[Node]) -> Set[Node]:
- res = set()
-
- def visit(node: Node):
- if node not in res:
- res.add(node)
- for n in node.neighbors("backward"):
- visit(n)
-
- for n in nodes:
- visit(n)
- return res
-
-
-def prune(
- vqs: Set[ValNode], fqs: Set[CallNode], selection: List[ValNode]
-) -> Tuple[Set[ValNode], Set[CallNode]]:
- # remove vqs of degree 1 that are not in selection
- pass
-
-
-def is_connected(val_queries: Set[ValNode], func_queries: Set[CallNode]) -> bool:
- """
- Check if the graph defined by these objects is a connected graph
- """
- if len(val_queries) == 0:
- raise NotImplementedError
- visited: Dict[Node, bool] = {}
-
- def visit(node: Node):
- if node not in visited.keys():
- visited[node] = True
- for n in node.neighbors("both"):
- visit(n)
-
- start = list(val_queries)[0]
- visit(start)
- return len(visited) == len(val_queries) + len(func_queries)
-
-
-def is_component(nodes: Set[Node]) -> bool:
- """
- Check if the given nodes are a full component of the graph
- """
- for node in nodes:
- if any(n not in nodes for n in node.neighbors()):
- return False
- return True
-
-
-def get_fibers(v_map: Dict[Node, str]) -> Dict[str, Set[Node]]:
- fibers = defaultdict(set)
- for vq, name in v_map.items():
- fibers[name].add(vq)
- return fibers
-
-
-def hash_groups(groups: Dict[str, Set[Node]]) -> str:
- return Hashing.hash_set(
- {Hashing.hash_set({str(id(v)) for v in g}) for g in groups.values()}
- )
-
-
-class InducedSubgraph:
- def __init__(self, vqs: Set[ValNode], fqs: Set[CallNode]):
- self.vqs = vqs
- self.fqs = fqs
- self.nodes: Set[Node] = vqs.union(fqs)
-
- def neighbors(
- self, node: Node, direction: Literal["forward", "backward", "both"]
- ) -> Set[Node]:
- return {n for n in node.neighbors(direction=direction) if n in self.nodes}
-
- def fq_inputs(self, fq: CallNode) -> Dict[str, ValNode]:
- return {k: v for k, v in fq.inputs.items() if v in self.vqs}
-
- def fq_outputs(self, fq: CallNode) -> Dict[str, ValNode]:
- return {k: v for k, v in fq.outputs.items() if v in self.vqs}
-
- def consumers(self, vq: ValNode) -> List[Tuple[str, CallNode]]:
- return [(k, v) for k, v in zip(vq.consumed_as, vq.consumers) if v in self.fqs]
-
- def creators(self, vq: ValNode) -> List[Tuple[str, CallNode]]:
- return [(k, v) for k, v in zip(vq.created_as, vq.creators) if v in self.fqs]
-
- def topsort(
- self, canonical_labels: Optional[Dict[Node, str]]
- ) -> Tuple[List[Node], Set[Node], Set[Node]]:
- """
- Get the topsort, the sources, and the sinks. Optionally, provide a
- canonical label for each node to get a canonical topsort.
- """
- # return a topological sort of the graph, and the sources and sinks
- in_degrees = {
- n: len(self.neighbors(node=n, direction="backward")) for n in self.nodes
- }
- sources = {v for v in self.nodes if in_degrees[v] == 0}
- sinks = set()
- res = []
- S = [v for v in self.nodes if in_degrees[v] == 0]
- if canonical_labels is not None:
- S = sorted(S, key=lambda x: canonical_labels[x])
- while len(S) > 0:
- v = S.pop(0)
- res.append(v)
- forward_neighbors = self.neighbors(node=v, direction="forward")
- if len(forward_neighbors) == 0:
- sinks.add(v)
- for w in self.neighbors(node=v, direction="forward"):
- in_degrees[w] -= 1
- if in_degrees[w] == 0:
- S.append(w)
- if canonical_labels is not None:
- # TODO: inefficient, needs a heap instead
- S = sorted(S, key=lambda x: canonical_labels[x])
- return res, sources, sinks
-
- def digest_neighbors_fq(self, fq: CallNode, colors: Dict[Node, str]) -> str:
- neighs = {**self.fq_inputs(fq=fq), **self.fq_outputs(fq=fq)}
- return Hashing.hash_dict({k: colors[v] for k, v in neighs.items()})
-
- def digest_neighbors_vq(self, vq: ValNode, colors: Dict[Node, str]) -> str:
- creator_labels = [
- Hashing.hash_list([k, colors[v]]) for k, v in self.creators(vq=vq)
- ]
- consumer_labels = [
- Hashing.hash_list([k, colors[v]]) for k, v in self.consumers(vq=vq)
- ]
- if isinstance(vq.tp, StructType):
- lst = [
- Hashing.hash_set(set(creator_labels)),
- Hashing.hash_set(set(consumer_labels)),
- ]
- else:
- lst = [
- Hashing.hash_multiset(creator_labels),
- Hashing.hash_multiset(consumer_labels),
- ]
- return Hashing.hash_list(lst)
-
- def run_color_refinement(
- self, initialization: Dict[Node, str], verbose: bool = False
- ) -> Dict[Node, str]:
- """
- Run the color refinement algorithm on this graph. Return a dict mapping
- each node to its color.
- """
- colors = initialization
- groups = get_fibers(initialization)
- iteration = 0
- while True:
- if verbose:
- logger.info(f"Color refinement: iteration {iteration}")
- new_colors = {}
- for node in self.nodes:
- if isinstance(node, CallNode):
- neighbors_digest = self.digest_neighbors_fq(fq=node, colors=colors)
- elif isinstance(node, ValNode):
- neighbors_digest = self.digest_neighbors_vq(vq=node, colors=colors)
- else:
- raise Exception(f"Unknown node type {type(node)}")
- new_colors[node] = Hashing.hash_list([colors[node], neighbors_digest])
- new_groups = get_fibers(new_colors)
- if hash_groups(new_groups) == hash_groups(groups):
- break
- else:
- groups = new_groups
- colors = new_colors
- iteration += 1
- return colors
-
- @staticmethod
- def is_homomorphism(
- s: "InducedSubgraph",
- t: "InducedSubgraph",
- v_map: Dict[ValNode, ValNode],
- f_map: Dict[CallNode, CallNode],
- ) -> bool:
- for svq, tvq in v_map.items():
- if not isinstance(svq, ValNode):
- return False
- if not isinstance(tvq, ValNode):
- return False
- if not svq.tp == tvq.tp:
- return False
- for sfq, tfq in f_map.items():
- if not isinstance(sfq, CallNode):
- return False
- if not isinstance(tfq, CallNode):
- return False
- if not sfq.func_op.sig == tfq.func_op.sig:
- return False
- sfq_inps, sfq_outps = s.fq_inputs(fq=sfq), s.fq_outputs(fq=sfq)
- tfq_inps, tfq_outps = t.fq_inputs(fq=tfq), t.fq_outputs(fq=tfq)
- if not set(sfq_inps.keys()) == set(tfq_inps.keys()):
- return False
- if not set(sfq_outps.keys()) == set(tfq_outps.keys()):
- return False
- if not all(v_map[v] == tfq_inps[k] for k, v in sfq_inps.items()):
- return False
- if not all(v_map[v] == tfq_outps[k] for k, v in sfq_outps.items()):
- return False
- return True
-
- @staticmethod
- def are_canonically_isomorphic(s: "InducedSubgraph", t: "InducedSubgraph") -> bool:
- try:
- s_vlabels, s_flabels, s_topsort = s.canonicalize(strict=True)
- t_vlabels, t_flabels, t_topsort = t.canonicalize(strict=True)
- except AssertionError:
- raise NotImplementedError
- if len(s_vlabels) != len(t_vlabels) or len(s_flabels) != len(t_flabels):
- return False
- t_vinverse, t_finverse = invert_dict(t_vlabels), invert_dict(t_flabels)
- v_map = {vq: t_vinverse[v] for vq, v in s_vlabels.items()}
- f_map = {fq: t_finverse[f] for fq, f in s_flabels.items()}
- if not InducedSubgraph.is_homomorphism(s, t, v_map, f_map):
- return False
- if not InducedSubgraph.is_homomorphism(
- t, s, invert_dict(v_map), invert_dict(f_map)
- ):
- return False
- return True
-
- def get_projections(
- self, colors: Dict[Node, str]
- ) -> Tuple[Dict[ValNode, ValNode], Dict[CallNode, CallNode]]:
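-        # Collapse each color class to a single prototype node, producing a
-        # quotient graph together with the projection maps onto it.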
- v_groups = defaultdict(list)
- f_groups = defaultdict(list)
- for node, color in colors.items():
- if isinstance(node, ValNode):
- v_groups[color].append(node)
- elif isinstance(node, CallNode):
- f_groups[color].append(node)
- v_map: Dict[ValNode, ValNode] = {}
- f_map: Dict[CallNode, CallNode] = {}
-        for color, gp in v_groups.items():
-            representative = gp[0]
-            with_constraint = [vq for vq in gp if vq.constraint is not None]
-            if len(with_constraint) == 0:
-                representative_constraint = None
-            elif len(with_constraint) == 1:
-                representative_constraint = with_constraint[0].constraint
-            else:
-                raise ValueError("Multiple constraints for a single value")
-            prototype = ValNode(
-                tp=representative.tp, constraint=representative_constraint
-            )
-            for vq in gp:
-                v_map[vq] = prototype
- for color, gp in f_groups.items():
- representative = gp[0]
- rep_inps = self.fq_inputs(fq=representative)
- rep_outps = self.fq_outputs(fq=representative)
- prototype = CallNode.link(
- inputs={k: v_map[vq] for k, vq in rep_inps.items()},
- outputs={k: v_map[vq] for k, vq in rep_outps.items()},
- func_op=representative.func_op,
- orientation=representative.orientation,
- constraint=None,
- )
- for fq in gp:
- f_map[fq] = prototype
- t = InducedSubgraph(vqs=set(v_map.values()), fqs=set(f_map.values()))
- assert self.is_homomorphism(s=self, t=t, v_map=v_map, f_map=f_map)
- return v_map, f_map
-
- def canonicalize(
- self, strict: bool = False, method: str = "wl1"
- ) -> Tuple[Dict[ValNode, str], Dict[CallNode, str], List[Node]]:
- """
- Return canonical labels for each node, as well as a canonical topological sort.
- """
- initialization = {}
- for vq in self.vqs:
- initialization[vq] = "null"
- for fq in self.fqs:
- initialization[fq] = fq.func_op.sig.versioned_internal_name
- colors = self.run_color_refinement(initialization=initialization)
- v_colors = {vq: colors[vq] for vq in self.vqs}
- f_colors = {fq: colors[fq] for fq in self.fqs}
- if strict:
- assert len(set(v_colors.values())) == len(v_colors)
- assert len(set(f_colors.values())) == len(f_colors)
- canonical_topsort, sources, sinks = self.topsort(
- canonical_labels={**v_colors, **f_colors}
- )
- return v_colors, f_colors, canonical_topsort
-
- def project(
- self,
- ) -> Tuple[Dict[ValNode, ValNode], Dict[CallNode, CallNode], List[Node]]:
- v_colors, f_colors, topsort = self.canonicalize()
- colors: Dict[Node, str] = {**v_colors, **f_colors}
- return *self.get_projections(colors=colors), topsort
-
-
-def get_canonical_order(vqs: Set[ValNode], fqs: Set[CallNode]) -> List[ValNode]:
- g = InducedSubgraph(vqs=vqs, fqs=fqs)
- _, _, canonical_topsort = g.canonicalize()
- return [vq for vq in canonical_topsort if isinstance(vq, ValNode)]
-
-
-def copy_subgraph(
- vqs: Set[ValNode], fqs: Set[CallNode]
-) -> Tuple[Dict[ValNode, ValNode], Dict[CallNode, CallNode]]:
- """
- Copy the subgraph supported on the given nodes. Return maps from the
- original nodes to their copies.
- """
- v_map = {
- v: ValNode(tp=v.tp, constraint=v.constraint, name=v.name, refs=v.refs)
- for v in vqs
- }
- f_map = {}
- for fq in fqs:
- inputs = {k: v_map[v] for k, v in fq.inputs.items() if v in vqs}
- outputs = {k: v_map[v] for k, v in fq.outputs.items() if v in vqs}
- f_map[fq] = CallNode.link(
- inputs=inputs,
- outputs=outputs,
- func_op=fq.func_op,
- orientation=fq.orientation,
- constraint=fq.constraint,
- calls=fq.calls,
- )
- return v_map, f_map
diff --git a/mandala/queries/main.py b/mandala/queries/main.py
deleted file mode 100644
index 6e4ed6f..0000000
--- a/mandala/queries/main.py
+++ /dev/null
@@ -1,182 +0,0 @@
-from collections import Counter
-from ..common_imports import *
-from ..core.utils import OpKey
-from ..core.config import parse_output_idx
-from .weaver import ValNode, CallNode
-from .compiler import Compiler
-from .solver import NaiveQueryEngine
-from .viz import get_names
-from .graphs import (
- InducedSubgraph,
- is_connected,
- copy_subgraph,
-)
-
-
-class Querier:
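-    """
-    Static helpers for validating, compiling and executing queries defined by
-    a graph of value and call nodes.
-    """
-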
- @staticmethod
- def check_df(
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- df: pd.DataFrame,
- funcs: Dict[str, Callable],
- ):
- """
-        Check the validity of a query result projected from the given graph by
-        re-executing the given functions on each row.
- """
- cols = set(df.columns)
- assert cols <= set(vq.name for vq in vqs)
- for fq in fqs:
- func = funcs[fq.func_op.sig.ui_name]
- input_cols: Dict[str, str] = {
- k: vq.name for k, vq in fq.inputs.items() if vq.name in cols
- }
-            output_cols: Dict[int, str] = {
- parse_output_idx(k): vq.name
- for k, vq in fq.outputs.items()
- if vq.name in cols
- }
- for i, row in df.iterrows():
- inputs = {k: row[v] for k, v in input_cols.items()}
- outputs = func(**inputs)
- if not (isinstance(outputs, tuple)):
- outputs = (outputs,)
- for j, v in enumerate(outputs):
-                    assert row[output_cols[j]] == v
-
- @staticmethod
- def execute_naive(
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- selection: List[ValNode],
- memoization_tables: Dict[str, pd.DataFrame],
- filter_duplicates: bool,
- table_evaluator: Callable,
- visualize_steps_at: Optional[Path] = None,
- ) -> pd.DataFrame:
- v_copymap, f_copymap = copy_subgraph(vqs=vqs, fqs=fqs)
- vqs = set(v_copymap.values())
- fqs = set(f_copymap.values())
- select_copies = [v_copymap[vq] for vq in selection]
- tables = {
- f: memoization_tables[f.func_op.sig.versioned_internal_name] for f in fqs
- }
- query_graph = NaiveQueryEngine(
- vqs=vqs,
- fqs=fqs,
- selection=select_copies,
- tables=tables,
- _table_evaluator=table_evaluator,
- _visualize_steps_at=visualize_steps_at,
- )
- logger.debug("Solving query...")
- df = query_graph.solve()
- if filter_duplicates:
- df = df.drop_duplicates(keep="first")
- return df
-
- @staticmethod
- def compile(
- selection: List[ValNode],
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- version_constraints: Optional[Dict[OpKey, Optional[Set[str]]]],
- filter_duplicates: bool = True,
- call_uids: Optional[Dict[Tuple[str, int], List[str]]] = None,
- ) -> str:
- """
-        Compile the given queries into a single query over the memoization tables.
- """
- Querier.add_fq_constraints(fqs=fqs, call_uids=call_uids)
- compiler = Compiler(vqs=vqs, fqs=fqs)
- query = compiler.compile(
- select_queries=selection,
- filter_duplicates=filter_duplicates,
- semantic_version_constraints=version_constraints,
- )
- return query
-
- @staticmethod
- def add_fq_constraints(
- fqs: Set[CallNode], call_uids: Optional[Dict[Tuple[str, int], List[str]]]
- ):
- if call_uids is None:
- return
- for fq in fqs:
- if not fq.func_op.is_builtin:
- sig = fq.func_op.sig
- op_id = (sig.internal_name, sig.version)
- if op_id in call_uids:
- fq.constraint = call_uids[op_id]
-
- @staticmethod
- def validate_query(
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- selection: List[ValNode],
- names: Dict[ValNode, str],
- ):
- if not selection: # empty selection
- raise ValueError("Empty selection")
- if len(vqs) == 0: # empty query
- raise ValueError("Query is empty")
- if not is_connected(val_queries=vqs, func_queries=fqs): # disconnected query
- msg = f"Query is not connected! This could lead to a very large table.\n"
- logger.warning(msg)
- if not len(set(names.values())) == len(names): # duplicate names
- duplicates = [k for k, v in Counter(names.values()).items() if v > 1]
- raise ValueError("Duplicate names in value queries: " + str(duplicates))
-
- @staticmethod
- def prepare_projection_query(
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- selection: List[ValNode],
- name_hints: Dict[ValNode, str],
- ):
- graph = InducedSubgraph(vqs=vqs, fqs=fqs)
- v_map, f_map, _ = graph.project()
- validate_projection(
- source_selection=selection,
- v_map=v_map,
- source_selection_names={
- k: v for k, v in name_hints.items() if k in selection
- },
- )
- target_selection = [v_map[vq] for vq in selection]
- ### get the names in the projected graph
- g = InducedSubgraph(vqs=set(v_map.values()), fqs=set(f_map.values()))
- _, _, canonical_topsort = g.canonicalize()
- target_name_hints = {
- v_map[vq]: name for vq, name in name_hints.items() if vq in v_map.keys()
- }
- target_names = get_names(
- hints=target_name_hints,
- canonical_order=[vq for vq in canonical_topsort if isinstance(vq, ValNode)],
- )
- assert set(target_names.keys()) == set(v_map.values())
- return v_map, f_map, target_selection, target_names
-
-
-def validate_projection(
- source_selection: List[ValNode],
- v_map: Dict[ValNode, ValNode],
- source_selection_names: Dict[ValNode, str],
-):
- """
- Check that the selected nodes in the source project to distinct nodes in the
-    target. Raise an error if this is not the case.
-
-    Here `source_selection_names` is a (partial) dict from source nodes to names.
- """
- fibers = defaultdict(list)
- for vq in source_selection:
- fibers[v_map[vq]].append(vq)
- if any(len(fiber) > 1 for fiber in fibers.values()):
- # find the first fiber with more than one query
- fiber = next(fiber for fiber in fibers.values() if len(fiber) > 1)
- raise ValueError(
- f"Ambiguous query: nodes {[source_selection_names.get(x, '?') for x in fiber]} have the "
- f"same role in the computational graph."
- )
diff --git a/mandala/queries/solver.py b/mandala/queries/solver.py
deleted file mode 100644
index c83d028..0000000
--- a/mandala/queries/solver.py
+++ /dev/null
@@ -1,264 +0,0 @@
-from ..common_imports import *
-from ..core.model import FuncOp
-from ..core.sig import Signature
-from ..core.config import Config
-from .weaver import ValNode, CallNode, traverse_all
-from .graphs import InducedSubgraph
-from .viz import visualize_graph, get_names
-from ..core.utils import invert_dict
-
-
-class NaiveQueryEngine:
- """
- Represents the graph expressing a query in a form that is suitable for
- incrementally computing the join of all the tables.
-
-    Used as an alternative to an RDBMS engine for computing queries.
-
-    Should work with induced subgraphs.
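-
-    The `solve` method greedily joins the pair of call-node tables that share
-    the most value queries, repeating until a single table remains.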
- """
-
- def __init__(
- self,
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- selection: List[ValNode],
- tables: Dict[CallNode, pd.DataFrame],
- _table_evaluator: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
- _visualize_steps_at: Optional[Path] = None,
- ):
- self.vqs = vqs
- self.fqs = list(fqs)
- self.g = InducedSubgraph(vqs=vqs, fqs=fqs)
- self.selection = selection
- # {func query: table of data}. Note that there may be multiple func
- # queries with the same table, but we keep separate references to each
- # in order to enable recursively joining nodes in the graph.
-
-        # restrict the tables to the columns induced by the graph
- self.induce_tables(tables=tables)
- self.tables = tables
- for k, v in self.tables.items():
- for col in Config.special_call_cols:
- if col in v.columns:
- v.drop(columns=[col], inplace=True)
-
- # for visualization
- self._visualize_intermediate_states = _visualize_steps_at is not None
- self._table_evaluator = _table_evaluator
- self._visualize_steps_at = _visualize_steps_at
- if self._visualize_intermediate_states:
- assert self._table_evaluator is not None
-
- def induce_tables(self, tables: Dict[CallNode, pd.DataFrame]):
- for fq, df in tables.items():
- inps, outps = self.get_fq_inputs(fq), self.get_fq_outputs(fq)
- induced_keys = set(inps.keys()) | set(outps.keys())
- df.drop(
- columns=[c for c in df.columns if c not in induced_keys], inplace=True
- )
-
- def get_fq_inputs(self, fq: CallNode) -> Dict[str, ValNode]:
- if fq in self.g.fqs:
- return self.g.fq_inputs(fq=fq)
- else:
- return fq.inputs
-
- def get_fq_outputs(self, fq: CallNode) -> Dict[str, ValNode]:
- if fq in self.g.fqs:
- return self.g.fq_outputs(fq=fq)
- else:
- return fq.outputs
-
- def _get_col_to_vq_mappings(
- self, func_query: CallNode
- ) -> Tuple[Dict[str, ValNode], Dict[ValNode, List[str]]]:
- """
- Given a FuncQuery, returns:
- - a mapping from column names to the ValQuery that they point to
- - a mapping from ValQuery objects to the list of column names that point to it
- """
- df = self.tables[func_query]
- col_to_vq = {}
- vq_to_cols = defaultdict(list)
- for name, val_query in self.get_fq_inputs(func_query).items():
- assert name in df.columns
- col_to_vq[name] = val_query
- vq_to_cols[val_query].append(name)
- for output_name, val_query in self.get_fq_outputs(func_query).items():
- assert output_name in df.columns
- col_to_vq[output_name] = val_query
- vq_to_cols[val_query].append(output_name)
- return col_to_vq, vq_to_cols
-
- def _drop_self_constraints(
- self, df: pd.DataFrame, vq_to_cols: Dict[ValNode, List[str]]
- ) -> Tuple[pd.DataFrame, Dict[str, ValNode], Dict[ValNode, str]]:
- new_col_to_vq = {}
- df = df.copy()
- for vq, cols in vq_to_cols.items():
- if len(cols) > 1:
- representative = cols[0]
- for col in cols[1:]:
- df = df[df[representative] == df[col]]
- df.drop(columns=col, inplace=True)
- new_col_to_vq[cols[0]] = vq
- new_vq_to_col = {vq: col for col, vq in new_col_to_vq.items()}
- return df, new_col_to_vq, new_vq_to_col
-
- def _join_dataframes(
- self,
- df1: pd.DataFrame,
- df2: pd.DataFrame,
- left_on: List[str],
- right_on: List[str],
- ) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, str]]:
- """
- Join two dataframes along the specified dimensions.
-
- Returns
- - the result,
- - together with coprojections from the columns of each dataframe to
- the columns of the result.
- """
- assert len(left_on) == len(right_on)
- # rename the dataframe columns to avoid conflicts
- mapping1 = {col: f"input_{i}" for i, col in enumerate(df1.columns)}
- ncols1 = len(df1.columns)
- mapping2 = {col: f"input_{i + ncols1}" for i, col in enumerate(df2.columns)}
- renamed_df1 = df1.rename(columns=mapping1)
- renamed_df2 = df2.rename(columns=mapping2)
- # join the dataframes
- logger.info(f"Joining tables of shapes {df1.shape} and {df2.shape}...")
- start = time.time()
- if len(left_on) == 0:
- df = renamed_df1.merge(renamed_df2, how="cross")
- else:
- df = renamed_df1.merge(
- renamed_df2,
- left_on=[mapping1[col] for col in left_on],
- right_on=[mapping2[col] for col in right_on],
- )
- end = time.time()
- logger.info(f"Join took {round(end - start, 3)} seconds")
- # drop duplicate columns from the *right* dataframe
- df.drop(columns=[mapping2[col] for col in right_on], inplace=True)
- # construct the mapping functions from the columns of each dataframe to
- # the columns of the result
- from1 = mapping1
- from2 = {}
- for col in df2.columns:
- if mapping2[col] in df.columns:
- # first, assign the columns that we did not drop
- from2[col] = mapping2[col]
- for left_col, right_col in zip(left_on, right_on):
- # fill in the remaining ones
- from2[right_col] = mapping1[left_col]
- return df, from1, from2
-
- def merge(self, f1: CallNode, f2: CallNode):
- """
- Merge two func query nodes in the graph by joining their tables along
- the columns that correspond to the shared inputs/outputs
- """
- # get the data
- df1, df2 = self.tables[f1], self.tables[f2]
- # compute correspondence between columns and vqs
- col_to_vq1, vq_to_cols1 = self._get_col_to_vq_mappings(f1)
- col_to_vq2, vq_to_cols2 = self._get_col_to_vq_mappings(f2)
- # apply self-join constraints
- df1, col_to_vq1, vq_to_col1 = self._drop_self_constraints(df1, vq_to_cols1)
- df2, col_to_vq2, vq_to_col2 = self._drop_self_constraints(df2, vq_to_cols2)
- # compute the pairs of columns along which we need to join
- # {shared value query: (col1, col2)}
- intersection_vq_to_col_pairs = OrderedDict({})
- for col, vq in col_to_vq1.items():
- if vq in col_to_vq2.values():
- intersection_vq_to_col_pairs[vq] = (col, vq_to_col2[vq])
- left_on = [col for _, (col, _) in intersection_vq_to_col_pairs.items()]
- right_on = [col for _, (_, col) in intersection_vq_to_col_pairs.items()]
- df, from1, from2 = self._join_dataframes(
- df1=df1, df2=df2, left_on=left_on, right_on=right_on
- )
- # get the correspondence between columns and vqs for the new table
- inputs = {}
- for col in df.columns:
- if col in from1.values():
- col_1 = invert_dict(from1)[col]
- inputs[col] = col_to_vq1[col_1]
- elif col in from2.values():
- col_2 = invert_dict(from2)[col]
- inputs[col] = col_to_vq2[col_2]
- else:
- raise ValueError()
- # insert new func query
- sig = Signature(
- ui_name="internal_node",
- input_names=set(inputs.keys()),
- n_outputs=0,
- version=0,
- defaults={},
- input_annotations={k: Any for k in inputs.keys()},
- output_annotations=[Any for _ in range(0)],
- )
- func_op = FuncOp._from_sig(sig=sig)
- f = CallNode(inputs=inputs, func_op=func_op, outputs={}, constraint=None)
- for k, v in inputs.items():
- v.add_consumer(consumer=f, consumed_as=k)
- self.tables[f] = df
- self.fqs.append(f)
- # remove the old func queries from the graph
- f1.unlink()
- f2.unlink()
- self.fqs = [f for f in self.fqs if f not in (f1, f2)]
- del self.tables[f1], self.tables[f2]
-
- ### solver and solver utils
- def compute_intersection_size(self, f1: CallNode, f2: CallNode) -> int:
- col_to_vq1, vq_to_cols1 = self._get_col_to_vq_mappings(f1)
- col_to_vq2, vq_to_cols2 = self._get_col_to_vq_mappings(f2)
- return len(set(col_to_vq1.values()) & set(col_to_vq2.values()))
-
- def _visualize_state(self, step_num: int):
- assert self._visualize_intermediate_states
- val_queries, func_queries = traverse_all(vqs=self.selection)
- memoization_tables = {
- k: self._table_evaluator(v) for k, v in self.tables.items()
- }
- visualize_graph(
- vqs=val_queries,
- fqs=func_queries,
- names=get_names(hints={}, canonical_order=list(val_queries)),
- output_path=self._visualize_steps_at / f"{step_num}.svg",
- layout="bipartite",
- memoization_tables=memoization_tables,
- )
-
- def solve(self, verbose: bool = False) -> pd.DataFrame:
- step_num = 0
- while len(self.fqs) > 1:
- if verbose:
- logger.info(f"step {step_num}")
- if self._visualize_intermediate_states:
- self._visualize_state(step_num=step_num)
- intersections = {}
- # compute pairwise intersections
- for f1 in self.fqs:
- for f2 in self.fqs:
- if f1 == f2:
- continue
- intersections[(f1, f2)] = self.compute_intersection_size(f1, f2)
- # pick the pair with the largest intersection
- f1, f2 = max(intersections, key=lambda x: intersections.get(x, 0))
- # merge the pair
- self.merge(f1, f2)
- step_num += 1
- if self._visualize_intermediate_states:
- self._visualize_state(step_num=step_num)
- assert len(self.tables) == 1
- df = self.tables[self.fqs[0]]
- f = self.fqs[0]
- # figure out which columns to select
- col_to_vq, vq_to_cols = self._get_col_to_vq_mappings(f)
- cols = [vq_to_cols[vq][0] for vq in self.selection]
- return df[cols]
diff --git a/mandala/queries/viz.py b/mandala/queries/viz.py
deleted file mode 100644
index 6492bd5..0000000
--- a/mandala/queries/viz.py
+++ /dev/null
@@ -1,600 +0,0 @@
-from abc import ABC, abstractmethod
-from ..common_imports import *
-from typing import Literal
-from ..core.config import parse_output_idx, Config
-from ..core.model import Ref
-from ..core.wrapping import unwrap
-from ..core.tps import ListType, StructType, DictType, SetType
-from .weaver import (
- ValNode,
- CallNode,
- StructOrientations,
- get_items,
- get_elts,
- get_elt_and_struct,
- get_idx,
- is_key,
- is_idx,
- traverse_all,
- get_vq_orientation,
-)
-from .graphs import InducedSubgraph, get_canonical_order
-from ..ui.viz import (
- Node,
- Edge,
- SOLARIZED_LIGHT,
- to_dot_string,
- write_output,
- HTMLBuilder,
- Cell,
-)
-import textwrap
-
-
-class ValueLoaderABC(ABC):
- @abstractmethod
- def load_value(self, full_uid: str) -> Any:
- raise NotImplementedError
-
-
-class GraphPrinter:
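-    """
-    Render a graph of value and call nodes as readable Python-like code, one
-    line per node, in a canonical topological order.
-    """
-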
- def __init__(
- self,
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- value_loader: Optional[ValueLoaderABC] = None,
- names: Optional[Dict[ValNode, str]] = None,
- fnames: Optional[Dict[CallNode, str]] = None,
- ):
- self.vqs = vqs
- self.fqs = fqs
- self.value_loader = value_loader
- self.g = InducedSubgraph(vqs=vqs, fqs=fqs)
- self.v_labels, self.f_labels, self.vq_to_node = self.g.canonicalize()
- self.full_topsort, self.sources, self.sinks = self.g.topsort(
- canonical_labels={**self.v_labels, **self.f_labels}
- )
- if names is None:
- names = get_names(
- hints={}, canonical_order=get_canonical_order(vqs=vqs, fqs=fqs)
- )
- self.names = names
- self.fnames = fnames
-
- def get_struct_comment(
- self,
- vq: ValNode,
- elt_names: Tuple[str, ...],
- idx_names: Optional[Tuple[str, ...]] = None,
- ) -> str:
- id_to_name = {"__list__": "list", "__dict__": "dict", "__set__": "set"}
- if len(elt_names) == 1:
- s = f"{self.names[vq]} will match any {id_to_name[vq.tp.struct_id]} containing a match for {elt_names[0]}"
- if idx_names is not None:
- s += f" at index {idx_names[0]}"
- else:
- s = f'{self.names[vq]} will match any {id_to_name[vq.tp.struct_id]} containing matches for each of {", ".join(elt_names)}'
- if idx_names is not None:
- s += f" at indices {', '.join(idx_names)}"
- return s
-
- def get_source_comment(self, vq: ValNode) -> str:
- if is_idx(vq=vq):
- return "index into list"
- elif is_key(vq=vq):
- return "key into dict"
- else:
- return "input to computation; can match anything"
-
- def get_construct_computation_rhs(self, node: ValNode) -> str:
- """
- Given a constructive struct, return
- [name_0, ..., name_1] for lists
- {key_0: name_0, ..., key_n: name_n} for dicts
- {name_0, ..., name_n} for sets
- """
- if isinstance(node.tp, (ListType, DictType)):
- idxs_and_elts = list(get_items(vq=node).values())
- idxs = [x[0] for x in idxs_and_elts]
- idx_values = [
- unwrap(self.value_loader.load_value(full_uid=idx.constraint[0]))
- if idx.constraint is not None
- else i
- for i, idx in enumerate(idxs)
- ]
- elts = [x[1] for x in idxs_and_elts]
- elt_names = tuple([str(self.names[elt]) for elt in elts])
- idx_value_to_elt_name = {
- idx_value: elt_name
- for idx_value, elt_name in zip(idx_values, elt_names)
- }
- if isinstance(node.tp, ListType):
- assert sorted(idx_value_to_elt_name.keys()) == list(range(len(elts)))
- rhs = f'[{", ".join([idx_value_to_elt_name[i] for i in range(len(elts))])}]'
- elif isinstance(node.tp, DictType):
- elt_strings = [
- f"'{idx_value}': {elt_name}"
- for idx_value, elt_name in idx_value_to_elt_name.items()
- ]
- rhs = f'{", ".join(elt_strings)}'
- rhs = f"{{{rhs}}}"
- else:
- raise RuntimeError
- elif isinstance(node.tp, SetType):
- elts = list(get_elts(vq=node).values())
- elt_names = tuple([str(self.names[elt]) for elt in elts])
- rhs = f'{", ".join(elt_names)}'
- rhs = f"{{{rhs}}}"
- else:
- raise ValueError
- return rhs
-
- def get_construct_query_rhs(self, node: ValNode) -> str:
- if isinstance(node.tp, (ListType, DictType)):
- idxs_and_elts = list(get_items(vq=node).values())
- idxs = [x[0] for x in idxs_and_elts]
- elts = [x[1] for x in idxs_and_elts]
- elt_names = tuple([str(self.names[elt]) for elt in elts])
- idx_names = tuple(
- [str(self.names[idx]) if idx is not None else "?" for idx in idxs]
- )
- comment = self.get_struct_comment(
- vq=node, elt_names=elt_names, idx_names=idx_names
- )
- if isinstance(node.tp, ListType):
- rhs = f'ListQ(elts=[{", ".join(elt_names)}], idxs=[{", ".join(idx_names)}]) # {comment}'
- elif isinstance(node.tp, DictType):
- elt_strings = [
- f"{idx_name}: {elt_name}"
- for idx_name, elt_name in zip(idx_names, elt_names)
- ]
- rhs = f'DictQ(dct={{..., {", ".join(elt_strings)}, ...}}) # {comment}'
- else:
- raise RuntimeError
- elif isinstance(node.tp, SetType):
- elts = list(get_elts(vq=node).values())
- elt_names = tuple([str(self.names[elt]) for elt in elts])
- comment = self.get_struct_comment(vq=node, elt_names=elt_names)
- rhs = f'SetQ(elts=[{", ".join(elt_names)}]) # {comment}'
- else:
- raise ValueError
- return rhs
-
- def get_op_computation_line(
- self,
- node: CallNode,
- require_all_inputs: bool = False,
- add_node_name: bool = True,
- ) -> str:
- """
- Get the line of code representing a function call of the form
- output_0, output_1 = func_name(input_0=input_0, input_1=input_1)
- """
- full_returns = node.returns_interp
- full_returns_names = [
- self.names[vq] if vq is not None else "_" for vq in full_returns
- ]
- lhs = ", ".join(full_returns_names)
- # rhs
- args_dict = {arg_name: self.names[vq] for arg_name, vq in node.inputs.items()}
- for inp_name in node.func_op.sig.input_names:
- if inp_name not in args_dict and require_all_inputs:
- raise RuntimeError
- args_string = ", ".join(
- [f"{k}={v}" for k, v in args_dict.items()]
- + [f"{k}=_" for k in node.func_op.sig.input_names if k not in args_dict]
- )
- rhs = f"{node.func_op.sig.ui_name}({args_string})"
- if add_node_name:
- rhs = rhs + " # OP NAME: " + self.fnames[node]
- if len(lhs) == 0:
- return rhs
- else:
- return f"{lhs} = {rhs}"
-
- def get_op_query_line(self, node: CallNode) -> str:
- # lhs
- full_returns = node.returns_interp
- full_returns_names = [
- self.names[vq] if vq is not None else "_" for vq in full_returns
- ]
- lhs = ", ".join(full_returns_names)
- # rhs
- args_dict = {arg_name: self.names[vq] for arg_name, vq in node.inputs.items()}
- for inp_name in node.func_op.sig.input_names:
- if inp_name not in args_dict:
- args_dict[inp_name] = "Q()"
- args_string = ", ".join([f"{k}={v}" for k, v in args_dict.items()])
- rhs = f"{node.func_op.sig.ui_name}({args_string})"
- if len(lhs) == 0:
- return rhs
- else:
- return f"{lhs} = {rhs}"
-
- def get_destruct_computation_line(self, node: CallNode) -> str:
- elt, struct = get_elt_and_struct(fq=node)
- idx = get_idx(fq=node)
- lhs = f"{self.names[elt]}"
- if idx is None:
- raise NotImplementedError
- # inline the index value if the index is a source in the graph
- if idx in self.sources:
- if idx.constraint is None:
- rhs = f"{self.names[struct]}[{self.names[idx]}]"
- else:
- assert len(idx.constraint) == 1
- idx_ref = self.value_loader.load_value(full_uid=idx.constraint[0])
- idx_value = unwrap(idx_ref)
- if isinstance(idx_value, int):
- rhs = f"{self.names[struct]}[{idx_value}]"
- elif isinstance(idx_value, str):
- rhs = f"{self.names[struct]}['{idx_value}']"
- else:
- raise RuntimeError
- else:
- rhs = f"{self.names[struct]}[{self.names[idx]}]"
- return f"{lhs} = {rhs}"
-
- def get_destruct_query_line(self, node: CallNode) -> str:
- elt, struct = get_elt_and_struct(fq=node)
- idx = get_idx(fq=node)
- lhs = f"{self.names[elt]}"
- if idx is None:
- idx_label = "?"
- else:
- idx_label = self.names[idx]
- rhs = f"{self.names[struct]}[{idx_label}] # {self.names[elt]} will match any element of a match for {self.names[struct]} at index matching {idx_label}"
- return f"{lhs} = {rhs}"
-
- def print_computational_graph(
- self, show_sources_as: Literal["values", "uids", "omit", "name_only"] = "values"
- ) -> str:
- res = []
- for node in self.full_topsort:
- if isinstance(node, ValNode):
- if node in self.sources:
- if is_idx(node) or is_key(node):
- # exclude indices/keys if they are sources (we will inline them)
- continue
- if show_sources_as == "omit":
- continue
- elif show_sources_as == "name_only":
- res.append(self.names[node])
- continue
- assert node.constraint is not None
- assert len(node.constraint) == 1
- if show_sources_as == "values":
- ref = self.value_loader.load_value(full_uid=node.constraint[0])
- value = unwrap(ref)
- rep = textwrap.shorten(repr(value), 25)
- res.append(f"{self.names[node]} = {rep}")
- elif show_sources_as == "uids":
- uid, causal_uid = Ref.parse_full_uid(
- full_uid=node.constraint[0]
- )
- res.append(
- f"{self.names[node]} = Ref(uid={uid}, causal_uid={causal_uid})"
- )
- else:
- raise ValueError
- elif isinstance(node.tp, StructType):
- if get_vq_orientation(node) != StructOrientations.construct:
- continue
- rhs = self.get_construct_computation_rhs(node=node)
- lhs = f"{self.names[node]}"
- line = f"{lhs} = {rhs}"
- res.append(line)
- elif isinstance(node, CallNode):
- if not node.func_op.is_builtin:
- res.append(self.get_op_computation_line(node=node))
- else:
- if node.orientation == StructOrientations.destruct:
- res.append(self.get_destruct_computation_line(node=node))
- return "\n".join(res)
-
- def print_query_graph(
- self,
- selection: Optional[List[ValNode]],
- pprint: bool = False,
- ):
- res = []
- for node in self.full_topsort:
- if isinstance(node, ValNode):
- if node in self.sources:
- comment = self.get_source_comment(vq=node)
- res.append(f"{self.names[node]} = Q() # {comment}")
- elif isinstance(node.tp, StructType):
- if get_vq_orientation(node) != StructOrientations.construct:
- continue
- rhs = self.get_construct_query_rhs(node=node)
- lhs = f"{self.names[node]}"
- line = f"{lhs} = {rhs}"
- res.append(line)
- elif isinstance(node, CallNode):
- if not node.func_op.is_builtin:
- res.append(self.get_op_query_line(node=node))
- else:
- if node.orientation == StructOrientations.destruct:
- res.append(self.get_destruct_query_line(node=node))
- if selection is not None:
- res.append(
- f"result = storage.df({', '.join([self.names[vq] for vq in selection])})"
- )
- res = [textwrap.indent(line, " ") for line in res]
- if Config.has_rich and pprint:
- from rich.syntax import Syntax
-
- highlighted = Syntax(
- "\n".join(res),
- "python",
- theme="solarized-light",
- line_numbers=False,
- )
- # return Panel.fit(highlighted, title="Computational Graph")
- return highlighted
- else:
- return "\n".join(res)
-
-
-def print_graph(
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- names: Dict[ValNode, str],
- selection: Optional[List[ValNode]],
- pprint: bool = False,
-):
- printer = GraphPrinter(vqs=vqs, fqs=fqs, names=names)
- s = printer.print_query_graph(selection=selection)
- if Config.has_rich and pprint:
- rich.print(s)
- else:
- print(s)
-
-
-def graph_to_dot(
- vqs: List[ValNode],
- fqs: List[CallNode],
- names: Dict[ValNode, str],
- layout: Literal["computational", "bipartite"] = "computational",
- memoization_tables: Optional[Dict[CallNode, pd.DataFrame]] = None,
-) -> str:
- # should work for subgraphs
- assert set(vqs) <= set(names.keys())
- nodes = {} # val/op query -> Node obj
- edges = []
- col_names = []
- counter = 0
- for vq in vqs:
- if names[vq] is not None:
- col_names.append(names[vq])
- else:
- col_names.append(f"unnamed_{counter}")
- counter += 1
- for vq, col_name in zip(vqs, col_names):
- html_label = HTMLBuilder()
- if hasattr(vq, "_hidden_message"):
- msg = vq._hidden_message
- text = f"{col_name} ({msg})"
- else:
- text = str(col_name)
- html_label.add_row(
- cells=[
- Cell(
- text=text,
- port=None,
- bold=True,
- bgcolor=SOLARIZED_LIGHT["orange"],
- font_color=SOLARIZED_LIGHT["base3"],
- )
- ]
- )
- node = Node(
- internal_name=str(id(vq)),
- # label=str(val_query.column_name),
- label=html_label.to_html_like_label(),
- color=SOLARIZED_LIGHT["blue"],
- shape="plain",
- )
- nodes[vq] = node
- for func_query in fqs:
- html_label = HTMLBuilder()
- func_preview = (
- # f'{func_query.displayname}({", ".join(func_query.inputs.keys())})'
- func_query.displayname
- )
- if hasattr(func_query, "_hidden_message"):
- msg = func_query._hidden_message
- func_preview = f"{func_preview} ({msg})"
- title_cell = Cell(
- text=func_preview,
- port=None,
- bgcolor=SOLARIZED_LIGHT["blue"],
- bold=True,
- font_color=SOLARIZED_LIGHT["base3"],
- )
- input_cells = []
- output_cells = []
- for input_name in func_query.inputs.keys():
- input_cells.append(Cell(text=input_name, port=input_name, bold=True))
- # html_label.add_row(elts=[Cell(text=input_name, port=input_name)])
- for output_idx in range(len(func_query.outputs)):
- output_cells.append(
- Cell(
- text=f"output_{output_idx}", port=f"output_{output_idx}", bold=True
- )
- )
- if layout == "bipartite":
- html_label.add_row(cells=[title_cell])
- if len(input_cells + output_cells) > 0:
- html_label.add_row(cells=input_cells + output_cells)
- if memoization_tables is not None:
- column_names = [cell.text for cell in input_cells + output_cells]
- port_names = [cell.port for cell in input_cells + output_cells]
- # remove ports from the table column cells
- for cell in input_cells + output_cells:
- cell.port = None
- df = memoization_tables[func_query][column_names]
- rows = list(df.head().itertuples(index=False))
- for tup in rows:
- html_label.add_row(
- cells=[
- Cell(text=textwrap.shorten(str(x), 25), bold=True)
- for x in tup
- ]
- )
- # add port names to the cells in the *last* row
- for cell, port_name in zip(html_label.rows[-1], port_names):
- cell.port = port_name
- elif layout == "computational":
- if len(input_cells) > 0:
- html_label.add_row(input_cells)
- html_label.add_row([title_cell])
- if len(output_cells) > 0:
- html_label.add_row(output_cells)
- else:
- raise ValueError(f"Unknown layout: {layout}")
- node = Node(
- internal_name=str(id(func_query)),
- # label=str(func_query.func_op.sig.ui_name),
- label=html_label.to_html_like_label(),
- color=SOLARIZED_LIGHT["red"],
- shape="plain",
- )
- nodes[func_query] = node
- for input_name, vq in func_query.inputs.items():
- if not vq in nodes:
- continue
- if layout == "bipartite":
- edges.append(
- Edge(
- target_node=nodes[vq],
- source_node=nodes[func_query],
- source_port=input_name,
- arrowtail="none",
- arrowhead="none",
- )
- )
- elif layout == "computational":
- edges.append(
- Edge(
- source_node=nodes[vq],
- target_node=nodes[func_query],
- target_port=input_name,
- )
- )
- else:
- raise ValueError(f"Unknown layout: {layout}")
- for output_name, vq in func_query.outputs.items():
- if not vq in nodes:
- continue
- edges.append(
- Edge(
- source_node=nodes[func_query],
- target_node=nodes[vq],
- source_port=output_name,
- arrowtail="none" if layout == "bipartite" else None,
- arrowhead="none" if layout == "bipartite" else None,
- )
- )
- return to_dot_string(nodes=list(nodes.values()), edges=edges, groups=[])
-
-
-def visualize_graph(
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- names: Optional[Dict[ValNode, str]],
- layout: Literal["computational", "bipartite"] = "computational",
- memoization_tables: Optional[Dict[CallNode, pd.DataFrame]] = None,
- output_path: Optional[Path] = None,
- show_how: Literal["none", "browser", "inline", "open"] = "none",
-):
- if names is None:
- names = get_names(
- hints={}, canonical_order=get_canonical_order(vqs=vqs, fqs=fqs)
- )
- dot_string = graph_to_dot(
- vqs=list(vqs),
- fqs=list(fqs),
- names=names,
- layout=layout,
- memoization_tables=memoization_tables,
- )
- if output_path is None:
- tempfile_obj, output_name = tempfile.mkstemp(suffix=".svg")
- output_path = Path(output_name)
- write_output(
- output_path=output_path,
- dot_string=dot_string,
- output_ext="svg",
- show_how=show_how,
- )
- return output_path
-
-
-def show(*vqs: ValNode):
- vqs, fqs = traverse_all(vqs=list(vqs), direction="both")
-    visualize_graph(vqs=vqs, fqs=fqs, names=None, show_how="browser")
-
-
-def extract_names_from_scope(scope: Dict[str, Any]) -> Dict[ValNode, str]:
- """
- Heuristic to get deterministic name for all ValQueries we can find in the
- scope.
- """
- names_per_vq = defaultdict(list)
- for k, v in scope.items():
- if k.startswith("_"):
- continue
- if isinstance(v, Ref) and v._query is not None:
- names_per_vq[v.query].append(k)
- elif isinstance(v, ValNode):
- names_per_vq[v].append(k)
- res = {}
-    for vq, names in names_per_vq.items():
-        if len(names) > 1:
-            logger.warning(
-                f"Found multiple names for {vq}: {names}, choosing {sorted(names)[0]}"
-            )
-        res[vq] = sorted(names)[0]
- return res
-
-
-def get_names(
- hints: Dict[ValNode, str], canonical_order: List[ValNode]
-) -> Dict[ValNode, str]:
- """
-    Get names for the given ordered list of ValQueries, using the following
- priority:
- - from the object itself;
- - the hints;
- - a_{counter} for the rest.
- """
- counter = Count()
- idx_counter = Count()
- key_counter = Count()
- existing_names = set(hints.values()) | {
- vq.name for vq in canonical_order if vq.name is not None
- }
- res = {}
- for vq in canonical_order:
- if vq.name is not None:
- res[vq] = vq.name
- elif vq in hints:
- res[vq] = hints[vq]
- else:
- if is_key(vq):
- c, prefix = key_counter, "key"
- elif is_idx(vq):
- c, prefix = idx_counter, "idx"
- else:
- c, prefix = counter, "a"
- while f"{prefix}{c.i}" in existing_names:
- c.i += 1
- res[vq] = f"{prefix}{c.i}"
- c.i += 1
- return res
-
-
-class Count:
- def __init__(self):
- self.i = 0
diff --git a/mandala/queries/weaver.py b/mandala/queries/weaver.py
deleted file mode 100644
index 5e6f486..0000000
--- a/mandala/queries/weaver.py
+++ /dev/null
@@ -1,710 +0,0 @@
-from abc import ABC, abstractmethod
-from ..common_imports import *
-from ..core.model import FuncOp, Ref, wrap_atom, Call
-from ..core.wrapping import causify_atom
-from ..core.config import dump_output_name, parse_output_idx
-from ..core.tps import Type, ListType, DictType, SetType, AnyType, StructType
-from ..core.builtins_ import Builtins, ListRef, DictRef, SetRef
-from ..core.utils import Hashing, concat_lists, invert_dict
-from typing import Literal
-from typing import Sequence, Iterator
-
-T = TypeVar("T")
-T1 = TypeVar("T1")
-
-
-class StructOrientations:
- # at runtime only
- construct = "construct"
- destruct = "destruct"
-
-
-class Node(ABC):
- @abstractmethod
- def neighbors(
- self, direction: Literal["backward", "forward", "both"] = "both"
- ) -> List["Node"]:
- raise NotImplementedError
-
-
-class ValNode(Node):
- def __init__(
- self,
- constraint: Optional[List[str]],
- tp: Optional[Type],
- creators: Optional[List["CallNode"]] = None,
- created_as: Optional[List[str]] = None,
- _label: Optional[str] = None,
- name: Optional[str] = None,
- refs: Optional[List[Ref]] = None,
- ):
- self.creators: List["CallNode"] = [] if creators is None else creators
- self.created_as = [] if created_as is None else created_as
- self.consumers: List["CallNode"] = []
- self.consumed_as: List[str] = []
- self.name = name
- self.constraint = constraint
- self.tp = tp
- self._label: Optional[str] = _label
- self._refs: Optional[List[Ref]] = None
- self._refs_hash: Optional[str] = None
-
- if refs is not None:
- self.refs = refs
-
- @property
- def refs(self) -> List[Ref]:
- return self._refs
-
- @refs.setter
- def refs(self, value: List[Ref]):
- # a single assignment expression guarantees no broken state
- self._refs, self._refs_hash = value, ValNode.get_refs_hash(value)
-
- @staticmethod
- def get_refs_hash(refs: List[Ref]) -> str:
- return Hashing.get_content_hash(
- obj=[r.causal_uid for r in refs],
- )
-
- @property
- def refs_hash(self) -> str:
- if self._refs_hash is None:
- raise ValueError("No refs")
- return self._refs_hash
-
- def add_consumer(self, consumer: "CallNode", consumed_as: str):
- self.consumers.append(consumer)
- self.consumed_as.append(consumed_as)
-
- def add_creator(self, creator: "CallNode", created_as: str):
- assert isinstance(created_as, str)
- self.creators.append(creator)
- self.created_as.append(created_as)
-
- def neighbors(
- self, direction: Literal["backward", "forward", "both"] = "both"
- ) -> Set["CallNode"]:
- backward = self.creators
- forward = self.consumers
- if direction == "backward":
- return set(backward)
- elif direction == "forward":
- return set(forward)
- elif direction == "both":
- return set(backward + forward)
-
- def named(self, name: str) -> Any:
- self.name = name
- return self
-
- def __getitem__(self, idx: Union[int, str, "ValNode"]) -> "ValNode":
- tp = self.tp or _infer_type(self)
- if isinstance(tp, ListType):
- return BuiltinQueries.GetListItemQuery(
- lst=self, idx=qwrap(idx, tp=tp.elt_type)
- )
- elif isinstance(tp, DictType):
- assert isinstance(idx, (ValNode, str))
- return BuiltinQueries.GetDictItemQuery(
- dct=self, key=qwrap(idx, tp=tp.elt_type)
- )
- else:
- raise NotImplementedError(f"Cannot index into query of type {tp}")
-
- def __repr__(self):
- if self.name is not None:
- return f"ValNode({self.name})"
- else:
- return f"ValNode({self.tp})"
-
- def inplace_mask(self, mask: np.ndarray):
- """
- Inplace mask the refs of this `ValQuery` object.
- """
- if self.refs is None:
- raise ValueError("No refs to mask")
- if isinstance(mask, np.ndarray):
- assert mask.dtype == np.dtype("bool")
- assert mask.shape[0] == len(self.refs)
- self.refs = [r for r, m in zip(self.refs, mask) if m]
- else:
- raise NotImplementedError("Indexing only supported for boolean arrays")
-
- def get_constraint(self, *values) -> List[str]:
- wrapped = [wrap_atom(v) for v in values]
- for w in wrapped:
- causify_atom(w)
- return [w.full_uid for w in wrapped]
-
- def pin(self, *values):
- if len(values) == 0:
- raise ValueError("Must pin to at least one value")
- else:
- self.constraint = self.get_constraint(*values)
-
- def unpin(self):
- self.constraint = None
-
-
-def copy(vq: ValNode, label: Optional[str] = None) -> ValNode:
- return ValNode(
- constraint=vq.constraint,
- tp=vq.tp,
- creators=vq.creators,
- created_as=vq.created_as,
- _label=label,
- )
-
-
-class CallNode(Node):
- def __init__(
- self,
- inputs: Dict[str, ValNode],
- func_op: FuncOp,
- outputs: Dict[str, ValNode],
- constraint: Optional[List[str]],
- orientation: Optional[str] = None,
- calls: Optional[List[Call]] = None,
- ):
- self.func_op = func_op
- self.inputs = inputs
- self.outputs = outputs
- self.orientation = orientation
- self.constraint = constraint
- self._calls: Optional[List[Call]] = None
- self._calls_hash: Optional[str] = None
-
- if calls is not None:
- self.calls = calls
-
- @property
- def calls(self) -> List[Call]:
- return self._calls
-
- @property
- def calls_hash(self) -> str:
- if self._calls_hash is None:
- raise ValueError("No df")
- return self._calls_hash
-
- @calls.setter
- def calls(self, value: List[Call]):
- # a single assignment expression guarantees no broken state
- self._calls, self._calls_hash = value, CallNode.get_calls_hash(value)
-
- @staticmethod
- def get_calls_hash(calls: List[Call]) -> str:
- return Hashing.get_content_hash(
- [c.causal_uid for c in calls],
- )
-
- def inplace_mask(self, mask: np.ndarray):
- """
- Inplace mask the call UIDs of this `FuncQuery` object.
- """
- if self.calls is None:
- raise ValueError("No call UIDs to mask")
- if isinstance(mask, np.ndarray):
- assert mask.dtype == np.dtype("bool")
- assert mask.shape[0] == len(self.calls)
- self.calls = [c for c, m in zip(self.calls, mask) if m]
- else:
- raise NotImplementedError("Indexing only supported for boolean arrays")
-
- def set_outputs(self, outputs: Dict[str, ValNode]):
- self.outputs = outputs
-
- @property
- def returns(self) -> List[ValNode]:
- if self.func_op.is_builtin:
- raise NotImplementedError()
- else:
- ord_outputs = {parse_output_idx(k): v for k, v in self.outputs.items()}
- ord_outputs = [ord_outputs[i] for i in range(len(ord_outputs))]
- return ord_outputs
-
- @property
- def returns_interp(self) -> List[Optional[ValNode]]:
- if self.func_op.is_builtin:
- raise NotImplementedError()
- else:
- ord_outputs = {parse_output_idx(k): v for k, v in self.outputs.items()}
- res = []
- for i in range(self.func_op.sig.n_outputs):
- if i in ord_outputs:
- res.append(ord_outputs[i])
- else:
- res.append(None)
- return res
-
- @property
- def displayname(self) -> str:
- data = {
- "__list__": {"construct": "__list__", "destruct": "__getitem__"},
- "__dict__": {"construct": "__dict__", "destruct": "__getitem__"},
- "__set__": {"construct": "__set__", "destruct": "__getitem__"},
- }
- if self.func_op._is_builtin:
- return data[self.func_op.sig.ui_name][self.orientation]
- else:
- return self.func_op.sig.ui_name
-
- def neighbors(
- self, direction: Literal["backward", "forward", "both"] = "both"
- ) -> Set[ValNode]:
- backward = list(self.inputs.values())
- forward = list(self.outputs.values())
- if direction == "backward":
- return set(backward)
- elif direction == "forward":
- return set(forward)
- elif direction == "both":
- return set(backward + forward)
-
- @staticmethod
- def link(
- inputs: Dict[str, ValNode],
- func_op: FuncOp,
- outputs: Dict[str, ValNode],
- constraint: Optional[List[str]],
- orientation: Optional[str],
- include_indexing: bool = True,
- calls: Optional[List[Call]] = None,
- ) -> "CallNode":
- """
- Link a func query into the graph
- """
- if func_op._is_builtin:
- struct_id = func_op.sig.ui_name
- assert orientation is not None
- joined = {**inputs, **outputs}
- if not include_indexing:
- if struct_id == "__list__":
- joined = {k: v for k, v in joined.items() if k != "idx"}
- if struct_id == "__dict__":
- joined = {k: v for k, v in joined.items() if k != "key"}
- input_keys = Builtins.IO[orientation][struct_id]["in"] & joined.keys()
- output_keys = Builtins.IO[orientation][struct_id]["out"] & joined.keys()
- effective_inputs = {k: joined[k] for k in input_keys}
- effective_outputs = {k: joined[k] for k in output_keys}
- else:
- assert orientation is None
- effective_inputs = inputs
- effective_outputs = outputs
- result = CallNode(
- inputs=effective_inputs,
- func_op=func_op,
- outputs=effective_outputs,
- orientation=orientation,
- constraint=constraint,
- calls=calls,
- )
- for name, inp in effective_inputs.items():
- inp.add_consumer(consumer=result, consumed_as=name)
- for name, out in effective_outputs.items():
- out.add_creator(creator=result, created_as=name)
- return result
-
- def unlink(self):
- """
- Remove this `FuncQuery` from the graph.
- """
- for inp in self.inputs.values():
- idxs = [i for i, x in enumerate(inp.consumers) if x is self]
- inp.consumers = [x for i, x in enumerate(inp.consumers) if i not in idxs]
- inp.consumed_as = [
- x for i, x in enumerate(inp.consumed_as) if i not in idxs
- ]
- for out in self.outputs.values():
- idxs = [i for i, x in enumerate(out.creators) if x is self]
- out.creators = [x for i, x in enumerate(out.creators) if i not in idxs]
- out.created_as = [x for i, x in enumerate(out.created_as) if i not in idxs]
-
- def __repr__(self):
- args_string = ", ".join(f"{k}={v}" for k, v in self.inputs.items())
- if self.orientation is not None:
- args_string += f", orientation={self.orientation}"
- return f"CallNode({self.func_op.sig.ui_name}, {args_string})"
-
-
-def traverse_all(
- vqs: Set[ValNode],
- direction: Literal["backward", "forward", "both"] = "both",
-) -> Tuple[Set[ValNode], Set[CallNode]]:
- """
- Extend the given `ValQuery` objects to all objects connected to them through
- function inputs and/or outputs.
- """
- vqs_ = {_ for _ in vqs}
- fqs_: Set[CallNode] = set()
- found_new = True
- while found_new:
- found_new = False
- val_neighbors = concat_lists([v.neighbors(direction=direction) for v in vqs_])
- op_neighbors = concat_lists([o.neighbors(direction=direction) for o in fqs_])
- if any(k not in fqs_ for k in val_neighbors):
- found_new = True
- for neigh in val_neighbors:
- if neigh not in fqs_:
- fqs_.add(neigh)
- if any(k not in vqs_ for k in op_neighbors):
- found_new = True
- for neigh in op_neighbors:
- if neigh not in vqs_:
- vqs_.add(neigh)
- return vqs_, fqs_
-
-
-class BuiltinQueries:
- @staticmethod
- def ListQ(
- elts: List[ValNode], idxs: Optional[List[Optional[ValNode]]] = None
- ) -> ValNode:
- result = ValNode(
- creators=[], created_as=[], tp=ListType(elt_type=AnyType()), constraint=None
- )
- if idxs is None:
- idxs = [
- ValNode(constraint=None, tp=AnyType(), creators=[], created_as=[])
- for _ in elts
- ]
- for elt, idx in zip(elts, idxs):
- CallNode.link(
- inputs={"lst": result, "elt": elt, "idx": idx},
- func_op=Builtins.list_op,
- outputs={},
- constraint=None,
- orientation=StructOrientations.construct,
- )
- return result
-
- @staticmethod
- def DictQ(dct: Dict[ValNode, ValNode]) -> ValNode:
- result = ValNode(
- creators=[], created_as=[], tp=DictType(elt_type=AnyType()), constraint=None
- )
- for key, val in dct.items():
- CallNode.link(
- inputs={"dct": result, "key": key, "val": val},
- func_op=Builtins.dict_op,
- outputs={},
- constraint=None,
- orientation=StructOrientations.construct,
- )
- return result
-
- @staticmethod
- def SetQ(elts: Set[ValNode]) -> ValNode:
- result = ValNode(
- creators=[], created_as=[], tp=SetType(elt_type=AnyType()), constraint=None
- )
- for elt in elts:
- CallNode.link(
- inputs={"st": result, "elt": elt},
- func_op=Builtins.set_op,
- outputs={},
- constraint=None,
- orientation=StructOrientations.construct,
- )
- return result
-
- @staticmethod
- def GetListItemQuery(lst: ValNode, idx: Optional[ValNode] = None) -> ValNode:
- elt_tp = lst.tp.elt_type if isinstance(lst.tp, ListType) else None
- result = ValNode(creators=[], created_as=[], tp=elt_tp, constraint=None)
- CallNode.link(
- inputs={"lst": lst, "elt": result, "idx": idx},
- func_op=Builtins.list_op,
- outputs={},
- orientation=StructOrientations.destruct,
- constraint=None,
- )
- return result
-
- @staticmethod
- def GetDictItemQuery(dct: ValNode, key: Optional[ValNode] = None) -> ValNode:
- val_tp = dct.tp.elt_type if isinstance(dct.tp, DictType) else None
- result = ValNode(creators=[], created_as=[], tp=val_tp, constraint=None)
- CallNode.link(
- inputs={"dct": dct, "key": key, "val": result},
- func_op=Builtins.dict_op,
- outputs={},
- orientation=StructOrientations.destruct,
- constraint=None,
- )
- return result
-
- ############################################################################
- ### syntactic sugar
- ############################################################################
- @staticmethod
- def is_pattern(obj: Any) -> bool:
- if type(obj) is list and Ellipsis in obj:
- return all(
- BuiltinQueries.is_pattern(elt) or isinstance(elt, ValNode)
- for elt in obj
- if elt is not Ellipsis
- )
- elif type(obj) is dict and Ellipsis in obj:
- return all(
- BuiltinQueries.is_pattern(elt) or isinstance(elt, ValNode)
- for elt in obj.values()
- if elt is not Ellipsis
- )
- elif type(obj) is set:
- return all(
- BuiltinQueries.is_pattern(elt) or isinstance(elt, ValNode)
- for elt in obj
- if elt is not Ellipsis
- )
- else:
- return False
-
- @staticmethod
- def link_pattern(obj: Union[list, dict, set, ValNode]) -> ValNode:
- if isinstance(obj, ValNode):
- return obj
- elif type(obj) is list:
- elts = [
- BuiltinQueries.link_pattern(elt) for elt in obj if elt is not Ellipsis
- ]
- result = ValNode(
- creators=[],
- created_as=[],
- tp=ListType(elt_type=AnyType()),
- constraint=None,
- )
- for elt in elts:
- CallNode.link(
- inputs={
- "lst": result,
- "elt": elt,
- "idx": ValNode(
- constraint=None, tp=AnyType(), creators=[], created_as=[]
- ),
- },
- func_op=Builtins.list_op,
- outputs={},
- constraint=None,
- orientation=StructOrientations.construct,
- )
- elif type(obj) is dict:
- elts = {
- k: BuiltinQueries.link_pattern(v)
- for k, v in obj.items()
- if k is not Ellipsis
- }
- result = ValNode(
- creators=[],
- created_as=[],
- tp=DictType(elt_type=AnyType()),
- constraint=None,
- )
- for k, v in elts.items():
- CallNode.link(
- inputs={"dct": result, "key": k, "val": v},
- func_op=Builtins.dict_op,
- outputs={},
- constraint=None,
- orientation=StructOrientations.construct,
- )
- elif type(obj) is set:
- elts = {
- BuiltinQueries.link_pattern(elt) for elt in obj if elt is not Ellipsis
- }
- result = ValNode(
- creators=[],
- created_as=[],
- tp=SetType(elt_type=AnyType()),
- constraint=None,
- )
- for elt in elts:
- CallNode.link(
- inputs={"st": result, "elt": elt},
- func_op=Builtins.set_op,
- outputs={},
- constraint=None,
- orientation=StructOrientations.construct,
- )
- else:
- raise ValueError
- return result
-
-
-def qwrap(obj: Any, tp: Optional[Type] = None, strict: bool = False) -> ValNode:
- """
- Produce a ValQuery from an object.
- """
- if isinstance(obj, ValNode):
- return obj
- elif isinstance(obj, Ref):
- assert obj._query is not None, "Ref must be linked to a query"
- return obj.query
- elif BuiltinQueries.is_pattern(obj=obj):
- if not strict:
- return BuiltinQueries.link_pattern(obj=obj)
- else:
- raise ValueError
- else:
- if strict:
- raise ValueError("value must be a `ValQuery` or `Ref`")
- if tp is None:
- tp = AnyType()
- # wrap a raw value as a pointwise constraint
- uid = obj.uid if isinstance(obj, Ref) else Hashing.get_content_hash(obj)
- return ValNode(
- tp=tp,
- creators=[],
- created_as=[],
- constraint=[uid],
- )
-
-
-def call_query(
- func_op: FuncOp, inputs: Dict[str, Union[list, dict, set, ValNode, Ref, Any]]
-) -> List[ValNode]:
- for k in inputs.keys():
- inputs[k] = qwrap(obj=inputs[k])
- assert all(isinstance(inp, ValNode) for inp in inputs.values())
- ord_outputs = [
- ValNode(creators=[], created_as=[], tp=tp, constraint=None)
- for tp in func_op.output_types
- ]
- outputs = {dump_output_name(index=i): o for i, o in enumerate(ord_outputs)}
- CallNode.link(
- inputs=inputs,
- func_op=func_op,
- outputs=outputs,
- orientation=None,
- constraint=None,
- )
- return ord_outputs
-
-
-################################################################################
-### introspection
-################################################################################
-def _infer_type(val_query: ValNode) -> Type:
- consumer_op_names = [c.func_op.sig.ui_name for c in val_query.consumers]
- mapping = {"__list__": ListType(), "__dict__": DictType(), "__set__": SetType()}
- tps = [mapping.get(x, None) for x in consumer_op_names]
- struct_tps = [x for x in tps if x is not None]
- if len(struct_tps) == 0:
- return AnyType()
- elif len(struct_tps) == 1:
- return struct_tps[0]
- else:
- raise RuntimeError(f"Multiple types for {val_query}: {struct_tps}")
-
-
-def get_vq_orientation(vq: ValNode) -> str:
- if not isinstance(vq.tp, StructType):
- raise ValueError
- if (
- len(vq.creators) == 1
- and vq.creators[0].orientation == StructOrientations.destruct
- ):
- return StructOrientations.destruct
- elif len(vq.creators) == 1 and vq.creators[0].orientation is None:
- return StructOrientations.destruct
- else:
- return StructOrientations.construct
-
-
-def is_idx(vq: ValNode) -> bool:
- for consumer, consumed_as in zip(vq.consumers, vq.consumed_as):
- if consumed_as == "idx" and consumer.func_op.sig.ui_name == "__list__":
- return True
- return False
-
-
-def is_key(vq: ValNode) -> bool:
- for consumer, consumed_as in zip(vq.consumers, vq.consumed_as):
- if consumed_as == "key" and consumer.func_op.sig.ui_name == "__dict__":
- return True
- return False
-
-
-def get_elt_fqs(vq: ValNode) -> List[CallNode]:
- assert isinstance(vq.tp, StructType)
- struct_id = vq.tp.struct_id
- orientation = get_vq_orientation(vq)
- fqs_to_search = (
- vq.consumers if orientation == StructOrientations.destruct else vq.creators
- )
- fqs = [
- fq
- for fq in fqs_to_search
- if fq.func_op.sig.ui_name == struct_id and fq.orientation == orientation
- ]
- return fqs
-
-
-def get_elts(vq: ValNode) -> Dict[CallNode, ValNode]:
- """
- Get the constituent element queries of a set, as a dictionary of {fq: vq}
- pairs.
- """
- orientation = get_vq_orientation(vq)
- elt_fqs = get_elt_fqs(vq)
- return {
- fq: fq.inputs["elt"]
- if orientation == StructOrientations.construct
- else fq.outputs["elt"]
- for fq in elt_fqs
- }
-
-
-def get_items(vq: ValNode) -> Dict[CallNode, Tuple[Optional[ValNode], ValNode]]:
- """
- Get the constituent elements and indices of a list or dict in the form of
- {fq: (idx_vq, elt_vq)} pairs.
- """
- assert isinstance(vq.tp, (ListType, DictType))
- orientation = get_vq_orientation(vq)
- elt_fqs = get_elt_fqs(vq)
- elt_key = "elt" if isinstance(vq.tp, ListType) else "val"
- idx_key = "idx" if isinstance(vq.tp, ListType) else "key"
-    return {
-        # for a destructuring call the element is an output of the call,
-        # while for a constructing call both index and element are inputs
-        fq: (fq.inputs.get(idx_key), fq.outputs[elt_key])
-        if orientation == StructOrientations.destruct
-        else (fq.inputs.get(idx_key), fq.inputs[elt_key])
-        for fq in elt_fqs
-    }
-
-
-def get_elt_and_struct(fq: CallNode) -> Tuple[ValNode, ValNode]:
- assert fq.func_op.is_builtin
- struct_id = fq.func_op.sig.ui_name
- elt_target = (
- fq.outputs if fq.orientation == StructOrientations.destruct else fq.inputs
- )
- struct_target = (
- fq.inputs if fq.orientation == StructOrientations.destruct else fq.outputs
- )
- if struct_id == "__list__":
- return elt_target["elt"], struct_target["lst"]
- elif struct_id == "__dict__":
- return elt_target["val"], struct_target["dct"]
- elif struct_id == "__set__":
- return elt_target["elt"], struct_target["st"]
- else:
- raise NotImplementedError()
-
-
-def get_idx(fq: CallNode) -> Optional[ValNode]:
- assert fq.func_op.is_builtin
- struct_id = fq.func_op.sig.ui_name
- idx_target = fq.inputs
- if struct_id == "__list__":
- return idx_target.get("idx", None)
- elif struct_id == "__dict__":
- return idx_target.get("key", None)
- else:
- raise ValueError
-
-
-def prepare_query(ref: Ref, tp: Type):
- if ref._query is None:
- ref._query = ValNode(tp=tp, constraint=None, creators=[], created_as=[])
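
For reference, the dispatch order implemented by `qwrap` above can be sketched in a standalone snippet. The `ValNode`, `Ref`, and `content_hash` names below are simplified stand-ins for illustration (not the removed mandala classes), and the collection-pattern branch handled by `BuiltinQueries.is_pattern` / `link_pattern` is omitted:

```
# minimal sketch of qwrap's dispatch, using simplified stand-in classes
from dataclasses import dataclass
from hashlib import sha256
from typing import Any, List, Optional

def content_hash(obj: Any) -> str:
    # stand-in for Hashing.get_content_hash
    return sha256(repr(obj).encode()).hexdigest()

@dataclass
class ValNode:
    constraint: Optional[List[str]] = None  # pointwise UID constraint, if any

@dataclass
class Ref:
    uid: str
    _query: Optional[ValNode] = None

def qwrap_sketch(obj: Any) -> ValNode:
    if isinstance(obj, ValNode):  # already a query node: pass through
        return obj
    if isinstance(obj, Ref):  # a stored value already linked to a query
        assert obj._query is not None, "Ref must be linked to a query"
        return obj._query
    # otherwise, constrain the query to match exactly this raw value
    return ValNode(constraint=[content_hash(obj)])

print(qwrap_sketch(42).constraint)  # a single content-hash constraint
```
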
diff --git a/mandala/queries/workflow.py b/mandala/queries/workflow.py
deleted file mode 100644
index 2643586..0000000
--- a/mandala/queries/workflow.py
+++ /dev/null
@@ -1,288 +0,0 @@
-from ..common_imports import *
-from .weaver import *
-from ..core.model import Ref, FuncOp
-from ..core.utils import Hashing
-from ..core.config import dump_output_name, parse_output_idx
-
-
-class CallStruct:
- def __init__(self, func_op: FuncOp, inputs: Dict[str, Ref], outputs: List[Ref]):
- self.func_op = func_op
- self.inputs = inputs
- self.outputs = outputs
-
-
-class Workflow:
- """
- An intermediate representation of a collection of calls, possibly not all of
- which have been executed, following a particular computational graph.
-
- Used to:
- - represent work to be done in a `batch` context as a data structure
- - encode an entire workflow for e.g. testing scenarios that simulate
- real-world workloads
- """
-
- def __init__(self):
- ### encoding the shape
- # in topological order
- self.var_nodes: List[ValNode] = []
- # note that there may be many var nodes with the same causal hash
- self.var_node_to_causal_hash: Dict[ValNode, str] = {}
- # in topological order
- self.op_nodes: List[CallNode] = []
- # self.causal_hash_to_op_node: Dict[str, FuncQuery] = {}
- self.op_node_to_causal_hash: Dict[CallNode, str] = {}
- ### encoding instance data
- # multiple refs may map to the same query node
- self.value_to_var: Dict[Ref, ValNode] = {}
- # for a given op node, there may be multiple call structs
- self.op_node_to_call_structs: Dict[CallNode, List[CallStruct]] = {}
-
- def check_invariants(self):
- assert set(self.var_node_to_causal_hash.keys()) == set(self.var_nodes)
- assert set(self.op_node_to_causal_hash.keys()) == set(self.op_nodes)
- assert set(self.op_node_to_call_structs.keys()) == set(self.op_nodes)
- assert set(self.value_to_var.values()) <= set(self.var_nodes)
- for op_node in self.op_nodes:
- for call_struct in self.op_node_to_call_structs[op_node]:
- input_locations = {
- k: self.value_to_var[v] for k, v in call_struct.inputs.items()
- }
- assert input_locations == op_node.inputs
- output_locations = {
- dump_output_name(i): self.value_to_var[v]
- for i, v in enumerate(call_struct.outputs)
- }
- assert output_locations == op_node.outputs
-
- def get_default_hash(self) -> str:
- return Hashing.get_content_hash(obj="null")
-
- @property
- def callable_op_nodes(self) -> List[CallNode]:
- # return op nodes that have non-empty inputs
- res = []
- var_to_values = self.var_to_values()
- for op_node in self.op_nodes:
- if all([len(var_to_values[var]) > 0 for var in op_node.inputs.values()]):
- res.append(op_node)
- return res
-
- @property
- def inputs(self) -> List[ValNode]:
- # return [var for var in self.var_nodes if var.creator is None]
- return [var for var in self.var_nodes if len(var.creators) == 0]
-
- def var_to_values(self) -> Dict[ValNode, List[Ref]]:
- res = defaultdict(list)
- for value, var in self.value_to_var.items():
- res[var].append(value)
- return res
-
- def add_var(self, val_query: Optional[ValNode] = None) -> ValNode:
- res = (
- val_query
- if val_query is not None
- # else ValQuery(creator=None, created_as=None)
- else ValNode(creators=[], created_as=[], constraint=None, tp=AnyType())
- )
- # if res.creator is None:
- if len(res.creators) == 0:
- causal_hash = self.get_default_hash()
- else:
- # creator_hash = self.op_node_to_causal_hash[res.creator]
- creator_hash = self.op_node_to_causal_hash[res.creators[0]]
- # causal_hash = Hashing.get_content_hash(obj=[creator_hash,
- # res.created_as])
- causal_hash = Hashing.get_content_hash(
- obj=[creator_hash, res.created_as[0]]
- )
- self.var_nodes.append(res)
- self.var_node_to_causal_hash[res] = causal_hash
- return res
-
- def get_op_hash(
- self,
- func_op: FuncOp,
- node_inputs: Optional[Dict[str, ValNode]] = None,
- val_inputs: Optional[Dict[str, Ref]] = None,
- ) -> str:
- assert (node_inputs is None) != (val_inputs is None)
- if val_inputs is not None:
- node_inputs = {
- name: self.value_to_var[val] for name, val in val_inputs.items()
- }
- assert node_inputs is not None
- input_causal_hashes = {
- name: self.var_node_to_causal_hash[val] for name, val in node_inputs.items()
- }
- input_causal_hashes = {
- k: v for k, v in input_causal_hashes.items() if v != self.get_default_hash()
- }
- input_causal_hashes = sorted(input_causal_hashes.items())
- op_representation = [
- input_causal_hashes,
- func_op.sig.versioned_internal_name,
- ]
- causal_hash = Hashing.get_content_hash(obj=op_representation)
- return causal_hash
-
- def add_op(
- self,
- inputs: Dict[str, ValNode],
- func_op: FuncOp,
- ) -> Tuple[CallNode, Dict[str, ValNode]]:
- # TODO: refactor the `FuncQuery` creation here
- res = CallNode(inputs=inputs, func_op=func_op, outputs={}, constraint=None)
- causal_hash = self.get_op_hash(node_inputs=inputs, func_op=func_op)
- self.op_nodes.append(res)
- self.op_node_to_causal_hash[res] = causal_hash
- # create outputs
- outputs = {}
- for i in range(func_op.sig.n_outputs):
- # output = self.add_var(val_query=ValQuery(creator=res,
- # created_as=i))
- output_name = dump_output_name(index=i)
- output = self.add_var(
- val_query=ValNode(
- creators=[res],
- created_as=[output_name],
- constraint=None,
- tp=AnyType(),
- )
- )
- outputs[output_name] = output
- # assign outputs to op
- res.set_outputs(outputs=outputs)
- self.op_node_to_call_structs[res] = []
- return res, outputs
-
- def add_value(self, value: Ref, var: ValNode):
- assert var in self.var_nodes
- self.value_to_var[value] = var
-
- def add_call_struct(self, call_struct: CallStruct):
- # process inputs
- func_op, inputs, outputs = (
- call_struct.func_op,
- call_struct.inputs,
- call_struct.outputs,
- )
- if any([inp not in self.value_to_var.keys() for inp in inputs.values()]):
- raise NotImplementedError()
- op_hash = self.get_op_hash(func_op=func_op, val_inputs=inputs)
- if op_hash not in self.op_node_to_causal_hash.values():
- # create op
- op_node, output_nodes = self.add_op(
- inputs={name: self.value_to_var[inp] for name, inp in inputs.items()},
- func_op=func_op,
- )
- else:
- candidates = [
- op_node
- for op_node in self.op_nodes
- if self.op_node_to_causal_hash[op_node] == op_hash
- and op_node.inputs
- == {name: self.value_to_var[inp] for name, inp in inputs.items()}
- ]
- op_node = candidates[0]
- output_nodes = op_node.outputs
- # process outputs
- outputs_dict = {dump_output_name(i): output for i, output in enumerate(outputs)}
- for k in outputs_dict.keys():
- self.value_to_var[outputs_dict[k]] = output_nodes[k]
- self.op_node_to_call_structs[op_node].append(call_struct)
-
- ############################################################################
- ###
- ############################################################################
- @staticmethod
- def from_call_structs(call_structs: List[CallStruct]) -> "Workflow":
- """
- Assumes calls are given in topological order
- """
- res = Workflow()
- input_var = res.add_var()
- for call_struct in call_structs:
- inputs = call_struct.inputs
- for inp in inputs.values():
- if inp not in res.value_to_var.keys():
- res.add_value(value=inp, var=input_var)
- res.add_call_struct(call_struct)
- return res
-
- @staticmethod
- def from_traversal(
- vqs: List[ValNode],
- ) -> Tuple["Workflow", Dict[ValNode, ValNode]]:
- vqs, fqs = traverse_all(vqs, direction="backward")
- vqs_topsort = reversed(vqs)
- fqs_topsort = reversed(fqs)
- # input_vqs = [vq for vq in vqs_topsort if vq.creator is None]
- input_vqs = [vq for vq in vqs_topsort if len(vq.creators) == 0]
- res = Workflow()
- vq_to_new_vq = {}
- for vq in input_vqs:
- new_vq = res.add_var(val_query=vq)
- vq_to_new_vq[vq] = new_vq
- for fq in fqs_topsort:
- new_inputs = {name: vq_to_new_vq[vq] for name, vq in fq.inputs.items()}
- new_fq, new_outputs = res.add_op(inputs=new_inputs, func_op=fq.func_op)
- for k in new_outputs.keys():
- vq, new_vq = fq.outputs[k], new_outputs[k]
- vq_to_new_vq[vq] = new_vq
- return res, vq_to_new_vq
-
- ############################################################################
- ###
- ############################################################################
- @property
- def empty(self) -> bool:
- return len(self.value_to_var) == 0
-
- @property
- def shape_size(self) -> int:
- return len(self.op_nodes) + len(self.var_nodes)
-
- @property
- def num_calls(self) -> int:
- return sum(
- [
- len(call_structs)
- for call_structs in self.op_node_to_call_structs.values()
- ]
- )
-
- @property
- def is_saturated(self) -> bool:
- var_to_values = self.var_to_values()
- return all([len(var_to_values[var]) > 0 for var in self.var_nodes])
-
- @property
- def has_delayed(self) -> bool:
- return any([value.is_delayed() for value in self.value_to_var.keys()])
-
- def print_shape(self):
- var_names = {var: f"var_{i}" for i, var in enumerate(self.var_nodes)}
- for var in self.inputs:
- print(f"{var_names[var]} = Q()")
- for op_node in self.op_nodes:
- numbered_outputs = {
- parse_output_idx(k): v for k, v in op_node.outputs.items()
- }
- outputs_list = [numbered_outputs[i] for i in range(len(numbered_outputs))]
- lhs = ", ".join([var_names[var] for var in outputs_list])
- print(
- f"{lhs} = {op_node.func_op.sig.ui_name}("
- + ", ".join(
- [f"{name}={var_names[var]}" for name, var in op_node.inputs.items()]
- )
- + ")"
- )
-
-
-class History:
- def __init__(self, workflow: Workflow, node: ValNode):
- self.workflow = workflow
- self.node = node
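
The causal hashing that `Workflow` relies on above (an op node is identified by its versioned name plus the causal hashes of its non-default inputs, and each output variable by its creator's hash plus the output name) can be illustrated with a self-contained sketch; the JSON + SHA-256 encoding and the helper names here are assumptions for illustration rather than the removed `Hashing` utilities:

```
# sketch of the causal hashing used to deduplicate op/variable nodes
import hashlib
import json
from typing import Dict

DEFAULT_HASH = hashlib.sha256(b"null").hexdigest()  # hash for free input variables

def causal_hash_of_op(op_name: str, input_hashes: Dict[str, str]) -> str:
    # drop inputs carrying the default (unconstrained) hash, sort the rest
    filtered = sorted((k, v) for k, v in input_hashes.items() if v != DEFAULT_HASH)
    return hashlib.sha256(json.dumps([filtered, op_name]).encode()).hexdigest()

def causal_hash_of_output(op_hash: str, output_name: str) -> str:
    return hashlib.sha256(json.dumps([op_hash, output_name]).encode()).hexdigest()

# two calls to the same op on unconstrained inputs map to the same op node
h1 = causal_hash_of_op("train_model", {"X": DEFAULT_HASH, "y": DEFAULT_HASH})
h2 = causal_hash_of_op("train_model", {"X": DEFAULT_HASH, "y": DEFAULT_HASH})
assert h1 == h2
print(causal_hash_of_output(h1, "output_0"))
```
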
diff --git a/mandala/_next/storage.py b/mandala/storage.py
similarity index 100%
rename from mandala/_next/storage.py
rename to mandala/storage.py
diff --git a/mandala/_next/storage_utils.py b/mandala/storage_utils.py
similarity index 100%
rename from mandala/_next/storage_utils.py
rename to mandala/storage_utils.py
diff --git a/mandala/storages/__init__.py b/mandala/storages/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/mandala/storages/kv.py b/mandala/storages/kv.py
deleted file mode 100644
index 64d88b0..0000000
--- a/mandala/storages/kv.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from ..common_imports import *
-from typing import Generic, TypeVar
-from multiprocessing import Manager
-from multiprocessing.managers import DictProxy
-
-
-_KT = TypeVar("_KT")
-_VT = TypeVar("_VT")
-
-
-class KVCache(Generic[_KT, _VT]):
- """
-    Interface for key-value stores of Python objects (keyed by strings).
- """
-
- def __init__(self):
- self.dirty_entries = set()
-
- def exists(self, k: _KT) -> bool:
- raise NotImplementedError
-
- def set(self, k: _KT, v: _VT) -> None:
- raise NotImplementedError
-
- def get(self, k: _KT) -> Any:
- raise NotImplementedError
-
- def __getitem__(self, k: _KT) -> _VT:
- raise NotImplementedError
-
- def __setitem__(self, k: _KT, v: _VT) -> None:
- raise NotImplementedError
-
- def delete(self, k: _KT) -> None:
- raise NotImplementedError
-
- def keys(self) -> List[_KT]:
- raise NotImplementedError
-
- def items(self) -> Iterable[Tuple[_KT, _VT]]:
- raise NotImplementedError
-
- def evict_all(self) -> None:
- """
- Remove all entries from the cache.
- """
- raise NotImplementedError
-
- def clear_all(self) -> None:
- """
- Mark all entries as clean.
- """
- raise NotImplementedError
-
-
-class InMemoryStorage(KVCache):
- """
- Simple in-memory implementation for local testing and/or buffering
- """
-
- def __init__(self):
- self.data: dict[str, Any] = {}
- self.dirty_entries: set[str] = set()
-
- def __repr__(self):
- return f"InMemoryStorage(data={self.data})"
-
- def exists(self, k: str) -> bool:
- return k in self.data
-
- def set(self, k: str, v: Any):
- self.data[k] = v
- self.dirty_entries.add(k)
-
- def get(self, k: str) -> Any:
- return self.data[k]
-
- def __getitem__(self, k: str) -> Any:
- return self.data[k]
-
- def __setitem__(self, k: str, v: Any) -> None:
- self.data[k] = v
- self.dirty_entries.add(k)
-
- def delete(self, k: str):
- del self.data[k]
- self.dirty_entries.remove(k)
-
- def keys(self) -> List[str]:
- return list(self.data.keys())
-
- def items(self) -> Iterable[Tuple[str, Any]]:
- return self.data.items()
-
- def evict_all(self) -> None:
- self.data = {}
- self.dirty_entries = set()
-
- def clear_all(self) -> None:
- self.dirty_entries = set()
-
- @property
- def is_clean(self) -> bool:
- return len(self.dirty_entries) == 0
-
-
-class MultiProcInMemoryStorage(KVCache):
- def __init__(self):
- manager = Manager()
- self.data: DictProxy[str, Any] = manager.dict()
- self.dirty_entries: DictProxy[str, None] = manager.dict()
-
- def __repr__(self):
- return f"MultiProcInMemoryStorage(data={self.data})"
-
- def exists(self, k: str) -> bool:
- return k in self.data.keys()
-
- def set(self, k: str, v: Any):
- self.data[k] = v
- self.dirty_entries[k] = None
-
- def get(self, k: str) -> Any:
- return self.data[k]
-
- def delete(self, k: str):
- del self.data[k]
- self.dirty_entries.pop(k)
-
- def keys(self) -> List[str]:
- return list(self.data.keys())
-
- @property
- def is_clean(self) -> bool:
- return len(self.dirty_entries) == 0
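
As a usage note, the `KVCache` interface above separates writing values from flushing them: `set` marks a key dirty, and `clear_all` marks everything clean once the data has been persisted elsewhere. A minimal dict-backed sketch of that contract (an illustration, not the removed `InMemoryStorage`):

```
# toy cache illustrating the dirty-entry bookkeeping of the KVCache interface
from typing import Any, Dict, Set

class DictCache:
    def __init__(self):
        self.data: Dict[str, Any] = {}
        self.dirty_entries: Set[str] = set()

    def set(self, k: str, v: Any) -> None:
        self.data[k] = v
        self.dirty_entries.add(k)  # remember what has not been flushed yet

    def get(self, k: str) -> Any:
        return self.data[k]

    def clear_all(self) -> None:
        self.dirty_entries = set()  # everything is now considered persisted

cache = DictCache()
cache.set("call:abc123", {"output_0": 7})
assert cache.dirty_entries == {"call:abc123"}  # dirty until flushed
cache.clear_all()
assert not cache.dirty_entries
```
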
diff --git a/mandala/storages/rel_impls/__init__.py b/mandala/storages/rel_impls/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/mandala/storages/rel_impls/bases.py b/mandala/storages/rel_impls/bases.py
deleted file mode 100644
index f06ce49..0000000
--- a/mandala/storages/rel_impls/bases.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from abc import ABC, abstractmethod
-import sqlite3
-from pypika import Query, Column
-from ...common_imports import *
-from .utils import Connection
-import pyarrow as pa
-
-
-class RelStorage(ABC):
- """
- Responsible for the low-level (i.e., unaware of mandala-specific concepts)
- interactions with the relational part of the storage, such as creating and
- extending tables, running queries, etc. This is intended to be a pretty
- generic, minimal database interface, supporting just the things we need.
-
- It's deliberately referred to as "relational storage" as opposed to a
- "relational database" because simpler implementations exist.
- """
-
- @abstractmethod
- def create_relation(
- self,
- name: str,
- columns: List[Tuple[str, Optional[str]]], # [(col name, type), ...]
- defaults: Dict[str, Any], # {col name: default value, ...}
- primary_key: Optional[Union[str, List[str]]] = None,
- if_not_exists: bool = True,
- conn: Optional[Any] = None,
- ):
- """
- Create a relation with the given name and columns.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def create_column(self, relation: str, name: str, default_value: str):
- raise NotImplementedError()
-
- @abstractmethod
- def insert(self, name: str, df: pd.DataFrame):
- """
- Append rows to a table
- """
- raise NotImplementedError()
-
- @abstractmethod
- def upsert(self, relation: str, ta: pa.Table, conn: Optional[Connection] = None):
- """
- Upsert rows in a table based on index
- """
- raise NotImplementedError()
-
- @abstractmethod
- def delete(
- self,
- relation: str,
- where_col: str,
- where_values: List[str],
- conn: Optional[Connection] = None,
- ):
- """
- Delete rows from a table where `where_col` is in `where_values`
- """
- raise NotImplementedError()
-
- @abstractmethod
- def get_data(
- self, table: str, conn: Optional[sqlite3.Connection] = None
- ) -> pd.DataFrame:
- """
- Fetch data from a table.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def execute_df(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[sqlite3.Connection] = None,
- ) -> pd.DataFrame:
- """
- Execute a query and return the result as a DataFrame.
- """
- raise NotImplementedError()
diff --git a/mandala/storages/rel_impls/duckdb_impl.py b/mandala/storages/rel_impls/duckdb_impl.py
deleted file mode 100644
index d37a397..0000000
--- a/mandala/storages/rel_impls/duckdb_impl.py
+++ /dev/null
@@ -1,307 +0,0 @@
-import duckdb
-import pyarrow as pa
-from duckdb import DuckDBPyConnection as Connection
-from pypika import Query, Column
-
-from .bases import RelStorage
-from .utils import Transactable, transaction
-from ...core.config import Config
-from ...core.utils import get_uid
-from ...common_imports import *
-
-
-class DuckDBRelStorage(RelStorage, Transactable):
- UID_DTYPE = "VARCHAR" # TODO - change this
- TEMP_ARROW_TABLE = Config.temp_arrow_table
-
- def __init__(self, address: Optional[str] = None, _read_only: bool = False):
- self.in_memory = address is None
- self._read_only = _read_only
- if self.in_memory:
- self._conn = duckdb.connect(database=":memory:", read_only=self._read_only)
- else:
- self.path = address
-
- ############################################################################
- ### transaction interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return (
- self._conn
- if self.in_memory
- else duckdb.connect(database=self.path, read_only=self._read_only)
- )
-
- def _end_transaction(self, conn: Connection):
- if not self.in_memory:
- conn.close()
-
- ############################################################################
- ###
- ############################################################################
- @transaction()
- def get_tables(self, conn: Optional[Connection] = None) -> List[str]:
- return self.execute_df(query="SHOW TABLES;", conn=conn)["name"].tolist()
-
- @transaction()
- def table_exists(self, relation: str, conn: Optional[Connection] = None) -> bool:
- return relation in self.get_tables(conn=conn)
-
- @transaction()
- def get_data(self, table: str, conn: Optional[Connection] = None) -> pd.DataFrame:
- return self.execute_df(query=f"SELECT * FROM {table};", conn=conn)
-
- @transaction()
- def get_count(self, table: str, conn: Optional[Connection] = None) -> int:
- df = self.execute_df(query=f"SELECT COUNT(*) FROM {table};", conn=conn)
- return df["count_star()"].item()
-
- @transaction()
- def get_all_data(
- self, conn: Optional[Connection] = None
- ) -> Dict[str, pd.DataFrame]:
- tables = self.get_tables(conn=conn)
- data = {}
- for table in tables:
- data[table] = self.get_data(table=table, conn=conn)
- return data
-
- ############################################################################
- ### schema management
- ############################################################################
- @transaction()
- def create_relation(
- self,
- name: str,
- columns: List[Tuple[str, Optional[str]]], # [(col name, type), ...]
- defaults: Dict[str, Any], # {col name: default value, ...}
- primary_key: Optional[str] = None,
- if_not_exists: bool = True,
- conn: Optional[Connection] = None,
- ):
- """
- Create a table with given columns, with an optional primary key.
- Columns are given as tuples of (name, type).
- Columns without a dtype are assumed to be of type `self.UID_DTYPE`.
- """
- query = Query.create_table(table=name).columns(
- *[
- Column(
- column_name=column_name,
- column_type=column_type
- if column_type is not None
- else self.UID_DTYPE,
- default=defaults.get(column_name, None),
- # nullable=False,
- )
- for column_name, column_type in columns
- ],
- )
- if if_not_exists:
- query = query.if_not_exists()
- if primary_key is not None:
- query = query.primary_key(primary_key)
- conn.execute(str(query))
- logger.debug(
- f'Created table "{name}" with columns {[elt[0] for elt in columns]}'
- )
-
- @transaction()
- def create_column(
- self,
- relation: str,
- name: str,
- default_value: str,
- conn: Optional[Connection] = None,
- ):
- """
- Add a new column to a table.
- """
- query = f"ALTER TABLE {relation} ADD COLUMN {name} {self.UID_DTYPE} DEFAULT '{default_value}'"
- conn.execute(query=query)
- logger.debug(f'Added column "{name}" to table "{relation}"')
-
- @transaction()
- def drop_column(self, relation: str, name: str, conn: Optional[Connection] = None):
- """
- Drop a column from a table.
- """
- query = f"ALTER TABLE {relation} DROP COLUMN {name}"
- conn.execute(query=query)
- logger.debug(f'Dropped column "{name}" from table "{relation}"')
-
- @transaction()
- def rename_relation(
- self, name: str, new_name: str, conn: Optional[Connection] = None
- ):
- """
- Rename a table
- """
- query = f"ALTER TABLE {name} RENAME TO {new_name};"
- conn.execute(query)
- logger.debug(f'Renamed table "{name}" to "{new_name}"')
-
- @transaction()
- def rename_column(
- self, relation: str, name: str, new_name: str, conn: Optional[Connection] = None
- ):
- """
- Rename a column
- """
- query = f'ALTER TABLE {relation} RENAME "{name}" TO "{new_name}";'
- conn.execute(query)
- logger.debug(f'Renamed column "{name}" of table "{relation}" to "{new_name}"')
-
- @transaction()
- def rename_columns(
- self, relation: str, mapping: Dict[str, str], conn: Optional[Connection] = None
- ):
- # factorize the renaming into two maps that can be applied atomically
- part_1 = {k: get_uid() for k in mapping.keys()}
- part_2 = {part_1[k]: v for k, v in mapping.items()}
- for k, v in part_1.items():
- self.rename_column(relation=relation, name=k, new_name=v, conn=conn)
- for k, v in part_2.items():
- self.rename_column(relation=relation, name=k, new_name=v, conn=conn)
- if len(mapping) > 0:
- logger.debug(f'Renamed columns of table "{relation}" via mapping {mapping}')
-
- ############################################################################
- ### instance management
- ############################################################################
- @transaction()
- def _get_cols(self, relation: str, conn: Optional[Connection] = None) -> List[str]:
- """
-        DuckDB-specific method to get the *ordered* columns of a table.
- """
- return (
- self.execute_arrow(query=f'DESCRIBE "{relation}";', conn=conn)
- .column("column_name")
- .to_pylist()
- )
-
- @transaction()
- def _get_primary_keys(
- self, relation: str, conn: Optional[Connection] = None
- ) -> List[str]:
- """
-        DuckDB-specific method to get the primary key of a table.
- """
- constraint_type = "PRIMARY KEY"
- df = self.execute_df(query=f"SELECT * FROM duckdb_constraints();", conn=conn)
- df = df[["table_name", "constraint_type", "constraint_column_names"]]
- df = df[
- (df["table_name"] == relation) & (df["constraint_type"] == constraint_type)
- ]
- if len(df) == 0:
- raise NotImplementedError()
- elif len(df) == 1:
- return df["constraint_column_names"].item()
- else:
- raise NotImplementedError(f"Multiple primary keys for {relation}")
-
- @transaction()
- def insert(self, relation: str, ta: pa.Table, conn: Optional[Connection] = None):
- """
- Append rows to a table
- """
- if len(ta) == 0:
- return
- table_cols = self._get_cols(relation=relation, conn=conn)
- assert set(ta.column_names) == set(table_cols)
- cols_string = ", ".join([f'"{column_name}"' for column_name in ta.column_names])
- conn.register(view_name=self.TEMP_ARROW_TABLE, python_object=ta)
- conn.execute(
- f'INSERT INTO "{relation}" ({cols_string}) SELECT * FROM {self.TEMP_ARROW_TABLE}'
- )
- conn.unregister(view_name=self.TEMP_ARROW_TABLE)
-
- @transaction()
- def upsert(
- self,
- relation: str,
- ta: pa.Table,
- key_cols: Optional[List[str]] = None,
- conn: Optional[Connection] = None,
- ):
- """
- Upsert rows in a table based on primary key.
-
- TODO: currently does NOT update matching rows
- """
- if len(ta) == 0:
- return
- if not self.table_exists(relation, conn=conn):
- raise RuntimeError()
- col_names = ta.column_names if isinstance(ta, pa.Table) else ta.columns
- cols_string = ", ".join([f'"{column_name}"' for column_name in col_names])
- conn.register(view_name=self.TEMP_ARROW_TABLE, python_object=ta)
- if key_cols is None:
- primary_keys = self._get_primary_keys(relation=relation, conn=conn)
- if len(primary_keys) != 1:
- raise NotImplementedError()
- primary_key = primary_keys[0]
- query = f'INSERT INTO "{relation}" ({cols_string}) SELECT * FROM {self.TEMP_ARROW_TABLE} WHERE "{primary_key}" NOT IN (SELECT "{primary_key}" FROM "{relation}")'
- else:
- key_cols_str = ", ".join([f'"{column_name}"' for column_name in key_cols])
- key_cols_str = f"({key_cols_str})"
- query = f'INSERT INTO "{relation}" ({cols_string}) SELECT * FROM {self.TEMP_ARROW_TABLE} WHERE {key_cols_str} NOT IN (SELECT {key_cols_str} FROM "{relation}")'
- conn.execute(query)
- conn.unregister(view_name=self.TEMP_ARROW_TABLE)
-
- @transaction()
- def delete(
- self, relation: str, index: List[str], conn: Optional[Connection] = None
- ):
- """
- Delete rows from a table based on index
- """
- primary_keys = self._get_primary_keys(relation=relation, conn=conn)
- if len(primary_keys) != 1:
- raise NotImplementedError()
- primary_key = primary_keys[0]
- in_str = ", ".join([f"'{i}'" for i in index])
- query = f'DELETE FROM "{relation}" WHERE {primary_key} IN ({in_str})'
- conn.execute(query)
-
- ############################################################################
- ### queries
- ############################################################################
- @transaction()
- def execute_arrow(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[Connection] = None,
- ) -> pa.Table:
- if parameters is None:
- parameters = []
- if not isinstance(query, str):
- query = str(query)
- return conn.execute(query, parameters=parameters).fetch_arrow_table()
-
- @transaction()
- def execute_no_results(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[Connection] = None,
- ) -> None:
- if parameters is None:
- parameters = []
- if not isinstance(query, str):
- query = str(query)
- return conn.execute(query, parameters=parameters)
-
- @transaction()
- def execute_df(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[Connection] = None,
- ) -> pd.DataFrame:
- if parameters is None:
- parameters = []
- if not isinstance(query, str):
- query = str(query)
- return conn.execute(query, parameters=parameters).fetchdf()
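
The DuckDB `upsert` above only inserts rows whose key is not already present (matching rows are not updated, per the TODO). A standalone sketch of that pattern, assuming the `duckdb` and `pyarrow` packages and an illustrative `calls` table:

```
# insert-if-missing "upsert" via a registered Arrow table, as in upsert() above
import duckdb
import pyarrow as pa

conn = duckdb.connect(database=":memory:")
conn.execute("CREATE TABLE calls (causal_uid VARCHAR PRIMARY KEY, output_0 VARCHAR)")
conn.execute("INSERT INTO calls VALUES ('uid1', 'a')")

incoming = pa.table({"causal_uid": ["uid1", "uid2"], "output_0": ["a", "b"]})
conn.register("__temp_arrow_table__", incoming)
# rows whose primary key already exists in `calls` are skipped
conn.execute(
    'INSERT INTO calls ("causal_uid", "output_0") '
    "SELECT * FROM __temp_arrow_table__ "
    'WHERE "causal_uid" NOT IN (SELECT "causal_uid" FROM calls)'
)
conn.unregister("__temp_arrow_table__")
print(conn.execute("SELECT * FROM calls ORDER BY causal_uid").fetchall())
# [('uid1', 'a'), ('uid2', 'b')]
```
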
diff --git a/mandala/storages/rel_impls/sqlite_impl.py b/mandala/storages/rel_impls/sqlite_impl.py
deleted file mode 100644
index 0e37edc..0000000
--- a/mandala/storages/rel_impls/sqlite_impl.py
+++ /dev/null
@@ -1,393 +0,0 @@
-import sqlite3
-from pandas.api.types import is_string_dtype
-import pyarrow as pa
-from pypika import Query, Column
-
-from .bases import RelStorage
-from .utils import Transactable, transaction
-from ...core.utils import get_uid
-from ...common_imports import *
-
-
-class SQLiteRelStorage(RelStorage, Transactable):
- UID_DTYPE = "VARCHAR"
-
- def __init__(
- self,
- address: Optional[str] = None,
- _read_only: bool = False,
- autocommit: bool = False,
- journal_mode: str = "WAL",
- page_size: int = 32768,
- mmap_size_MB: int = 256,
- cache_size_pages: int = 1000,
- synchronous: str = "normal",
- ):
- self.journal_mode = journal_mode
- self.page_size = page_size
- self.mmap_size_MB = mmap_size_MB
- self.cache_size_pages = cache_size_pages
- self.synchronous = synchronous
- self.autocommit = autocommit
- self._read_only = _read_only
-
- self.in_memory = address is None
- if self.in_memory:
- self._id = get_uid()
- self._connection_address = f"file:{self._id}?mode=memory&cache=shared"
- self._conn = sqlite3.connect(
- str(self._connection_address), isolation_level=None, uri=True
- )
- with self._conn:
- self.apply_optimizations(self._conn)
- else:
- self._connection_address = address
-
- def get_optimizations(self) -> List[str]:
- """
- This needs some explaining:
- - you cannot change `page_size` after setting `journal_mode = WAL`
- - `journal_mode = WAL` is persistent across database connections
- - `cache_size` is in pages when positive, in kB when negative
- """
- if self.mmap_size_MB is None:
- mmap_size = 0
- else:
- mmap_size = self.mmap_size_MB * 1024**2
- pragma_dict = OrderedDict(
- [
- # 'temp_store': 'memory',
- ("synchronous", self.synchronous),
- ("page_size", self.page_size),
- ("cache_size", self.cache_size_pages),
- ("journal_mode", self.journal_mode),
- ("mmap_size", mmap_size),
- ("foreign_keys", "ON"),
- ]
- )
- lines = [f"PRAGMA {k} = {v};" for k, v in pragma_dict.items()]
- return lines
-
- def apply_optimizations(self, conn: sqlite3.Connection):
- opts = self.get_optimizations()
- for line in opts:
- conn.execute(line)
-
- def read_cursor(self, c: sqlite3.Cursor) -> pd.DataFrame:
- if c.description is None:
- assert len(c.fetchall()) == 0
- return pd.DataFrame()
- colnames = [col[0] for col in c.description]
- df = pd.DataFrame(c.fetchall(), columns=colnames)
- return self.postprocess_df(df)
-
- def postprocess_df(self, df: pd.DataFrame) -> pd.DataFrame:
- for col, dtype in df.dtypes.items():
- if not is_string_dtype(dtype):
- df[col] = df[col].astype(str)
- return df
-
- ############################################################################
- ### transaction interface
- ############################################################################
- def _get_connection(self) -> sqlite3.Connection:
- if self.in_memory:
- return self._conn
- else:
- return sqlite3.connect(
- str(self._connection_address), isolation_level=None # "IMMEDIATE"
- )
-
- def _end_transaction(self, conn: sqlite3.Connection):
- conn.commit()
- if not self.in_memory:
- conn.close()
-
- ############################################################################
- ###
- ############################################################################
- @transaction()
- def get_tables(self, conn: Optional[sqlite3.Connection] = None) -> List[str]:
- query = "SELECT name FROM sqlite_master WHERE type='table';"
- cur = conn.cursor()
- cur.execute(query)
- return [row[0] for row in cur.fetchall()]
-
- @transaction()
- def table_exists(
- self, relation: str, conn: Optional[sqlite3.Connection] = None
- ) -> bool:
- return relation in self.get_tables(conn=conn)
-
- @transaction()
- def get_data(
- self, table: str, conn: Optional[sqlite3.Connection] = None
- ) -> pd.DataFrame:
- return self.execute_df(query=f"SELECT * FROM {table};", conn=conn)
-
- @transaction()
- def get_count(self, table: str, conn: Optional[sqlite3.Connection] = None) -> int:
- query = f"SELECT COUNT(*) FROM {table};"
- return int(self.execute_df(query=query, conn=conn).iloc[0, 0])
-
- @transaction()
- def get_all_data(
- self, conn: Optional[sqlite3.Connection] = None
- ) -> Dict[str, pd.DataFrame]:
- return {
- table: self.get_data(table, conn=conn)
- for table in self.get_tables(conn=conn)
- }
-
- ############################################################################
- ### schema management
- ############################################################################
- @transaction()
- def create_relation(
- self,
- name: str,
- columns: List[Tuple[str, Optional[str]]], # [(col name, type), ...]
- defaults: Dict[str, Any], # {col name: default value, ...}
- primary_key: Optional[Union[str, List[str]]] = None,
- if_not_exists: bool = True,
- conn: Optional[sqlite3.Connection] = None,
- ):
- """
- Create a table with given columns, with an optional primary key.
- Columns are given as tuples of (name, type).
- Columns without a dtype are assumed to be of type `self.UID_DTYPE`.
- """
- query = Query.create_table(table=name).columns(
- *[
- Column(
- column_name=column_name,
- column_type=column_type
- if column_type is not None
- else self.UID_DTYPE,
- default=defaults.get(column_name, None),
- # nullable=False,
- )
- for column_name, column_type in columns
- ],
- )
- if if_not_exists:
- query = query.if_not_exists()
- if primary_key is not None:
- if isinstance(primary_key, str):
- query = query.primary_key(primary_key)
- else:
- query = query.primary_key(*primary_key)
- conn.execute(str(query))
- logger.debug(
- f'Created table "{name}" with columns {[elt[0] for elt in columns]}'
- )
-
- @transaction()
- def create_column(
- self,
- relation: str,
- name: str,
- default_value: str,
- conn: Optional[sqlite3.Connection] = None,
- ):
- """
- Add a new column to a table.
- """
- query = (
- f"ALTER TABLE {relation} ADD COLUMN {name} TEXT DEFAULT '{default_value}'"
- )
- conn.execute(query)
- logger.debug(f'Created column "{name}" in table "{relation}"')
-
- @transaction()
- def drop_column(
- self, relation: str, name: str, conn: Optional[sqlite3.Connection] = None
- ):
- """
- Drop a column from a table.
- """
- query = f'ALTER TABLE {relation} DROP COLUMN "{name}"'
- conn.execute(query)
- logger.debug(f'Dropped column "{name}" from table "{relation}"')
-
- @transaction()
- def rename_relation(
- self, name: str, new_name: str, conn: Optional[sqlite3.Connection] = None
- ):
- """
- Rename a table
- """
- query = f"ALTER TABLE {name} RENAME TO {new_name};"
- conn.execute(query)
- logger.debug(f'Renamed table "{name}" to "{new_name}"')
-
- @transaction()
- def rename_column(
- self,
- relation: str,
- name: str,
- new_name: str,
- conn: Optional[sqlite3.Connection] = None,
- ):
- """
- Rename a column
- """
- query = f'ALTER TABLE {relation} RENAME COLUMN "{name}" TO "{new_name}";'
- conn.execute(query)
- logger.debug(f'Renamed column "{name}" in table "{relation}" to "{new_name}"')
-
- @transaction()
- def rename_columns(
- self,
- relation: str,
- mapping: Dict[str, str],
- conn: Optional[sqlite3.Connection] = None,
- ):
- # factorize the renaming into two maps that can be applied atomically
- part_1 = {k: get_uid() for k in mapping.keys()}
- part_2 = {part_1[k]: v for k, v in mapping.items()}
- for k, v in part_1.items():
- self.rename_column(relation=relation, name=k, new_name=v, conn=conn)
- for k, v in part_2.items():
- self.rename_column(relation=relation, name=k, new_name=v, conn=conn)
- if len(mapping) > 0:
- logger.debug(f'Renamed columns of table "{relation}" via mapping {mapping}')
-
- ############################################################################
- ### instance management
- ############################################################################
- @transaction()
- def _get_cols(
- self, relation: str, conn: Optional[sqlite3.Connection] = None
- ) -> List[str]:
- """
- get the *ordered* columns of a table.
- """
- query = f"PRAGMA table_info({relation})"
- df = self.execute_df(query=query, conn=conn)
- return list(df["name"])
-
- @transaction()
- def _get_primary_keys(
- self, relation: str, conn: Optional[sqlite3.Connection] = None
- ) -> List[str]:
- """
- get the primary key of a table.
- """
- query = f"PRAGMA table_info({relation})"
- df = self.execute_df(query=query, conn=conn)
- return list(df[df["pk"].apply(int) == 1]["name"])
-
- @transaction()
- def insert(
- self, relation: str, ta: pa.Table, conn: Optional[sqlite3.Connection] = None
- ):
- """
- Append rows to a table
- """
- df = ta.to_pandas()
- if df.empty:
- return
- columns = df.columns.tolist()
- col_str = ", ".join([f'"{col}"' for col in columns])
- placeholder_str = ",".join(["?" for _ in columns])
- query = f"INSERT INTO {relation}({col_str}) VALUES ({placeholder_str})"
- parameters = list(df.itertuples(index=False))
- conn.executemany(query, parameters)
-
- @transaction()
- def upsert(
- self,
- relation: str,
- ta: pa.Table,
- key_cols: Optional[List[str]] = None,
- conn: Optional[sqlite3.Connection] = None,
- ):
- """
- Upsert rows in a table based on primary key.
- """
- if isinstance(ta, pa.Table):
- df = ta.to_pandas()
- else:
- df = ta
- if df.empty: # engine complains
- return
- columns = df.columns.tolist()
- col_str = ", ".join([f'"{col}"' for col in columns])
- placeholder_str = ",".join(["?" for _ in columns])
- query = (
- f"INSERT OR REPLACE INTO {relation}({col_str}) VALUES ({placeholder_str})"
- )
- parameters = list(df.itertuples(index=False))
- conn.executemany(query, parameters)
-
- @transaction()
- def delete(
- self,
- relation: str,
- where_col: str,
- where_values: List[str],
- conn: Optional[sqlite3.Connection] = None,
- ):
- """
- Delete rows from a table where `where_col` is in `where_values`
- """
- query = f"DELETE FROM {relation} WHERE {where_col} IN ({','.join(['?']*len(where_values))})"
- conn.execute(query, where_values)
-
- @transaction()
- def vacuum(self, warn: bool = True, conn: Optional[sqlite3.Connection] = None):
- """
- ! this needs a lot of free space on disk to work (~2x db size)
- """
- if warn:
- total_db_size = (
- self.execute_df("PRAGMA page_count").astype(float).iloc[0, 0]
- * self.page_size
- )
-            question = "Vacuuming a database of size {:.2f} MB; this may take a long time and requires ~2x as much free space on disk. Are you sure?".format(
- total_db_size / 1024**2
- )
- # ask the user if they are sure
- user_input = input(question + " (y/n): ")
- if user_input.lower() != "y":
- logging.info("Aborting vacuuming.")
- return
- conn.execute("VACUUM")
-
- @transaction()
- def execute_df(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[sqlite3.Connection] = None,
- ) -> pd.DataFrame:
- if parameters is None:
- parameters = []
- cursor = conn.execute(str(query), parameters)
- return self.postprocess_df(self.read_cursor(cursor))
-
- @transaction()
- def execute_arrow(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[sqlite3.Connection] = None,
- ) -> pa.Table:
- if isinstance(query, Query):
- query = str(query)
- df = self.execute_df(query=query, parameters=parameters, conn=conn)
- return pa.Table.from_pandas(df)
-
- @transaction()
- def execute_no_results(
- self,
- query: Union[str, Query],
- parameters: List[Any] = None,
- conn: Optional[sqlite3.Connection] = None,
- ) -> None:
- if parameters is None:
- parameters = []
- if not isinstance(query, str):
- query = str(query)
- return conn.execute(query, parameters)
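
By contrast, the SQLite `upsert` above uses `INSERT OR REPLACE` with `executemany`, so rows sharing the primary key are overwritten rather than skipped. A minimal stdlib-only sketch (table and column names are illustrative):

```
# INSERT OR REPLACE upsert, as used by SQLiteRelStorage.upsert above
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # one of the pragmas applied at connection time
conn.execute("CREATE TABLE calls (causal_uid TEXT PRIMARY KEY, output_0 TEXT)")

rows = [("uid1", "a"), ("uid2", "b")]
query = 'INSERT OR REPLACE INTO calls("causal_uid", "output_0") VALUES (?, ?)'
conn.executemany(query, rows)
# re-upserting uid1 with a new value replaces the existing row
conn.executemany(query, [("uid1", "a2")])
print(conn.execute("SELECT * FROM calls ORDER BY causal_uid").fetchall())
# [('uid1', 'a2'), ('uid2', 'b')]
```
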
diff --git a/mandala/storages/rel_impls/utils.py b/mandala/storages/rel_impls/utils.py
deleted file mode 100644
index 54bf7a9..0000000
--- a/mandala/storages/rel_impls/utils.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from ...common_imports import *
-from ...core.config import Config
-from abc import ABC, abstractmethod
-import functools
-
-if Config.has_duckdb:
- from duckdb import DuckDBPyConnection as Connection
-else:
-
- class Connection:
- pass
-
-
-class Transactable(ABC):
- @abstractmethod
- def _get_connection(self) -> Connection:
- raise NotImplementedError()
-
- @abstractmethod
- def _end_transaction(self, conn: Connection):
- raise NotImplementedError()
-
-
-class Transaction:
- def __init__(self):
- pass
-
-    def __call__(self, method: "Callable") -> "Callable":
- @functools.wraps(method)
- def inner(instance: Transactable, *args, conn: Connection = None, **kwargs):
- transaction_started_here = False
- if conn is None:
- # new transaction
- conn = instance._get_connection()
- instance._current_conn = conn
- transaction_started_here = True
- # conn.execute("BEGIN IMMEDIATE")
- try:
- result = method(instance, *args, conn=conn, **kwargs)
- if transaction_started_here:
- instance._end_transaction(conn=conn)
- return result
- except Exception as e:
- if transaction_started_here:
- conn.rollback()
- instance._end_transaction(conn=conn)
- raise e
- # instance._end_transaction(conn=conn)
- # return result
- # else:
- # # nest in existing transaction
- # result = method(instance, *args, conn=conn, **kwargs)
- # return result
-
- return inner
-
-
-transaction = Transaction
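
The `transaction` decorator above lets a method either join a caller's open connection or open (and later commit/close) its own. A simplified standalone sketch of the same pattern, using a plain function decorator and sqlite3 stand-ins rather than the removed `Transactable` classes:

```
# sketch of the "join or start a transaction" decorator pattern
import functools
import sqlite3
from typing import Optional

def transaction(method):
    @functools.wraps(method)
    def inner(self, *args, conn: Optional[sqlite3.Connection] = None, **kwargs):
        started_here = conn is None
        if started_here:
            conn = self._get_connection()  # new transaction
        try:
            result = method(self, *args, conn=conn, **kwargs)
            if started_here:
                self._end_transaction(conn)
            return result
        except Exception:
            if started_here:
                conn.rollback()
                self._end_transaction(conn)
            raise
    return inner

class Store:
    def __init__(self, path: str = ":memory:"):
        self._conn = sqlite3.connect(path)

    def _get_connection(self) -> sqlite3.Connection:
        return self._conn

    def _end_transaction(self, conn: sqlite3.Connection) -> None:
        conn.commit()

    @transaction
    def put(self, k: str, v: str, conn: Optional[sqlite3.Connection] = None):
        conn.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)")
        conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", (k, v))

store = Store()
store.put("a", "1")  # opens and commits its own transaction
```
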
diff --git a/mandala/storages/rels.py b/mandala/storages/rels.py
deleted file mode 100644
index 122bd0c..0000000
--- a/mandala/storages/rels.py
+++ /dev/null
@@ -1,1351 +0,0 @@
-from collections import defaultdict
-
-import pyarrow as pa
-import pyarrow.parquet as pq
-from pypika import Query, Table, Parameter
-from pypika.terms import LiteralValue
-
-from ..common_imports import *
-from ..core.config import Config, dump_output_name, Provenance
-from ..core.model import Call, FuncOp, Ref, ValueRef, collect_detached
-from ..core.builtins_ import Builtins
-from ..core.wrapping import unwrap
-from ..core.utils import get_fibers_as_lists
-from ..core.sig import Signature
-from ..deps.versioner import Versioner
-from ..utils import serialize, deserialize, _rename_cols
-
-if Config.has_duckdb:
- from .rel_impls.duckdb_impl import DuckDBRelStorage
-from .rel_impls.utils import Transactable, transaction, Connection
-from .rel_impls.bases import RelStorage
-
-
-# {internal table name -> serialized (internally named) table}
-RemoteEventLogEntry = Dict[str, bytes]
-
-
-class SpilledRef:
- pass
-
-
-class VersionAdapter(Transactable):
-    # TODO: this is too similar to SigAdapter below; refactor.
-    # Like SigAdapter, but for dependency/versioning state: encapsulates
-    # methods to load and write the dependency table.
- def __init__(self, rel_adapter: "RelAdapter"):
- self.rel_adapter = rel_adapter
- self.rel_storage = rel_adapter.rel_storage
-
- ############################################################################
- ### `Transactable` interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return self.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_storage._end_transaction(conn=conn)
-
- ###
- @transaction()
- def dump_state(
- self,
- state: Versioner,
- conn: Optional[Connection] = None,
- ):
- """
-        Dump the given versioning (dependency) state to the database. Should
-        always call this after the versioner has been updated.
- """
- # delete existing, if any
- index_col = "index"
- query = f'DELETE FROM {self.rel_adapter.DEPS_TABLE} WHERE "{index_col}" = 0'
- conn.execute(query)
- # insert new
- serialized = serialize(obj=state)
- df = pd.DataFrame(
- {
- index_col: [0],
- "deps": [serialized],
- }
- )
- ta = pa.Table.from_pandas(df)
- self.rel_storage.insert(relation=self.rel_adapter.DEPS_TABLE, ta=ta, conn=conn)
-
- @transaction()
- def has_state(self, conn: Optional[Connection] = None) -> bool:
- query = f'SELECT * FROM {self.rel_adapter.DEPS_TABLE} WHERE "index" = 0'
- df = self.rel_storage.execute_df(query=query, conn=conn)
- return len(df) != 0
-
- @transaction()
- def load_state(self, conn: Optional[Connection] = None) -> Optional[Versioner]:
- """
-        Load the versioning (dependency) state from the database. All
-        interactions with this state are done transactionally through this
-        method.
- """
- query = f'SELECT * FROM {self.rel_adapter.DEPS_TABLE} WHERE "index" = 0'
- df = self.rel_storage.execute_df(query=query, conn=conn)
- if len(df) == 0:
- return None
- else:
- return deserialize(df["deps"][0])
-
-
-class SigAdapter(Transactable):
- """
- Responsible for state transitions of the schema that update the
- signature objects *and* the relational tables in a transactional way.
- """
-
- def __init__(
- self,
- rel_adapter: "RelAdapter",
- ):
- self.rel_adapter = rel_adapter
- self.rel_storage = self.rel_adapter.rel_storage
-
- @transaction()
- def dump_state(
- self, state: Dict[Tuple[str, int], Signature], conn: Optional[Connection] = None
- ):
- """
- Dump the given state of the signatures to the database. Should always
- call this after the signatures have been updated.
- """
- # delete existing, if any
- index_col = "index"
- query = (
- f'DELETE FROM {self.rel_adapter.SIGNATURES_TABLE} WHERE "{index_col}" = 0'
- )
- conn.execute(query)
- # insert new
- serialized = serialize(obj=state)
- df = pd.DataFrame(
- {
- index_col: [0],
- "signatures": [serialized],
- }
- )
- ta = pa.Table.from_pandas(df)
- self.rel_storage.insert(
- relation=self.rel_adapter.SIGNATURES_TABLE, ta=ta, conn=conn
- )
-
- @transaction()
- def has_state(self, conn: Optional[Connection] = None) -> bool:
- query = f'SELECT * FROM {self.rel_adapter.SIGNATURES_TABLE} WHERE "index" = 0'
- df = self.rel_storage.execute_df(query=query, conn=conn)
- return len(df) != 0
-
- @transaction()
- def load_state(
- self, conn: Optional[Connection] = None
- ) -> Dict[Tuple[str, int], Signature]:
- """
- Load the state of the signatures from the database. All interactions
- with the state of the signatures are done transactionally through this
- method.
- """
- query = f'SELECT * FROM {self.rel_adapter.SIGNATURES_TABLE} WHERE "index" = 0'
- df = self.rel_storage.execute_df(query=query, conn=conn)
- if len(df) == 0:
- return {}
- else:
- return deserialize(df["signatures"][0])
-
- def check_invariants(
- self,
- sigs: Optional[Dict[Tuple[str, int], Signature]] = None,
- conn: Optional[Connection] = None,
- ):
- """
- This checks that the invariants of the *set* of signatures for the storage
- hold. This means that:
- - all versions of a signature are consecutive integers starting from
- 0
- - signatures have the same UI name iff they have the same internal
- name
-
- Invariants for individual signatures are not checked by this (they are
- checked by the `Signature` class).
- """
- # check version numbering
- if sigs is None:
- sigs = self.load_state(conn=conn)
- internal_names = {internal_name for internal_name, _ in sigs.keys()}
- for internal_name in internal_names:
-            versions = [version for name, version in sigs.keys() if name == internal_name]
- assert sorted(versions) == list(range(len(versions)))
- # check exactly 1 UI name per internal name
- internal_to_ui_names = defaultdict(set)
- for (internal_name, _), sig in sigs.items():
- internal_to_ui_names[internal_name].add(sig.ui_name)
- for internal_name, ui_names in internal_to_ui_names.items():
- assert (
- len(ui_names) == 1
- ), f"Internal name {internal_name} has multiple UI names: {ui_names}"
- # check exactly 1 internal name per UI name
- ui_to_internal_names = defaultdict(set)
- for (internal_name, _), sig in sigs.items():
- ui_to_internal_names[sig.ui_name].add(internal_name)
- for ui_name, internal_names in ui_to_internal_names.items():
- assert (
- len(internal_names) == 1
- ), f"UI name {ui_name} has multiple internal names: {internal_names}"
-
- ############################################################################
- ### `Transactable` interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return self.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_storage._end_transaction(conn=conn)
-
- ############################################################################
- ###
- ############################################################################
- @transaction()
- def load_ui_sigs(
- self, conn: Optional[Connection] = None
- ) -> Dict[Tuple[str, int], Signature]:
- """
- Get the signatures indexed by (ui_name, version)
- """
- sigs = self.load_state(conn=conn)
- res = {(sig.ui_name, sig.version): sig for sig in sigs.values()}
- assert len(res) == len(sigs)
- return res
-
- @transaction()
- def exists_versioned_ui(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> bool:
- """
- Check if the signature exists based on its UI name *and* version
- """
- return (sig.ui_name, sig.version) in self.load_ui_sigs(conn=conn)
-
- @transaction()
- def exists_any_version(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> bool:
- """
- Check using internal name (or UI name, if it has no internal data) if
- there exists any version for this signature.
- """
- if sig.has_internal_data:
- return any(
- sig.internal_name == k[0] for k in self.load_state(conn=conn).keys()
- )
- else:
- return any(sig.ui_name == k[0] for k in self.load_ui_sigs(conn=conn).keys())
-
- @transaction()
- def get_latest_version(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> Signature:
- """
- Get the latest version of the signature, based on internal name or UI
- name if it has no internal data.
- """
- sigs = self.load_state(conn=conn)
- if sig.has_internal_data:
- versions = [k[1] for k in sigs.keys() if k[0] == sig.internal_name]
- if len(versions) == 0:
- raise ValueError(f"No versions for signature {sig}")
- version = max(versions)
- return sigs[(sig.internal_name, version)]
- else:
- versions = [
- k[1] for k in self.load_ui_sigs(conn=conn).keys() if k[0] == sig.ui_name
- ]
- if len(versions) == 0:
- raise ValueError(f"No versions for signature {sig}")
- version = max(versions)
- return self.load_ui_sigs(conn=conn)[(sig.ui_name, version)]
-
- @transaction()
- def get_versions(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> List[int]:
- """
- Get all versions of the signature, based on internal name or UI name if
- it has no internal data.
- """
- sigs = self.load_state(conn=conn)
- if sig.has_internal_data:
- return [k[1] for k in sigs.keys() if k[0] == sig.internal_name]
- else:
- ui_sigs = self.load_ui_sigs(conn=conn)
- return [k[1] for k in ui_sigs.keys() if k[0] == sig.ui_name]
-
- @transaction()
- def exists_internal(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> bool:
- """
- Check if the signature exists based on its *internal* name
- """
- return (sig.internal_name, sig.version) in self.load_state(conn=conn)
-
- @transaction()
- def internal_to_ui(self, conn: Optional[Connection] = None) -> Dict[str, str]:
- """
- Get a mapping from internal names to UI names
- """
- return {k[0]: v.ui_name for k, v in self.load_state(conn=conn).items()}
-
- @transaction()
- def ui_to_internal(self, conn: Optional[Connection] = None) -> Dict[str, str]:
- """
- Get a mapping from UI names to internal names
- """
- return {v.ui_name: k[0] for k, v in self.load_state(conn=conn).items()}
-
- @transaction()
- def ui_names(self, conn: Optional[Connection] = None) -> Set[str]:
- # return the set of ui names
- return set(self.ui_to_internal(conn=conn).keys())
-
- @transaction()
- def internal_names(self, conn: Optional[Connection] = None) -> Set[str]:
- # return the set of internal names
- return set(self.internal_to_ui(conn=conn).keys())
-
- @transaction()
- def is_sig_table_name(
- self, name: str, use_internal: bool, conn: Optional[Connection] = None
- ) -> bool:
- """
- Check if the name is a valid name for a table corresponding to a
- signature.
- """
- parts = name.split("_", 1)
- return (
- parts[0]
- in (
- self.internal_names(conn=conn)
- if use_internal
- else self.ui_names(conn=conn)
- )
- and parts[1].isdigit()
- )
-
- ############################################################################
- ### elementary transitions for local state
- ############################################################################
- @transaction()
- def _init_deps(self, sig: Signature, conn: Optional[Connection] = None):
- pass
- # deps = self.deps_adapter.load_state(conn=conn)
- # deps.op_graphs[(sig.internal_name, sig.version)] = DependencyGraph()
- # self.deps_adapter.dump_state(state=deps, conn=conn)
-
- @transaction()
- def _create_relation(self, sig: Signature, conn: Optional[Connection] = None):
- io_colnames = list(sig.input_names) + [
- dump_output_name(index=i) for i in range(sig.n_outputs)
- ]
- all_cols = [(col, None) for col in Config.special_call_cols] + [
- (column, None) for column in io_colnames
- ]
- self.rel_storage.create_relation(
- name=sig.versioned_ui_name,
- columns=all_cols,
- primary_key=Config.causal_uid_col,
- defaults=sig.new_ui_input_default_uids,
- conn=conn,
- )
- self._init_deps(sig=sig, conn=conn)
-
- @transaction()
- def create_sig(self, sig: Signature, conn: Optional[Connection] = None):
- """
- Create a new signature `sig`. `sig` must have internal data, and not be
- present in storage at any version.
- """
- assert sig.has_internal_data
- sigs = self.load_state(conn=conn)
- assert sig.internal_name not in self.internal_names(conn=conn)
- # assert (sig.internal_name, sig.version) not in sigs.keys()
- sigs[(sig.internal_name, sig.version)] = sig
- # write signatures
- self.dump_state(state=sigs, conn=conn)
- # create relation
- self._create_relation(sig=sig, conn=conn)
- logger.debug(f"Created signature:\n{sig}")
-
- @transaction()
- def create_new_version(self, sig: Signature,
- strict: bool = False,
- conn: Optional[Connection] = None):
- """
- Create a new version of an already existing function using the `sig`
- object. `sig` must have internal data, and the internal name must
- already be present in some version.
- """
- assert sig.has_internal_data
- latest_sig = self.get_latest_version(sig=sig, conn=conn)
- if strict:
- assert sig.version == latest_sig.version + 1
- # update signatures
- sigs = self.load_state(conn=conn)
- sigs[(sig.internal_name, sig.version)] = sig
- self.dump_state(state=sigs, conn=conn)
- # create relation
- self._create_relation(sig=sig, conn=conn)
- logger.debug(f"Created new version:\n{sig}")
-
- @transaction()
- def update_sig(self, sig: Signature, conn: Optional[Connection] = None):
- """
- Update an existing signature. `sig` must have internal data, and
- must already exist in storage.
- """
- assert sig.has_internal_data
- sigs = self.load_state(conn=conn)
- assert (sig.internal_name, sig.version) in sigs.keys()
- current = sigs[(sig.internal_name, sig.version)]
- # the `update` method also ensures that the signature is compatible
- n_outputs_new = sig.n_outputs
- n_outputs_old = current.n_outputs
- new_sig, updates = current.update(new=sig)
- # update the signature data
- sigs[(sig.internal_name, sig.version)] = new_sig
- # create new inputs in the database, if any
- for new_input, default_value in updates.items():
- internal_input_name = new_sig.ui_to_internal_input_map[new_input]
- full_default_uid = new_sig._new_input_defaults_uids[internal_input_name]
- self.rel_storage.create_column(
- relation=new_sig.versioned_ui_name,
- name=new_input,
- default_value=full_default_uid,
- conn=conn,
- )
- # insert the default in the objects *in the database*, if it's
- # not there already
- default_uid, _ = Ref.parse_full_uid(full_uid=full_default_uid)
- default_vref = ValueRef(uid=default_uid, obj=default_value, in_memory=True)
- self.rel_adapter.obj_set(uid=default_uid, value=default_vref, conn=conn)
- self.rel_adapter.obj_set_causal(full_uid=full_default_uid, conn=conn)
- # update the outputs in the database, if this is allowed
- n_rows = self.rel_storage.get_count(table=new_sig.versioned_ui_name, conn=conn)
- if n_rows > 0 and n_outputs_new != n_outputs_old:
- raise ValueError(
- f"Cannot change the number of outputs of a signature that has already been used. "
- f"Current number of outputs: {n_outputs_old}, new number of outputs: {n_outputs_new}."
- )
- if n_outputs_new > n_outputs_old:
- for i in range(n_outputs_old, n_outputs_new):
- self.rel_storage.create_column(
- relation=new_sig.versioned_ui_name,
- name=dump_output_name(index=i),
- default_value=None,
- conn=conn,
- )
- if n_outputs_new < n_outputs_old:
- for i in range(n_outputs_new, n_outputs_old):
- self.rel_storage.drop_column(
- relation=new_sig.versioned_ui_name,
- name=dump_output_name(index=i),
- conn=conn,
- )
- if len(updates) > 0:
- logger.debug(
-                f"Updated signature:\n new inputs: {updates}\n new signature:\n {sig}"
- )
- self.dump_state(state=sigs, conn=conn)
-
- @transaction()
- def update_ui_name(
- self,
- sig: Signature,
- conn: Optional[Connection] = None,
- validate_only: bool = False,
- ) -> Dict[Tuple[str, int], Signature]:
- """
- Update a signature's UI name using the given `Signature` object to get
- the new name. `sig` must have internal data, and must carry the new UI
- name.
-
- NOTE: the `sig` may have the same UI name as the current signature, in
- which case this method does nothing but return the current state of the
- signatures.
-
- This method has the option of only generating the new state of the
- signatures without performing the update.
- """
- assert sig.has_internal_data
- assert self.exists_internal(sig=sig, conn=conn)
- sigs = self.load_state(conn=conn)
- all_versions = self.get_versions(sig=sig, conn=conn)
- current_ui_name = sigs[(sig.internal_name, all_versions[0])].ui_name
- new_ui_name = sig.ui_name
- if current_ui_name == new_ui_name:
- # nothing to do
- return sigs
- # make sure there are no conflicts
- internal_to_ui = self.internal_to_ui(conn=conn)
- if new_ui_name in internal_to_ui.values():
- raise ValueError(
- f"UI name {new_ui_name} already exists for another signature."
- )
- for version in all_versions:
- current = sigs[(sig.internal_name, version)]
- if current.ui_name != sig.ui_name:
- new_sig = current.rename(new_name=sig.ui_name)
- # update signature object in memory
- sigs[(sig.internal_name, version)] = new_sig
- if not validate_only:
- # update table
- self.rel_storage.rename_relation(
- name=current.versioned_ui_name,
- new_name=new_sig.versioned_ui_name,
- conn=conn,
- )
- if not validate_only:
- # update signatures state
- self.dump_state(state=sigs, conn=conn)
- if current_ui_name != new_ui_name:
- logger.debug(
- f"Updated UI name of signature: from {current_ui_name} to {new_ui_name}"
- )
- return sigs
-
- @transaction()
- def update_input_ui_names(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> Signature:
- """
- Update a signature's input UI names from the given `Signature` object.
- `sig` must have internal data, and must carry the new UI input names.
- """
- assert sig.has_internal_data
- sigs = self.load_state(conn=conn)
- current = sigs[(sig.internal_name, sig.version)]
- current_internal_to_ui = current.internal_to_ui_input_map
- new_internal_to_ui = sig.internal_to_ui_input_map
- renaming_map = {
- current_internal_to_ui[k]: new_internal_to_ui[k]
- for k in current_internal_to_ui
- if current_internal_to_ui[k] != new_internal_to_ui[k]
- }
- # update signature object
- new_sig = current.rename_inputs(mapping=renaming_map)
- sigs[(sig.internal_name, sig.version)] = new_sig
- self.dump_state(state=sigs, conn=conn)
- # update table columns
- self.rel_storage.rename_columns(
- relation=new_sig.versioned_ui_name, mapping=renaming_map, conn=conn
- )
- if len(renaming_map) > 0:
- logger.debug(
-                f"Updated input UI names of signature {sig.ui_name} via mapping {renaming_map}"
- )
- return new_sig
-
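The renaming logic in `update_input_ui_names` above reduces to diffing two internal-to-UI name maps. A minimal, self-contained sketch of that diff (plain dictionaries standing in for mandala's `Signature` maps; the names are hypothetical):

```python
# Hypothetical stand-ins for a signature's internal-to-UI input maps.
current_internal_to_ui = {"i0": "x", "i1": "y"}
new_internal_to_ui = {"i0": "x", "i1": "how_many_times"}

# Same comprehension as in `update_input_ui_names`: only inputs whose UI name
# actually changed end up in the renaming map, keyed by the *old* UI name.
renaming_map = {
    current_internal_to_ui[k]: new_internal_to_ui[k]
    for k in current_internal_to_ui
    if current_internal_to_ui[k] != new_internal_to_ui[k]
}
assert renaming_map == {"y": "how_many_times"}
```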
- @transaction()
- def rename_tables(
- self,
- tables: Dict[str, TableType],
- to: str = "internal",
- conn: Optional[Connection] = None,
- ) -> Dict[str, TableType]:
- """
-        Rename the tables in a dictionary of {versioned name: table} pairs,
-        along with their columns, to either internal or UI names.
- """
- result = {}
- assert to in ["internal", "ui"]
- for table_name, table in tables.items():
- if self.is_sig_table_name(
- name=table_name, use_internal=(to != "internal"), conn=conn
- ):
- if to == "internal":
- ui_name, version = Signature.parse_versioned_name(table_name)
- sig = self.load_ui_sigs(conn=conn)[ui_name, version]
- new_table_name = sig.versioned_internal_name
- mapping = sig.ui_to_internal_input_map
- else:
- internal_name, version = Signature.parse_versioned_name(table_name)
- sig = self.load_state(conn=conn)[internal_name, version]
- new_table_name = sig.versioned_ui_name
- mapping = sig.internal_to_ui_input_map
- result[new_table_name] = _rename_cols(table=table, mapping=mapping)
- else:
- result[table_name] = table
- return result
-
-
-class RelAdapter(Transactable):
- EVENT_LOG_TABLE = Config.event_log_table
- VREF_TABLE = Config.vref_table
- CAUSAL_VREF_TABLE = Config.causal_vref_table
- SIGNATURES_TABLE = Config.schema_table
- DEPS_TABLE = Config.deps_table
- PROVENANCE_TABLE = Config.provenance_table
- # tables to be excluded from certain operations
- SPECIAL_TABLES = (
- EVENT_LOG_TABLE,
- Config.temp_arrow_table,
- SIGNATURES_TABLE,
- DEPS_TABLE,
- PROVENANCE_TABLE,
- )
-
- def __init__(
- self,
- rel_storage: RelStorage,
- spillover_dir: Optional[Path] = None,
- spillover_threshold_mb: Optional[float] = None,
- ):
- self.rel_storage = rel_storage
- self.spillover_dir = spillover_dir
- self.spillover_threshold_mb = (
- Config.spillover_threshold_mb
- if spillover_threshold_mb is None
- else spillover_threshold_mb
- )
- if self.spillover_dir is not None:
- self.spillover_dir.mkdir(parents=True, exist_ok=True)
- self.sig_adapter = SigAdapter(rel_adapter=self)
- self.init()
- # check if we are connecting to an existing instance
- conn = self._get_connection()
- if not self.sig_adapter.has_state(conn=conn):
- self.sig_adapter.dump_state(state={}, conn=conn)
- self._end_transaction(conn=conn)
-
- @transaction()
- def init(self, conn: Optional[Connection] = None):
- if self.rel_storage._read_only:
- return
- self.rel_storage.create_relation(
- name=self.VREF_TABLE,
- columns=[(Config.uid_col, None), (Config.vref_value_col, "blob")],
- primary_key=Config.uid_col,
- defaults={},
- if_not_exists=True,
- conn=conn,
- )
- self.rel_storage.create_relation(
- name=self.CAUSAL_VREF_TABLE,
- columns=[(Config.full_uid_col, None)],
- primary_key=Config.full_uid_col,
- defaults={},
- if_not_exists=True,
- conn=conn,
- )
- self.rel_storage.create_relation(
- name=self.PROVENANCE_TABLE,
- columns=[
- (Provenance.causal_uid, None),
- (Provenance.name, None),
- (Provenance.call_causal_uid, None),
- (Provenance.direction, None),
- (Provenance.op_id, None),
- ],
- primary_key=[Provenance.call_causal_uid, Provenance.name],
- defaults={},
- if_not_exists=True,
- conn=conn,
- )
- # Initialize the event log.
- # The event log is just a list of UIDs that changed, for now.
-        # The UID column stores the vref/call UID, and the `table` column stores
-        # the table in which this UID can be found.
- self.rel_storage.create_relation(
- name=self.EVENT_LOG_TABLE,
- columns=[(Config.uid_col, None), ("table", "varchar")],
- primary_key=Config.uid_col,
- defaults={},
- if_not_exists=True,
- conn=conn,
- )
- # The signatures table is a binary dump of the signatures
- self.rel_storage.create_relation(
- name=self.SIGNATURES_TABLE,
- columns=[
- ("index", "int"),
- ("signatures", "blob"),
- ],
- primary_key="index",
- defaults={},
- if_not_exists=True,
- conn=conn,
- )
- self.rel_storage.create_relation(
- name=self.DEPS_TABLE,
- columns=[
- ("index", "int"),
- ("deps", "blob"),
- ],
- primary_key="index",
- defaults={},
- if_not_exists=True,
- conn=conn,
- )
-
- @transaction()
- def get_call_tables(self, conn: Optional[Connection] = None) -> List[str]:
- tables = self.rel_storage.get_tables(conn=conn)
- return [
- t
- for t in tables
- if t not in self.SPECIAL_TABLES
- and t != self.VREF_TABLE
- and t != self.CAUSAL_VREF_TABLE
- ]
-
- @transaction()
- def get_vrefs(self, conn: Optional[Connection] = None) -> pd.DataFrame:
- """
- Returns a dataframe of the deserialized values of the value references
- in the storage.
- """
- data = self.rel_storage.get_data(table=self.VREF_TABLE, conn=conn)
- data["value"] = data["value"].apply(lambda vref: unwrap(deserialize(vref)))
- return data
-
- @transaction()
- def get_causal_vrefs(self, conn: Optional[Connection] = None) -> pd.DataFrame:
- data = self.rel_storage.get_data(table=self.CAUSAL_VREF_TABLE, conn=conn)
- return data
-
- @transaction()
- def get_all_call_data(
- self, conn: Optional[Connection] = None
- ) -> Dict[str, pd.DataFrame]:
- """
-        Return a dictionary of all the memoization tables (labeled by versioned
-        UI name).
- """
- result = {}
- for table in self.get_call_tables(conn=conn):
- result[table] = self.rel_storage.get_data(table=table, conn=conn)
- return result
-
- ############################################################################
- ### event log stuff
- ############################################################################
- @transaction()
- def get_event_log(self, conn: Optional[Connection] = None) -> pd.DataFrame:
- return self.rel_storage.get_data(table=self.EVENT_LOG_TABLE, conn=conn)
-
- @transaction()
- def clear_event_log(self, conn: Optional[Connection] = None):
- event_log_table = Table(self.EVENT_LOG_TABLE)
- query = Query.from_(event_log_table).delete()
- self.rel_storage.execute_no_results(query=query, conn=conn)
-
- ############################################################################
- ### `Transactable` interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return self.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_storage._end_transaction(conn=conn)
-
- ############################################################################
- ### call methods
- ############################################################################
- @transaction()
- def _get_current_names(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> Tuple[str, Dict[str, str]]:
- """
- Given a possibly stale signature `sig`, return
- - the current ui name for this signature
- - a mapping of stale input names to their current values
- """
- current_sigs = self.sig_adapter.load_state(conn=conn)
- current_sig = current_sigs[sig.internal_name, sig.version]
- true_versioned_ui_name = current_sig.versioned_ui_name
- stale_to_true_input_mapping = {
- k: current_sig.internal_to_ui_input_map[v]
- for k, v in sig.ui_to_internal_input_map.items()
- # for k, v in current_sig.ui_to_internal_input_map.items()
- }
- return true_versioned_ui_name, stale_to_true_input_mapping
-
- @transaction()
- def tabulate_calls(
- self, calls: List[Call], conn: Optional[Connection] = None
- ) -> Dict[str, pa.Table]:
- """
- Converts call objects to a dictionary of {op UI name: table
- to upsert} pairs.
-
- Note that the calls can involve functions in different stages of
- staleness. This method can handle calls to many different variants of
- the same function (adding inputs, renaming the function or its inputs).
- To handle calls to stale functions, this passes through internal names
- to get the current UI names.
- """
- if not len(calls) == len(set([call.full_uid for call in calls])):
- # something fishy may be going on
- raise InternalError("Calls must have unique UIDs")
- # split by operation *internal* name to group calls to the same op in
- # the same group, even if UI names are different.
- calls_by_op = defaultdict(list)
- for call in calls:
- calls_by_op[call.func_op.sig.versioned_internal_name].append(call)
- res = {}
- for versioned_internal_name, calls in calls_by_op.items():
- rows = []
- true_sig = self.sig_adapter.load_state()[
- Signature.parse_versioned_name(versioned_internal_name)
- ]
- true_versioned_ui_name = None
- for call in calls:
- # it is necessary to process each call individually to properly
- # handle multiple stale variants of this op
- sig = call.func_op.sig
- # get the current state of this signature
- (
- true_versioned_ui_name,
- # stale UI input -> current UI input. This could vary across calls
- stale_to_true_input_mapping,
- ) = self._get_current_names(sig, conn=conn)
- # form the input UIDs
- input_uids = {
- stale_to_true_input_mapping[k]: v.full_uid
- for k, v in call.inputs.items()
- }
- # patch the input uids using the true signature. This is
-                # necessary to do here because duckdb seems to have issues with
- # interpreting NaNs from pyarrow as nulls
- for k, v in true_sig.new_ui_input_default_uids.items():
- if k not in input_uids:
- input_uids[k] = v
- row = {
- Config.uid_col: call.uid,
- Config.causal_uid_col: call.causal_uid,
- Config.content_version_col: call.content_version,
- Config.semantic_version_col: call.semantic_version,
- Config.transient_col: call.transient,
- **input_uids,
- **{
- dump_output_name(index=i): v.full_uid
- for i, v in enumerate(call.outputs)
- },
- }
- rows.append(row)
- assert true_versioned_ui_name is not None
- res[true_versioned_ui_name] = pa.Table.from_pylist(rows)
- return res
-
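To make the grouping in `tabulate_calls` concrete, here is a toy sketch of the same idea: group call records by the op they belong to and remap possibly stale input names to their current ones before building the per-op rows. The dictionaries and names below are illustrative only, not mandala's actual `Call`/`Signature` objects:

```python
from collections import defaultdict

# Two calls to the same op; the second one was recorded under a stale input name.
calls = [
    {"op": "op_0_v0", "uid": "a", "inputs": {"x": "u1"}},
    {"op": "op_0_v0", "uid": "b", "inputs": {"x_old": "u2"}},
]
# Assumed mapping from (possibly stale) input names to the current UI names.
stale_to_current = {"x": "x", "x_old": "x"}

rows_by_op = defaultdict(list)
for call in calls:
    row = {"__causal_uid__": call["uid"]}
    row.update({stale_to_current[k]: v for k, v in call["inputs"].items()})
    rows_by_op[call["op"]].append(row)

# Both rows now share a unified "x" column, ready to be upserted into the
# op's memoization table.
assert rows_by_op["op_0_v0"] == [
    {"__causal_uid__": "a", "x": "u1"},
    {"__causal_uid__": "b", "x": "u2"},
]
```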
- @transaction()
- def upsert_calls(self, calls: List[Call], conn: Optional[Connection] = None):
- """
- Upserts *detached* calls in the relational storage so that they will
- show up in declarative queries.
- """
- if len(calls) > 0: # avoid dealing with empty dataframes
- ### upsert full ref uids
- full_uids = set()
- for call in calls:
- full_uids.update({v.full_uid for v in call.inputs.values()})
- full_uids.update({v.full_uid for v in call.outputs})
- self.rel_storage.upsert(
- relation=self.CAUSAL_VREF_TABLE,
- ta=pa.Table.from_pydict(
- {
- Config.full_uid_col: list(full_uids),
- }
- ),
- conn=conn,
- )
- ### upsert calls
- for table_name, ta in self.tabulate_calls(calls).items():
- self.rel_storage.upsert(relation=table_name, ta=ta, conn=conn)
- # Write changes to the event log table
- self.rel_storage.upsert(
- relation=self.EVENT_LOG_TABLE,
- ta=pa.Table.from_pydict(
- {
- Config.uid_col: ta[Config.uid_col],
- "table": [table_name] * len(ta),
- }
- ),
- conn=conn,
- )
- ### upsert in provenance
- self.upsert_provenance(calls=calls, conn=conn)
-
- @transaction()
- def _query_call(
- self, uid: str, by_causal: bool, conn: Optional[Connection] = None
- ) -> pa.Table:
-        # TODO: replace this with something more efficient
- col = Config.causal_uid_col if by_causal else Config.uid_col
- all_tables = [
- Query.from_(table_name)
- .where(Table(table_name)[col] == Parameter("$1"))
- .select(
- Table(table_name)[col],
- LiteralValue(f"'{table_name}'"),
- )
- for table_name in self.get_call_tables()
- ]
- query = "\nUNION\n".join([str(q) for q in all_tables])
- # query = sum(all_tables[1:], start=all_tables[0])
- return self.rel_storage.execute_arrow(query, [uid], conn=conn)
-
- @transaction()
- def call_exists(
- self, uid: str, by_causal: bool, conn: Optional[Connection] = None
- ) -> bool:
- return len(self._query_call(uid, by_causal=by_causal, conn=conn)) > 0
-
- @transaction()
- def call_get_lazy(
- self, uid: str, by_causal: bool, conn: Optional[Connection] = None
- ) -> Call:
- """
- Return the call with the inputs/outputs as lazy value references.
- """
- row = self._query_call(uid, by_causal=by_causal, conn=conn).take([0])
- col = Config.causal_uid_col if by_causal else Config.uid_col
- table_name = row.column(1)[0]
- table = Table(table_name)
- query = (
- Query.from_(table).where(table[col] == Parameter("$1")).select(table.star)
- )
- results = self.rel_storage.execute_arrow(query, [uid], conn=conn)
- # determine the signature for this call
- ui_name, version = Signature.parse_versioned_name(
- versioned_name=str(table_name)
- )
- sig = self.sig_adapter.load_ui_sigs(conn=conn)[ui_name, version]
- return Call.from_row(results, func_op=FuncOp._from_sig(sig=sig))
-
- @transaction()
- def mget_call_lazy(
- self,
- versioned_ui_name: str,
- uids: List[str],
- by_causal: bool = True,
- conn: Optional[Connection] = None,
- ) -> List[Call]:
- """
- Get many calls to the same op.
- """
- if not by_causal:
- raise NotImplementedError()
- table_name = versioned_ui_name
- table = Table(table_name)
- query = (
- Query.from_(table)
- .where(table[Config.causal_uid_col].isin(uids))
- .select(table.star)
- )
- results = self.rel_storage.execute_df(query, conn=conn).set_index(
- Config.causal_uid_col
- )
- ui_name, version = Signature.parse_versioned_name(
- versioned_name=str(table_name)
- )
- sig = self.sig_adapter.load_ui_sigs(conn=conn)[ui_name, version]
- calls_by_causal_uid = {}
- for causal_uid, row in results.iterrows():
- call_dict = dict(row)
- call_dict.update({Config.causal_uid_col: causal_uid})
- calls_by_causal_uid[causal_uid] = Call.from_row(
- call_dict, func_op=FuncOp._from_sig(sig=sig)
- )
- return [calls_by_causal_uid[uid] for uid in uids]
-
- ############################################################################
- ### object methods
- ############################################################################
- def _spillover_criterion(self, serialized: bytes) -> bool:
- return len(serialized) / 1024**2 > self.spillover_threshold_mb
-
- def _mset_spillover(self, uids: List[str], refs: List[Ref]):
- assert self.spillover_dir is not None
- for uid, ref in zip(uids, refs):
- dump_path = self.spillover_dir / f"{uid}.joblib"
- with open(dump_path, "wb") as f:
- joblib.dump(ref, f)
- if Config.warnings:
- size_in_mb = round(os.path.getsize(dump_path) / 10**6, 2)
- logger.warning(
- f"Spill over {ref.__repr__(shorten=True)} of size {size_in_mb}MB"
- )
-
- def _mget_spillover(self, uids: List[str]) -> List[Ref]:
- assert self.spillover_dir is not None
- values = []
- for uid in uids:
- try:
- with open(self.spillover_dir / f"{uid}.joblib", "rb") as f:
- values.append(joblib.load(f))
- except FileNotFoundError:
- if Config.warnings:
- logger.warning(f"Could not find spillover file for uid {uid}")
- values.append(Ref.from_uid(uid=uid))
- return values
-
- def _serialize_spillover(self, ref: Ref) -> bytes:
- s = serialize(ref)
- if self._spillover_criterion(s):
- substitute = Ref.from_uid(uid=ref.uid)
- substitute._obj = SpilledRef()
- self._mset_spillover([ref.uid], [ref])
- return serialize(substitute)
- else:
- return s
-
- def _deserialize_spillover(self, serialized: bytes) -> Ref:
- value = deserialize(serialized)
- if isinstance(value._obj, SpilledRef):
- return self._mget_spillover([value.uid])[0]
- else:
- return value
-
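The three helpers above implement a simple size-threshold spillover: if a serialized ref exceeds the threshold, its payload goes to a per-UID file on disk and only a lightweight substitute is stored in the database. A rough, self-contained sketch of that round trip, using `pickle` and a plain dict in place of mandala's serializer and `Ref`/`SpilledRef` objects:

```python
import pickle
import tempfile
from pathlib import Path

SPILLOVER_THRESHOLD_MB = 0.000001  # tiny threshold so the example triggers

def maybe_spill(uid: str, obj, spillover_dir: Path) -> bytes:
    """Serialize `obj`; spill large payloads to a per-uid file and return a marker."""
    blob = pickle.dumps(obj)
    if len(blob) / 1024**2 > SPILLOVER_THRESHOLD_MB:
        (spillover_dir / f"{uid}.pkl").write_bytes(blob)
        return pickle.dumps({"uid": uid, "spilled": True})
    return blob

def load_maybe_spilled(uid: str, blob: bytes, spillover_dir: Path):
    """Deserialize a stored blob, following the spillover marker if present."""
    value = pickle.loads(blob)
    if isinstance(value, dict) and value.get("spilled"):
        return pickle.loads((spillover_dir / f"{uid}.pkl").read_bytes())
    return value

spill_dir = Path(tempfile.mkdtemp())
stored_blob = maybe_spill("u1", list(range(1000)), spill_dir)
assert load_maybe_spilled("u1", stored_blob, spill_dir) == list(range(1000))
```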
- @transaction()
- def obj_exists(
- self, uids: List[str], conn: Optional[Connection] = None
- ) -> List[bool]:
- if len(uids) == 0:
- return []
- table = Table(Config.vref_table)
- query = (
- Query.from_(table)
- .where(table[Config.uid_col].isin(uids))
- .select(table[Config.uid_col])
- )
- existing_uids = set(
- self.rel_storage.execute_df(query=query, conn=conn)[Config.uid_col]
- )
- return [uid in existing_uids for uid in uids]
-
- @transaction()
- def obj_set(
- self,
- uid: str,
- value: Ref,
- shallow: bool = True,
- conn: Optional[Connection] = None,
- ) -> None:
- self.obj_sets(vrefs={uid: value}, shallow=shallow, conn=conn)
-
- @transaction()
- def obj_set_causal(
- self,
- full_uid: str,
- conn: Optional[Connection] = None,
- ) -> None:
- self.rel_storage.upsert(
- relation=Config.causal_vref_table,
- ta=pd.DataFrame({Config.full_uid_col: [full_uid]}),
- conn=conn,
- )
-
- @transaction()
- def obj_sets(
- self,
- vrefs: Dict[str, Ref],
- shallow: bool = True,
- conn: Optional[Connection] = None,
- ) -> None:
- if not shallow:
- raise NotImplementedError()
- uids = list(vrefs.keys())
- indicators = self.obj_exists(uids, conn=conn)
- new_uids = [uid for uid, indicator in zip(uids, indicators) if not indicator]
- if self.spillover_dir is not None:
- serializer = self._serialize_spillover
- else:
- serializer = serialize
- serialized_refs = [serializer(vrefs[new_uid].dump()) for new_uid in new_uids]
-        # build a pandas DataFrame here to get around pyarrow's 2GB limit
- ta = pd.DataFrame(
- {
- Config.uid_col: new_uids,
- "value": serialized_refs,
- }
- )
- self.rel_storage.upsert(relation=Config.vref_table, ta=ta, conn=conn)
- log_ta = pa.Table.from_pylist(
- [
- {Config.uid_col: new_uid, "table": Config.vref_table}
- for new_uid in new_uids
- ]
- )
- self.rel_storage.upsert(relation=self.EVENT_LOG_TABLE, ta=log_ta, conn=conn)
-
- @transaction()
- def obj_gets(
- self,
- uids: List[str],
- depth: Optional[int] = None,
- _attach_atoms: bool = True,
- conn: Optional[Connection] = None,
- ) -> List[Ref]:
- """
- Returns a list of value references for the given uids.
-
- Note that causal UIDs are not set for these values.
- """
- if len(uids) == 0:
- return []
- if depth == 0:
- return [Ref.from_uid(uid=uid) for uid in uids]
- elif depth == 1:
- table = Table(Config.vref_table)
- if not _attach_atoms:
- query_uids = [uid for uid in uids if Builtins.is_builtin_uid(uid=uid)]
- else:
- query_uids = uids
- if len(query_uids) > 0:
- query = (
- Query.from_(table)
- .where(table[Config.uid_col].isin(query_uids))
- .select(table[Config.uid_col], table["value"])
- )
- output = self.rel_storage.execute_df(query, conn=conn)
- if self.spillover_dir is not None:
- deserializer = self._deserialize_spillover
- else:
- deserializer = deserialize
- output["value"] = output["value"].map(lambda x: deserializer(bytes(x)))
- else:
- output = pd.DataFrame(columns=[Config.uid_col, "value"])
- if not _attach_atoms:
- atoms_df = pd.DataFrame(
- {
- Config.uid_col: [
- uid for uid in uids if not Builtins.is_builtin_uid(uid=uid)
- ],
- "value": None,
- }
- )
- atoms_df["value"] = atoms_df[Config.uid_col].map(
- lambda x: Ref.from_uid(uid=x)
- )
- output = pd.concat([output, atoms_df])
- return output.set_index(Config.uid_col).loc[uids, "value"].tolist()
- elif depth is None:
- results = [Ref.from_uid(uid=uid) for uid in uids]
- self.mattach(
- vrefs=results, shallow=False, _attach_atoms=_attach_atoms, conn=conn
- )
- return results
-
- @transaction()
- def obj_get(
- self,
- uid: str,
- depth: Optional[int] = None,
- _attach_atoms: bool = True,
- conn: Optional[Connection] = None,
- ) -> Ref:
- vref_option = self.obj_gets(
- uids=[uid], depth=depth, _attach_atoms=_attach_atoms, conn=conn
- )[0]
- if vref_option is None:
- raise ValueError(f"Ref with uid {uid} does not exist")
- return vref_option
-
- @transaction()
- def mattach(
- self,
- vrefs: List[Ref],
- shallow: bool = False,
- _attach_atoms: bool = True,
- conn: Optional[Connection] = None,
- ) -> None:
- """
-        Attach objects in place. If `shallow`, attach only the next level;
-        otherwise, attach recursively down to the leaf nodes.
-
- Note that some objects may already be attached.
- """
- ### pass to the vrefs that need to be attached
- detached_vrefs = collect_detached(refs=vrefs, include_transient=False)
- vrefs = detached_vrefs
- ### group the vrefs by uid
- vrefs_by_uid = get_fibers_as_lists(mapping={vref: vref.uid for vref in vrefs})
- unique_uids = list(vrefs_by_uid.keys())
- ### load one level of the unique vrefs
- vals = self.obj_gets(
- uids=unique_uids, depth=1, _attach_atoms=_attach_atoms, conn=conn
- ) #! this can be optimized
- for i, uid in enumerate(unique_uids):
- for obj in vrefs_by_uid[uid]:
- if vals[i].in_memory:
- obj.attach(reference=vals[i])
- if not shallow:
- residues = collect_detached(refs=vals, include_transient=False)
- if len(residues) > 0:
- self.mattach(vrefs=residues, shallow=False, conn=conn)
-
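`mattach` works level by level: load one level of detached refs, attach the payloads in place, then recurse into any children that are still detached. A toy sketch of that pattern with a hypothetical `ToyRef` class and an in-memory store (not mandala's `Ref` machinery):

```python
class ToyRef:
    """A hypothetical lazy reference: only the uid is known until attached."""
    def __init__(self, uid, obj=None, in_memory=False):
        self.uid, self.obj, self.in_memory = uid, obj, in_memory

# uid -> payload; list payloads hold child uids (a stand-in for composite values).
STORE = {"a": ["b", "c"], "b": 1, "c": 2}

def mattach(refs):
    detached = [r for r in refs if not r.in_memory]
    for r in detached:
        payload = STORE[r.uid]
        if isinstance(payload, list):  # composite value: children stay lazy for now
            r.obj = [ToyRef(child_uid) for child_uid in payload]
        else:
            r.obj = payload
        r.in_memory = True
    children = [c for r in detached if isinstance(r.obj, list) for c in r.obj]
    if children:
        mattach(children)  # recurse until the leaves are attached

root = ToyRef("a")
mattach([root])
assert [child.obj for child in root.obj] == [1, 2]
```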
- ############################################################################
- ### provenance methods
- ############################################################################
- @transaction()
- def upsert_provenance(self, calls: List[Call], conn: Optional[Connection] = None):
- rows = []
- for call in calls:
- call_causal = call.causal_uid
- for name, inp in call.inputs.items():
- rows.append(
- {
- Provenance.causal_uid: inp.causal_uid,
- Provenance.name: name,
- Provenance.call_causal_uid: call_causal,
- Provenance.direction: "input",
- Provenance.op_id: call.func_op.sig.versioned_internal_name,
- }
- )
- for i, output in enumerate(call.outputs):
- name = dump_output_name(index=i)
- rows.append(
- {
- Provenance.causal_uid: output.causal_uid,
- Provenance.name: name,
- Provenance.call_causal_uid: call_causal,
- Provenance.direction: "output",
- Provenance.op_id: call.func_op.sig.versioned_internal_name,
- }
- )
- df = pd.DataFrame(rows)
- self.rel_storage.upsert(relation=Config.provenance_table, ta=df, conn=conn)
-
- ############################################################################
- ### deletion
- ############################################################################
- @transaction()
- def delete_causal_refs(
- self, full_uids: List[str], conn: Optional[Connection] = None
- ):
- self.rel_storage.delete(
- relation=Config.causal_vref_table,
- where_col=Config.full_uid_col,
- where_values=full_uids,
- conn=conn,
- )
-
- @transaction()
- def delete_refs(
- self, uids: List[str], verbose: bool = False, conn: Optional[Connection] = None
- ):
- self.rel_storage.delete(
- relation=Config.vref_table,
- where_col=Config.uid_col,
- where_values=uids,
- conn=conn,
- )
- full_uids = set(
- self.rel_storage.execute_df(
- query=f"SELECT {Config.full_uid_col} FROM {Config.causal_vref_table}",
- conn=conn,
- )[Config.full_uid_col].tolist()
- )
- full_uids_to_delete = {
- full_uid for full_uid in full_uids if full_uid.rsplit(".", 1)[0] in uids
- }
- if verbose:
- logger.info(f"Deleting {len(full_uids_to_delete)} causal refs")
- self.delete_causal_refs(full_uids=list(full_uids_to_delete), conn=conn)
- spillover_uids = self._get_spillover_uids(uids=uids)
- for uid in spillover_uids:
- (self.spillover_dir / f"{uid}.joblib").unlink(missing_ok=True)
-
- @transaction()
- def cleanup_vrefs(self, verbose: bool = False, conn: Optional[Connection] = None):
- """
- Delete all value references that are not referenced by any call.
- """
- # collect all the full uids referenced by calls
- all_referenced_uids = set()
- for table in self.get_call_tables(conn=conn):
- df = self.rel_storage.get_data(table=table, conn=conn)
- for col in [
- col for col in df.columns if col not in Config.special_call_cols
- ]:
- all_referenced_uids.update(
- {x.rsplit(".", 1)[0] for x in df[col].unique()}
- )
- # delete the unreferenced full uids
- stored_uids = set(
- self.rel_storage.execute_df(
- query=f"SELECT {Config.uid_col} FROM {Config.vref_table}", conn=conn
- )[Config.uid_col].tolist()
- )
- unreferenced_uids = stored_uids - all_referenced_uids
- if verbose:
- spillover_uids = self._get_spillover_uids(uids=list(unreferenced_uids))
- total_memory_usage = 0
- for spillover_uid in spillover_uids:
- size_in_mb = round(
- os.path.getsize(self.spillover_dir / f"{spillover_uid}.joblib")
- / 1024**2,
- 2,
- )
- total_memory_usage += size_in_mb
- total_memory_usage += self.get_vref_memory_usage(
- uids=list(unreferenced_uids - spillover_uids), conn=conn
- )
- logger.info(
- f"Deleting {len(unreferenced_uids)} unreferenced value refs that take up {total_memory_usage}MB"
- )
- self.delete_refs(uids=list(unreferenced_uids), conn=conn)
-
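`cleanup_vrefs` is a small garbage collector: it collects every UID referenced by any memoization table (stripping the causal suffix from each full UID with `rsplit(".", 1)`) and deletes whatever stored UIDs remain unreferenced. A self-contained sketch of that set computation on toy data (the real code reads these sets from the relational storage):

```python
# Toy memoization tables: column values are full uids of the form "<uid>.<causal>".
call_tables = {
    "inc_0": [{"x": "u1.c1", "output_0": "u2.c2"}],
    "add_0": [{"x": "u2.c3", "y": "u3.c4", "output_0": "u4.c5"}],
}
stored_uids = {"u1", "u2", "u3", "u4", "u5"}  # u5 is no longer referenced

referenced = {
    full_uid.rsplit(".", 1)[0]
    for rows in call_tables.values()
    for row in rows
    for full_uid in row.values()
}
unreferenced = stored_uids - referenced
assert unreferenced == {"u5"}  # these uids would be deleted from the vref table
```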
- def _get_spillover_uids(self, uids: Iterable[str]) -> Set[str]:
- if self.spillover_dir is None:
- return set()
- spillover_uids = set()
- for uid in uids:
- if (self.spillover_dir / f"{uid}.joblib").exists():
- spillover_uids.add(uid)
- return spillover_uids
-
- @transaction()
- def delete_calls(
- self,
- versioned_ui_name: str,
- causal_uids: List[str],
- conn: Optional[Connection] = None,
- ):
- # delete from memoization table
- self.rel_storage.delete(
- relation=versioned_ui_name,
- where_col=Config.causal_uid_col,
- where_values=causal_uids,
- conn=conn,
- )
- # delete from provenance table
- self.rel_storage.delete(
- relation=Config.provenance_table,
- where_col=Provenance.call_causal_uid,
- where_values=causal_uids,
- conn=conn,
- )
-
- @transaction()
- def get_vref_memory_usage(
- self, uids: List[str], units: str = "MB", conn: Optional[Connection] = None
- ) -> float:
- if not uids:
-            return 0.0
- if units != "MB":
- raise NotImplementedError
- if len(uids) == 1:
- in_clause = f"('{uids[0]}')"
- else:
- in_clause = str(tuple(uids))
- query = (
- f"SELECT length({Config.vref_value_col}) AS size_in_bytes FROM {Config.vref_table} WHERE __uid__ IN "
- + in_clause
- )
- df = self.rel_storage.execute_df(query=query, conn=conn)
- total_size_in_bytes = df["size_in_bytes"].astype(int).sum()
- size_in_mb = round(total_size_in_bytes / 1024**2, 2)
- return size_in_mb
diff --git a/mandala/storages/remote_impls/__init__.py b/mandala/storages/remote_impls/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/mandala/storages/remote_impls/mongo_impl.py b/mandala/storages/remote_impls/mongo_impl.py
deleted file mode 100644
index 4c85c99..0000000
--- a/mandala/storages/remote_impls/mongo_impl.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# import datetime
-# from typing import Tuple, List
-#
-# import pymongo
-#
-# from mandala.storages.rels import RemoteEventLogEntry
-# from mandala.storages.remote_storage import RemoteStorage
-#
-#
-# class MongoRemoteStorage(RemoteStorage):
-# def __init__(self, db_name: str, client: pymongo.MongoClient):
-# self.db_name = db_name
-# self.client = client
-# self.log = client.experiment_data[self.db_name].event_log
-#
-# def save_event_log_entry(self, entry: RemoteEventLogEntry):
-# response = self.log.insert_one({"tables": entry})
-# assert response.acknowledged
-# # Set the timestamp based on the server time.
-# self.log.update_one(
-# {"_id": response.inserted_id},
-# {"$currentDate": {"timestamp": {"$type": "date"}}},
-# )
-#
-# def get_log_entries_since(
-# self, timestamp: datetime.datetime
-# ) -> Tuple[List[RemoteEventLogEntry], datetime.datetime]:
-# entries = []
-# last_timestamp = datetime.datetime.fromtimestamp(0)
-# for entry in self.log.find({"timestamp": {"$gt": timestamp}}):
-# entries.append(entry["tables"])
-# if entry["timestamp"] > last_timestamp:
-# last_timestamp = entry["timestamp"]
-# return entries, last_timestamp
-#
diff --git a/mandala/storages/remote_impls/mongo_mock.py b/mandala/storages/remote_impls/mongo_mock.py
deleted file mode 100644
index 667f1b4..0000000
--- a/mandala/storages/remote_impls/mongo_mock.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# from ...core.sig import Signature
-# from ...common_imports import *
-# import mongomock
-# import datetime
-#
-# from mandala.storages.rels import RemoteEventLogEntry
-# from mandala.storages.remote_storage import RemoteStorage
-#
-#
-# class MongoMockRemoteStorage(RemoteStorage):
-# def __init__(self, db_name: str, client: mongomock.MongoClient):
-# self.db_name = db_name
-# self.client = client
-# self.log = client.experiment_data[self.db_name].event_log
-# self.sigs: Dict[Tuple[str, int], Signature] = {}
-#
-# def pull_signatures(self) -> Dict[Tuple[str, int], Signature]:
-# return self.sigs
-#
-# def push_signatures(self, new_sigs: Dict[Tuple[str, int], Signature]) -> None:
-# current_internal_sigs = self.sigs
-# for (internal_name, version), new_sig in new_sigs.items():
-# if (internal_name, version) in current_internal_sigs:
-# current_sig = current_internal_sigs[(internal_name, version)]
-# if not current_sig.is_compatible(new_sig):
-# raise ValueError(
-# f"Signature {internal_name}:{version} is incompatible with {new_sig}"
-# )
-# self.sigs = new_sigs
-#
-# def save_event_log_entry(self, entry: RemoteEventLogEntry):
-# response = self.log.insert_one({"tables": entry})
-# assert response.acknowledged
-# # Set the timestamp based on the server time.
-# self.log.update_one(
-# {"_id": response.inserted_id},
-# {"$currentDate": {"timestamp": {"$type": "date"}}},
-# )
-#
-# def get_log_entries_since(
-# self, timestamp: datetime.datetime
-# ) -> Tuple[List[RemoteEventLogEntry], datetime.datetime]:
-# logger.debug(f"Getting log entries since {timestamp}")
-# entries = []
-# last_timestamp = datetime.datetime.fromtimestamp(0)
-# for entry in self.log.find({"timestamp": {"$gt": timestamp}}):
-# entries.append(entry["tables"])
-# if entry["timestamp"] > last_timestamp:
-# last_timestamp = entry["timestamp"]
-# return entries, last_timestamp
-#
diff --git a/mandala/storages/remote_storage.py b/mandala/storages/remote_storage.py
deleted file mode 100644
index 20a57ef..0000000
--- a/mandala/storages/remote_storage.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import abc
-import datetime
-from abc import abstractmethod
-from typing import Optional
-
-from ..common_imports import *
-from ..core.sig import Signature
-
-from mandala.storages.rels import RemoteEventLogEntry, RelAdapter
-
-
-class RemoteStorage(abc.ABC):
- @abstractmethod
- def save_event_log_entry(self, entry: RemoteEventLogEntry):
- raise NotImplementedError()
-
- @abstractmethod
- def get_log_entries_since(
- self, timestamp: datetime.datetime
- ) -> Tuple[List[RemoteEventLogEntry], datetime.datetime]:
- raise NotImplementedError()
-
- @abstractmethod
- def pull_signatures(self) -> Dict[Tuple[str, int], Signature]:
- raise NotImplementedError()
-
- @abstractmethod
- def push_signatures(self, new_sigs: Dict[Tuple[str, int], Signature]) -> None:
- raise NotImplementedError()
-
-
-# class RemoteSyncManager:
-# def __init__(
-# self,
-# local_storage: RelAdapter,
-# remote_storage: RemoteStorage,
-# timestamp: Optional[datetime.datetime] = None,
-# ):
-# self.local_storage = local_storage
-# self.remote_storage = remote_storage
-# self.last_timestamp = (
-# timestamp if timestamp is not None else datetime.datetime.fromtimestamp(0)
-# )
-#
-# def sync_from_remote(self):
-# new_log_entries, timestamp = self.remote_storage.get_log_entries_since(
-# self.last_timestamp
-# )
-# self.local_storage.apply_from_remote(new_log_entries)
-# self.last_timestamp = timestamp
-#
-# def sync_to_remote(self):
-# changes = self.local_storage.bundle_to_remote()
-# self.remote_storage.save_event_log_entry(changes)
-#
diff --git a/mandala/storages/sigs.py b/mandala/storages/sigs.py
deleted file mode 100644
index dfed46d..0000000
--- a/mandala/storages/sigs.py
+++ /dev/null
@@ -1,296 +0,0 @@
-import pyarrow as pa
-
-from ..common_imports import *
-from ..core.config import Config, dump_output_name
-from ..core.sig import Signature
-from .rel_impls.utils import Transactable, transaction, Connection
-from .rels import RelAdapter, serialize, SigAdapter
-from .remote_storage import RemoteStorage
-
-
-class SigSyncer(Transactable):
- """
- Responsible for syncing the local schema and the server.
-
- There are two kinds of updates:
- - updates from the server: these come in bulk when you pull the current
- true state of the signatures. The signature objects come with their
- internal data.
- - updates from the client: these could happen piece by piece, or in bulk
- if using a context manager. The signature objects may not have
-      internal data if they come straight from the client's code, so they
-      need to be matched to the existing signatures.
-
- This class ensures that all updates are valid against the current state of
- the signatures on the server, and that only successful updates against this
- copy go through to the local storage.
-
- """
-
- def __init__(
- self,
- sig_adapter: SigAdapter,
- root: Optional[Union[Path, RemoteStorage]] = None,
- ):
- self.sig_adapter = sig_adapter
- self.root = root
- self.rel_storage = self.sig_adapter.rel_storage
-
- ############################################################################
- ### `Transactable` interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return self.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_storage._end_transaction(conn=conn)
-
- ############################################################################
- ### sync with server
- ############################################################################
- @property
- def has_remote(self) -> bool:
- return self.root is not None
-
- def pull_signatures(self) -> Dict[Tuple[str, int], Signature]:
- """
-        Pull the current state of the signatures from the remote and return
-        them. Callers are responsible for checking compatibility with the local
-        signatures and applying any updates.
- """
- if isinstance(self.root, RemoteStorage):
- new_sigs = self.root.pull_signatures()
- return new_sigs
- else:
- raise ValueError("No remote storage to pull from.")
-
- def push_signatures(self, sigs: Dict[Tuple[str, int], Signature]):
- if isinstance(self.root, RemoteStorage):
- assert isinstance(self.root, RemoteStorage)
- self.root.push_signatures(new_sigs=sigs)
-
- @transaction()
- def sync_from_remote(self, conn: Optional[Connection] = None):
- """
- Update state from signatures *with internal data* (coming from the
- server). This includes:
- - creating new signatures
- - updating existing signatures
- - renaming functions and inputs
- """
- logger.debug("Syncing signatures from remote...")
- if self.has_remote:
- sigs = self.pull_signatures()
- # sort them by internal name and version to ensure earlier versions
- # are updated first
- sigs = sorted(sigs.values(), key=lambda s: (s.internal_name, s.version))
- for sig in sigs:
-                logger.debug(f"Processing signature {sig}")
- if self.sig_adapter.exists_internal(sig=sig, conn=conn):
- # first set the ui name to the current one (if necessary)
- self.sig_adapter.update_ui_name(sig=sig, conn=conn)
- # then, update the input names too (if necessary)
- self.sig_adapter.update_input_ui_names(sig=sig, conn=conn)
-                # now, update the (already name-aligned) signature from the new one
- self.sig_adapter.update_sig(sig=sig, conn=conn)
- elif self.sig_adapter.exists_any_version(sig=sig, conn=conn):
- self.sig_adapter.create_new_version(sig=sig, conn=conn)
- else:
- self.sig_adapter.create_sig(sig=sig, conn=conn)
-
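`sync_from_remote` above routes each pulled signature to exactly one of three local operations, depending on what already exists. A toy sketch of that dispatch over hypothetical (internal name, version) pairs, with strings standing in for the actual `SigAdapter` calls:

```python
# Local state keyed by (internal_name, version); one version of "f" exists already.
local = {("f_internal", 0): "sig-f-v0"}
pulled = [("f_internal", 0), ("f_internal", 1), ("g_internal", 0)]

actions = []
# Sort so that earlier versions of an op are processed before later ones.
for internal_name, version in sorted(pulled):
    if (internal_name, version) in local:
        actions.append(("update_sig", internal_name, version))
    elif any(name == internal_name for name, _ in local):
        actions.append(("create_new_version", internal_name, version))
    else:
        actions.append(("create_sig", internal_name, version))
    local[(internal_name, version)] = "synced"

assert [a[0] for a in actions] == ["update_sig", "create_new_version", "create_sig"]
```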
- ############################################################################
- ### atomic operations by the client
- ############################################################################
- @transaction()
- def validate_transaction(
- self,
- new_sig: Signature,
- all_sigs: Dict[Tuple[str, int], Signature],
- conn: Optional[Connection] = None,
- ) -> bool:
- """
- Check that a new signature is compatible with a current state of the
- signatures WITHOUT actually updating the state. This is used to check
- that a transaction is valid before committing it.
- """
- assert new_sig.has_internal_data
- if self.sig_adapter.exists_internal(sig=new_sig, conn=conn):
- current = all_sigs[new_sig.internal_name, new_sig.version]
- compatible, reason_not = current.is_compatible(
- new=new_sig,
- )
- if compatible:
- return True
- else:
- raise ValueError(reason_not)
- else:
- return True
-
- @transaction()
- def sync_create(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> Signature:
- """
- Pull the current state of the signatures from the remote, make sure that
- they are compatible with the creation of this signature, then push the
- updated signatures to the remote and create the signature locally.
- """
- self.sync_from_remote(conn=conn)
- new_sig = sig._generate_internal()
- self.validate_transaction(
- new_sig=new_sig,
- all_sigs=self.sig_adapter.load_state(conn=conn),
- conn=conn,
- )
- all_sigs = self.sig_adapter.load_state(conn=conn)
- all_sigs[new_sig.internal_name, new_sig.version] = new_sig
- self.push_signatures(sigs=all_sigs)
- self.sig_adapter.create_sig(sig=new_sig, conn=conn)
- return new_sig
-
- @transaction()
- def sync_update(
- self, sig: Signature, conn: Optional[Connection] = None
- ) -> Signature:
- """
- Pull the current state of the signatures from the remote, make sure that
- they are compatible with the update of this signature, then push the
- updated signatures to the remote and update the signature locally.
- """
- self.sync_from_remote(conn=conn)
- current = self.sig_adapter.load_ui_sigs(conn=conn)[sig.ui_name, sig.version]
- if current != sig:
- new_sig, _ = current.update(new=sig)
- self.validate_transaction(
- new_sig=new_sig, all_sigs=self.sig_adapter.load_state(conn=conn)
- )
- all_sigs = self.sig_adapter.load_state(conn=conn)
- all_sigs[(current.internal_name, current.version)] = new_sig
- self.push_signatures(sigs=all_sigs)
- self.sig_adapter.update_sig(sig=new_sig, conn=conn)
- return new_sig
- else:
- return current
-
- @transaction()
- def sync_new_version(
- self, sig: Signature, conn: Optional[Connection] = None,
- strict: bool = False,
- ) -> Signature:
- """
- Pull the current state of the signatures from the remote, make sure that
- they are compatible with the creation of a new version given by this
- signature, then push the updated signatures to the remote and create the
- new version locally.
- """
- self.sync_from_remote(conn=conn)
- if not self.sig_adapter.exists_any_version(sig=sig, conn=conn):
- raise ValueError()
- latest_sig = self.sig_adapter.get_latest_version(sig=sig, conn=conn)
- new_version = latest_sig.version + 1
- if strict:
-            if sig.version != new_version:
- raise ValueError(f"New version must be {new_version}, not {sig.version}")
- new_sig = sig._generate_internal(internal_name=latest_sig.internal_name)
- new_sig.check_invariants()
- self.validate_transaction(
- new_sig=new_sig, all_sigs=self.sig_adapter.load_state(conn=conn)
- )
- all_sigs = self.sig_adapter.load_state(conn=conn)
- all_sigs[(new_sig.internal_name, new_sig.version)] = new_sig
- self.push_signatures(sigs=all_sigs)
- self.sig_adapter.create_new_version(sig=new_sig, conn=conn)
- return new_sig
-
- @transaction()
- def sync_rename_sig(
- self, sig: Signature, new_name: str, conn: Optional[Connection] = None
- ) -> Signature:
- """
- Pull the current state of the signatures from the remote, make sure that
- they are compatible with the renaming of this signature, then push the
- updated signatures to the remote and rename the signature locally.
- """
- self.sync_from_remote(conn=conn)
- #! note: we validate before the renaming. Ideally we should have logic
- # to do this for the new signature directly
- # self.validate_transaction(
- # new_sig=sig, all_sigs=self.sig_adapter.load_state(conn=conn)
- # )
- new_sig = sig.rename(new_name=new_name)
- all_sigs = self.sig_adapter.update_ui_name(
- sig=new_sig, conn=conn, validate_only=True
- )
- # all_sigs = self.sig_adapter.load_state(conn=conn)
- # all_sigs[(new_sig.internal_name, new_sig.version)] = new_sig
- self.push_signatures(sigs=all_sigs)
- self.sig_adapter.update_ui_name(sig=new_sig, conn=conn)
- return new_sig
-
- @transaction()
- def sync_rename_input(
- self,
- sig: Signature,
- input_name: str,
- new_input_name: str,
- conn: Optional[Connection] = None,
- ) -> Signature:
- """
- Pull the current state of the signatures from the remote, make sure that
- they are compatible with the renaming of this input, then push the
- updated signatures to the remote and rename the input locally.
- """
- self.sync_from_remote(conn=conn)
- #! note: we validate before the renaming. Ideally we should have logic
- # to do this for the new signature directly
- self.validate_transaction(
- new_sig=sig, all_sigs=self.sig_adapter.load_state(conn=conn)
- )
- new_sig = sig.rename_inputs(mapping={input_name: new_input_name})
- all_sigs = self.sig_adapter.load_state(conn=conn)
- all_sigs[(new_sig.internal_name, new_sig.version)] = new_sig
- self.push_signatures(sigs=all_sigs)
- self.sig_adapter.update_input_ui_names(sig=new_sig, conn=conn)
- return new_sig
-
- @transaction()
- def sync_from_local(
- self,
- sig: Signature,
- conn: Optional[Connection] = None,
- ) -> Signature:
- """
- Create a new signature, create a new version, or update an existing one,
- and immediately send changes to the server.
- """
- if sig.version is None:
- if self.sig_adapter.exists_any_version(sig=sig, conn=conn):
- action = "use_latest"
- else:
- action = "create"
- else:
- if self.sig_adapter.exists_versioned_ui(sig=sig, conn=conn):
- action = "update_current"
- elif self.sig_adapter.exists_any_version(sig=sig, conn=conn):
- action = "new_version"
- else:
- action = "create"
- if action == "use_latest":
- latest_sig = self.sig_adapter.get_latest_version(sig=sig, conn=conn)
- new_sig = copy.deepcopy(sig)
- new_sig.version = latest_sig.version
- res = self.sync_update(sig=new_sig, conn=conn)
- elif action == "update_current":
- # current_version = self.sig_adapter.get_version(sig=sig, version=sig.version, conn=conn)
- # new_sig = copy.deepcopy(current_sig)
- # new_sig.version = latest_sig.version
- res = self.sync_update(sig=sig, conn=conn)
- elif action == "new_version":
- res = self.sync_new_version(sig=sig, conn=conn)
- elif action == "create":
- sig = copy.deepcopy(sig)
- sig.version = 0
- res = self.sync_create(sig=sig, conn=conn)
- else:
- raise ValueError(f"Unknown action {action}")
- return res
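The branching in `sync_from_local` amounts to a small decision table over three booleans. A sketch of that table as a pure function (the real method derives these flags from the signature store rather than taking them as arguments):

```python
def choose_action(version_given: bool, exists_this_version: bool, exists_any_version: bool) -> str:
    """Mirror of the action selection in `sync_from_local`."""
    if not version_given:
        return "use_latest" if exists_any_version else "create"
    if exists_this_version:
        return "update_current"
    return "new_version" if exists_any_version else "create"

assert choose_action(False, False, True) == "use_latest"
assert choose_action(True, True, True) == "update_current"
assert choose_action(True, False, True) == "new_version"
assert choose_action(True, False, False) == "create"
```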
diff --git a/mandala/tests/__init__.py b/mandala/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/mandala/tests/_test_remote_mock.py b/mandala/tests/_test_remote_mock.py
deleted file mode 100644
index 1e7905f..0000000
--- a/mandala/tests/_test_remote_mock.py
+++ /dev/null
@@ -1,301 +0,0 @@
-import mongomock
-
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-def test_disjoint_funcs():
- """
- Test a basic scenario where two users define two different functions and do
- some work with each, then sync their results.
-
- Expected behavior: the two storages end up in the same state.
- """
- client = mongomock.MongoClient()
- root = MongoMockRemoteStorage(db_name="test", client=client)
-
- # create multiple storages connected to it
- storage_1 = Storage(root=root)
- storage_2 = Storage(root=root)
-
- ### do work with storage 1
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage_1.run():
- inc(23)
-
- ### do work with storage 2
- @op
- def mult(x: int, y: int) -> int:
- return x * y
-
- with storage_2.run():
- mult(23, 42)
-
- # synchronize storage_1 with the new work
- storage_1.sync_with_remote()
-
- # verify that both storages have the same state
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
-
-def test_create_func():
- """
- Unit test for function creation in a multi-user setting.
-
- Specifically, test a scenario where:
- - one user defines a function
- - another user defines the same function under the same name
-
- Expected behavior:
- - the first creator wins: both functions are assigned the UID issued to
- the first creator.
- """
- client = mongomock.MongoClient()
- root = MongoMockRemoteStorage(db_name="test", client=client)
- storage_1 = Storage(root=root)
- storage_2 = Storage(root=root)
-
- @op(ui_name="f")
- def f_1(x: int) -> int:
- return x + 1
-
- @op(ui_name="f")
- def f_2(x: int) -> int:
- return x + 1
-
- storage_1.synchronize(f=f_1)
- storage_2.synchronize(f=f_2)
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
-
-
-def test_add_input():
- """
- Unit test for adding an input to a function in a multi-user setting.
-
- Specifically, test a scenario where:
- - users 1 and 2 agree on the definition of a function `f`
- - user 1 adds an input to the function
- - user 2 tries to send calls to the initial function to the server
-
- Expected behavior:
- - calls to the old variant of the function work as expected:
- - user 2 is able to commit their calls and send them to the server
- - user 1 is able to then load these new calls into their local storage
- - the two storages end up in the same state
- """
- client = mongomock.MongoClient()
- root = MongoMockRemoteStorage(db_name="test", client=client)
- storage_1 = Storage(root=root)
- storage_2 = Storage(root=root)
-
- @op(ui_name="inc")
- def inc_1(x: int) -> int:
- return x + 1
-
- @op(ui_name="inc")
- def inc_2(x: int) -> int:
- return x + 1
-
- storage_1.synchronize(f=inc_1)
- storage_2.synchronize(f=inc_2)
-
- @op(ui_name="inc")
- def inc_1(x: int, how_many_times: int = 1) -> int:
- return x + how_many_times
-
- storage_1.synchronize(f=inc_1)
- assert not signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- with storage_2.run():
- inc_2(23)
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert not data_is_equal(storage_1=storage_1, storage_2=storage_2)
- storage_1.sync_from_remote()
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
-
-def test_rename_func():
- """
- Unit test for renaming a function in a multi-user setting.
-
- Specifically, test a scenario where:
- - users 1 and 2 agree on the definition of a function `f`
-    - user 1 renames the function to `g`
-    - unbeknownst to them, user 2 still sees `f` and uses it in
- computations.
-
- Expected behavior:
- - calls to the old variant of the function work as expected:
- - user 2 is able to commit their calls and send them to the server
- - user 1 is able to then load these new calls into their local storage
- - the two storages end up in the same state
-
- ! Possible confusion:
- - if user 2 then re-synchronizes the old variant of the function called
- `f`, this will create a new function.
- """
- client = mongomock.MongoClient()
- root = MongoMockRemoteStorage(db_name="test", client=client)
- storage_1 = Storage(root=root)
- storage_2 = Storage(root=root)
-
- @op(ui_name="f")
- def f_1(x: int) -> int:
- return x + 1
-
- @op(ui_name="f")
- def f_2(x: int) -> int:
- return x + 1
-
- storage_1.synchronize(f=f_1)
- storage_2.synchronize(f=f_2)
-
- storage_1.rename_func(func=f_1, new_name="g")
-
- assert not signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- with storage_2.run():
- f_2(23)
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert not data_is_equal(storage_1=storage_1, storage_2=storage_2)
- storage_1.sync_from_remote()
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
-
-def test_rename_input():
- """
- Analogous to `test_rename_func`.
- """
- client = mongomock.MongoClient()
- root = MongoMockRemoteStorage(db_name="test", client=client)
- storage_1 = Storage(root=root)
- storage_2 = Storage(root=root)
-
- @op(ui_name="f")
- def f_1(x: int) -> int:
- return x + 1
-
- @op(ui_name="f")
- def f_2(x: int) -> int:
- return x + 1
-
- storage_1.synchronize(f=f_1)
- storage_2.synchronize(f=f_2)
-
- storage_1.rename_arg(func=f_1, name="x", new_name="y")
-
- assert not signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- with storage_2.run():
- f_2(23)
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert not data_is_equal(storage_1=storage_1, storage_2=storage_2)
- storage_1.sync_from_remote()
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
-
-def test_remote_lots_of_stuff():
- Config.autowrap_inputs = True
- Config.autounwrap_inputs = True
- # create a *single* (mock) remote database
- client = mongomock.MongoClient()
- root = MongoMockRemoteStorage(db_name="test", client=client)
-
- # create multiple storages connected to it
- storage_1 = Storage(root=root)
- storage_2 = Storage(root=root)
-
- def check_all_invariants():
- check_invariants(storage=storage_1)
- check_invariants(storage=storage_2)
-
- ### do stuff with storage 1
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage_1.run():
- for i in range(20, 25):
- j = inc(x=i)
- final = add(i, j)
-
- ### do stuff with storage 2
- @op
- def mult(x: int, y: int) -> int:
- return x * y
-
- with storage_2.run():
- for i, j in zip(range(20, 25), range(20, 25)):
- k = mult(i, j)
-
- storage_2.sync_with_remote()
- storage_1.sync_with_remote()
- storage_2.sync_with_remote()
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
- ### now, rename a function in storage 1!
- storage_1.rename_func(func=inc, new_name="inc_new")
-
- @op
- def inc_new(x: int) -> int:
- return x + 1
-
- storage_1.synchronize(f=inc_new)
-
- ### rename argument too
- storage_1.rename_arg(func=inc_new, name="x", new_name="x_new")
-
- @op
- def inc_new(x_new: int) -> int:
- return x_new + 1
-
- storage_1.synchronize(f=inc_new)
-
- # do work with the renamed function in storage 1
- with storage_1.run():
- for i in range(20, 30):
- j = inc_new(x_new=i)
- final = add(i, j)
-
- # now sync stuff
- storage_2.sync_with_remote()
-
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
- # check we can do work with the renamed function in storage_2
- @op
- def inc_new(x_new: int) -> int:
- return x_new + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage_2.run():
- for i in range(20, 40):
- j = inc_new(x_new=i)
- final = add(i, j)
- storage_2.sync_with_remote()
- storage_1.sync_with_remote()
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
-
- # do some versioning stuff
- @op(version=1)
- def add(x: int, y: int, z: int) -> int:
- return x + y + z
-
- with storage_2.run():
- add(x=1, y=2, z=3)
-
- storage_2.sync_with_remote()
- storage_1.sync_with_remote()
- assert signatures_are_equal(storage_1=storage_1, storage_2=storage_2)
- assert data_is_equal(storage_1=storage_1, storage_2=storage_2)
diff --git a/mandala/tests/stateful_utils.py b/mandala/tests/stateful_utils.py
deleted file mode 100644
index b4a4048..0000000
--- a/mandala/tests/stateful_utils.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from collections import OrderedDict
-from mandala.common_imports import *
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.core.utils import Hashing, invert_dict
-from mandala.queries.compiler import *
-from mandala.queries.weaver import ValNode, CallNode
-import string
-
-
-def combine_inputs(*args, **kwargs) -> str:
- return Hashing.get_content_hash(obj=(args, kwargs))
-
-
-def generate_deterministic(seed: str, n_outputs: int) -> List[str]:
- result = []
- current = seed
- for i in range(n_outputs):
- new = Hashing.get_content_hash(obj=current)
- result.append(new)
- current = new
- return result
-
-
-def random_string(size: int = 10) -> str:
- return "".join(random.choice(string.ascii_letters) for _ in range(size))
-
-
-TEMPLATE = """
-def {name}({inputs}) -> {output_annotation}:
- if {n_outputs} == 0:
- return None
- elif {n_outputs} == 1:
- return generate_deterministic(seed=combine_inputs({inputs}),
- n_outputs=1)[0]
- else:
- return tuple(generate_deterministic(seed=combine_inputs({inputs}), n_outputs={n_outputs}))
-"""
-
-
-def make_func(
- ui_name: str,
- input_names: List[str],
- n_outputs: int,
-) -> types.FunctionType:
- """
- Generate a deterministic function with given interface
- """
- inputs = ", ".join(input_names)
- output_annotation = (
- "None"
- if n_outputs == 0
- else "Any"
- if n_outputs == 1
- else f"Tuple[{', '.join(['Any'] * n_outputs)}]"
- )
- code = TEMPLATE.format(
- name=ui_name,
- inputs=inputs,
- output_annotation=output_annotation,
- n_outputs=n_outputs,
- )
- f = compile(code, "", "exec")
- exec(f)
- f = locals()[ui_name]
- return f
-
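The template trick in `make_func` above formats a function's source string and brings it to life with `compile`/`exec`. A self-contained variant of the same idea is sketched below; it execs into an explicit namespace dict, which is a more portable way to retrieve the generated function (a variant choice for illustration, not how the removed code does it):

```python
# Hypothetical template: the generated function returns its arguments in sorted order.
TOY_TEMPLATE = """
def {name}({inputs}):
    return tuple(sorted(({inputs})))
"""

def make_toy_func(name, input_names):
    code = TOY_TEMPLATE.format(name=name, inputs=", ".join(input_names))
    namespace = {}
    # exec into a fresh dict instead of relying on locals() being writable
    exec(compile(code, "<generated>", "exec"), namespace)
    return namespace[name]

f = make_toy_func("order3", ["a", "b", "c"])
assert f(3, 1, 2) == (1, 2, 3)
```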
-
-def make_func_from_sig(sig: Signature) -> types.FunctionType:
- return make_func(sig.ui_name, list(sig.input_names), sig.n_outputs)
-
-
-def make_op(
- ui_name: str,
- input_names: List[str],
- n_outputs: int,
- defaults: Dict[str, Any],
- version: int = 0,
- deterministic: bool = True,
-) -> FuncOp:
- """
- Generate a deterministic function with given interface
- """
- sig = Signature(
- ui_name=ui_name,
- input_names=set(input_names),
- n_outputs=n_outputs,
- version=version,
- defaults=defaults,
- input_annotations={k: Any for k in input_names},
- output_annotations=[Any] * n_outputs,
- )
- f = make_func(ui_name, input_names, n_outputs)
- return FuncOp._from_data(func=f, sig=sig)
diff --git a/mandala/tests/test_async.py b/mandala/tests/test_async.py
deleted file mode 100644
index ea59ec9..0000000
--- a/mandala/tests/test_async.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-@pytest.mark.asyncio
-async def test_unit():
- storage = Storage()
-
- @op
- async def inc(x: int) -> int:
- time.sleep(1)
- print("Hi there!")
- return x + 1
-
- with storage.run():
- z = await inc(1)
- assert storage.similar(z).shape[0] == 1
-
- # now run 10 calls in parallel
- with storage.run():
- tasks = [inc(i) for i in range(10)]
- results = await asyncio.gather(*tasks)
- assert storage.similar(results[0]).shape[0] == 10
-
- # run again
- with storage.run():
- tasks = [inc(i) for i in range(10)]
- results = await asyncio.gather(*tasks)
- assert storage.similar(results[0]).shape[0] == 10
diff --git a/mandala/tests/test_causal.py b/mandala/tests/test_causal.py
deleted file mode 100644
index b9fac77..0000000
--- a/mandala/tests/test_causal.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-def test_unit():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- x = inc(23)
- df = storage.get_table(inc, values="uids")
- assert df.shape[0] == 1
- with storage.run():
- y = inc(22)
- z = inc(y)
- df = storage.get_table(inc, values="uids")
- assert df.shape[0] == 3
-
-
-def test_queries():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def dec(x: int) -> int:
- return x - 1
-
- with storage.run():
- x = inc(23)
-
- with storage.run():
- y = dec(24)
-
- with storage.query():
- x = inc(Q())
- y = dec(x)
- df = storage.df(x, y)
- assert df.empty
- num_content_calls_inc = storage.get_table(
- inc, drop_duplicates=True, values="uids"
- ).shape[0]
- num_content_calls_dec = storage.get_table(
- dec, drop_duplicates=True, values="uids"
- ).shape[0]
- assert num_content_calls_inc == 1
- assert num_content_calls_dec == 1
-
- with storage.run(allow_calls=False):
- x = inc(23)
- y = dec(x)
- num_content_calls_inc = storage.get_table(
- inc, drop_duplicates=True, values="uids"
- ).shape[0]
- num_content_calls_dec = storage.get_table(
- dec, drop_duplicates=True, values="uids"
- ).shape[0]
- assert num_content_calls_inc == 1
- assert num_content_calls_dec == 1
- with storage.query():
- x = inc(Q())
- y = dec(x)
- df = storage.df(x, y)
- assert df.shape[0] == 1
-
-
-def test_bug():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage.run():
- for x in range(5):
- for y in range(5):
- z = inc(x)
- w = add(z, y)
- assert storage.get_table(add).shape[0] == 25
diff --git a/mandala/tests/test_cfs.py b/mandala/tests/test_cfs.py
index d76b0b1..b75ad16 100644
--- a/mandala/tests/test_cfs.py
+++ b/mandala/tests/test_cfs.py
@@ -1,117 +1,18 @@
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.queries.weaver import *
+from mandala._next.imports import *
-def test_construction():
-
- storage = Storage()
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- @op
- def f(x: int, a: int = 1) -> Tuple[int, int]:
- return x + a, x - a
-
- with storage.run():
- things = [add(i, i + 1) for i in range(10)]
- other_things = [f(i) for i in range(10)]
-
- rf1 = ComputationFrame.from_refs(things, storage=storage)
- rf2 = ComputationFrame.from_refs([t[0] for t in other_things], storage=storage)
- rf3 = ComputationFrame.from_op(func=add, storage=storage)
- rf4 = ComputationFrame.from_op(func=f, storage=storage)
-
-
-def test_back():
- storage = Storage()
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- @op
- def f(x: int, a: int = 1) -> Tuple[int, int]:
- return x + a, x - a
-
- @op
- def mul(x: int, y: int) -> int:
- return x * y
-
- with storage.run():
- cs = []
- for i in range(10):
- a = add(i, i + 1)
- b = f(a)
- c = mul(b[0], b[1])
- cs.append(c)
-
- rf1 = ComputationFrame.from_refs(cs, storage=storage, name="c")
- rf1.back("c")
- rf1.back("c", inplace=True)
- rf1.back()
-
-
-def test_evals():
- storage = Storage()
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- @op
- def f(x: int, a: int = 1) -> Tuple[int, int]:
- return x + a, x - a
-
- @op
- def mul(x: int, y: int) -> int:
- return x * y
-
- with storage.run():
- cs = []
- for i in range(10):
- a = add(i, i + 1)
- b = f(a)
- c = mul(b[0], b[1])
- cs.append(c)
-
- rf = ComputationFrame.from_refs(cs, storage=storage, name="c")
-
- rf[["c"]]
- rf[list(rf.var_nodes.keys())]
- df = rf.eval("c")
- assert len(df) == 10
- df = rf.eval()
- assert len(df) == 10
-
- rf = ComputationFrame.from_op(func=add, storage=storage)
- sub_rf = rf[rf.eval("x") < 5]
- assert len(sub_rf) == 5
-
-
-def test_deletion():
-
+def test_single_func():
storage = Storage()
@op
def inc(x: int) -> int:
- print("hey!")
return x + 1
- with storage.run():
- for i in range(10):
- inc(i)
-
- # check the number of rows goes down by the expected amount
- rf = ComputationFrame.from_op(func=inc, storage=storage)
- rf[rf.eval("x") < 5].delete(delete_dependents=True, ask=False)
- rf = ComputationFrame.from_op(func=inc, storage=storage)
- assert len(rf) == 5
-
- # re-compute the calls
- storage.cache.evict_all()
- with storage.run():
+ with storage:
for i in range(10):
inc(i)
+
+ cf = storage.cf(inc)
+ df = cf.df()
+ assert df.shape == (10, 3)
+ assert (df['output_0'] == df['x'] + 1).all()
\ No newline at end of file
diff --git a/mandala/tests/test_components.py b/mandala/tests/test_components.py
deleted file mode 100644
index f604095..0000000
--- a/mandala/tests/test_components.py
+++ /dev/null
@@ -1,622 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-def test_rel_storage():
- rel_storage = SQLiteRelStorage()
- assert set(rel_storage.get_tables()) == set()
- rel_storage.create_relation(
- name="test", columns=[("a", None), ("b", None)], primary_key="a", defaults={}
- )
- assert set(rel_storage.get_tables()) == {"test"}
- assert rel_storage.get_data(table="test").empty
- df = pd.DataFrame({"a": ["x", "y"], "b": ["z", "w"]})
- ta = pa.Table.from_pandas(df)
- rel_storage.insert(relation="test", ta=ta)
- assert (rel_storage.get_data(table="test") == df).all().all()
- rel_storage.upsert(relation="test", ta=ta)
- assert (rel_storage.get_data(table="test") == df).all().all()
- rel_storage.create_column(relation="test", name="c", default_value="a")
- df["c"] = ["a", "a"]
- assert (rel_storage.get_data(table="test") == df).all().all()
- # rel_storage.delete(relation="test", index=["x", "y"])
- # assert rel_storage.get_data(table="test").empty
-
-
-def test_main_storage():
- storage = Storage()
- # check that things work on an empty storage
- storage.sig_adapter.load_state()
- storage.rel_adapter.get_all_call_data()
-
-
-def test_wrapping():
- assert unwrap(23) == 23
- assert unwrap(23.0) == 23.0
- assert unwrap("23") == "23"
- assert unwrap([1, 2, 3]) == [1, 2, 3]
-
- vref = wrap_atom(23)
- assert wrap_atom(vref) is vref
- try:
- wrap_atom(vref, uid="aaaaaa")
- except:
- assert True
-
-
-def test_unwrapping():
- # tuples
- assert unwrap((1, 2, 3)) == (1, 2, 3)
- vrefs = (wrap_atom(23), wrap_atom(24), wrap_atom(25))
- assert unwrap(vrefs, through_collections=True) == (23, 24, 25)
- assert unwrap(vrefs, through_collections=False) == vrefs
- # sets
- assert unwrap({1, 2, 3}) == {1, 2, 3}
- vrefs = {wrap_atom(23), wrap_atom(24), wrap_atom(25)}
- assert unwrap(vrefs, through_collections=True) == {23, 24, 25}
- assert unwrap(vrefs, through_collections=False) == vrefs
- # lists
- assert unwrap([1, 2, 3]) == [1, 2, 3]
- vrefs = [wrap_atom(23), wrap_atom(24), wrap_atom(25)]
- assert unwrap(vrefs, through_collections=True) == [23, 24, 25]
- assert unwrap(vrefs, through_collections=False) == vrefs
- # dicts
- assert unwrap({"a": 1, "b": 2, "c": 3}) == {"a": 1, "b": 2, "c": 3}
- vrefs = {"a": wrap_atom(23), "b": wrap_atom(24), "c": wrap_atom(25)}
- assert unwrap(vrefs, through_collections=True) == {"a": 23, "b": 24, "c": 25}
- assert unwrap(vrefs, through_collections=False) == vrefs
-
-
-def test_reprs():
- x = wrap_atom(23)
- repr(x), str(x)
-
-
-################################################################################
-### contexts
-################################################################################
-def test_nesting_new_api():
- storage = Storage()
-
- with storage.run() as c:
- assert c.mode == MODES.run
-
- with storage.query() as q:
- assert q.mode == MODES.query
-
- with storage.run() as c_1:
- with storage.run() as c_2:
- with storage.run() as c_3:
- assert c_1 is c_2 is c_3
-
- with storage.run() as c:
- assert c.mode == MODES.run
- assert c.storage is storage
- with storage.query() as q:
- assert q is c
- assert q.storage is storage
- assert q.mode == MODES.query
- assert c.mode == MODES.run
-
-
-def test_noop():
- # check that ops are noops when not in a context
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- assert inc(23) == 24
-
-
-def test_failures():
- storage = Storage()
-
- try:
- with storage.run(bla=23):
- pass
- except:
- assert True
-
-
-################################################################################
-### test ops
-################################################################################
-def test_signatures():
- sig = Signature(
- ui_name="f",
- input_names={"x", "y"},
- n_outputs=1,
- defaults={"y": 42},
- version=0,
- input_annotations={"x": int, "y": int},
- output_annotations=[Any],
- )
-
- # if internal data has not been set, it should not be accessible
- try:
- sig.internal_name
- except ValueError:
- assert True
-
- try:
- sig.ui_to_internal_input_map
- except ValueError:
- assert True
-
- with_internal = sig._generate_internal()
- assert not sig.has_internal_data
- assert with_internal.has_internal_data
-
- ### invalid signature changes
- # remove input
- new = Signature(
- ui_name="f",
- input_names={"x", "z"},
- n_outputs=1,
- defaults={"y": 42},
- version=0,
- input_annotations={"x": int, "z": int},
- output_annotations=[Any],
- )
- try:
- sig.update(new=new)
- except ValueError:
- assert True
- new = Signature(
- ui_name="f",
- input_names={"x", "y"},
- n_outputs=1,
- defaults={},
- version=0,
- input_annotations={"x": int, "y": int},
- output_annotations=[Any],
- )
- try:
- sig.update(new=new)
- except ValueError:
- assert True
- # change version
- new = Signature(
- ui_name="f",
- input_names={"x", "y"},
- n_outputs=1,
- defaults={"y": 42},
- version=1,
- input_annotations={"x": int, "y": int},
- output_annotations=[Any],
- )
- try:
- sig.update(new=new)
- except ValueError:
- assert True
-
- # add input
- sig = sig._generate_internal()
- try:
- sig.create_input(name="y", default=23, annotation=Any)
- except ValueError:
- assert True
- new = sig.create_input(name="z", default=23, annotation=Any)
- assert new.input_names == {"x", "y", "z"}
-
-
-def test_output_name_failure():
-
- try:
-
- @op
- def f(output_0: int) -> int:
- return output_0
-
- except:
- assert True
-
-
-def test_changing_num_outputs():
-
- storage = Storage()
-
- @op
- def f(x: int):
- return x
-
- try:
- with storage.run():
- f(1)
- except Exception:
- assert True
-
- @op
- def f(x: int) -> int:
- return x
-
- with storage.run():
- f(1)
-
- @op
- def f(x: int) -> Tuple[int, int]:
- return x
-
- try:
- with storage.run():
- f(1)
- except ValueError:
- assert True
-
- @op
- def f(x: int) -> int:
- return x
-
- with storage.run():
- f(1)
-
-
-def test_nout():
- storage = Storage()
-
- @op(nout=2)
- def f(x: int):
- return x, x
-
- with storage.run():
- a, b = f(1)
- assert unwrap(a) == 1 and unwrap(b) == 1
-
- @op(nout=0)
- def g(x: int) -> Tuple[int, int]:
- pass
-
- with storage.run():
- c = g(1)
- assert c is None
-
-
-################################################################################
-### test storage
-################################################################################
-OUTPUT_ROOT = Path(__file__).parent / "output"
-
-
-def test_get():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- y = inc(23)
-
- y_full = storage.rel_adapter.obj_get(uid=y.uid)
- assert y_full.in_memory
- assert unwrap(y_full) == 24
-
- y_lazy = storage.rel_adapter.obj_get(uid=y.uid, _attach_atoms=False)
- assert not y_lazy.in_memory
- assert y_lazy.obj is None
-
- @op
- def get_prime_factors(n: int) -> Set[int]:
- factors = set()
- d = 2
- while d * d <= n:
- while (n % d) == 0:
- factors.add(d)
- n //= d
- d += 1
- if n > 1:
- factors.add(n)
- return factors
-
- with storage.run():
- factors = get_prime_factors(42)
-
- factors_full = storage.rel_adapter.obj_get(uid=factors.uid)
- assert factors_full.in_memory
- assert all([x.in_memory for x in factors_full])
-
- factors_shallow = storage.rel_adapter.obj_get(uid=factors.uid, depth=1)
- assert factors_shallow.in_memory
- assert all([not x.in_memory for x in factors_shallow])
-
- storage.rel_adapter.mattach(vrefs=[factors_shallow])
- assert all([x.in_memory for x in factors_shallow])
-
- @superop
- def get_factorizations(n: int) -> List[List[int]]:
- # get all factorizations of a number into factors
- n = unwrap(n)
- divisors = [i for i in range(2, n + 1) if n % i == 0]
- result = [[n]]
- for divisor in divisors:
- sub_solutions = unwrap(get_factorizations(n // divisor))
- result.extend(
- [
- [divisor] + sub_solution
- for sub_solution in sub_solutions
- if min(sub_solution) >= divisor
- ]
- )
- return result
-
- with storage.run():
- factorizations = get_factorizations(42)
-
- factorizations_full = storage.rel_adapter.obj_get(uid=factorizations.uid)
- assert unwrap(factorizations_full) == [[42], [2, 21], [2, 3, 7], [3, 14], [6, 7]]
- factorizations_shallow = storage.rel_adapter.obj_get(
- uid=factorizations.uid, depth=1
- )
- assert factorizations_shallow.in_memory
- storage.rel_adapter.mattach(vrefs=[factorizations_shallow.obj[0]])
- assert unwrap(factorizations_shallow.obj[0]) == [42]
-
- # result, call, wrapped_inputs = storage.call_run(
- # func_op=get_factorizations.func_op,
- # inputs={"n": 42},
- # _call_depth=0,
- # )
-
-
-def test_persistent():
- db_path = OUTPUT_ROOT / "test_persistent.db"
- if db_path.exists():
- db_path.unlink()
- storage = Storage(db_path=db_path)
-
- try:
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def get_prime_factors(n: int) -> Set[int]:
- factors = set()
- d = 2
- while d * d <= n:
- while (n % d) == 0:
- factors.add(d)
- n //= d
- d += 1
- if n > 1:
- factors.add(n)
- return factors
-
- @superop
- def get_factorizations(n: int) -> List[List[int]]:
- # get all factorizations of a number into factors
- n = unwrap(n)
- divisors = [i for i in range(2, n + 1) if n % i == 0]
- result = [[n]]
- for divisor in divisors:
- sub_solutions = unwrap(get_factorizations(n // divisor))
- result.extend(
- [
- [divisor] + sub_solution
- for sub_solution in sub_solutions
- if min(sub_solution) >= divisor
- ]
- )
- return result
-
- with storage.run():
- y = inc(23)
- factors = get_prime_factors(42)
- factorizations = get_factorizations(42)
- assert all([x.in_memory for x in (y, factors, factorizations)])
-
- with storage.run():
- y = inc(23)
- factors = get_prime_factors(42)
- factorizations = get_factorizations(42)
- assert all([not x.in_memory for x in (y, factors, factorizations)])
-
- with storage.run(lazy=False):
- y = inc(23)
- factors = get_prime_factors(42)
- factorizations = get_factorizations(42)
- assert all([x.in_memory for x in (y, factors, factorizations)])
-
- with storage.run():
- y = inc(23)
- assert not y.in_memory
- y._auto_attach()
- assert y.in_memory
- factors = get_prime_factors(42)
- assert not factors.in_memory
- 7 in factors
- assert factors.in_memory
-
- factorizations = get_factorizations(42)
- assert not factorizations.in_memory
- n = len(factorizations)
- assert factorizations.in_memory
- #! this is now in memory b/c new caching
- # assert not factorizations[0].in_memory
- # factorizations[0][0]
- # assert factorizations[0].in_memory
-
- #! this is now in memory b/c new caching
- # for elt in factorizations[1]:
- # assert not elt.in_memory
-
- except Exception as e:
- raise e
- finally:
- db_path.unlink()
-
-
-def test_magics():
- db_path = OUTPUT_ROOT / "test_magics.db"
- if db_path.exists():
- db_path.unlink()
- storage = Storage(db_path=db_path)
-
- try:
- Config.enable_ref_magics = True
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- x = inc(23)
-
- with storage.run():
- x = inc(23)
- assert not x.in_memory
- if x > 0:
- y = inc(x)
- assert x.in_memory
-
- with storage.run():
- x = inc(23)
- y = inc(x)
- if x + y > 0:
- z = inc(x)
-
- with storage.run():
- x = inc(23)
- y = inc(x)
- if x:
- z = inc(x)
-
- except Exception as e:
- raise e
- finally:
- db_path.unlink()
-
-
-def test_spillover():
- db_path = OUTPUT_ROOT / "test_spillover.db"
- if db_path.exists():
- db_path.unlink()
- spillover_dir = OUTPUT_ROOT / "test_spillover/"
- if spillover_dir.exists():
- shutil.rmtree(spillover_dir)
- storage = Storage(db_path=db_path, spillover_dir=spillover_dir)
-
- try:
- import numpy as np
-
- @op
- def create_large_array() -> np.ndarray:
- return np.random.rand(10_000_000)
-
- with storage.run():
- x = create_large_array()
-
- assert len(os.listdir(spillover_dir)) == 1
- path = spillover_dir / os.listdir(spillover_dir)[0]
- with open(path, "rb") as f:
- data = unwrap(joblib.load(f))
- assert np.allclose(data, unwrap(x))
-
- with storage.run():
- x = create_large_array()
- assert not x.in_memory
- x = unwrap(x)
-
- except Exception as e:
- raise e
- finally:
- db_path.unlink()
- shutil.rmtree(spillover_dir)
-
-
-def test_batching_unit():
-
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.batch():
- y = inc(23)
-
- assert unwrap(y) == 24
- assert y.uid is not None
- all_data = storage.rel_storage.get_all_data()
- assert all_data[Config.vref_table].shape[0] == 2
- assert all_data[inc.func_op.sig.versioned_ui_name].shape[0] == 1
-
-
-def test_exclude_arg():
- storage = Storage()
-
- @op
- def inc(x: int, __excluded__=False) -> int:
- return x + 1
-
- with storage.run():
- y = inc(23)
-
-
-def test_provenance():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- y = inc(23)
- storage.prov(y)
- storage.print_graph(y)
-
- ### struct inputs
- @op
- def avg_list(nums: List[int]) -> float:
- return sum(nums) / len(nums)
-
- @op
- def avg_dict(nums: Dict[str, int]) -> float:
- return sum(nums.values()) / len(nums)
-
- @op
- def avg_set(nums: Set[int]) -> float:
- return sum(nums) / len(nums)
-
- with storage.run():
- x = avg_list([1, 2, 3])
- y = avg_dict({"a": 1, "b": 2, "c": 3})
- z = avg_set({1, 2, 3})
- for v in [x, y, z]:
- storage.prov(v)
- storage.print_graph(v)
-
- ### struct outputs
- @op
- def get_list() -> List[int]:
- return [1, 2, 3]
-
- @op
- def get_dict() -> Dict[str, int]:
- return {"a": 1, "b": 2, "c": 3}
-
- @op
- def get_set() -> Set[int]:
- return {1, 2, 3}
-
- with storage.run():
- x = get_list()
- y = get_dict()
- z = get_set()
- a = x[0]
- b = y["a"]
- for v in [x, y, z, a, b]:
- storage.prov(v)
- storage.print_graph(v)
-
- ### a mess of stuff
- with storage.run():
- a = get_list()
- x = avg_list(a[:2])
- y = avg_dict(get_dict())
- z = avg_set(get_set())
- for v in [a, x, y, z]:
- storage.prov(v)
- storage.print_graph(v)
diff --git a/mandala/tests/test_deps.py b/mandala/tests/test_deps.py
deleted file mode 100644
index fbabce9..0000000
--- a/mandala/tests/test_deps.py
+++ /dev/null
@@ -1,615 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.deps.shallow_versions import DAG
-import numpy as np
-
-
-def test_dag():
- d = DAG(content_type="code")
- try:
- d.commit("something")
- except AssertionError:
- pass
-
- content_hash_1 = d.init(initial_content="something")
- assert len(d.commits) == 1
- assert d.head == content_hash_1
- content_hash_2 = d.commit(content="something else", is_semantic_change=True)
- assert (
- d.commits[content_hash_2].semantic_hash
- != d.commits[content_hash_1].semantic_hash
- )
- assert d.head == content_hash_2
- content_hash_3 = d.commit(content="something else #2", is_semantic_change=False)
- assert (
- d.commits[content_hash_3].semantic_hash
- == d.commits[content_hash_2].semantic_hash
- )
-
- content_hash_4 = d.sync(content="something else")
- assert content_hash_4 == content_hash_2
- assert d.head == content_hash_2
-
- d.show()
-
-
-MODULE_NAME = "mandala.tests.test_deps"
-DEPS_PACKAGE = "mandala.tests"
-DEPS_PATH = Path(__file__).absolute().resolve()
-# MODULE_NAME = '__main__'
-
-
-def _test_version_reprs(storage: Storage):
- for dag in storage.get_versioner().component_dags.values():
- for compact in [True, False]:
- dag.show(compact=compact)
- for version in storage.get_versioner().get_flat_versions().values():
- storage.get_versioner().present_dependencies(commits=version.semantic_expansion)
- storage.get_versioner().global_topology.show(path=generate_path(ext=".png"))
- repr(storage.get_versioner().global_topology)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_unit(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
-
- # to be able to import this name
- global f_1, A
-
- A = 42
-
- @op
- def f_1(x) -> int:
- return 23 + A
-
- with storage.run():
- f_1(1)
-
- vs = storage.get_versioner()
- f_1_versions = vs.versions[MODULE_NAME, "f_1"]
- assert len(f_1_versions) == 1
- version = f_1_versions[list(f_1_versions.keys())[0]]
- assert set(version.support) == {(MODULE_NAME, "f_1"), (MODULE_NAME, "A")}
- _test_version_reprs(storage=storage)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_libraries(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
- if tracer_impl is SysTracer:
- track = lambda x: x
- else:
- from mandala.deps.tracers.dec_impl import track
-
- global f_2, f_1
-
- @track
- def f_1(x) -> int:
- return 23
-
- # use functions from libraries to make sure we don't trace them
- @op
- def f_2(x) -> int:
- df = pd.DataFrame({"a": [1, 2, 3]})
- array = np.array([1, 2, 3])
- x = array.mean() + np.random.uniform()
- return f_1(x)
-
- with storage.run():
- f_2(1)
-
- vs = storage.get_versioner()
- f_2_versions = vs.versions[MODULE_NAME, "f_2"]
- assert len(f_2_versions) == 1
- version = f_2_versions[list(f_2_versions.keys())[0]]
- assert set(version.support) == {(MODULE_NAME, "f_2"), (MODULE_NAME, "f_1")}
- _test_version_reprs(storage=storage)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_deps(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
- global dep_1, f_2, A
- if tracer_impl is SysTracer:
- track = lambda x: x
- else:
- from mandala.deps.tracers.dec_impl import track
-
- A = 42
-
- @track
- def dep_1(x) -> int:
- return 23
-
- @track
- @op
- def f_2(x) -> int:
- return dep_1(x) + A
-
- with storage.run():
- f_2(1)
-
- vs = storage.get_versioner()
- f_2_versions = vs.versions[MODULE_NAME, "f_2"]
- assert len(f_2_versions) == 1
- version = f_2_versions[list(f_2_versions.keys())[0]]
- assert set(version.support) == {
- (MODULE_NAME, "f_2"),
- (MODULE_NAME, "dep_1"),
- (MODULE_NAME, "A"),
- }
- _test_version_reprs(storage=storage)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_changes(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
-
- global f
-
- @op
- def f(x) -> int:
- return x + 1
-
- with storage.run():
- f(1)
- commit_1 = storage.sync_component(
- component=f,
- is_semantic_change=None,
- )
-
- @op
- def f(x) -> int:
- return x + 2
-
- commit_2 = storage.sync_component(component=f, is_semantic_change=True)
- assert commit_1 != commit_2
- with storage.run():
- f(1)
-
- @op
- def f(x) -> int:
- return x + 1
-
- # confirm we reverted to the previous version
- commit_3 = storage.sync_component(
- component=f,
- is_semantic_change=None,
- )
- assert commit_3 == commit_1
- with storage.run(allow_calls=False):
- f(1)
-
- # create a new branch
- @op
- def f(x) -> int:
- return x + 3
-
- commit_4 = storage.sync_component(
- component=f,
- is_semantic_change=True,
- )
- assert commit_4 not in (commit_1, commit_2)
- with storage.run():
- f(1)
-
- f_versions = storage.get_versioner().versions[MODULE_NAME, "f"]
- assert len(f_versions) == 3
- semantic_versions = [v.semantic_version for v in f_versions.values()]
- assert len(set(semantic_versions)) == 3
- _test_version_reprs(storage=storage)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_superops(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
- Config.enable_ref_magics = True
-
- global f_1, f_2, f_3
-
- @op
- def f_1(x) -> int:
- return x + 1
-
- @superop
- def f_2(x) -> int:
- return f_1(x) + 1
-
- with storage.run():
- f_1(1)
-
- with storage.run(attach_call_to_outputs=True):
- a = f_2(1)
- call: Call = a._call
- version = storage.get_versioner().versions[MODULE_NAME, "f_2"][call.content_version]
- assert set(version.support) == {(MODULE_NAME, "f_1"), (MODULE_NAME, "f_2")}
-
- @superop
- def f_3(x) -> int:
- return f_2(x) + 1
-
- with storage.run(attach_call_to_outputs=True):
- a = f_3(1)
- call: Call = a._call
- version = storage.get_versioner().versions[MODULE_NAME, "f_3"][call.content_version]
- assert set(version.support) == {
- (MODULE_NAME, "f_1"),
- (MODULE_NAME, "f_2"),
- (MODULE_NAME, "f_3"),
- }
- _test_version_reprs(storage=storage)
- Config.enable_ref_magics = False
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_dependency_patterns(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
- global A, B, f_1, f_2, f_3, f_4, f_5, f_6
- if tracer_impl is SysTracer:
- track = lambda x: x
- else:
- from mandala.deps.tracers.dec_impl import track
-
- # global vars
- A = 23
- B = [1, 2, 3]
-
- # using a global var
- @track
- def f_1(x) -> int:
- return x + A
-
- # calling another function
- @track
- def f_2(x) -> int:
- return f_1(x) + B[0]
-
- # different dependencies per call
- @op
- def f_3(x) -> int:
- if x % 2 == 0:
- return f_2(2 * x)
- else:
- return f_1(x + 1)
-
- with storage.run(attach_call_to_outputs=True):
- x = f_3(0)
- call: Call = x._call
- version = storage.get_versioner().get_flat_versions()[call.content_version]
- assert version.support == {
- (MODULE_NAME, "f_3"),
- (MODULE_NAME, "f_2"),
- (MODULE_NAME, "A"),
- (MODULE_NAME, "B"),
- (MODULE_NAME, "f_1"),
- }
- with storage.run(attach_call_to_outputs=True):
- x = f_3(1)
- call: Call = x._call
- version = storage.get_versioner().get_flat_versions()[call.content_version]
- assert version.support == {
- (MODULE_NAME, "f_3"),
- (MODULE_NAME, "f_1"),
- (MODULE_NAME, "A"),
- }
-
- # using a lambda
- @op
- def f_4(x) -> int:
- f = lambda y: f_1(y) + B[0]
- return f(x)
-
- # make sure the call in the lambda is detected
- with storage.run(attach_call_to_outputs=True):
- x = f_4(10)
- call: Call = x._call
- version = storage.get_versioner().get_flat_versions()[call.content_version]
- assert version.support == {
- (MODULE_NAME, "f_4"),
- (MODULE_NAME, "f_1"),
- (MODULE_NAME, "A"),
- (MODULE_NAME, "B"),
- }
-
- # using comprehensions and generators
- @superop
- def f_5(x) -> int:
- x = unwrap(x)
- a = [f_1(y) for y in range(x)]
- b = {f_2(y) for y in range(x)}
- c = {y: f_3(y) for y in range(x)}
- return sum(unwrap(f_4(y)) for y in range(x))
-
- with storage.run():
- f_5(10)
-
- f_5_versions = storage.get_versioner().versions[MODULE_NAME, "f_5"]
- assert len(f_5_versions) == 1
- version = f_5_versions[list(f_5_versions.keys())[0]]
- assert set(version.support) == {
- (MODULE_NAME, "f_5"),
- (MODULE_NAME, "f_4"),
- (MODULE_NAME, "f_3"),
- (MODULE_NAME, "f_2"),
- (MODULE_NAME, "f_1"),
- (MODULE_NAME, "A"),
- (MODULE_NAME, "B"),
- }
-
- # nested comprehensions and generators
- @superop
- def f_6(x) -> int:
- x = unwrap(x)
- # nested list comprehension
- a = sum([sum([f_1(y) for y in range(x)]) for z in range(x)])
- # nested comprehension with generator
- b = sum(sum(f_2(y) for y in range(x)) for z in range(unwrap(f_3(x))))
- return a + b
-
- with storage.run():
- f_6(2)
-
- f_6_versions = storage.get_versioner().versions[MODULE_NAME, "f_6"]
- assert len(f_6_versions) == 1
- version = f_6_versions[list(f_6_versions.keys())[0]]
- assert set(version.support) == {
- (MODULE_NAME, "f_6"),
- (MODULE_NAME, "f_3"),
- (MODULE_NAME, "f_2"),
- (MODULE_NAME, "f_1"),
- (MODULE_NAME, "A"),
- (MODULE_NAME, "B"),
- }
- _test_version_reprs(storage=storage)
- storage.versions(f_6)
- storage.sources(f_6)
- storage.get_code(version_id=version.content_version)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_recursion(tracer_impl):
- ### mutual recursion
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
- Config.enable_ref_magics = True
- global s_1, s_2
-
- @superop
- def s_1(x) -> int:
- if x == 0:
- return 0
- else:
- return s_2(x - 1) + 1
-
- @superop
- def s_2(x) -> int:
- if x == 0:
- return 0
- else:
- return s_1(x - 1) + 1
-
- with storage.run(attach_call_to_outputs=True):
- a = s_1(0)
- call_1 = a._call
- b = s_1(1)
- call_2 = b._call
- version_1 = storage.get_versioner().get_flat_versions()[call_1.content_version]
- assert version_1.support == {(MODULE_NAME, "s_1")}
- version_2 = storage.get_versioner().get_flat_versions()[call_2.content_version]
- assert version_2.support == {(MODULE_NAME, "s_1"), (MODULE_NAME, "s_2")}
- _test_version_reprs(storage=storage)
- Config.enable_ref_magics = False
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_memoized_tracking(tracer_impl):
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
- if tracer_impl is SysTracer:
- track = lambda x: x
- else:
- from mandala.deps.tracers.dec_impl import track
-
- global f_1, f_2, f, g
-
- @track
- def f_1(x):
- return x + 1
-
- @track
- def f_2(x):
- return x + 2
-
- @op
- def f(x) -> int:
- if x % 2 == 0:
- return f_1(x)
- else:
- return f_2(x)
-
- @superop
- def g(x) -> List[int]:
- return [f(i) for i in range(unwrap(x))]
-
- with storage.run():
- for i in range(10):
- f(i)
-
- with storage.run():
- z = g(10)
-
-
-@pytest.mark.parametrize("tracer_impl", [SysTracer, DecTracer])
-def test_transient(tracer_impl):
- global f
-
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
-
- @op
- def f(x) -> int:
- return Transient(x + 1)
-
- with storage.run():
- a = f(42)
- with storage.run(recompute_transient=True):
- a = f(42)
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_queries_unit(tracer_impl):
- Config.query_engine = "sql" # doesn't work yet with the naive engine
-
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
-
- global g_1
-
- ### create an op, run it and check the query result
- @op
- def g_1(x) -> int:
- return x + 2
-
- with storage.run():
- [g_1(i) for i in range(10)]
-
- with storage.query():
- i = Q()
- j = g_1(i)
- df = storage.df(i.named("i"), j.named("j"))
-
- assert df.shape == (10, 2)
- assert (df["j"] == df["i"] + 2).all()
-
- ### change the op semantically and check that the query result is empty
- @op
- def g_1(x) -> int:
- return x + 3
-
- storage.sync_component(
- component=g_1,
- is_semantic_change=True,
- )
-
- with storage.query():
- i = Q()
- j = g_1(i)
- df = storage.df(i.named("i"), j.named("j"))
-
- assert df.empty
-
- ### run the op again and check that the query result is correct for the new version
- with storage.run():
- [g_1(i) for i in range(10)]
-
- with storage.query():
- i = Q()
- j = g_1(i)
- df = storage.df(i.named("i"), j.named("j"))
-
- assert df.shape == (10, 2)
- assert (df["j"] == df["i"] + 3).all()
-
- ### go back to the old version and check that the query result is correct
- @op
- def g_1(x) -> int:
- return x + 2
-
- storage.sync_component(
- component=g_1,
- is_semantic_change=None,
- )
- with storage.query():
- i = Q()
- j = g_1(i)
- df = storage.df(i.named("i"), j.named("j"))
- assert df.shape == (10, 2)
- assert (df["j"] == df["i"] + 2).all()
-
-
-@pytest.mark.parametrize("tracer_impl", [DecTracer, SysTracer])
-def test_queries_multiple_versions(tracer_impl):
- Config.query_engine = "sql" # doesn't work yet with the naive engine
- storage = Storage(
- deps_path=DEPS_PATH, deps_package=DEPS_PACKAGE, tracer_impl=tracer_impl
- )
-
- global f_1, f_2, f_3
-
- ### define an op with multiple semantically-compatible versions
- @op
- def f_1(x) -> int:
- return x + 2
-
- @op
- def f_2(x) -> int:
- return x + 3
-
- @op
- def f_3(x) -> int:
- if x % 2 == 0:
- return f_2(2 * x)
- else:
- return f_1(x + 1)
-
- with storage.run():
- for i in range(10): # make sure both versions are used
- f_3(i)
-
- with storage.query():
- i = Q()
- j = f_3(i)
- df_1 = storage.df(i.named("i"), j.named("j"))
- assert df_1.shape == (10, 2)
- assert (df_1["j"] == df_1["i"].apply(f_3)).all()
-
- # change one of the dependencies semantically and check that the query
- # result is what remains from the other branch
- @op
- def f_1(x) -> int:
- return x + 4
-
- storage.sync_component(
- f_1,
- is_semantic_change=True,
- )
-
- with storage.query():
- i = Q()
- j = f_3(i)
- df_2 = storage.df(i.named("i"), j.named("j"))
- assert df_2.shape == (5, 2)
- assert sorted(df_2["i"].values.tolist()) == [0, 2, 4, 6, 8]
- assert (df_2["j"] == df_2["i"].apply(f_3)).all()
-
- ### go back to the old version and check that the query result is recovered
- @op
- def f_1(x) -> int:
- return x + 2
-
- storage.sync_component(
- f_1,
- is_semantic_change=None,
- )
-
- with storage.query():
- i = Q()
- j = f_3(i)
- df_3 = storage.df(i.named("i"), j.named("j"))
- assert (df_1 == df_3).all().all()
diff --git a/mandala/tests/test_memoization.py b/mandala/tests/test_memoization.py
index 2c6eccc..2f47bda 100644
--- a/mandala/tests/test_memoization.py
+++ b/mandala/tests/test_memoization.py
@@ -1,142 +1,133 @@
-from mandala.all import *
-from mandala.tests.utils import *
+from mandala._next.imports import *
-def test_func_creation():
+def test_storage():
storage = Storage()
- @op
- def add(x: int, y: int = 42) -> int:
- return x + y
-
- assert add.func_op.sig.n_outputs == 1
- assert add.func_op.sig.input_names == {"x", "y"}
- assert add.func_op.sig.defaults == {"y": 42}
- check_invariants(storage)
-
-
-@pytest.mark.parametrize("storage", generate_storages())
-def test_computation(storage):
@op
def inc(x: int) -> int:
return x + 1
+
+ with storage:
+ x = 1
+ y = inc(x)
+ z = inc(2)
+ w = inc(y)
+
+ assert w.cid == z.cid
+ assert w.hid != y.hid
+ assert w.cid != y.cid
+ assert storage.unwrap(y) == 2
+ assert storage.unwrap(z) == 3
+ assert storage.unwrap(w) == 3
+ for ref in (y, z, w):
+ assert storage.attach(ref).in_memory
+ assert storage.attach(ref).obj == storage.unwrap(ref)
+
+
+def test_signatures():
+ storage = Storage()
- @op
- def add(x: int, y: int) -> int:
- return x + y
+ @op # a function with a wild input/output signature
+ def add(x, *args, y: int = 1, **kwargs):
+ # just sum everything
+ res = x + sum(args) + y + sum(kwargs.values())
+ if kwargs:
+ return res, kwargs
+ elif args:
+ return None
+ else:
+ return res
+
+ with storage:
+ # call the func in all the ways
+ sum_1 = add(1)
+        sum_2 = add(1, 2, 3, 4)
+ sum_3 = add(1, 2, 3, 4, y=5)
+ sum_4 = add(1, 2, 3, 4, y=5, z=6)
+ sum_5 = add(1, 2, 3, 4, z=5, w=7)
+
+ assert storage.unwrap(sum_1) == 2
+    assert storage.unwrap(sum_2) is None
+    assert storage.unwrap(sum_3) is None
+ assert storage.unwrap(sum_4) == (21, {'z': 6})
+ assert storage.unwrap(sum_5) == (23, {'z': 5, 'w': 7})
+
+
+def test_retracing():
+ storage = Storage()
- # chain some functions
- with storage.run():
- x = 23
- y = inc(x)
- z = add(x, y=y)
+ @op
+ def inc(x):
+ return x + 1
- check_invariants(storage)
- # run it again
- with storage.run():
- x = 23
- y = inc(x)
- z = add(x, y)
- check_invariants(storage)
- # do some more things
- with storage.run():
- x = 42
- y = inc(x)
- z = add(x, y)
+ ### iterating a function
+ with storage:
+ start = 1
for i in range(10):
- z = add(z, i)
- check_invariants(storage)
+ start = inc(start)
+ with storage:
+ start = 1
+ for i in range(10):
+ start = inc(start)
-@pytest.mark.parametrize("storage", generate_storages())
-def test_retracing(storage):
- @op
- def inc(x: int) -> int:
- return x + 1
-
+ ### composing functions
@op
- def add(x: int, y: int) -> int:
+ def add(x, y):
return x + y
-
- with storage.run():
- x = 23
- y = inc(x)
- z = add(x, y)
-
- with storage.run(allow_calls=False):
- x = 23
- y = inc(x)
- z = add(x, y)
-
- try:
- with storage.run(allow_calls=False):
- x = 24
- y = inc(x)
- z = add(x, y)
- assert False
- except Exception as e:
- assert True
-
-
-def test_debugging():
+
+ with storage:
+ inp = [1, 2, 3, 4, 5]
+ stage_1 = [inc(x) for x in inp]
+ stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]
+
+ with storage:
+ inp = [1, 2, 3, 4, 5]
+ stage_1 = [inc(x) for x in inp]
+ stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]
+
+
+def test_lists():
storage = Storage()
@op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage.run(debug_calls=True):
- x = 23
- y = inc(x)
- z = add(x, y)
-
- with storage.run(debug_calls=True):
- x = 23
- y = inc(x)
- z = add(x, y)
-
-
-def _a():
+ def get_sum(elts: MList[int]) -> int:
+ return sum(elts)
+
@op
- def generate_dataset() -> Tuple[int, int]:
- return 1, 2
-
+ def primes_below(n: int) -> MList[int]:
+ primes = []
+ for i in range(2, n):
+ for p in primes:
+ if i % p == 0:
+ break
+ else:
+ primes.append(i)
+ return primes
+
@op
- def train_model(
- train_dataset: int,
- test_dataset: int,
- learning_rate: float,
- batch_size: int,
- num_epochs: int,
- ) -> Tuple[int, float]:
- return train_dataset + test_dataset + learning_rate, batch_size + num_epochs
-
- storage = Storage()
-
- with storage.run():
- X, y = generate_dataset()
- for batch_size in (100, 200, 400):
- for learning_rate in (1, 2, 3):
- model, acc = train_model(
- X,
- y,
- learning_rate=learning_rate,
- batch_size=batch_size,
- num_epochs=10,
- )
-
- with storage.run():
- X, y = generate_dataset()
- for batch_size in (100, 200, 400):
- for learning_rate in (1, 2, 3):
- model, acc = train_model(
- X,
- y,
- learning_rate=learning_rate,
- batch_size=batch_size,
- num_epochs=10,
- )
+ def chunked_square(elts: MList[int]) -> MList[int]:
+        # a stand-in for an op that processes a large input in chunks
+        # to avoid out-of-memory errors
+ return [x*x for x in elts]
+
+ with storage:
+ n = 10
+ primes = primes_below(n)
+ sum_primes = get_sum(primes)
+ assert len(primes) == 4
+ # check indexing
+ assert storage.unwrap(primes[0]) == 2
+ assert storage.unwrap(primes[:2]) == [2, 3]
+
+ ### lists w/ overlapping elements
+ with storage:
+ n = 100
+ primes = primes_below(n)
+ for i in range(0, len(primes), 2):
+ sum_primes = get_sum(primes[:i+1])
+
+ with storage:
+ elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ squares = chunked_square(elts)
diff --git a/mandala/tests/test_queries.py b/mandala/tests/test_queries.py
deleted file mode 100644
index 37a642e..0000000
--- a/mandala/tests/test_queries.py
+++ /dev/null
@@ -1,380 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.queries.viz import visualize_graph
-from mandala.queries.weaver import *
-from mandala.queries.graphs import *
-from mandala.queries.viz import *
-
-OUTPUT_ROOT = Path(__file__).parent / "output/"
-
-
-def get_graph(vqs: Set[ValNode]) -> InducedSubgraph:
- vqs, fqs = traverse_all(vqs=vqs, direction="both")
- return InducedSubgraph(vqs=vqs, fqs=fqs)
-
-
-@pytest.mark.parametrize("storage", generate_storages())
-def test_queries_basics(storage):
- Config.query_engine = "_test"
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage.run():
- for i in range(20, 25):
- j = inc(i)
- final = add(i, y=j)
-
- with storage.query():
- i = Q().named("i")
- j = inc(i).named("j")
- final = add(i, y=j).named("final")
- df = storage.df(i, j, final)
- assert set(df["i"]) == {i for i in range(20, 25)}
- assert all(df["j"] == df["i"] + 1)
- df_refs = storage.df(i, j, final, values="refs")
- df_uids = storage.df(i, j, final, values="uids")
- df_lazy = storage.df(i, j, final, values="lazy")
- assert compare_dfs_as_relations(df_refs.applymap(lambda x: x.uid), df_uids)
- assert compare_dfs_as_relations(
- df_refs.applymap(lambda x: x.detached()), df_lazy
- )
- assert compare_dfs_as_relations(df_refs.applymap(unwrap), df)
- assert compare_dfs_as_relations(df_lazy.applymap(unwrap), df)
-
- check_invariants(storage)
- vqs, fqs = traverse_all([i, j, final])
- visualize_graph(
- vqs=vqs, fqs=fqs, output_path=OUTPUT_ROOT / "test_basics.svg", names=None
- )
-
-
-def test_empty():
- storage = Storage()
-
- with storage.query():
- try:
- df = storage.df()
- except:
- pass
-
-
-def test_queries_static_builder():
- """
- A one-stop test for all of the following cases:
- - queries to match data structures incl. nested structures
- - queries to match elements of data structures
- - wrapping raw objects into queries
- - merging the query graph from a run block and a query block
-
- """
- storage = Storage()
-
- @op
- def f(x) -> int:
- return x + 1
-
- @op
- def g(x, y) -> Tuple[int, int]:
- return x + y, x * y
-
- @op
- def avg(numbers: list) -> float:
- return sum(numbers) / len(numbers)
-
- @op
- def avg_dict(numbers: dict) -> float:
- return sum(numbers.values()) / len(numbers)
-
- @op
- def avg_set(numbers: set) -> float:
- return sum(numbers) / len(numbers)
-
- @op
- def repeat(x, n) -> list:
- return [x] * n
-
- @op
- def dictify(x, y) -> dict:
- return {"x": x, "y": y}
-
- @op
- def nested_dictify(x, y) -> Dict[str, List[int]]:
- return {"x": [x, x], "y": [y, y]}
-
- def get_graph_1():
- with storage.query():
- x = Q()
- y = f(x)
- z, w = g(y, x)
- ### constructive stuff
- lst_1 = qwrap([z, w, {Q(): x, ...: ...}, {x, y, ...}, ...])
- a_1 = avg(lst_1)
- dct_1 = qwrap({Q(): z, ...: ...})
- a_2 = avg_dict(dct_1)
- st_1 = qwrap({x, y, ...})
- a_3 = avg_set(st_1)
- ### destructive stuff
- lst_2 = repeat(x, Q())
- elt_0 = lst_2[Q()]
- dct_2 = dictify(x, y)
- val_0 = dct_2[Q()]
- dct_3 = nested_dictify(x, y)
- val_1 = dct_3[Q()]
- elt_1 = val_1[Q()]
- res = get_graph(vqs={elt_1, val_0, elt_0, a_3, a_2, a_1})
- return res
-
- def get_graph_2():
- with storage.query():
- x = Q()
- y = f(x)
- z, w = g(y, x)
- helper_1 = BuiltinQueries.DictQ(dct={Q(): x})
- helper_2 = BuiltinQueries.SetQ(elts={x, y})
- lst_1 = BuiltinQueries.ListQ(elts=[z, w, helper_1, helper_2])
- a_1 = avg(lst_1)
- dct_1 = BuiltinQueries.DictQ(dct={Q(): z})
- a_2 = avg_dict(dct_1)
- st_1 = BuiltinQueries.SetQ(elts={x, y})
- a_3 = avg_set(st_1)
- lst_2 = repeat(x, Q())
- elt_0 = lst_2[Q()]
- dct_2 = dictify(x, y)
- val_0 = dct_2[Q()]
- dct_3 = nested_dictify(x, y)
- val_1 = dct_3[Q()]
- elt_1 = val_1[Q()]
- res = get_graph(vqs={elt_1, val_0, elt_0, a_3, a_2, a_1})
- return res
-
- g_1, g_2 = get_graph_1(), get_graph_2()
- assert InducedSubgraph.are_canonically_isomorphic(g_1, g_2)
-
-
-def test_queries_visualization():
- storage = Storage()
-
- @op
- def f(x: int, y: int) -> Tuple[int, int, int]:
- return x, y, x + y
-
- @op
- def g() -> int:
- return 42
-
- @op
- def h(z: int, w: int):
- pass
-
- for func in [f, g, h]:
- storage.synchronize(func)
-
- with storage.query():
- a = g().named("a")
- b = Q().named("b")
- c, d, e = f(x=a, y=b)
- c.named("c"), d.named("d"), e.named("e")
- h(z=c, w=d)
- h(z=a, w=b)
- x, y, z = f(x=d, y=e)
- x.named("x"), y.named("y"), z.named("z")
- storage.draw_graph(a, b, c, d, e, traverse="both", show_how="none")
- storage.print_graph(a, b, c, d, e, traverse="both")
-
-
-def test_queries_exceptions():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return unwrap(x) + 1
-
- try:
- # passing raw values in queries should (currently) raise an error
- with storage.query():
- y = inc(23)
- assert False
- except Exception:
- assert True
-
-
-def test_queries_filter_duplicates():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage.run():
- for x in range(5):
- for y in range(5):
- z = inc(x)
- print(z.causal_uid)
- w = add(z, y)
-
- with storage.query():
- x = Q().named("x")
- y = Q().named("y")
- z = inc(x).named("z")
- w = add(z, y).named("w")
- df_1 = storage.df(x, z, drop_duplicates=True, context=False)
- df_2 = storage.df(x, z, drop_duplicates=False, context=False)
- df_3 = storage.df(
- x,
- z,
- drop_duplicates=True,
- context=False,
- engine="naive",
- _visualize_steps_at=OUTPUT_ROOT,
- )
- df_4 = storage.df(
- x,
- z,
- drop_duplicates=False,
- context=False,
- engine="naive",
- _visualize_steps_at=OUTPUT_ROOT,
- )
-
- assert len(df_1) == 5
- assert len(df_2) == 25
- assert compare_dfs_as_relations(df_1, df_3)
- assert compare_dfs_as_relations(df_2, df_4)
-
-
-def test_queries_weird():
- Config.autowrap_inputs = True
- Config.autounwrap_inputs = True
- Config.query_engine = "_test"
-
- storage = Storage()
-
- @op
- def a(f: int, g: int):
- return
-
- @op
- def b(k: int, l: int):
- return
-
- @op
- def c(h: int, i: int) -> int:
- return h + i
-
- @op
- def d(j: int) -> Tuple[int, int, int]:
- return j, j, j
-
- @op
- def e(m: int) -> int:
- return m + 1
-
- for f in [a, b, c, d, e]:
- storage.synchronize(f)
-
- with storage.run():
- var_0 = 23
- var_1 = 42
- a(var_0, var_0)
- b(var_0, var_0)
- a(f=var_0, g=var_0)
- var_2 = c(h=var_0, i=var_0)
- a(f=var_0, g=var_0)
- var_3, var_4, var_5 = d(j=var_1)
- b(k=var_4, l=var_1)
- var_6 = e(m=var_2)
-
- with storage.query():
- var_0 = Q()
- var_1 = Q()
- a(f=var_0, g=var_0)
- b(k=var_0, l=var_0)
- a(f=var_0, g=var_0)
- var_2 = c(h=var_0, i=var_0)
- a(f=var_0, g=var_0)
- var_3, var_4, var_5 = d(j=var_1)
- b(k=var_4, l=var_1)
- var_6 = e(m=var_2)
- df = storage.df(var_0, var_1, var_2, var_3, var_4, var_5, var_6)
-
-
-def test_queries_required_args():
- storage = Storage()
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- with storage.run():
- for x in [1, 2, 3]:
- for y in [4, 5, 6]:
- add(x, y)
-
- with storage.query():
- x = Q().named("x")
- result = add(x=x).named("result")
- df = storage.df(x, result)
- assert df.shape == (9, 2)
-
-
-def test_generalized():
- Config.query_engine = "sql"
- storage = Storage()
-
- @op
- def create_list(x) -> list:
- return [x + i for i in range(10)]
-
- @op
- def consume_list(nums: list) -> int:
- return sum(nums)
-
- with storage.run():
- for x in wrap(list(range(10))):
- nums = create_list(x)
- for i in [2, 4, 6, 8, 10]:
- res = consume_list(nums[:i])
- df = storage.similar(res, context=True)
- assert df.shape[0] == 300
-
- with storage.query():
- idx0 = Q()
- idx0.pin(0)
- x = Q()
- nums = create_list(x=x)
- a0 = nums[idx0]
- a1 = ListQ(elts=[a0], idxs=[idx0])
- res = consume_list(nums=a1)
- result = storage.df(idx0, x, nums, a0, a1, res)
- assert result.shape[0] == 50
-
-
-def test_defaults():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- z = inc(23)
-
- @op
- def inc(x: int, amount: int = 1) -> int:
- return x + amount
-
- with storage.run():
- z = inc(23, 10)
-
- df = storage.similar(z)
- assert df.shape[0] == 2
diff --git a/mandala/tests/test_refactoring.py b/mandala/tests/test_refactoring.py
deleted file mode 100644
index 645033f..0000000
--- a/mandala/tests/test_refactoring.py
+++ /dev/null
@@ -1,408 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-def test_add_input():
- Config.autowrap_inputs = True
- Config.autounwrap_inputs = True
-
- storage = Storage()
-
- ############################################################################
- ### check that old calls are preserved
- ############################################################################
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- x = inc(23)
-
- @op
- def inc(x: int, y=1) -> int:
- return x + y
-
- with storage.run():
- x = inc(23)
-
- assert inc.func_op.sig.input_names == {"x", "y"}
- df = storage.get_table(inc)
- assert all(c in df.columns for c in ["x", "y"])
- assert df.shape[0] == 1
-
- ############################################################################
- ### check that defaults can be overridden
- ############################################################################
- @op
- def add_many(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=add_many)
-
- @op
- def add_many(x: int, y: int = 23, z: int = 42) -> int:
- return x + y + z
-
- storage.synchronize(f=add_many)
-
- with storage.run():
- add_many(0)
- add_many(0, 1)
- add_many(0, 1, 2)
-
- with storage.query():
- x, y, z = Q(), Q(), Q()
- w = add_many(x, y, z)
- df = storage.df(x, y, z, w)
-
- rows = set(tuple(row) for row in df.values.tolist())
- assert rows == {(0, 1, 2, 3), (0, 1, 42, 43), (0, 23, 42, 65)}
-
- ############################################################################
- ### check that queries work with defaults
- ############################################################################
- with storage.query() as q:
- x = Q()
- w = add_many(x)
- df = storage.df(x, w)
-
- ############################################################################
- ### check that invalid ways to add an input are not allowed
- ############################################################################
-
- ### no default
- @op
- def no_default(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=no_default)
-
- try:
-
- @op
- def no_default(x: int, y: int) -> int:
- return x + y
-
- storage.synchronize(f=no_default)
- except:
- assert True
-
-
-def test_add_input_bug():
- """
- The issue was that the sync logic was expecting to see the UIDs for the
- defaults upon re-synchronizing the updated version.
- """
- storage = Storage()
-
- @op
- def f() -> int:
- return 1
-
- with storage.run():
- f()
-
- @op
- def f(x: int = 23) -> int:
- return x
-
- with storage.run():
- f()
-
- @op
- def f(x: int = 23) -> int:
- return x
-
- with storage.run():
- f()
-
-
-def test_default_change():
- """
- Changing default values is not allowed for @ops
- """
- storage = Storage()
-
- @op
- def f(x: int = 23) -> int:
- return x
-
- with storage.run():
- a = f()
-
- @op
- def f(x: int = 42) -> int:
- return x
-
- try:
- with storage.run():
- b = f()
- except:
- assert True
-
-
-def test_func_renaming():
-
- storage = Storage()
-
- ############################################################################
- ### unit test
- ############################################################################
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- x = inc(23)
-
- storage.rename_func(func=inc, new_name="inc_new")
- assert inc.is_invalidated
-
- # define correct function
- @op
- def inc_new(x: int) -> int:
- return x + 1
-
- with storage.run():
- inc_new(23)
-
- df = storage.get_table(inc_new)
- # make sure the call was not new
- assert df.shape[0] == 1
- # make sure we did not create a new function
- sigs = [v for k, v in storage.sig_adapter.load_state().items() if not v.is_builtin]
- assert len(sigs) == 1
-
- ############################################################################
- ### check that name collisions are not allowed
- ############################################################################
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def new_inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=inc)
- storage.synchronize(f=new_inc)
-
- try:
- storage.rename_func(func=inc, new_name="new_inc")
- except:
- assert True
-
- ############################################################################
- ### permute names
- ############################################################################
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def new_inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=inc)
- storage.synchronize(f=new_inc)
-
- storage.rename_func(func=inc, new_name="temp")
- storage.rename_func(func=new_inc, new_name="inc")
-
- @op
- def temp(x: int) -> int:
- return x + 1
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=temp)
- storage.synchronize(f=inc)
-
- storage.rename_func(func=temp, new_name="new_inc")
-
-
-def test_arg_renaming():
- storage = Storage()
-
- ############################################################################
- ### unit test
- ############################################################################
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- x = inc(23)
-
- storage.rename_arg(func=inc, name="x", new_name="x_new")
- assert inc.is_invalidated
-
- # define correct function
- @op
- def inc(x_new: int) -> int:
- return x_new + 1
-
- with storage.run():
- x = inc(23)
-
- df = storage.get_table(inc)
- # make sure the call was not new
- assert df.shape[0] == 1
- # make sure we did not create a new function
- sigs = [v for k, v in storage.sig_adapter.load_state().items() if not v.is_builtin]
- assert len(sigs) == 1
-
- ############################################################################
- ### check collisions are not allowed
- ############################################################################
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- storage.synchronize(f=add)
- try:
- storage.rename_arg(func=add, name="x", new_name="y")
- except:
- assert True
-
-
-def test_renaming_failures_1():
- """
- Try to do a rename on a function that was invalidated
- """
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=inc)
-
- storage.rename_func(func=inc, new_name="inc_new")
- try:
- storage.rename_func(func=inc, new_name="inc_other")
- except:
- assert True
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- storage.synchronize(f=add)
-
- storage.rename_arg(func=add, name="x", new_name="z")
- try:
- storage.rename_arg(func=add, name="y", new_name="w")
- except:
- assert True
-
-
-def test_renaming_failures_2():
- """
- Try renaming a function to a name that already exists for another function
- """
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- for f in (inc, add):
- storage.synchronize(f=f)
-
- try:
- storage.rename_func(func=inc, new_name="add")
- except:
- assert True
-
-
-def test_renaming_inside_context_1():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=inc)
-
- try:
- with storage.run():
- storage.rename_func(func=inc, new_name="inc_new")
- inc(23)
- except:
- assert True
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- storage.synchronize(f=add)
-
- try:
- with storage.run():
- storage.rename_arg(func=add, name="x", new_name="z")
- add(23, 42)
- except:
- assert True
-
-
-def test_renaming_inside_context_2():
- """
- Like the previous one, but with uncommitted work
- """
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=inc)
-
- try:
- with storage.run():
- inc(23)
- storage.rename_func(func=inc, new_name="inc_new")
- except:
- assert True
-
- @op
- def add(x: int, y: int) -> int:
- return x + y
-
- storage.synchronize(f=add)
-
- try:
- with storage.run():
- add(23, 42)
- storage.rename_arg(func=add, name="x", new_name="z")
- except:
- assert True
-
-
-def test_other_refactoring_failures():
- storage = Storage()
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- storage.synchronize(f=inc)
-
- @op
- def inc(y: int) -> int:
- return y + 1
-
- try:
- storage.synchronize(f=inc)
- except:
- assert True
diff --git a/mandala/tests/test_stateful_bugs.py b/mandala/tests/test_stateful_bugs.py
deleted file mode 100644
index f2c3883..0000000
--- a/mandala/tests/test_stateful_bugs.py
+++ /dev/null
@@ -1,273 +0,0 @@
-from .test_stateful_slow import SingleClientSimulator, MultiClientSimulator
-
-
-def test_1():
- state = SingleClientSimulator(n_clients=1)
- state.add_workflow()
- state.add_input_var_to_workflow()
- state.create_op()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.add_input()
- state.execute_workflow()
- state.execute_workflow()
-
-
-def test_2():
- state = SingleClientSimulator(n_clients=1)
- state.add_workflow()
- state.create_op()
- state.add_input_var_to_workflow()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.add_input()
- state.execute_workflow()
- state.verify_state()
-
-
-def test_3():
- state = SingleClientSimulator(n_clients=1)
- state.create_op()
- state.add_workflow()
- state.add_workflow()
- state.rename_func()
- state.create_op()
- state.add_workflow()
- state.create_op()
- state.add_input()
- state.create_op()
- state.add_workflow()
- state.rename_input()
- state.add_workflow()
- state.rename_input()
- state.create_op()
- state.rename_input()
- state.add_input_var_to_workflow()
- state.rename_func()
- state.add_input()
- state.add_op_to_workflow()
- state.rename_input()
- state.add_call_to_workflow()
- state.add_input()
- state.add_input()
- state.add_input_var_to_workflow()
- state.add_op_to_workflow()
- state.rename_func()
- state.add_op_to_workflow()
- state.rename_func()
- state.add_input_var_to_workflow()
- state.rename_func()
- state.rename_func()
- state.rename_input()
- state.rename_func()
- state.rename_func()
- state.add_input()
- state.execute_workflow()
- state.add_op_to_workflow()
- state.add_op_to_workflow()
- state.rename_input()
- state.add_op_to_workflow()
- state.execute_workflow()
- state.rename_input()
- state.add_op_to_workflow()
- state.execute_workflow()
- state.add_op_to_workflow()
- state.rename_func()
- state.add_call_to_workflow()
- state.add_input_var_to_workflow()
- state.rename_func()
- state.rename_func()
- state.execute_workflow()
- state.rename_func()
- state.add_call_to_workflow()
- state.rename_func()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.add_call_to_workflow()
- state.add_input_var_to_workflow()
- state.add_call_to_workflow()
- state.rename_input()
- state.rename_func()
- state.rename_input()
- state.rename_func()
- state.execute_workflow()
- state.add_input_var_to_workflow()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.add_call_to_workflow()
- state.execute_workflow()
-
-
-def test_4():
- state = SingleClientSimulator(n_clients=1)
- state.create_op()
- state.create_new_version()
- state.rename_func()
- state.add_input()
-
-
-def test_5():
- state = SingleClientSimulator(n_clients=1)
- state.add_workflow()
- state.add_input_var_to_workflow()
- state.create_op()
- state.add_input()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.create_new_version()
- state.add_input_var_to_workflow()
- state.add_input()
- state.rename_input()
- state.create_new_version()
- state.add_input_var_to_workflow()
- state.add_op_to_workflow()
- state.rename_input()
- state.add_op_to_workflow()
- state.verify_state()
-
-
-def test_6():
- state = SingleClientSimulator(n_clients=1)
- state.add_workflow()
- state.create_op()
- state.add_input()
- state.rename_func()
- state.rename_func()
- state.rename_input()
- state.create_op()
- state.add_input()
- state.add_input_var_to_workflow()
- state.add_op_to_workflow()
- state.add_input()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.add_input()
- state.add_op_to_workflow()
- state.execute_workflow()
- state.add_call_to_workflow()
- state.execute_workflow()
- state.verify_state()
-
-
-def test_7():
- state = SingleClientSimulator(n_clients=1)
- state.create_op()
- state.create_new_version()
- state.rename_func()
-
-
-def _test_8():
- state = MultiClientSimulator(n_clients=3)
- state.create_op()
- state.create_op()
- state.add_input()
- state.add_input()
- state.create_op()
- state.add_input()
- state.add_input()
- state.add_input()
- state.rename_input()
- state.create_op()
- state.add_input()
- state.add_input()
- state.add_input()
- state.rename_input()
- state.add_input()
- state.add_input()
-
-
-def _test_9():
- state = MultiClientSimulator(n_clients=3)
- state.create_op()
- state.rename_input()
- state.sync_one()
- state.sync_one()
- state.add_workflow()
- state.rename_func()
- state.add_workflow()
- state.add_input()
- state.create_new_version()
- state.add_input_var_to_workflow()
- state.add_input()
- state.add_op_to_workflow()
- state.rename_func()
- state.add_call_to_workflow()
- state.execute_workflow()
- state.add_call_to_workflow()
- state.add_input()
- state.add_input_var_to_workflow()
- state.sync_one()
-
-
-def _test_10():
- state = MultiClientSimulator()
- state.add_workflow()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.create_op()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.execute_workflow()
- state.add_input()
- state.sync_one()
-
-
-def _test_11():
- state = MultiClientSimulator(n_clients=3)
- state.add_workflow()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.create_op()
- state.add_input()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.execute_workflow()
- state.add_input()
- state.sync_one()
- state.verify_state()
-
-
-def _test_12():
- state = MultiClientSimulator(n_clients=3)
- state.add_workflow()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.create_op()
- state.add_input_var_to_workflow()
- state.add_input_var_to_workflow()
- state.add_op_to_workflow()
- state.add_call_to_workflow()
- state.execute_workflow()
- state.add_input()
- state.sync_one()
- state.verify_state()
-
-
-def _test_13():
- state = MultiClientSimulator(n_clients=3)
- state.create_op()
- state.rename_func()
- state.rename_input()
- state.add_input()
- state.rename_func()
- state.rename_func()
- state.add_input()
- state.create_op()
- state.add_workflow()
- state.check_mock_storage_single()
- state.create_op()
- state.sync_one()
- state.create_op()
- state.add_input()
- state.rename_func()
- state.add_input()
- state.rename_input()
- state.check_mock_storage_single()
- state.rename_func()
- state.create_op()
- state.create_op()
- state.add_input()
- state.create_new_version()
- state.verify_state()
diff --git a/mandala/tests/test_stateful_slow.py b/mandala/tests/test_stateful_slow.py
deleted file mode 100644
index 397a560..0000000
--- a/mandala/tests/test_stateful_slow.py
+++ /dev/null
@@ -1,737 +0,0 @@
-from hypothesis.stateful import (
- RuleBasedStateMachine,
- Bundle,
- rule,
- initialize,
- precondition,
- invariant,
- run_state_machine_as_test,
-)
-from hypothesis._settings import settings, Verbosity
-from hypothesis import strategies as st
-
-from mandala.common_imports import *
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.tests.stateful_utils import *
-from mandala.queries.workflow import Workflow, CallStruct
-from mandala.core.utils import Hashing, get_uid, parse_full_uid
-from mandala.queries.compiler import *
-from mandala.core.model import Type, ListType
-from mandala.core.builtins_ import Builtins
-from mandala.core.sig import _get_return_annotations
-from mandala.storages.remote_storage import RemoteStorage
-from mandala.ui.executors import SimpleWorkflowExecutor
-from mandala.ui.funcs import FuncInterface
-from mandala.ui.storage import make_delayed
-
-
-class MockStorage:
- """
- A simple storage simulator that
- - stores all data in memory: calls as tables, vrefs as a dictionary
- - only uses internal names for signatures;
- - can be synced in a "naive" way with another storage: the state of the
- other storage is upserted into this storage.
- - note: currently, it replicates only the memoization tables and the table
- of values (not the causal UIDs table)
-
- It is invariant-checked at the entry and exit of every method. The
- state of this object should only be changed through these methods. This
- makes it easy to track down when an inconsistent update happened.
- """
-
- def __init__(self):
- self.calls: Dict[str, pd.DataFrame] = {}
- self.values: Dict[str, Any] = {}
- # versioned internal op name -> (internal input name -> default uid)
- self.default_uids: Dict[str, Dict[str, str]] = {}
- for builtin_op in Builtins.OPS.values():
- name = builtin_op.sig.versioned_internal_name
- self.calls[name] = pd.DataFrame(
- columns=list(builtin_op.sig.input_names) + Config.special_call_cols
- )
- self.default_uids[name] = {}
- self.check_invariants()
-
- def __eq__(self, other: Any):
- if not isinstance(other, MockStorage):
- return False
- values_equal = self.values == other.values
- default_uids_equal = self.default_uids == other.default_uids
- calls_equal = self.calls.keys() == other.calls.keys() and all(
- compare_dfs_as_relations(self.calls[k], other.calls[k])
- for k in self.calls.keys()
- )
- return values_equal and default_uids_equal and calls_equal
-
- def check_invariants(self):
- assert self.default_uids.keys() == self.calls.keys()
- # get all vref uids that appear in calls
- vref_uids_from_calls = []
- for k, df in self.calls.items():
- for col in df.columns:
- if col not in Config.special_call_cols:
- vref_uids_from_calls += df[col].values.tolist()
- vref_uids_from_calls = [parse_full_uid(x)[0] for x in vref_uids_from_calls]
- assert set(vref_uids_from_calls) <= set(self.values.keys())
- for versioned_internal_name, defaults in self.default_uids.items():
- df = self.calls[versioned_internal_name]
- for internal_input_name, default_uid in defaults.items():
- assert internal_input_name in df.columns
-
- def create_op(self, func_op: FuncOp):
- self.check_invariants()
- sig = func_op.sig
- if sig.versioned_internal_name in self.calls.keys():
- raise ValueError()
- if sig.versioned_internal_name in self.default_uids:
- raise ValueError()
- self.calls[sig.versioned_internal_name] = pd.DataFrame(
- columns=Config.special_call_cols
- + list(sig.ui_to_internal_input_map.values())
- + [dump_output_name(index=i) for i in range(sig.n_outputs)]
- )
- self.default_uids[sig.versioned_internal_name] = {}
- self.check_invariants()
-
- def add_input(
- self,
- func_op: FuncOp,
- internal_name: str,
- default_value: Any,
- default_full_uid: str,
- ):
- self.check_invariants()
- default_uid, default_causal_uid = parse_full_uid(default_full_uid)
- sig = func_op.sig
- df = self.calls[sig.versioned_internal_name]
- df[internal_name] = [default_full_uid for _ in range(len(df))]
- self.values[default_uid] = default_value
- self.default_uids[sig.versioned_internal_name][internal_name] = default_full_uid
- self.check_invariants()
-
- def rename_func(self, func_op: FuncOp, new_name: str):
- pass
-
- def rename_input(self, func_op: FuncOp, old_name: str, new_name: str):
- pass
-
- def create_new_version(self, new_version: FuncOp):
- self.check_invariants()
- self.create_op(func_op=new_version)
- self.check_invariants()
-
- def add_call(self, call: Call):
- self.check_invariants()
- func_op, inputs, outputs = call.func_op, call.inputs, call.outputs
- sig = func_op.sig
- row = {
- Config.uid_col: call.uid,
- Config.causal_uid_col: call.causal_uid,
- Config.content_version_col: call.content_version,
- Config.semantic_version_col: call.semantic_version,
- Config.transient_col: call.transient,
- **{sig.ui_to_internal_input_map[k]: v.full_uid for k, v in inputs.items()},
- **{dump_output_name(index=i): v.full_uid for i, v in enumerate(outputs)},
- }
- df = self.calls[sig.versioned_internal_name]
- # handle stale calls
- for k in df.columns:
- if k not in row.keys():
- row[k] = self.default_uids[sig.versioned_internal_name][k]
- row_df = pd.DataFrame([row])
- if row[Config.uid_col] not in df[Config.uid_col].values:
- self.calls[sig.versioned_internal_name] = pd.concat(
- [df, row_df],
- ignore_index=True,
- )
- for vref in itertools.chain(inputs.values(), outputs):
- self.values[vref.uid] = unwrap(vref)
- self.check_invariants()
-
- def sync_from_other(self, other: "MockStorage"):
- """
- Update this storage with the new data from another storage.
-
- NOTE: always copy the other storage's state into this storage; never use
- shared objects.
- """
- self.check_invariants()
- # update values
- for k, v in other.values.items():
- if k in self.values.keys():
- assert v == self.values[k]
- # avoid shared objects
- self.values[k] = copy.deepcopy(v)
- # update defaults
- for versioned_internal_name, defaults in other.default_uids.items():
- if versioned_internal_name in self.calls.keys():
- df = self.calls[versioned_internal_name]
- for internal_input_name, default_uid in defaults.items():
- if internal_input_name not in df.columns:
- df[internal_input_name] = [default_uid for _ in range(len(df))]
- self.default_uids[versioned_internal_name][
- internal_input_name
- ] = default_uid
- else:
- # use a copy to avoid a shared mutable object
- self.default_uids[versioned_internal_name] = copy.deepcopy(defaults)
- # update calls
- for versioned_internal_name, df in other.calls.items():
- if versioned_internal_name in self.calls.keys():
- current_df = self.calls[versioned_internal_name]
- new_df = df[~df[Config.uid_col].isin(current_df[Config.uid_col])].copy()
- self.calls[versioned_internal_name] = pd.concat(
- [current_df, new_df], ignore_index=True
- )
- else:
- self.calls[versioned_internal_name] = df.copy()
- self.check_invariants()
-
- def compare_with_real(self, real_storage: Storage):
- self.check_invariants()
- # extract values
- values_df = real_storage.rel_adapter.get_vrefs()
- values = {
- row[Config.uid_col]: row[Config.vref_value_col]
- for _, row in values_df.iterrows()
- }
- sigs = real_storage.sig_adapter.load_ui_sigs()
- # extract calls and defaults
- all_call_data = real_storage.rel_adapter.get_all_call_data()
- calls = {}
- default_uids = {}
- for versioned_ui_name, df in all_call_data.items():
- sig = sigs[Signature.parse_versioned_name(versioned_ui_name)]
- versioned_internal_name = sig.versioned_internal_name
- calls[versioned_internal_name] = df.rename(
- columns=sig.ui_to_internal_input_map
- )
- default_uids[versioned_internal_name] = sig._new_input_defaults_uids
- return (
- values == self.values
- and all(
- compare_dfs_as_relations(calls[k], self.calls[k]) for k in calls.keys()
- )
- and default_uids == self.default_uids
- )
-
-
-class ClientState:
- def __init__(self, root: Optional[RemoteStorage]):
- self.storage = Storage(root=root)
- self.mock_storage = MockStorage()
- self.workflows: List[Workflow] = []
- self.func_ops: List[FuncOp] = []
- self.num_func_renames = 0
- self.num_input_renames = 0
-
-
-class Preconditions:
- """
- A namespace for preconditions for the rules of the state machine.
-
- NOTE: Preconditions are defined as functions instead of lambdas to enable
- type introspection, autorefactoring, etc.
- """
-
- # control some of the transitions to avoid long chains, especially ones that
- # make the DB larger
- #! (does this actually optimize things? we need to benchmark)
- MAX_OPS_PER_CLIENT = 10
- MAX_INPUTS_PER_OP = 5
- MAX_WORKFLOWS_PER_CLIENT = 5
- MAX_WORKFLOW_SIZE_TO_ADD_VAR = 5
- MAX_WORKFLOW_SIZE_TO_ADD_OP = 10
- # prevent too many renames
- MAX_FUNC_RENAMES_PER_CLIENT = 20
- MAX_INPUT_RENAMES_PER_CLIENT = 20
-
- @staticmethod
- def create_op(instance: "SingleClientSimulator") -> Tuple[bool, List[ClientState]]:
- # return clients for which an op can be created
- candidates = [
- c
- for c in instance.clients
- if len(c.func_ops) < Preconditions.MAX_OPS_PER_CLIENT
- ]
- return len(candidates) > 0, candidates
-
- ############################################################################
- ### refactoring
- ############################################################################
- @staticmethod
- def add_input(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]:
- # return tuples of (client, op, op_idx) for which an input can be added
- candidates = []
- for c in instance.clients:
- for idx, func_op in enumerate(c.func_ops):
- if len(func_op.sig.input_names) < Preconditions.MAX_INPUTS_PER_OP:
- candidates.append((c, func_op, idx))
- return len(candidates) > 0, candidates
-
- @staticmethod
- def rename_func(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]:
- # return tuples of (client, op, idx) for which the op can be renamed
- candidates = []
- for c in instance.clients:
- if c.num_func_renames < Preconditions.MAX_FUNC_RENAMES_PER_CLIENT:
- for idx, func_op in enumerate(c.func_ops):
- candidates.append((c, func_op, idx))
- return len(candidates) > 0, candidates
-
- @staticmethod
- def rename_input(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]:
- # return tuples of (client, op, idx) for which an input can be renamed
- candidates = []
- for c in instance.clients:
- if c.num_input_renames < Preconditions.MAX_INPUT_RENAMES_PER_CLIENT:
- candidates += [(c, op, idx) for idx, op in enumerate(c.func_ops)]
- return len(candidates) > 0, candidates
-
- @staticmethod
- def create_new_version(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, FuncOp, int]]]:
- # return tuples of (client, op, idx of this op) for which a new version can be created
- candidates = []
- for c in instance.clients:
- num_ops = len(c.func_ops)
- if num_ops > 0 and num_ops < Preconditions.MAX_OPS_PER_CLIENT:
- candidates += [(c, op, idx) for idx, op in enumerate(c.func_ops)]
- return len(candidates) > 0, candidates
-
- ############################################################################
- ### growing workflows
- ############################################################################
- @staticmethod
- def add_workflow(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[ClientState]]:
- # return clients for which a workflow can be created
- candidates = [
- c
- for c in instance.clients
- if len(c.workflows) < Preconditions.MAX_WORKFLOWS_PER_CLIENT
- ]
- return len(candidates) > 0, candidates
-
- @staticmethod
- def add_input_var_to_workflow(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, Workflow]]]:
- # return tuples of (client, workflow) for which an input var can be added
- candidates = []
- for c in instance.clients:
- for w in c.workflows:
- if w.shape_size < Preconditions.MAX_WORKFLOW_SIZE_TO_ADD_VAR:
- candidates.append((c, w))
- return len(candidates) > 0, candidates
-
- @staticmethod
- def add_op_to_workflow(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, Workflow, FuncOp]]]:
- # return tuples of (client, workflow, op) for which an op can be added
- candidates = []
- for c in instance.clients:
- for w in c.workflows:
- if (
- w.shape_size < Preconditions.MAX_WORKFLOW_SIZE_TO_ADD_OP
- and len(w.var_nodes) > 0
- ):
- candidates += [(c, w, op) for op in c.func_ops]
- return len(candidates) > 0, candidates
-
- @staticmethod
- def add_call_to_workflow(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, Workflow, CallNode]]]:
- # return tuples of (client, workflow, op_node) for which a call can be added
- candidates = []
- for c in instance.clients:
- for w in c.workflows:
- candidates += [(c, w, op_node) for op_node in w.callable_op_nodes]
- return len(candidates) > 0, candidates
-
- @staticmethod
- def execute_workflow(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, Workflow]]]:
- # return tuples of (client, workflow) for which a workflow can be
- # executed
- candidates = []
- for c in instance.clients:
- for w in c.workflows:
- if w.num_calls > 0:
- candidates.append((c, w))
- return len(candidates) > 0, candidates
-
- @staticmethod
- def check_mock_storage_single(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[ClientState]]:
- candidates = [c for c in instance.clients if len(c.workflows) > 0]
- return len(candidates) > 0, candidates
-
- @staticmethod
- def sync_all(
- instance: "SingleClientSimulator",
- ) -> bool:
- # require at least some calls to be executed
- for client in instance.clients:
- if any(
- len(df) > 0
- for df in client.storage.rel_adapter.get_all_call_data().values()
- ):
- return True
- return False
-
- @staticmethod
- def query_workflow(
- instance: "SingleClientSimulator",
- ) -> Tuple[bool, List[Tuple[ClientState, Workflow]]]:
- # return tuples of (client, workflow) for which a workflow can be
- # queried
- candidates = []
- for c in instance.clients:
- for w in c.workflows:
- if len(w.var_nodes) > 0 and w.is_saturated:
- candidates.append((c, w))
- return len(candidates) > 0, candidates
-
-
-class SingleClientSimulator(RuleBasedStateMachine):
- def __init__(self, n_clients: int = 1):
- super().__init__()
- # client = mongomock.MongoClient()
- # root = MongoMockRemoteStorage(db_name="test", client=client)
- root = None
- self.clients = [ClientState(root=root) for _ in range(n_clients)]
- # a central storage on which all operations are performed
- self.mock_storage = MockStorage()
-        #! keep everything deterministic; only use `random` (seeded below) for generating test data
- random.seed(0)
-
- ############################################################################
- ### schema modifications
- ############################################################################
- @precondition(lambda machine: Preconditions.create_op(machine)[0])
- @rule()
- def create_op(self):
- """
- Add a random op to the storage.
- """
- client = random.choice(Preconditions.create_op(self)[1])
- new_func_op = make_op(
- ui_name=random_string(),
- input_names=[random_string() for _ in range(random.randint(1, 3))],
- n_outputs=random.randint(0, 3),
- defaults={},
- version=0,
- deterministic=True,
- )
- client.storage.synchronize_op(func_op=new_func_op)
- client.func_ops.append(new_func_op)
- for mock_storage in [client.mock_storage, self.mock_storage]:
- mock_storage.create_op(func_op=new_func_op)
- mock_storage.check_invariants()
-
- @precondition(lambda machine: Preconditions.add_input(machine)[0])
- @rule()
- def add_input(self):
- candidates = Preconditions.add_input(self)[1]
- client, func_op, idx = random.choice(candidates)
- # idx = random.randint(0, len(self._func_ops) - 1)
- # func_op = self._func_ops[idx]
- f = func_op.func
- sig = func_op.sig
- # simulate update using low-level API
- new_name = random_string()
- default_value = 23
- new_sig = sig.create_input(name=new_name, default=default_value, annotation=Any)
- # TODO: provide a new function with extra input as a user would
- new_func_op = FuncOp._from_data(func=make_func_from_sig(new_sig), sig=new_sig)
- client.storage.synchronize_op(func_op=new_func_op)
- client.func_ops[idx] = new_func_op
- new_sig = new_func_op.sig
- for mock_storage in [client.mock_storage, self.mock_storage]:
- mock_storage.add_input(
- func_op=new_func_op,
- internal_name=new_sig.ui_to_internal_input_map[new_name],
- default_value=default_value,
- default_full_uid=new_sig.new_ui_input_default_uids[new_name],
- )
- mock_storage.check_invariants()
-
- @precondition(lambda machine: Preconditions.rename_func(machine)[0])
- @rule()
- def rename_func(self):
- candidates = Preconditions.rename_func(self)[1]
- client, func_op, idx = random.choice(candidates)
- # idx = random.randint(0, len(self._func_ops) - 1)
- # func_op = self._func_ops[idx]
- new_name = random_string()
- # find and rename all versions
- all_versions = [
- (i, other_func_op)
- for i, other_func_op in enumerate(client.func_ops)
- if other_func_op.sig.internal_name == func_op.sig.internal_name
- ]
- rename_done = False
- for version_idx, func_op_version in all_versions:
- if not rename_done:
- # use the API the user would use to rename. This will rename all
- # versions.
- func_interface = FuncInterface(func_op=func_op_version)
- client.storage.synchronize(f=func_interface)
- new_sig = client.storage.rename_func(
- func=func_interface, new_name=new_name
- )
- rename_done = True
- else:
- # after the rename, get the true signature from the storage
- new_sig = client.storage.sig_adapter.load_state()[
- func_op_version.sig.internal_name, func_op_version.sig.version
- ]
- # now update the state of the simulator
- new_func_op_version = FuncOp._from_data(
- func=func_op_version.func, sig=new_sig
- )
- client.storage.synchronize_op(func_op=new_func_op_version)
- client.func_ops[version_idx] = new_func_op_version
- client.num_func_renames += 1
-
- @precondition(lambda machine: Preconditions.rename_input(machine)[0])
- @rule()
- def rename_input(self):
- candidates = Preconditions.rename_input(self)[1]
- client, func_op, func_idx = random.choice(candidates)
- # func_idx = random.randint(0, len(self._func_ops) - 1)
- # func_op = self._func_ops[func_idx]
- input_to_rename = random.choice(sorted(list(func_op.sig.input_names)))
- new_name = random_string()
- # use the API the user would use to rename
- func_interface = FuncInterface(func_op=func_op)
- client.storage.synchronize(f=func_interface)
- new_sig = client.storage.rename_arg(
- func=func_interface,
- name=input_to_rename,
- new_name=new_name,
- )
- # now update the state of the simulator
- new_func = make_func_from_sig(sig=new_sig)
- new_func_op = FuncOp._from_data(func=new_func, sig=new_sig)
- client.storage.synchronize_op(func_op=new_func_op)
- client.func_ops[func_idx] = new_func_op
- client.num_input_renames += 1
-
- @precondition(lambda machine: Preconditions.create_new_version(machine)[0])
- @rule()
- def create_new_version(self):
- candidates = Preconditions.create_new_version(self)[1]
- client, func_op, func_idx = random.choice(candidates)
- # func_idx = random.randint(0, len(self._func_ops) - 1)
- # func_op = self._func_ops[func_idx]
- latest_version = client.storage.sig_adapter.get_latest_version(sig=func_op.sig)
- new_version = latest_version.version + 1
- new_func_op = make_op(
- ui_name=func_op.sig.ui_name,
- input_names=[random_string() for _ in range(random.randint(1, 3))],
- n_outputs=random.randint(0, 3),
- defaults={},
- version=new_version,
- deterministic=True,
- )
- client.storage.synchronize_op(func_op=new_func_op)
- client.func_ops.append(new_func_op)
- for mock_storage in [client.mock_storage, self.mock_storage]:
- mock_storage.create_new_version(new_version=new_func_op)
- mock_storage.check_invariants()
-
- ############################################################################
- ### generating workflows
- ############################################################################
- @precondition(lambda machine: Preconditions.add_workflow(machine)[0])
- @rule()
- def add_workflow(self):
- """
- Add a new (empty) workflow to the test.
- """
- candidates = Preconditions.add_workflow(self)[1]
- client = random.choice(candidates)
- res = Workflow()
- client.workflows.append(res)
-
- @precondition(lambda machine: Preconditions.add_input_var_to_workflow(machine)[0])
- @rule()
- def add_input_var_to_workflow(self):
- candidates = Preconditions.add_input_var_to_workflow(self)[1]
- client, workflow = random.choice(candidates)
- # workflow = random.choice([w for w in self._workflows if w.shape_size < 5])
- # always add a value to make sampling proceed faster
- var = workflow.add_var()
- workflow.add_value(value=wrap_atom(get_uid()), var=var)
-
- @precondition(lambda machine: Preconditions.add_op_to_workflow(machine)[0])
- @rule()
- def add_op_to_workflow(self):
- """
- Add an instance of some op to some workflow.
- """
- candidates = Preconditions.add_op_to_workflow(self)[1]
- client, workflow, func_op = random.choice(candidates)
- # func_op = random.choice(self._func_ops)
- # workflow = random.choice([w for w in self._workflows if len(w.var_nodes) > 0])
- # pick inputs randomly from workflow
- inputs = {
- name: random.choice(workflow.var_nodes) for name in func_op.sig.input_names
- }
- # add function over these inputs
- _, _ = workflow.add_op(inputs=inputs, func_op=func_op)
-
- @precondition(lambda machine: Preconditions.add_call_to_workflow(machine)[0])
- @rule()
- def add_call_to_workflow(self):
- candidates = Preconditions.add_call_to_workflow(self)[1]
- client, workflow, op_node = random.choice(candidates)
- input_vars = op_node.inputs
- # pick random values
- var_to_values = workflow.var_to_values()
- input_values = {
- name: random.choice(var_to_values[var]) for name, var in input_vars.items()
- }
- func_op = op_node.func_op
- output_types = [Type.from_annotation(a) for a in func_op.output_annotations]
- outputs = [make_delayed(tp=tp) for tp in output_types]
- call_struct = CallStruct(
- func_op=op_node.func_op, inputs=input_values, outputs=outputs
- )
- workflow.add_call_struct(call_struct=call_struct)
-
- def _execute_workflow(self, client: ClientState, workflow: Workflow) -> List[Call]:
- client.storage.sync_from_remote()
- client.mock_storage.sync_from_other(other=self.mock_storage)
- calls = SimpleWorkflowExecutor().execute(
- workflow=workflow, storage=client.storage
- )
- client.storage.commit(calls=calls)
- client.storage.sync_to_remote()
- for mock_storage in [client.mock_storage, self.mock_storage]:
- for call in calls:
- mock_storage.add_call(call=call)
- mock_storage.check_invariants()
- return calls
-
- @precondition(lambda machine: Preconditions.execute_workflow(machine)[0])
- @rule()
- def execute_workflow(self):
- candidates = Preconditions.execute_workflow(self)[1]
- client, workflow = random.choice(candidates)
- calls = self._execute_workflow(client=client, workflow=workflow)
- # client.storage.sync_from_remote()
- # calls = SimpleWorkflowExecutor().execute(
- # workflow=workflow, storage=client.storage
- # )
- # client.storage.commit(calls=calls)
- # client.storage.sync_to_remote()
-
- ############################################################################
- ### multi-client rules
- ############################################################################
- @rule()
- def sync_one(self):
- client = random.choice(self.clients)
- client.storage.sync_with_remote()
-
- @precondition(lambda machine: Preconditions.check_mock_storage_single(machine)[0])
- @rule()
- def check_mock_storage_single(self):
- # run all workflows for a given client and check that the state equals
- # the state of the mock storage
- candidates = Preconditions.check_mock_storage_single(self)[1]
- client = random.choice(candidates)
- for workflow in client.workflows:
- self._execute_workflow(client=client, workflow=workflow)
- assert client.mock_storage.compare_with_real(client.storage)
-
- @precondition(lambda machine: Preconditions.sync_all(machine))
- @rule()
- def sync_all(self):
- for client in self.clients:
- client.storage.sync_with_remote()
- client.mock_storage.sync_from_other(other=self.mock_storage)
- for client in self.clients:
- # have to do it again!
- client.storage.sync_with_remote()
- client.mock_storage.sync_from_other(other=self.mock_storage)
- for client in self.clients:
- assert client.mock_storage == self.mock_storage
- assert self.mock_storage.compare_with_real(real_storage=client.storage)
-
- @invariant()
- def verify_state(self):
- for client in self.clients:
- # make sure that functions called on the storage work
- client.storage.rel_adapter.get_all_call_data()
- for table in client.storage.rel_storage.get_tables():
- client.storage.rel_storage.get_count(table=table)
- # check storage invariants
- check_invariants(storage=client.storage)
- # check invariants on the workflows
- for w in client.workflows:
- w.check_invariants()
- # w.print_shape()
- # check the individual signatures
- for func_op in client.func_ops:
- func_op.sig.check_invariants()
- # check the set of signatures
- client.storage.sig_adapter.check_invariants()
-
- def check_sig_synchronization(self):
- for client in self.clients:
- assert client.storage.root.sigs == client.storage.sig_adapter.load_state()
-
- # @precondition(lambda machine: Preconditions.query_workflow(machine)[0])
- # @rule()
- # def query_workflow(self):
- # candidates = Preconditions.query_workflow(self)[1]
- # client, workflow = random.choice(candidates)
- # # workflow = random.choice([w for w in client.workflows if w.is_saturated])
- # # path = Path(__file__).parent / f"bug.cloudpickle"
- # # op_nodes = copy.deepcopy(workflow.op_nodes)
- # # db_dump_path = Path(__file__).parent.absolute() / 'db_dump/'
- # # self.storage.rel_storage.execute_no_results(query=f"EXPORT DATABASE '{db_dump_path}';")
- # # data = (self._ops)
- # # with open(path, "wb") as f:
- # # cloudpickle.dump(data, f)
- # val_queries, op_queries = workflow.var_nodes, workflow.op_nodes
- # # workflow.print_shape()
- # df = client.storage.execute_query(select_queries=val_queries, engine='naive')
-
-
-class MultiClientSimulator(SingleClientSimulator):
- def __init__(self, n_clients: int = 3):
- super().__init__(n_clients=n_clients)
-
-
-MAX_EXAMPLES = 100
-STEP_COUNT = 25
-
-TestCaseSingle = SingleClientSimulator.TestCase
-TestCaseSingle.settings = settings(
- max_examples=MAX_EXAMPLES, deadline=None, stateful_step_count=STEP_COUNT
-)
-
-TestCaseMany = MultiClientSimulator.TestCase
-TestCaseMany.settings = settings(
- max_examples=MAX_EXAMPLES, deadline=None, stateful_step_count=STEP_COUNT
-)
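
For readers unfamiliar with the rule-based testing pattern that the deleted test_stateful_slow.py was built on, here is a minimal, self-contained sketch of the same Hypothesis machinery (rules, preconditions, invariants, and the generated TestCase). The CounterMachine below is a generic illustration using only the public hypothesis.stateful API; it is not mandala code.

    from hypothesis import settings
    from hypothesis.stateful import RuleBasedStateMachine, rule, precondition, invariant

    class CounterMachine(RuleBasedStateMachine):
        def __init__(self):
            super().__init__()
            self.model = 0  # the expected state (plays the role of MockStorage)
            self.real = 0   # the system under test (trivial stand-in for Storage)

        @rule()
        def increment(self):
            self.model += 1
            self.real += 1

        @precondition(lambda self: self.model > 0)
        @rule()
        def decrement(self):
            self.model -= 1
            self.real -= 1

        @invariant()
        def model_matches_real(self):
            # analogous to comparing MockStorage with the real Storage above
            assert self.model == self.real

    # Hypothesis wraps the machine into a unittest-style test case, configured
    # the same way as at the end of the deleted file.
    TestCounter = CounterMachine.TestCase
    TestCounter.settings = settings(
        max_examples=100, deadline=None, stateful_step_count=25
    )
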
diff --git a/mandala/tests/test_structs.py b/mandala/tests/test_structs.py
deleted file mode 100644
index bf71da5..0000000
--- a/mandala/tests/test_structs.py
+++ /dev/null
@@ -1,141 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.queries.weaver import *
-
-
-@pytest.mark.parametrize("storage", generate_storages())
-def test_unit(storage):
- Config.query_engine = "_test"
-
- ### lists
- @op
- def repeat(x: int, times: int = None) -> List[int]:
- return [x] * times
-
- @op
- def get_list_mean(nums: List[float]) -> float:
- return sum(nums) / len(nums)
-
- with storage.run():
- lst = repeat(x=42, times=23)
- a = lst[0]
- x = get_list_mean(nums=lst)
- y = get_list_mean(nums=lst[:10])
- assert unwrap(lst[1]) == 42
- assert len(lst) == 23
- storage.rel_adapter.obj_get(uid=lst.uid)
-
- with storage.query():
- x = Q().named("x")
- lst = repeat(x).named("lst")
- idx = Q().named("idx")
- elt = BuiltinQueries.GetListItemQuery(lst=lst, idx=idx).named("elt")
- df = storage.df(x, lst, elt, idx)
- assert df.shape == (23, 4)
- assert all(df["elt"] == 42)
- assert sorted(df["idx"]) == list(range(23))
-
- # test list constructor
- with storage.query():
- # a query for all the lists whose mean we've taken
- lst = BuiltinQueries.ListQ(elts=[Q()]).named("lst")
- x = get_list_mean(nums=lst).named("x")
- df = storage.df(lst, x)
-
- # test syntax sugar
- with storage.query():
- x = Q().named("x")
- lst = repeat(x).named("lst")
- first_elt = lst[Q()].named("first_elt")
- df = storage.df(x, lst, first_elt)
-
- ### dicts
- @op
- def get_dict_mean(nums: Dict[str, float]) -> float:
- return sum(nums.values()) / len(nums)
-
- @op
- def describe_sequence(seq: List[int]) -> Dict[str, float]:
- return {
- "min": min(seq),
- "max": max(seq),
- "mean": sum(seq) / len(seq),
- }
-
- with storage.run():
- dct = describe_sequence(seq=[1, 2, 3])
- dct_mean = get_dict_mean(nums=dct)
- dct_mean_2 = get_dict_mean(nums={"a": dct["min"]})
- assert unwrap(dct_mean) == 2.0
- assert unwrap(dct["min"]) == 1
- assert len(dct) == 3
- storage.rel_adapter.obj_get(uid=dct.uid)
-
- with storage.query():
- seq = Q().named("seq")
- dct = describe_sequence(seq=seq).named("dct")
- dct_mean = get_dict_mean(nums=dct).named("dct_mean")
- df = storage.df(seq, dct, dct_mean)
-
- # test dict constructor
- with storage.query():
- # a query for all the dicts whose mean we've taken
- dct = BuiltinQueries.DictQ(dct={Q(): Q()}).named("dct")
- dct_mean = get_dict_mean(nums=dct).named("dct_mean")
- df = storage.df(dct, dct_mean)
-
- # test syntax sugar
- # with storage.query():
- # seq = Q().named("seq")
- # dct = describe_sequence(seq=seq).named("dct")
- # dct_mean = dct["mean"].named("dct_mean")
- # df = storage.df(seq, dct, dct_mean)
-
- ### sets
- @op
- def mean_set(nums: Set[float]) -> float:
- return sum(nums) / len(nums)
-
- @op
- def get_prime_factors(num: int) -> Set[int]:
- factors = set()
- for i in range(2, num):
- while num % i == 0:
- factors.add(i)
- num /= i
- return factors
-
- with storage.run():
- factors = get_prime_factors(num=42)
- factors_mean = mean_set(nums=factors)
- assert unwrap(factors_mean) == 4.0
- assert len(factors) == 3
- storage.rel_adapter.obj_get(uid=factors.uid)
-
- with storage.query():
- num = Q().named("num")
- factors = BuiltinQueries.SetQ(elts={num}).named("factors")
- factors_mean = mean_set(nums=factors).named("factors_mean")
- df = storage.df(num, factors, factors_mean)
-
-
-@pytest.mark.parametrize("storage", generate_storages())
-def test_nested(storage):
- @op
- def sum_rows(mat: List[List[float]]) -> List[float]:
- return [sum(row) for row in mat]
-
- with storage.run():
- mat = [[1, 2, 3], [4, 5, 6]]
- sums = sum_rows(mat=mat)
- mat_2 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
- sums_2 = sum_rows(mat=mat_2)
- assert sums_2[0].uid == sums[0].uid
- assert sums_2[1].uid == sums[1].uid
-
- with storage.query():
- x = Q().named("x")
- row = BuiltinQueries.ListQ(elts=[x]).named("row")
- mat = BuiltinQueries.ListQ(elts=[row]).named("mat")
- row_sums = sum_rows(mat=mat).named("row_sums")
- df = storage.df(x, row, mat, row_sums)
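
The deleted test_structs.py above exercises the old pattern-matching query interface over data structures. As a condensed reference, here is a sketch of the core pattern, restricted to the same calls that appear in the deleted tests (old interface; it is not a description of the current API).

    from typing import List
    from mandala.all import *

    storage = Storage()

    @op
    def repeat(x: int, times: int = None) -> List[int]:
        return [x] * times

    with storage.run():
        lst = repeat(x=42, times=23)   # the call and the produced list are saved

    with storage.query():
        x = Q().named("x")             # wildcard variable
        lst = repeat(x).named("lst")   # constraint: lst was produced by repeat(x)
        df = storage.df(x, lst)        # one row per match found in storage
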
diff --git a/mandala/tests/test_superops.py b/mandala/tests/test_superops.py
deleted file mode 100644
index be71443..0000000
--- a/mandala/tests/test_superops.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-from mandala.queries.weaver import *
-
-
-@pytest.mark.parametrize("storage", generate_storages())
-def test_unit(storage):
- @op
- def f(x: int, y: int) -> int:
- return x + y
-
- @op
- def mean(nums: List[int]) -> float:
- return sum(nums) / len(nums)
-
- @op
- def dictmean(nums: Dict[str, Any]) -> float:
- return sum(nums.values()) / len(nums)
-
- @op
- def repeat(num: int, times: int) -> List[int]:
- return [num for something in range(times)]
-
- @op
- def make_dict(a: int, b: int) -> Dict[str, int]:
- return {"a": a, "b": b}
-
- @op
- def swap(x: int, y: int) -> Tuple[int, int]:
- return y, x
-
- @op
- def concat_lists(a: List[Any], b: List[Any]) -> List[Any]:
- return a + b
-
- @superop
- def workflow(a: int, b: int, c: int) -> List[Any]:
- dct = {"a": a, "b": b}
- dct_mean = dictmean(nums=dct)
- things = repeat(num=a, times=b)
- things_mean = mean(things)
- new = {"x": dct_mean, "y": things_mean}
- final = dictmean(nums=new)
- x, y = swap(x=final, y=c)
- return [x, y]
-
- @superop
- def super_workflow(a: int, b: int) -> List[Any]:
- things = workflow(a, a, a)
- other_things = workflow(b, b, b)
- return concat_lists(things, other_things)
-
- with storage.run():
- a = f(23, 42)
- b = f(4, 8)
- c = f(15, 16)
- avg = mean([a, b, c])
- sames = repeat(num=23, times=5)
- elt = sames[3]
- dict_avg = dictmean({"a": 23, "b": 42})
- dct = make_dict(a=23, b=42)
- z = workflow(a=23, b=42, c=10)
- w = super_workflow(a=a, b=b)
-
- with storage.run():
- a = f(23, 42)
- b = f(4, 8)
- c = f(15, 16)
- avg = mean([a, b, c])
- sames = repeat(num=23, times=5)
- elt = sames[3]
- dict_avg = dictmean({"a": 23, "b": 42})
- dct = make_dict(a=23, b=42)
- z = workflow(a=23, b=42, c=10)
- w = super_workflow(a=a, b=b)
-
-
-@pytest.mark.parametrize("storage", generate_storages())
-def test_mutual_recursion(storage):
- @superop
- def f(x: int) -> int:
- if unwrap(x) == 0:
- return 0
- else:
- return g(unwrap(x) - 1)
-
- @superop
- def g(x: int) -> int:
- if unwrap(x) == 0:
- return 0
- else:
- return f(unwrap(x) - 1)
-
- with storage.run():
- a = f(23)
diff --git a/mandala/tests/test_transient.py b/mandala/tests/test_transient.py
deleted file mode 100644
index f883fef..0000000
--- a/mandala/tests/test_transient.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-def test_unit():
- storage = Storage()
-
- @op
- def f(x) -> int:
- return Transient(x + 1)
-
- with storage.run(attach_call_to_outputs=True):
- a = f(42)
- call: Call = a._call
- assert call.transient
- assert a.transient
- assert unwrap(a) == 43
-
- with storage.run(attach_call_to_outputs=True):
- a = f(42)
- call: Call = a._call
- # assert not a.in_memory - no longer true b/c caching
- assert a.transient
- assert call.transient
-
- with storage.run(recompute_transient=True, attach_call_to_outputs=True):
- a = f(42)
- call: Call = a._call
-
- assert a.in_memory
- assert a.transient
- assert call.transient
-
-
-def test_composition():
- storage = Storage()
-
- @op
- def f(x) -> int:
- return Transient(x + 1)
-
- @op
- def g(x) -> int:
- return Transient(x**2)
-
- with storage.run():
- a = f(42)
- b = g(a)
-
- with storage.run():
- a = f(23)
-
- storage.cache.evict_all()
-
- try:
- with storage.run():
- a = f(23)
- b = g(a)
- assert False
- except ValueError as e:
- assert True
-
- with storage.run(recompute_transient=True):
- a = f(23)
- b = g(a)
- assert unwrap(b) == 576
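
The deleted test_transient.py covers outputs wrapped in Transient, which are tracked but not persisted. Below is a minimal sketch of that pattern, restricted to the old API exactly as it is used in the tests above.

    from mandala.all import *

    storage = Storage()

    @op
    def f(x) -> int:
        # Transient marks the output as "do not persist"; only metadata is stored
        return Transient(x + 1)

    with storage.run():
        a = f(42)                        # computed and kept in memory for this run

    storage.cache.evict_all()            # drop the in-memory values

    with storage.run(recompute_transient=True):
        a = f(42)                        # the transient value is recomputed on demand
        assert unwrap(a) == 43
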
diff --git a/mandala/_next/tests/test_versioning.py b/mandala/tests/test_versioning.py
similarity index 100%
rename from mandala/_next/tests/test_versioning.py
rename to mandala/tests/test_versioning.py
diff --git a/mandala/tests/test_versions.py b/mandala/tests/test_versions.py
deleted file mode 100644
index 1b2fac0..0000000
--- a/mandala/tests/test_versions.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from mandala.all import *
-from mandala.tests.utils import *
-
-
-def test_unit():
- storage = Storage()
-
- ############################################################################
- ### unit test
- ############################################################################
-
- @op
- def inc(x: int) -> int:
- return x + 1
-
- with storage.run():
- inc(23)
-
- @op(version=1)
- def inc(x: int) -> int:
- return x + 1
-
- # check unsynchronized functions work for this
- storage.sig_adapter.get_versions(sig=inc.func_op.sig)
-
- with storage.run():
- inc(23)
-
- sigs = [v for k, v in storage.sig_adapter.load_state().items() if not v.is_builtin]
- assert len(sigs) == 2
diff --git a/mandala/tests/utils.py b/mandala/tests/utils.py
deleted file mode 100644
index c2ab0bb..0000000
--- a/mandala/tests/utils.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import uuid
-import pytest
-
-from ..common_imports import *
-from mandala.all import *
-from mandala.ui.storage import Storage
-from mandala.ui.storage import MODES
-from mandala.core.config import Config, is_output_name
-from mandala.core.wrapping import compare_dfs_as_relations
-
-# from mandala.storages.remote_impls.mongo_impl import MongoRemoteStorage
-# from mandala.storages.remote_impls.mongo_mock import MongoMockRemoteStorage
-from mandala.storages.rels import RelAdapter
-
-
-def generate_db_path() -> Path:
- output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output")
- fname = str(uuid.uuid4()) + ".db"
- return output_dir / fname
-
-
-def generate_spillover_dir() -> Path:
- output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output")
- fname = str(uuid.uuid4())
- return output_dir / fname
-
-
-def generate_path(ext: str) -> Path:
- output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output")
- fname = str(uuid.uuid4()) + ext
- return output_dir / fname
-
-
-def generate_storages() -> List[Storage]:
- results = []
- for db_backend in ("sqlite",):
- for persistent in (True, False):
- for spillover in (True, False):
- results.append(
- Storage(
- db_backend=db_backend,
- db_path=generate_db_path() if persistent else None,
- spillover_dir=generate_spillover_dir() if spillover else None,
- spillover_threshold_mb=0,
- )
- )
- return results
-
-
-def signatures_are_equal(storage_1: Storage, storage_2: Storage) -> bool:
- sigs_1 = storage_1.sig_adapter.load_state()
- sigs_2 = storage_2.sig_adapter.load_state()
- if sigs_1.keys() != sigs_2.keys():
- return False
- for (internal_name, version), sig_1 in sigs_1.items():
- sig_2 = sigs_2[internal_name, version]
- if sig_1 != sig_2:
- return False
- return True
-
-
-def data_is_equal(
- storage_1: Storage, storage_2: Storage, return_reason: bool = False
-) -> Union[bool, Tuple[bool, str]]:
- data_1 = storage_1.rel_storage.get_all_data()
- data_2 = storage_2.rel_storage.get_all_data()
- #! remove some internal tables from the comparison
- for _internal_table in [RelAdapter.DEPS_TABLE, RelAdapter.PROVENANCE_TABLE]:
- if _internal_table in data_1:
- data_1.pop(_internal_table)
- data_2.pop(_internal_table)
- # compare the keys
- if data_1.keys() != data_2.keys():
- result, reason = False, f"Tables differ: {data_1.keys()} vs {data_2.keys()}"
- # compare the signatures
- sigs_1 = storage_1.sig_adapter.load_state()
- sigs_2 = storage_2.sig_adapter.load_state()
- if sigs_1.keys() != sigs_2.keys():
- result, reason = (
- False,
- f"Signature keys differ: {sigs_1.keys()} vs {sigs_2.keys()}",
- )
- if sigs_1 != sigs_2:
- result, reason = False, f"Signatures differ: {sigs_1} vs {sigs_2}"
- # compare the data
- elementwise_comparisons = {
- k: compare_dfs_as_relations(data_1[k], data_2[k], return_reason=True)
- for k in data_1.keys()
- if k != Config.schema_table
- }
- if all(result for result, _ in elementwise_comparisons.values()):
- result, reason = True, ""
- else:
- result, reason = (
- False,
- f"Found differences between tables: {elementwise_comparisons}",
- )
- if return_reason:
- return result, reason
- else:
- return result
-
-
-def check_invariants(storage: Storage):
- # check that signatures match tables
- ui_sigs = storage.sig_adapter.load_ui_sigs()
- call_tables = storage.rel_adapter.get_call_tables()
- columns_by_table = {}
- # collect table columns
- for call_table in call_tables:
- columns = storage.rel_storage._get_cols(relation=call_table)
- columns_by_table[call_table] = columns
- # check that all signatures are accounted for in the tables
- for sig in ui_sigs.values():
- table_name = sig.versioned_ui_name
- assert table_name in columns_by_table
- columns = columns_by_table[table_name]
- assert sig.input_names.issubset(set(columns))
- # check that all tables are accounted for in the signatures
- for call_table, columns in columns_by_table.items():
- input_cols = [
- col
- for col in columns
- if not is_output_name(col) and col not in Config.special_call_cols
- ]
- ui_name, version = Signature.parse_versioned_name(versioned_name=call_table)
- assert set(input_cols).issubset(ui_sigs[ui_name, version].input_names)
diff --git a/mandala/_next/tps.py b/mandala/tps.py
similarity index 100%
rename from mandala/_next/tps.py
rename to mandala/tps.py
diff --git a/mandala/_next/tutorials/gotchas.ipynb b/mandala/tutorials/gotchas.ipynb
similarity index 100%
rename from mandala/_next/tutorials/gotchas.ipynb
rename to mandala/tutorials/gotchas.ipynb
diff --git a/mandala/_next/tutorials/hello.ipynb b/mandala/tutorials/hello.ipynb
similarity index 100%
rename from mandala/_next/tutorials/hello.ipynb
rename to mandala/tutorials/hello.ipynb
diff --git a/mandala/_next/tutorials/ml.ipynb b/mandala/tutorials/ml.ipynb
similarity index 100%
rename from mandala/_next/tutorials/ml.ipynb
rename to mandala/tutorials/ml.ipynb
diff --git a/mandala/_next/ui.py b/mandala/ui.py
similarity index 100%
rename from mandala/_next/ui.py
rename to mandala/ui.py
diff --git a/mandala/ui/cfs.py b/mandala/ui/cfs.py
deleted file mode 100644
index bcf9d3e..0000000
--- a/mandala/ui/cfs.py
+++ /dev/null
@@ -1,895 +0,0 @@
-import textwrap
-from ..common_imports import *
-from ..core.config import *
-from ..core.tps import Type
-from ..core.model import Ref, FuncOp, Call
-from ..core.builtins_ import StructOrientations, Builtins
-from ..storages.rel_impls.utils import Transactable, transaction, Connection
-from ..queries.graphs import copy_subgraph
-from ..core.prov import propagate_struct_provenance
-from ..queries.weaver import CallNode, ValNode
-from ..queries.viz import GraphPrinter, visualize_graph
-from .funcs import FuncInterface
-from .storage import Storage, ValueLoader
-from .cfs_utils import estimate_uid_storage, convert_bytes_to
-
-
-class ComputationFrame(Transactable):
- """
- In-memory, dynamic representation of a slice of storage representing some
- computation, with methods for indexing, evaluating (i.e. loading from
- storage), and navigating back/forward along computation paths. These methods
- turn a `ComputationFrame` into a generalized dataframe over a computational
- graph, with columns corresponding to variables in the computation, and rows
- corresponding to values of those variables for a single instance of the
- computation.
-
- This is a simple declarative interface for exploring the storage, in
- contrast with using an imperative computational context (i.e., manipulating
- some memoized piece of code to interface with storage).
-
- The main differences between a `ComputationFrame` and a `DataFrame` are as
- follows:
- - by default, the `ComputationFrame` is lazy, i.e. it does not load the
- values of the variables it represents from the storage, but only their
- metadata in the form of `Ref` objects;
- - the `eval(vars)` method allows for loading the values of chosen
- variables from the storage, and returns an ordinary `pandas` dataframe
- with columns corresponding to the variables;
- - the `forward(vars)` and `back(vars)` methods allow for navigation
- along the computation graph, by expanding the graph to include the
- operation(s) that created/used given variables represented in the
- `ComputationFrame`;
- - the `creators(var)` and `consumers(var)` methods allow for inspecting
- the operations that created/used a variable (including ones not
- currently represented in the `ComputationFrame`);
-
- It has several important limitations:
- - the main one being that it can only represent a single computation in a
- given instance (i.e., a single composition of functions). This comes at the
- benefit of simplicity and ease of use.
- - when it finds a data structure that was constructed "by hand" out of
- multiple elements, it picks one of them arbitrarily to represent the
- computational history of the entire structure. Sometimes this makes sense
- (e.g., when the structure is a list of values with analogous provenance),
- sometimes it doesn't.
- - Similarly, when it finds a data structure that is the output of a
- computation, it picks one of its elements arbitrarily to represent the
- computational continuation of the entire structure. Dual caveats apply.
- """
-
- def __init__(
- self,
- call_nodes: Dict[str, CallNode],
- val_nodes: Dict[str, ValNode],
- storage: Storage,
- prov_df: Optional[pd.DataFrame] = None,
- ):
- self.op_nodes = call_nodes
- self.var_nodes = val_nodes
-
- self.storage = storage
- if prov_df is None:
- prov_df = storage.rel_storage.get_data(table=Config.provenance_table)
- prov_df = propagate_struct_provenance(prov_df)
- self.prov_df = prov_df
-
- def __len__(self) -> int:
- if len(self.op_nodes) > 0:
- return len(self.op_nodes[list(self.op_nodes.keys())[0]].calls)
- if len(self.var_nodes) > 0:
- return len(self.var_nodes[list(self.var_nodes.keys())[0]].refs)
- return 0
-
- def to_pandas(
- self,
- columns: Optional[Union[str, Iterable[str]]] = None,
- values: Literal["refs", "lazy", "objs"] = "lazy",
- ) -> Union[pd.Series, pd.DataFrame]:
- """
-        Extract a series/dataframe of refs from the ComputationFrame, analogous
-        to indexing into a pandas dataframe.
- """
- if columns is None:
-            columns = list(self.var_nodes.keys())
- if isinstance(columns, str):
- refs = pd.Series(self.var_nodes[columns].refs, name=columns)
- full_uids_df = refs.apply(lambda x: x.full_uid)
- res_df = self.storage.eval_df(
- full_uids_df=full_uids_df.to_frame(), values=values
- )
- # turn into a series
- return res_df[columns]
- elif isinstance(columns, list) and all(isinstance(x, str) for x in columns):
- refs = pd.DataFrame(
- {col: self.var_nodes[col].refs for col in columns}
- )
- return self.storage.eval_df(
- full_uids_df=refs.applymap(
- lambda x: x.full_uid if x is not None else None
- ),
- values=values,
- )
- else:
- raise ValueError(f"Invalid columns type: {type(columns)}")
-
- def __getitem__(
- self, indexer: Union[str, Iterable[str], np.ndarray]
- ) -> "ComputationFrame":
- """
- Analogous to pandas __getitem__, but tries to return a `ComputationFrame`
- """
- if isinstance(indexer, str):
- if indexer in self.var_nodes:
- return self.copy_subgraph(
- val_nodes=[self.var_nodes[indexer]], call_nodes=[]
- )
- elif indexer in self.op_nodes:
- return self.copy_subgraph(
- val_nodes=[], call_nodes=[self.op_nodes[indexer]]
- )
- else:
- raise ValueError(
- f"Column {indexer} not found in variables or operations"
- )
- elif isinstance(indexer, list) and all(isinstance(x, str) for x in indexer):
- var_keys = [x for x in self.var_nodes.keys() if x in indexer]
- call_keys = [x for x in self.op_nodes.keys() if x in indexer]
- return self.copy_subgraph(
- val_nodes={self.var_nodes[k] for k in var_keys},
- call_nodes={self.op_nodes[k] for k in call_keys},
- )
- elif isinstance(indexer, (np.ndarray, pd.Series)):
- if isinstance(indexer, pd.Series):
- indexer = indexer.values
- # boolean mask
- if indexer.dtype == bool:
- res = self.copy_subgraph()
- for k, v in res.var_nodes.items():
- v.inplace_mask(indexer)
- for k, v in res.op_nodes.items():
- v.inplace_mask(indexer)
- return res
- else:
- raise NotImplementedError(
- "Indexing with a non-boolean mask is not supported"
- )
- else:
- raise ValueError(
- f"Invalid indexer type into {self.__class__}: {type(indexer)}"
- )
-
- @transaction()
- def eval(
- self,
- indexer: Optional[Union[str, List[str]]] = None,
- conn: Optional[Connection] = None,
- ) -> Union[pd.Series, pd.DataFrame]:
- if indexer is None:
- indexer = list(self.var_nodes.keys())
- if isinstance(indexer, str):
- full_uids_df = (
- self[indexer]
- .to_pandas()
- .applymap(lambda x: x.full_uid if x is not None else None)
- )
- res_df = self.storage.eval_df(full_uids_df=full_uids_df, values="objs")
- res = res_df[indexer]
- return res
- else:
- full_uids_df = (
- self[indexer]
- .to_pandas()
- .applymap(lambda x: x.full_uid if x is not None else None)
- )
- return self.storage.eval_df(
- full_uids_df=full_uids_df, values="objs", conn=conn
- )
-
- @transaction()
- def creators(self, col: str, conn: Optional[Connection] = None) -> np.ndarray:
- calls, output_names = self.storage.get_creators(
- refs=self.var_nodes[col].refs, prov_df=self.prov_df, conn=conn
- )
- return np.array(
- [
- call.func_op.sig.versioned_ui_name if call is not None else None
- for call in calls
- ]
- )
-
- @transaction()
- def consumers(self, col: str, conn: Optional[Connection] = None) -> np.ndarray:
- calls_list, input_names_list = self.storage.get_consumers(
- refs=self.var_nodes[col].refs, prov_df=self.prov_df, conn=conn
- )
- res = np.empty(len(calls_list), dtype=object)
- res[:] = [
- tuple(
- [
- call.func_op.sig.versioned_ui_name if call is not None else None
- for call in calls
- ]
- )
- for calls in calls_list
- ]
- return res
-
- @transaction()
- def get_adjacent_calls(
- self,
- col: str,
- direction: Literal["back", "forward"],
- conn: Optional[Connection] = None,
- ) -> Dict[Tuple[str, str], List[Optional[Call]]]:
- """
- Given a column and a direction to traverse the graph (back or forward),
- return the calls that created/used the values in the column, along with
- the output/input names under which the values appear in the calls.
-
- The calls are grouped by the operation and the output/input name under
- which the values in this column appear in the calls.
- """
- refs: List[Ref] = self.var_nodes[col].refs
- if direction == "back":
- calls_list, names_list = self.storage.get_creators(
- refs=refs, prov_df=self.prov_df, conn=conn
- )
- calls_list = [[c] if c is not None else [] for c in calls_list]
- names_list = [[n] if n is not None else [] for n in names_list]
- elif direction == "forward":
- calls_list, names_list = self.storage.get_consumers(
- refs=refs, prov_df=self.prov_df, conn=conn
- )
- else:
- raise ValueError(f"Unknown direction: {direction}")
- index_dict = defaultdict(list) # (op_id, input/output name) -> (idx, call)
- # for i, (calls, names) in enumerate(zip(calls_list, names_list)):
- for i in range(len(refs)):
- calls = calls_list[i]
- names = names_list[i]
- for call, name in zip(calls, names):
- index_dict[(call.func_op.sig.versioned_ui_name, name)].append((i, call))
- res = {}
- for (op_id, name), indices_and_calls in index_dict.items():
- calls_list = [None for _ in range(len(refs))]
- for i, call in indices_and_calls:
- calls_list[i] = call
- res[(op_id, name)] = calls_list
- return res
-
- @transaction()
- def _back_all(
- self,
- res: "ComputationFrame",
- inplace: bool = False,
- verbose: bool = False,
- conn: Optional[Connection] = None,
- ) -> "ComputationFrame":
- # this means we want to expand the entire graph
- node_frontier = res.var_nodes.keys()
- visited = set()
- while True:
- res = res.back(
- cols=list(node_frontier),
- inplace=inplace,
- skip_failures=True,
- verbose=verbose,
- conn=conn,
- )
- visited |= node_frontier
- nodes_after = set(res.var_nodes.keys())
- node_frontier = nodes_after - visited
- if not node_frontier:
- break
- return res
-
- def join_var_node(
- self,
- refs: List[Ref],
- tp: Type,
- name_hint: Optional[str] = None,
- ) -> Tuple[str, ValNode]:
- refs_hash = ValNode.get_refs_hash(refs=refs)
- for var_name, var_node in self.var_nodes.items():
- if var_node.refs_hash == refs_hash:
- return var_name, var_node
- else:
- res = ValNode(
- tp=tp,
- refs=refs,
- constraint=None,
- )
- res_name = self.get_new_vname(hint=name_hint)
- self.var_nodes[res_name] = res
- return res_name, res
-
- def join_op_node(
- self,
- calls: List[Call],
- out_map: Optional[Dict[str, str]] = None, # output name -> var name
- in_map: Optional[Dict[str, str]] = None, # input name -> var name
- ) -> Tuple[str, CallNode]:
- """
- Join an op node to the graph. If a node with this hash already exists,
- only connect any not yet connected inputs/outputs; otherwise, create a
- new node and then connect inputs/outputs.
- """
- out_map = {} if out_map is None else out_map
- in_map = {} if in_map is None else in_map
- calls_hash = CallNode.get_calls_hash(calls=calls)
- for op_name, op_node in self.op_nodes.items():
- if op_node.calls_hash == calls_hash:
- res = op_node
- res_name = op_name
- break
- else:
- call_representative = calls[0]
- op = call_representative.func_op
-
- #! for struct calls, figure out the orientation; currently ad-hoc
- if op.is_builtin:
- logging.warning(
- f"Found a ref data structure: {op.sig.ui_name}; ComputationFrame support for this is experimental and may not work as expected."
- )
- output_names = list(out_map.keys())
- input_names = list(in_map.keys())
- if len(output_names) == 0:
- if any(x in input_names for x in ("elt", "value")):
- orientation = StructOrientations.construct
- elif any(x in input_names for x in ("lst", "dct", "st")):
- orientation = StructOrientations.destruct
- else:
- raise NotImplementedError
- else:
- if any(x in output_names for x in ("lst", "dct", "st")):
- orientation = StructOrientations.construct
- elif any(x in output_names for x in ("idx", "key")):
- orientation = StructOrientations.destruct
- else:
- raise NotImplementedError
- in_map, out_map = Builtins.reassign_io_using_orientation(
- in_dict=in_map,
- out_dict=out_map,
- orientation=orientation,
- builtin_id=op.sig.ui_name,
- )
- in_map_for_linking = {**in_map, **out_map}
- out_map_for_linking = {}
- else:
- orientation = None
- in_map_for_linking = in_map
- out_map_for_linking = out_map
-
- res = CallNode.link(
- calls=calls,
- inputs={k: self.var_nodes[v] for k, v in in_map_for_linking.items()},
- outputs={k: self.var_nodes[v] for k, v in out_map_for_linking.items()},
- constraint=None,
- func_op=call_representative.func_op,
- orientation=orientation,
- )
- res_name = self.get_new_cname(op)
- self.op_nodes[res_name] = res
- for k, v in out_map.items():
- if k not in res.outputs.keys():
- # connect manually
- res.outputs[k] = self.var_nodes[v]
- self.var_nodes[v].add_creator(creator=res, created_as=k)
- for k, v in in_map.items():
- if k not in res.inputs.keys():
- # connect manually
- res.inputs[k] = self.var_nodes[v]
- self.var_nodes[v].add_consumer(consumer=res, consumed_as=k)
- return res_name, res
-
- @transaction()
- def back(
- self,
- cols: Optional[Union[str, List[str]]] = None,
- inplace: bool = False,
- skip_failures: bool = False,
- verbose: bool = False,
- conn: Optional[Connection] = None,
- ) -> "ComputationFrame":
- res = self if inplace else self.copy_subgraph()
- if cols is None:
- # this means we want to expand the entire graph
- return res._back_all(res, inplace=inplace, verbose=verbose, conn=conn)
- if verbose:
- logger.info(f"Expanding graph to include the provenance of columns {cols}")
- if isinstance(cols, str):
- cols = [cols]
-
- adjacent_calls_data = {
- col: res.get_adjacent_calls(col=col, direction="back", conn=conn)
- for col in cols
- }
-
- filtered_cols = []
- for col, calls_dict in adjacent_calls_data.items():
- if len(calls_dict) > 1:
- reason = f"Values in column {col} were created by multiple ops and/or as different outputs: {[f'{op}::{out}' for op, out in calls_dict.keys()]}"
- if skip_failures:
- if verbose:
- logger.info(f"{reason}; skipping column {col}")
- continue
- else:
- raise ValueError(reason)
- if len(calls_dict) == 0 or any(call is None for call in calls_dict[list(calls_dict.keys())[0]]):
- reason = f"Some refs in column {col} were not created by any op"
- if skip_failures:
- if verbose:
- logger.info(f"{reason}; skipping column {col}")
- continue
- else:
- raise ValueError(reason)
- filtered_cols.append(col)
- cols = filtered_cols
-
- for col in cols:
- calls_dict = adjacent_calls_data[col]
- op_id, output_name = list(calls_dict.keys())[0]
- calls = calls_dict[(op_id, output_name)]
- # produce input nodes
- representative_call: Call = calls[0]
- op = representative_call.func_op
- input_node_names = {}
- for input_name, input_tp in op.input_types.items():
- input_node_names[input_name], _ = res.join_var_node(
- refs=[call.inputs[input_name] for call in calls],
- tp=input_tp,
- name_hint=input_name,)
- res.join_op_node(
- calls=calls,
- out_map={output_name: col},
- in_map={k: v for k, v in input_node_names.items()},
- )
-
- return res
-
- @transaction()
- def forward(
- self,
- cols: Optional[Union[str, List[str]]] = None,
- inplace: bool = False,
- skip_failures: bool = False,
- verbose: bool = False,
- conn: Optional[Connection] = None,
- ) -> "ComputationFrame":
- res = self if inplace else self.copy_subgraph()
- if cols is None:
- # this means we want to expand the entire graph
- return res._forward_all(res, inplace=inplace, verbose=verbose, conn=conn)
- if verbose:
- logger.info(f"Expanding graph to include the consumers of columns {cols}")
- if isinstance(cols, str):
- cols = [cols]
-
- adjacent_calls_data = {
- col: res.get_adjacent_calls(col=col, direction="forward", conn=conn)
- for col in cols
- }
-
- filtered_cols = []
- raise NotImplementedError
- # for col, calls_dict in adjacent_calls_data.items():
-
- #
- # filtered_cols.append(col)
- # cols = filtered_cols
-
- # for col in cols:
- # calls_dict = adjacent_calls_data[col]
- # op_id, input_name = list(calls_dict.keys())[0]
- # calls = calls_dict[(op_id, input_name)]
- # # produce output nodes
- # representative_call: Call = calls[0]
- # op = representative_call.func_op
- # output_node_names = {}
- # for output_name, output_tp in enumerate(op.output_types):
- # output_node_names[dump_output_name(i=output_name)], _ = res.join_var_node(
- # refs=[call.outputs[dump_output_name(i=output_name)] for call in calls],
- # tp=output_tp
-
- @transaction()
- def delete(
- self,
- delete_dependents: bool,
- verbose: bool = True,
- ask: bool = True,
- conn: Optional[Connection] = None,
- ):
- """
- ! this is a powerful method that can delete a lot of data, use with caution
-
- Delete the calls referenced by this ComputationFrame from the storage, and
- clean up any orphaned refs.
-
-        Warning: you probably want to apply this only to ComputationFrames that
-        are "forward-closed", i.e. that have been expanded to include all the
-        calls that use their values. Otherwise, you may end up with "zombie"
-        refs that have no provenance and no meaning in the context of the rest
-        of the storage. Alternatively, set `delete_dependents=True` to also
-        delete all calls that depend on the calls in this ComputationFrame
-        before cleaning up the orphaned refs.
- """
- # gather all the calls to be deleted
- call_uids_to_delete = defaultdict(list)
- call_outputs = {}
- for x in self.op_nodes.values():
- # process the call uids
- for call in x.calls:
- call_uids_to_delete[x.func_op.sig.versioned_ui_name].append(
- call.causal_uid
- )
- # process the call outputs
- if delete_dependents:
- for vnode in x.outputs.values():
- for ref in vnode.refs:
- call_outputs[ref.causal_uid] = ref
- if delete_dependents:
- dependent_calls = self.storage.get_dependent_calls(
- refs=list(call_outputs.values()), prov_df=self.prov_df, conn=conn
- )
- for call in dependent_calls:
- call_uids_to_delete[call.func_op.sig.versioned_ui_name].append(
- call.causal_uid
- )
- if verbose:
- # summarize the number of calls per op to be deleted
- for op, uids in call_uids_to_delete.items():
- print(f"Op {op} has {len(uids)} calls to be deleted")
- if ask:
- if input("Proceed? (y/n) ").strip().lower() != "y":
-                    logger.info("Aborting deletion")
- return
- for versioned_ui_name, call_uids in call_uids_to_delete.items():
- self.storage.rel_adapter.delete_calls(
- versioned_ui_name=versioned_ui_name,
- causal_uids=call_uids,
- conn=conn,
- )
- self.storage.rel_adapter.cleanup_vrefs(conn=conn, verbose=verbose)
-
- ############################################################################
-    ### creating new ComputationFrames
- ############################################################################
- def copy_subgraph(
- self,
- val_nodes: Optional[Iterable[ValNode]] = None,
- call_nodes: Optional[Iterable[CallNode]] = None,
- ) -> "ComputationFrame":
- """
- Get a copy of the ComputationFrame supported on the given nodes.
- """
- # must copy the graph
- val_nodes = (
- set(self.var_nodes.values()) if val_nodes is None else set(val_nodes)
- )
- call_nodes = (
- set(self.op_nodes.values()) if call_nodes is None else set(call_nodes)
- )
- val_map, call_map = copy_subgraph(
- vqs=val_nodes,
- fqs=call_nodes,
- )
- return ComputationFrame(
- call_nodes={
- k: call_map[v] for k, v in self.op_nodes.items() if v in call_map
- },
- val_nodes={
- k: val_map[v] for k, v in self.var_nodes.items() if v in val_map
- },
- storage=self.storage,
- prov_df=self.prov_df,
- )
-
- @staticmethod
- def from_refs(
- refs: Iterable[Ref],
- storage: Storage,
- prov_df: Optional[pd.DataFrame] = None,
- name: Optional[str] = None,
- ) -> "ComputationFrame":
- val_node = ValNode(
- constraint=None,
- tp=None,
-            refs=list(refs),  # materialize, since `refs` may be any iterable
- )
- name = "v0" if name is None else name
- return ComputationFrame(
- call_nodes={},
- val_nodes={name: val_node},
- storage=storage,
- prov_df=prov_df,
- )
-
- @staticmethod
- def from_op(
- func: FuncInterface,
- storage: Storage,
- prov_df: Optional[pd.DataFrame] = None,
- ) -> "ComputationFrame":
- """
- Get a ComputationFrame expressing the memoization table for a single function
- """
- storage.synchronize(f=func)
- reftable = storage.get_table(func, values="lazy", meta=True)
- op = func.func_op
- if op.is_builtin:
- raise ValueError("Cannot create a ComputationFrame from a builtin op")
- input_nodes = {
- input_name: ValNode(
- constraint=None,
- tp=op.input_types[input_name],
- refs=reftable[input_name].values.tolist(),
- )
- for input_name in op.input_types.keys()
- }
- output_nodes = {
- dump_output_name(i): ValNode(
- constraint=None,
- tp=tp,
- refs=reftable[dump_output_name(i)].values.tolist(),
- )
- for i, tp in enumerate(op.output_types)
- }
- call_uids = reftable[Config.causal_uid_col].values.tolist()
- calls = storage.cache.call_mget(
- uids=call_uids,
- versioned_ui_name=op.sig.versioned_ui_name,
- by_causal=True,
- )
- call_node = CallNode.link(
- inputs=input_nodes,
- func_op=op,
- outputs=output_nodes,
- constraint=None,
- calls=calls,
- orientation=None,
- )
- return ComputationFrame(
- call_nodes={func.func_op.sig.versioned_ui_name: call_node},
- val_nodes={
- k: v
- for k, v in itertools.chain(input_nodes.items(), output_nodes.items())
- },
- storage=storage,
- prov_df=prov_df,
- )
-
- ############################################################################
- ### visualization
- ############################################################################
- def get_printer(self) -> GraphPrinter:
- printer = GraphPrinter(
- vqs=set(self.var_nodes.values()),
- fqs=set(self.op_nodes.values()),
- names={v: k for k, v in self.var_nodes.items()},
- fnames={v: k for k, v in self.op_nodes.items()},
- value_loader=ValueLoader(storage=self.storage),
- )
- return printer
-
- def _get_string_representation(self) -> str:
- printer = self.get_printer()
- graph_description = printer.print_computational_graph(
- show_sources_as="name_only"
- )
- # indent the graph description
- graph_description = textwrap.indent(graph_description, " ")
- return f"{self.__class__.__name__} with {self.num_vars} variable(s), {self.num_ops} operation(s) and {len(self)} row(s), representing the computation:\n{graph_description}"
-
- def __repr__(self) -> str:
- return self._get_string_representation()
- # return self.to_pandas().head(5).to_string()
-
- def print(self):
- print(self._get_string_representation())
-
- def show(self, how: Literal["inline", "browser"] = "browser"):
- visualize_graph(
- vqs=set(self.var_nodes.values()),
- fqs=set(self.op_nodes.values()),
- layout="computational",
- names={v: k for k, v in self.var_nodes.items()},
- show_how=how,
- )
-
- def get_new_vname(self, hint: Optional[str] = None) -> str:
-        """
-        Return `hint` if it is free; otherwise, the first name of the form
-        `{prefix}{i}` (with prefix `hint`, or `v` if no hint is given) that is
-        not already in `self.var_nodes`.
-        """
- if hint is not None and hint not in self.var_nodes:
- return hint
- i = 0
- prefix = "v" if hint is None else hint
- while f"{prefix}{i}" in self.var_nodes:
- i += 1
- return f"{prefix}{i}"
-
- def get_new_cname(self, op: FuncOp) -> str:
- if op.sig.versioned_ui_name not in self.op_nodes:
- return op.sig.versioned_ui_name
- i = 0
- while f"{op.sig.versioned_ui_name}_{i}" in self.op_nodes:
- i += 1
- return f"{op.sig.versioned_ui_name}_{i}"
-
- def rename(self, columns: Dict[str, str], inplace: bool = False):
- for old_name, new_name in columns.items():
- if old_name not in self.var_nodes:
- raise ValueError(f"Column {old_name} does not exist")
- if new_name in self.var_nodes:
- raise ValueError(f"Column {new_name} already exists")
- if inplace:
- res = self
- else:
- res = self.copy_subgraph()
- for old_name, new_name in columns.items():
- res.var_nodes[new_name] = res.var_nodes.pop(old_name)
- return res
-
- def r(
- self,
- inplace: bool = False,
- **kwargs,
- ) -> "ComputationFrame":
- """
-        Shorthand for `rename`, passing the old -> new name mapping as keyword arguments
- """
- return self.rename(columns=kwargs, inplace=inplace)
-
- @property
- def num_vars(self) -> int:
- return len(self.var_nodes)
-
- @property
- def num_ops(self) -> int:
- return len(self.op_nodes)
-
- def get_var_info(
- self,
- include_uniques: bool = False,
- small_threshold_bytes: int = 4096,
- units: Literal["bytes", "KB", "MB", "GB"] = "MB",
- sample_size: int = 20,
- ) -> pd.DataFrame:
- var_rows = []
- for k, v in self.var_nodes.items():
- if len(v.refs) == 0:
-                avg_size_bytes, avg_size, std = 0, 0, 0  # avg_size_bytes is used below when include_uniques=True
- else:
- avg_size_bytes, std_bytes = estimate_uid_storage(
- uids=[ref.uid for ref in v.refs],
- storage=self.storage,
- units="bytes",
- sample_size=sample_size,
- )
- avg_size, std = convert_bytes_to(
- num_bytes=avg_size_bytes, units=units
- ), convert_bytes_to(num_bytes=std_bytes, units=units)
- # round to 2 decimal places
- avg_size, std = round(avg_size, 2), round(std, 2)
-            var_data = {
-                "name": k,
-                "size": f"{avg_size}±{std} {units}",
-                "size_bytes": avg_size_bytes,  # numeric copy, used only for sorting
-                "nunique": len(set(ref.uid for ref in v.refs)),
-            }
- if include_uniques:
- if avg_size_bytes < small_threshold_bytes:
- uniques = {ref.uid: ref for ref in v.refs}
- uniques_values = self.storage.unwrap(list(uniques.values()))
- try:
- uniques_values = sorted(uniques_values)
-                    except Exception:  # the values may not be mutually comparable
- pass
- var_data["unique_values"] = uniques_values
- else:
- var_data["unique_values"] = ""
- var_rows.append(var_data)
- var_df = pd.DataFrame(var_rows)
- var_df.set_index("name", inplace=True)
-        # sort by the numeric estimate; the "size" column is a formatted string
-        var_df = var_df.sort_values(by="size_bytes", ascending=False)
-        var_df = var_df.drop(columns=["size_bytes"])
- return var_df
-
- def get_op_info(self) -> pd.DataFrame:
- rows = []
- for k, op_node in self.op_nodes.items():
- input_types = op_node.func_op.input_types
- output_types = {
-                dump_output_name(i=i): op_node.func_op.output_types[i]
- for i in range(len(op_node.func_op.output_types))
- }
- input_types_dict = {
- k: input_types.get(k, output_types.get(k))
- for k in op_node.inputs.keys()
- }
- output_types_dict = {
- k: output_types.get(k, input_types.get(k))
- for k in op_node.outputs.keys()
- }
- signature = f'{op_node.func_op.sig.ui_name}({", ".join([f"{k}: {v}" for k, v in input_types_dict.items()])}) -> {", ".join([f"{k}: {v}" for k, v in output_types_dict.items()])}'
- rows.append(
- {
- "name": k,
- "function": op_node.func_op.sig.ui_name,
- "version": op_node.func_op.sig.version,
- "num_calls": len(op_node.calls),
- "num_unique_calls": len(
- set(call.causal_uid for call in op_node.calls)
- ),
- "signature": signature,
- }
- )
- op_df = pd.DataFrame(rows)
- op_df.set_index("name", inplace=True)
- return op_df
-
- def info(
- self,
- units: Literal["bytes", "KB", "MB", "GB"] = "MB",
- sample_size: int = 20,
- show_uniques: bool = False,
- small_threshold_bytes: int = 4096,
- ):
- """
- Print some basic info about the ComputationFrame
- """
- # print(self.__class__)
- var_df = self.get_var_info(
- include_uniques=show_uniques,
- small_threshold_bytes=small_threshold_bytes,
- units=units,
- sample_size=sample_size,
- )
- op_df = self.get_op_info()
- print(
- f"{self.__class__.__name__} with {self.num_vars} variable(s), {self.num_ops} operation(s), {len(self)} row(s)"
- )
- printer = self.get_printer()
- print("Computation graph:")
- print(
- textwrap.indent(
- printer.print_computational_graph(show_sources_as="name_only"), " "
- )
- )
- try:
- print("Variables:")
- import prettytable
- from io import StringIO
-
- output = StringIO()
- var_df.to_csv(output)
- output.seek(0)
- pt = prettytable.from_csv(output)
- print(textwrap.indent(pt.get_string(), " "))
- print("Operations:")
- output = StringIO()
- op_df.to_csv(output)
- output.seek(0)
- pt = prettytable.from_csv(output)
- print(textwrap.indent(pt.get_string(), " "))
- except ImportError:
- print("Variables:")
- print(textwrap.indent(var_df.to_string(), " "))
- print("Operations:")
- print(textwrap.indent(op_df.to_string(), " "))
- # representative_node = self.val_nodes[list(self.val_nodes.keys())[0]]
- # num_rows = len(representative_node.refs)
- # print(f"ComputationFrame with {self.num_vars} variable(s) and {self.num_ops} operations(s), representing {num_rows} computations")
-
- ############################################################################
- ### `Transactable` interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return self.storage.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.storage.rel_storage._end_transaction(conn=conn)
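
For orientation, here is a minimal, self-contained sketch (not mandala's API; all names are illustrative) of the fixed-point idea behind `_back_all` above: keep expanding a frontier of variables backwards through the calls that created them until no new variables appear.

```python
from typing import Dict, Set

def expand_backwards(
    variables: Set[str],
    creators: Dict[str, Set[str]],  # variable -> variables its creator call consumed
) -> Set[str]:
    """Return the closure of `variables` under "was computed from"."""
    result = set(variables)
    frontier = set(variables)
    while frontier:
        new_vars: Set[str] = set()
        for var in frontier:
            new_vars |= creators.get(var, set())
        frontier = new_vars - result  # only variables we haven't visited yet
        result |= frontier
    return result

# toy provenance for w = h(g(f(x)))
creators = {"w": {"z"}, "z": {"y"}, "y": {"x"}}
assert expand_backwards({"w"}, creators) == {"w", "z", "y", "x"}
```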
diff --git a/mandala/ui/cfs_utils.py b/mandala/ui/cfs_utils.py
deleted file mode 100644
index 9d99375..0000000
--- a/mandala/ui/cfs_utils.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from ..common_imports import *
-from ..core.config import *
-from .storage import Storage
-
-################################################################################
-### tools to summarize tables of refs
-################################################################################
-def convert_bytes_to(
- num_bytes: float, units: Literal["bytes", "KB", "MB", "GB"]
-) -> float:
- if units == "KB":
- return num_bytes / 1024
- elif units == "MB":
- return num_bytes / (1024**2)
- elif units == "GB":
- return num_bytes / (1024**3)
- elif units == "bytes":
- return num_bytes
- else:
- raise ValueError(f"Unknown units: {units}")
-
-
-def estimate_uid_storage(
- uids: List[str],
- storage: Storage,
- units: Literal["bytes", "KB", "MB", "GB"] = "bytes",
- sample_size: int = 20,
-) -> Tuple[float, float]:
-    # sample `sample_size` random uids from the column (with replacement)
- sample_uids = pd.Series(uids).sample(sample_size, replace=True).values
- query = (
- "SELECT length(value) AS size_in_bytes FROM __vrefs__ WHERE __uid__ IN "
- + str(tuple(sample_uids))
- )
- sample_counts = storage.rel_storage.execute_df(query)
- mean, std = (
- sample_counts["size_in_bytes"].astype(int).mean(),
- sample_counts["size_in_bytes"].astype(int).std(),
- )
- # check if std is nan
- if np.isnan(std):
- std = 0
-    # convert to the requested units
-    avg_column_size = convert_bytes_to(num_bytes=mean, units=units)
-    std = convert_bytes_to(num_bytes=std, units=units)
- # round to 2 decimal places
- avg_column_size = round(avg_column_size, 2)
- std = round(std, 2)
- return avg_column_size, std
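
The deleted `estimate_uid_storage` helper above estimates per-variable storage by sampling a few stored blobs instead of measuring all of them. A standalone sketch of the same sampling idea, with a plain callable standing in for the storage query (all names here are hypothetical):

```python
import random
import statistics
from typing import Callable, List, Tuple

def estimate_storage(
    uids: List[str],
    blob_size_bytes: Callable[[str], int],
    sample_size: int = 20,
    units: str = "MB",
) -> Tuple[float, float]:
    divisor = {"bytes": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}[units]
    # sample with replacement, like the deleted helper above
    sample = random.choices(uids, k=sample_size)
    sizes = [blob_size_bytes(uid) for uid in sample]
    mean = statistics.mean(sizes) / divisor
    std = statistics.pstdev(sizes) / divisor
    return round(mean, 2), round(std, 2)

# toy "storage": uid -> blob size in bytes
fake_sizes = {f"uid{i}": 1024 * (i + 1) for i in range(100)}
print(estimate_storage(list(fake_sizes), fake_sizes.__getitem__, units="KB"))
```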
diff --git a/mandala/ui/cfs_wip.py b/mandala/ui/cfs_wip.py
deleted file mode 100644
index f51d89b..0000000
--- a/mandala/ui/cfs_wip.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import textwrap
-from ..common_imports import *
-from ..core.config import *
-from ..core.tps import Type
-from ..core.model import Ref, FuncOp, Call
-from ..core.builtins_ import StructOrientations, Builtins
-from ..storages.rel_impls.utils import Transactable, transaction, Connection
-from ..queries.graphs import copy_subgraph
-from ..core.prov import propagate_struct_provenance
-from ..queries.weaver import CallNode, ValNode, PaddedList
-from ..queries.viz import GraphPrinter, visualize_graph
-from .funcs import FuncInterface
-from .storage import Storage, ValueLoader
-from .cfs_utils import estimate_uid_storage, convert_bytes_to
-
-
-
-class CanonicalCF:
- """
- Some experiments with more general cfs
- """
- pass
-
-
-
-class PaddedList(Sequence[Optional[T]]):
-    """
-    A list-like sequence of length `length`, backed by a sparse dict mapping
-    indices to values. Indexing returns the value stored at an index, or None
-    if there is no entry for that index.
-    """
-
- def __init__(self, support: Dict[int, T], length: int):
- self.support = support
- self.length = length
-
- def __repr__(self) -> str:
- return f"PaddedList({self.tolist()})"
-
- def tolist(self) -> List[Optional[T]]:
- return [self.support.get(i, None) for i in range(self.length)]
-
- def copy(self) -> "PaddedList":
- return PaddedList(support=self.support.copy(), length=self.length)
-
- def copy_item(self, i: int, times: int, inplace: bool = False) -> "PaddedList":
- res = self if inplace else self.copy()
- if i not in res.support:
- res.length = res.length + times
- else:
- for j in range(res.length, res.length + times):
- res.support[j] = res.support[i]
- res.length += times
- return res
-
- def change_length(self, length: int, inplace: bool = False) -> "PaddedList":
- res = self if inplace else self.copy()
- res.length = length
- return res
-
- def append_items(self, items: List[T], inplace: bool = False) -> "PaddedList":
- res = self if inplace else self.copy()
- for i, item in enumerate(items):
- res.support[res.length + i] = item
- res.length += len(items)
- return res
-
- @staticmethod
- def from_list(lst: List[Optional[T]], length: Optional[int] = None) -> "PaddedList":
- if length is None:
- length = len(lst)
- items = {i: v for i, v in enumerate(lst) if v is not None}
- return PaddedList(support=items, length=length)
-
- def dropna(self) -> List[T]:
- return [self.support[k] for k in sorted(self.support.keys())]
-
- @staticmethod
- def padded_like(plist: "PaddedList[T1]", values: List[T]) -> "PaddedList[T]":
- support = {i: values[j] for j, i in enumerate(sorted(plist.support.keys()))}
- return PaddedList(support=support, length=len(plist))
-
- def __len__(self) -> int:
- return self.length
-
- def __iter__(self) -> Iterator[Optional[T]]:
- return (self.support.get(i, None) for i in range(self.length))
-
- def __getitem__(
- self, idx: Union[int, slice, List[bool], np.ndarray]
- ) -> Union[T, "PaddedList"]:
- if isinstance(idx, int):
- return self.support.get(idx, None)
- elif isinstance(idx, slice):
- raise NotImplementedError
- elif isinstance(idx, (list, np.ndarray)):
- return self.masked(idx)
- else:
- raise NotImplementedError(
- "Indexing only supported for integers, slices, and boolean arrays"
- )
-
- def masked(self, mask: Union[List[bool], np.ndarray]) -> "PaddedList":
- """
- Return a new `PaddedList` object with the values masked by the given
- boolean array, and the indices updated accordingly.
- """
- if len(mask) != self.length:
- raise ValueError("Boolean mask must have the same length as the list")
- result_items = {}
- cur_masked_idx = 0
- for mask_idx, m in enumerate(mask):
- if m:
- if mask_idx in self.support:
- result_items[cur_masked_idx] = self.support[mask_idx]
- cur_masked_idx += 1
- return PaddedList(support=result_items, length=cur_masked_idx)
-
- def keep_only(self, indices: Set[int]) -> "PaddedList":
- """
- Return a new `PaddedList` object with the values masked by the given
- list of indices.
- """
- items = {i: v for i, v in self.support.items() if i in indices}
- return PaddedList(support=items, length=self.length)
-
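
The `PaddedList` above is essentially a fixed-length sequence backed by a sparse `{index: value}` dict. A self-contained toy with the same core semantics, including boolean masking that re-indexes the surviving entries (a simplified stand-in, not the deleted class):

```python
from typing import Dict, List, Optional, TypeVar

T = TypeVar("T")

class SparsePadded:
    def __init__(self, support: Dict[int, T], length: int):
        self.support, self.length = support, length

    def __getitem__(self, i: int) -> Optional[T]:
        return self.support.get(i)

    def tolist(self) -> List[Optional[T]]:
        return [self.support.get(i) for i in range(self.length)]

    def masked(self, mask: List[bool]) -> "SparsePadded":
        assert len(mask) == self.length
        kept: Dict[int, T] = {}
        new_idx = 0
        for old_idx, keep in enumerate(mask):
            if keep:
                if old_idx in self.support:
                    kept[new_idx] = self.support[old_idx]
                new_idx += 1  # None slots are kept, just re-indexed
        return SparsePadded(kept, new_idx)

p = SparsePadded({0: "a", 3: "b"}, length=5)
assert p.tolist() == ["a", None, None, "b", None]
assert p.masked([True, False, True, True, False]).tolist() == ["a", None, "b"]
```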
diff --git a/mandala/ui/context_cache.py b/mandala/ui/context_cache.py
deleted file mode 100644
index 848290f..0000000
--- a/mandala/ui/context_cache.py
+++ /dev/null
@@ -1,255 +0,0 @@
-from ..common_imports import *
-from ..core.model import Call, Ref, collect_detached
-from ..storages.rels import RelAdapter, VersionAdapter
-from ..deps.versioner import Versioner
-from ..storages.kv import KVCache, InMemoryStorage, MultiProcInMemoryStorage
-from ..storages.rel_impls.utils import Transactable, transaction, Connection
-
-
-class Cache(Transactable):
- """
- A layer between calls happening in the context and the persistent storage.
- Also responsible for detaching objects from the computational graph.
-
- All calls represented in the cache are detached (see `Call.detached`), but
- objects are represented as `Ref` instances (which are not detached).
-
- TODO: protect the objects in the cache from being modified.
- """
-
- def __init__(
- self,
- rel_adapter: RelAdapter,
- ):
- self.rel_adapter = rel_adapter
- # uid -> detached call
- self.call_cache_by_uid = InMemoryStorage()
- # causal uid -> detached call
- self.call_cache_by_causal = InMemoryStorage()
- # uid -> unlinked ref without causal
- self.obj_cache = InMemoryStorage()
-
- def mcache_call_and_objs(self, calls: List[Call]) -> None:
- # a more efficient version of `cache_call_and_objs` for multiple calls
- # that avoids calling `unlinked` multiple times on the same object,
- # which could be expensive for large collections.
- unique_vrefs = {}
- unique_calls = {}
- for call in calls:
- for vref in itertools.chain(call.inputs.values(), call.outputs):
- unique_vrefs[vref.uid] = vref
- unique_calls[call.causal_uid] = call
- for vref in unique_vrefs.values():
- self.obj_cache[vref.uid] = vref.unlinked(keep_causal=False)
- for call in unique_calls.values():
- self.cache_call(causal_uid=call.causal_uid, call=call)
-
- def cache_call_and_objs(self, call: Call) -> None:
- for vref in itertools.chain(call.inputs.values(), call.outputs):
- self.obj_cache[vref.uid] = vref.unlinked(keep_causal=False)
- self.cache_call(causal_uid=call.causal_uid, call=call)
-
- def cache_call(self, causal_uid: str, call: Call) -> None:
- self.call_cache_by_causal.set(causal_uid, call.detached())
- self.call_cache_by_uid.set(call.uid, call.detached())
-
- def mattach(
- self, vrefs: List[Ref], shallow: bool = False, _attach_atoms: bool = True
- ):
- """
- Regardless of `shallow`, recursively find all the refs that can be found
- in the cache. Then pass what's left to the storage method.
- """
- vrefs = collect_detached(vrefs, include_transient=False)
- cur_frontier = vrefs
- new_frontier = []
- not_found = []
- while len(cur_frontier) > 0:
- for vref in cur_frontier:
- if (
- vref.uid in self.obj_cache.keys()
- and self.obj_cache[vref.uid].in_memory
- ):
- vref.attach(reference=self.obj_cache[vref.uid])
- new_frontier.extend(
- collect_detached([vref], include_transient=False)
- )
- else:
- not_found.append(vref)
- cur_frontier = new_frontier
- new_frontier = []
- if len(not_found) > 0:
- self.rel_adapter.mattach(
- vrefs=not_found, shallow=shallow, _attach_atoms=_attach_atoms
- )
-
- def obj_get(self, obj_uid: str, causal_uid: Optional[str] = None) -> Ref:
- """
- Get the given object from the cache or the storage.
- """
- #! note that this is not transactional to avoid creating a connection
- #! when the object is already in the cache
- if self.obj_cache.exists(obj_uid):
- res = self.obj_cache.get(obj_uid)
- else:
- res = self.rel_adapter.obj_get(uid=obj_uid)
- res = res.clone()
- if causal_uid is not None:
- res.causal_uid = causal_uid
- return res
-
- def call_exists(self, uid: str, by_causal: bool) -> bool:
- #! note that this is not transactional to avoid creating a connection
- #! when the object is already in the cache
- if by_causal:
- return self.call_cache_by_causal.exists(
- uid
- ) or self.rel_adapter.call_exists(uid=uid, by_causal=True)
- else:
- return self.call_cache_by_uid.exists(uid) or self.rel_adapter.call_exists(
- uid=uid, by_causal=False
- )
-
- @transaction()
- def call_mget(
- self,
- uids: List[str],
- versioned_ui_name: str,
- by_causal: bool = True,
- lazy: bool = True,
- conn: Optional[Connection] = None,
- ) -> List[Call]:
- if not by_causal:
- raise NotImplementedError()
- if not lazy:
- raise NotImplementedError()
- res = [None for _ in uids]
- missing_indices = []
- missing_uids = []
- for i, uid in enumerate(uids):
- if self.call_cache_by_causal.exists(uid):
- res[i] = self.call_cache_by_causal.get(uid)
- else:
- missing_indices.append(i)
- missing_uids.append(uid)
- if len(missing_uids) > 0:
- lazy_calls = self.rel_adapter.mget_call_lazy(
- versioned_ui_name=versioned_ui_name,
- uids=missing_uids,
- by_causal=True,
- conn=conn,
- )
- for i, lazy_call in zip(missing_indices, lazy_calls):
- res[i] = lazy_call
- return res
-
- def call_get(self, uid: str, by_causal: bool, lazy: bool = True) -> Call:
- """
- Return a *detached* call with the given UID, if it exists.
- """
- #! note that this is not transactional to avoid creating a connection
- #! when the object is already in the cache
- if by_causal and self.call_cache_by_causal.exists(uid):
- return self.call_cache_by_causal.get(uid)
- elif not by_causal and self.call_cache_by_uid.exists(uid):
- return self.call_cache_by_uid.get(uid)
- else:
- lazy_call = self.rel_adapter.call_get_lazy(uid=uid, by_causal=by_causal)
- if not lazy:
- #! you need to be more careful here about guarantees provided by
- #! the cache
- raise NotImplementedError
- # # load the values of the inputs and outputs
- # inputs = {
- # k: self.obj_get(v.uid)
- # for k, v in lazy_call.inputs.items()
- # }
- # outputs = [self.obj_get(v.uid) for v in lazy_call.outputs]
- # call_without_outputs = lazy_call.set_input_values(inputs=inputs)
- # call = call_without_outputs.set_output_values(outputs=outputs)
- # return call
- else:
- return lazy_call
-
- @transaction()
- def commit(
- self,
- calls: Optional[List[Call]] = None,
- versioner: Optional[Versioner] = None,
-        version_adapter: Optional[VersionAdapter] = None,
- conn: Optional[Connection] = None,
- ):
- """
- Flush dirty (written since last time) calls and objs from the cache to
- persistent storage, and mark them as clean.
-
- Note that the cache keeps calls and objects in memory in case they are
- needed again, but the storage is the only source of truth.
- """
- if calls is None:
- new_objs = {
- key: self.obj_cache.get(key) for key in self.obj_cache.dirty_entries
- }
- new_calls = [
- self.call_cache_by_causal.get(key)
- for key in self.call_cache_by_causal.dirty_entries
- ]
- else:
- #! if calls are provided, we assume they are attached
- new_objs = {}
- for call in calls:
- for vref in itertools.chain(call.inputs.values(), call.outputs):
- new_obj = self.obj_cache[vref.uid]
- assert new_obj.in_memory
- new_objs[vref.uid] = new_obj
- new_calls = calls
- self.rel_adapter.obj_sets(new_objs, conn=conn)
- self.rel_adapter.upsert_calls(new_calls, conn=conn)
- # if self.evict_on_commit:
- # self.evict_caches()
- if versioner is not None:
- version_adapter.dump_state(state=versioner, conn=conn)
- self.clear_all()
-
- @transaction()
- def preload_objs(self, uids: List[str], conn: Optional[Connection] = None):
- """
-        Put the objects with the given UIDs in the cache. Intended for bulk
-        loading, since it opens a single connection for all of them.
- """
- uids_not_in_cache = [uid for uid in uids if not self.obj_cache.exists(uid)]
- for uid, vref in zip(
- uids_not_in_cache,
- self.rel_adapter.obj_gets(uids=uids_not_in_cache, conn=conn),
- ):
- self.obj_cache.set(k=uid, v=vref)
-
- def evict_all(self):
- """
- Remove all entries from the cache.
- """
- self.call_cache_by_causal.evict_all()
- self.call_cache_by_uid.evict_all()
- self.obj_cache.evict_all()
-
- def clear_all(self):
- """
- Mark all entries as clean, but don't remove them from the cache.
- """
- self.call_cache_by_causal.clear_all()
- self.call_cache_by_uid.clear_all()
- self.obj_cache.clear_all()
-
- def detach_all(self):
- for k, v in self.call_cache_by_causal.items():
- self.call_cache_by_causal[k] = v.detached()
- for k, v in self.call_cache_by_uid.items():
- self.call_cache_by_uid[k] = v.detached()
- for k, v in self.obj_cache.items():
- self.obj_cache[k] = v.detached()
-
- def _get_connection(self) -> Connection:
- return self.rel_adapter._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_adapter._end_transaction(conn=conn)
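
The `Cache.commit` pattern above (keep written entries in memory, track which are dirty, flush only those to persistent storage, then mark them clean) can be illustrated with a small standalone sketch; a dict stands in for the relational backend, and all names are illustrative:

```python
from typing import Any, Dict, Set

class DirtyTrackingCache:
    def __init__(self, backend: Dict[str, Any]):
        self.backend = backend
        self.entries: Dict[str, Any] = {}
        self.dirty: Set[str] = set()

    def set(self, key: str, value: Any) -> None:
        self.entries[key] = value
        self.dirty.add(key)

    def get(self, key: str) -> Any:
        if key in self.entries:
            return self.entries[key]
        value = self.backend[key]  # fall back to persistent storage
        self.entries[key] = value  # keep it around, but it is not dirty
        return value

    def commit(self) -> None:
        for key in self.dirty:
            self.backend[key] = self.entries[key]
        self.dirty.clear()  # entries stay cached, but are now clean

backend: Dict[str, Any] = {}
cache = DirtyTrackingCache(backend)
cache.set("x", 42)
assert "x" not in backend  # not persisted yet
cache.commit()
assert backend["x"] == 42
```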
diff --git a/mandala/ui/contexts.py b/mandala/ui/contexts.py
deleted file mode 100644
index 5c65d0e..0000000
--- a/mandala/ui/contexts.py
+++ /dev/null
@@ -1,166 +0,0 @@
-from typing import Literal
-
-from ..common_imports import *
-from ..core.config import Config
-from ..core.model import Call
-from ..queries.workflow import Workflow
-from ..queries.graphs import get_canonical_order
-from ..queries.weaver import traverse_all
-from ..queries.viz import get_names, extract_names_from_scope
-from ..queries.main import Querier
-
-from ..deps.versioner import CodeState, Versioner
-from .utils import MODES
-
-from ..queries.weaver import (
- qwrap,
-)
-
-
-class GlobalContext:
- current: Optional["Context"] = None
-
-
-class Context:
- def __init__(
- self,
- storage: "storage.Storage" = None,
- mode: str = MODES.run,
- lazy: bool = False,
- allow_calls: bool = True,
- debug_calls: bool = False,
- recompute_transient: bool = False,
- _attach_call_to_outputs: bool = False, # for debugging
- debug_truncate: Optional[int] = 20,
- ):
- self.storage = storage
- self.mode = mode
- self.lazy = lazy
- self.allow_calls = allow_calls
- self.debug_calls = debug_calls
- self.recompute_transient = recompute_transient
- self._attach_call_to_outputs = _attach_call_to_outputs
- self.debug_truncate = debug_truncate
- self.updates = {}
- self._updates_stack = []
- self._call_depth = 0
- self._call_structs = []
- self._call_uids: Dict[Tuple[str, int], List[str]] = defaultdict(list)
- self._defined_funcs: List["FuncInterface"] = []
- self._call_buffer: List[Call] = []
- self._code_state: CodeState = None
- self._cached_versioner: Versioner = None
-
- def _backup_state(self, keys: Iterable[str]) -> Dict[str, Any]:
- res = {}
- for k in keys:
- cur_v = self.__dict__[f"{k}"]
- if k == "storage": # gotta use a pointer
- res[k] = cur_v
- else:
- res[k] = copy.deepcopy(cur_v)
- return res
-
- def __enter__(self) -> "Context":
- is_top = len(self._updates_stack) == 0
- ### verify update keys
- updates = self.updates
- if not all(
- k
- in (
- "storage",
- "mode",
- "lazy",
- "allow_calls",
- "debug_calls",
- "recompute_transient",
- "_attach_call_to_outputs",
- )
- for k in updates.keys()
- ):
- raise ValueError(updates.keys())
- if "mode" in updates.keys() and updates["mode"] not in MODES.all_:
- raise ValueError(updates["mode"])
- ### backup state
- before_update = self._backup_state(keys=updates.keys())
- # self._updates_stack.append(before_update)
- ### apply updates
- for k, v in updates.items():
- if v is not None:
- self.__dict__[f"{k}"] = v
- # Load state from remote
- if self.storage is not None:
- # self.storage.sync_with_remote()
- self.storage.sync_from_remote()
- if (
- self.mode in (MODES.run, MODES.query)
- and self.storage is not None
- and self.storage.versioned
- ):
- storage = self.storage
- if is_top:
- versioner, code_state = storage.sync_code()
- self._cached_versioner = versioner
- self._code_state = code_state
- # this is last so that any exceptions don't leave the context in an
- # inconsistent state
- self._updates_stack.append(before_update)
- return self
-
- def _undo_updates(self):
- """
- Roll back the updates from the current level
- """
- if not self._updates_stack:
- raise InternalError("No context to exit from")
- ascent_updates = self._updates_stack.pop()
- for k, v in ascent_updates.items():
- self.__dict__[f"{k}"] = v
- # unlink from global if done
- if len(self._updates_stack) == 0:
- GlobalContext.current = None
-
- def __exit__(self, exc_type, exc_value, exc_traceback):
- if exc_type is None:
- if self.mode == MODES.run:
- if self.storage is not None:
- # commit calls from temp partition to main and tabulate them
- if Config.autocommit:
- self.storage.commit(versioner=self._cached_versioner)
- self.storage.sync_to_remote()
- elif self.mode == MODES.query:
- pass
- elif self.mode == MODES.noop:
- pass
- elif self.mode == MODES.batch:
- executor = SimpleWorkflowExecutor()
- workflow = Workflow.from_call_structs(self._call_structs)
- calls = executor.execute(workflow=workflow, storage=self.storage)
- self.storage.commit(calls=calls, versioner=self._cached_versioner)
- else:
- raise InternalError(self.mode)
- self._undo_updates()
- return None
-
- def __call__(
- self,
- storage: Optional["Storage"] = None,
- allow_calls: bool = True,
- debug_calls: bool = False,
- recompute_transient: bool = False,
- _attach_call_to_outputs: bool = False,
- **updates,
- ):
- self.updates = {
- "storage": storage,
- "allow_calls": allow_calls,
- "debug_calls": debug_calls,
- "recompute_transient": recompute_transient,
- "_attach_call_to_outputs": _attach_call_to_outputs,
- **updates,
- }
- return self
-
-
-from . import storage
-from .executors import SimpleWorkflowExecutor
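
The `Context` above backs up the settings it overrides on `__enter__` and restores them from a stack on `__exit__`, so nested `with` blocks compose correctly. A minimal standalone sketch of that backup/rollback pattern (not mandala's `Context`; names are illustrative):

```python
from typing import Any, Dict, List

class Settings:
    def __init__(self, **defaults: Any):
        self.values: Dict[str, Any] = dict(defaults)
        self._stack: List[Dict[str, Any]] = []
        self._pending: Dict[str, Any] = {}

    def __call__(self, **updates: Any) -> "Settings":
        self._pending = updates
        return self

    def __enter__(self) -> "Settings":
        backup = {k: self.values[k] for k in self._pending}
        self.values.update(self._pending)
        self._stack.append(backup)  # push last, so a failed update can't corrupt the stack
        return self

    def __exit__(self, *exc: Any) -> None:
        self.values.update(self._stack.pop())

s = Settings(mode="run", lazy=False)
with s(mode="query"):
    assert s.values["mode"] == "query"
    with s(lazy=True):
        assert s.values["lazy"] is True
    assert s.values["lazy"] is False
assert s.values["mode"] == "run"
```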
diff --git a/mandala/ui/executors.py b/mandala/ui/executors.py
deleted file mode 100644
index 034484c..0000000
--- a/mandala/ui/executors.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from ..common_imports import *
-from ..queries.workflow import Workflow
-from abc import ABC, abstractmethod
-from ..core.model import Call
-from ..core.config import MODES
-from ..core.wrapping import unwrap
-
-
-class WorkflowExecutor(ABC):
- @abstractmethod
- def execute(self, workflow: Workflow, storage: "storage.Storage") -> List[Call]:
- pass
-
-
-class SimpleWorkflowExecutor(WorkflowExecutor):
- def execute(self, workflow: Workflow, storage: "storage.Storage") -> List[Call]:
- result = []
- for op_node in workflow.op_nodes:
- call_structs = workflow.op_node_to_call_structs[op_node]
- for call_struct in call_structs:
- func_op, inputs, outputs = (
- call_struct.func_op,
- call_struct.inputs,
- call_struct.outputs,
- )
- assert all([not inp.is_delayed() for inp in inputs.values()])
- inputs = unwrap(inputs)
- fi = funcs.FuncInterface(func_op=func_op)
- with storage.run():
- _, call = fi.call(args=tuple(), kwargs=inputs)
- vref_outputs = call.outputs
- # vref_outputs, call, _ = r.call(args=tuple(), kwargs=inputs, conn=None)
- # vref_outputs, call, _ = storage.call_run(
- # func_op=func_op,
- # inputs=inputs,
- # _call_depth=0,
- # )
- # overwrite things
- for output, vref_output in zip(outputs, vref_outputs):
- output._obj = vref_output.obj
- output.uid = vref_output.uid
- output.in_memory = True
- result.append(call)
- # filter out repeated calls
- result = list({call.uid: call for call in result}.values())
- return result
-
-
-from . import storage
-from . import runner
-from . import contexts
-from . import funcs
diff --git a/mandala/ui/funcs.py b/mandala/ui/funcs.py
deleted file mode 100644
index de5331e..0000000
--- a/mandala/ui/funcs.py
+++ /dev/null
@@ -1,243 +0,0 @@
-from functools import wraps
-from ..common_imports import *
-from ..core.model import FuncOp, TransientObj, Call
-from ..core.utils import unwrap_decorators
-from ..core.sig import Signature, _postprocess_outputs
-from ..core.tps import AnyType
-from ..queries.weaver import ValNode, qwrap, call_query
-from ..deps.tracers.dec_impl import DecTracer
-
-from . import contexts
-from .utils import bind_inputs, format_as_outputs, MODES, wrap_atom
-
-
-def Q(pattern: Optional[Any] = None) -> "pattern":
- """
-    Create a `ValNode` instance to be used as a placeholder in a query
- """
- if pattern is None:
- # return ValQuery(creator=None, created_as=None)
- return ValNode(creators=[], created_as=[], constraint=None, tp=AnyType())
- else:
- return qwrap(obj=pattern)
-
-
-T = TypeVar("T")
-
-
-def Transient(obj: T, unhashable: bool = False) -> T:
- if contexts.GlobalContext.current is not None:
- return TransientObj(obj=obj, unhashable=unhashable)
- else:
- return obj
-
-
-class FuncInterface:
- """
- Wrapper around a memoized function.
-
- This is the object the `@op` decorator converts functions into.
- """
-
- def __init__(
- self,
- func_op: FuncOp,
- executor: str = "python",
- ):
- self.func_op = func_op
- self.__name__ = self.func_op.sig.ui_name
- self._is_synchronized = False
- self._is_invalidated = False
- self._storage_id = None
- self.executor = executor
- if (
- contexts.GlobalContext.current is not None
- and contexts.GlobalContext.current.mode == MODES.define
- ):
- contexts.GlobalContext.current._defined_funcs.append(self)
-
- @property
- def sig(self) -> Signature:
- return self.func_op.sig
-
- def __repr__(self) -> str:
- sig = self.func_op.sig
- if self._is_invalidated:
- # clearly distinguish stale functions
- return f"FuncInterface(func_name={sig.ui_name}, is_invalidated=True)"
- else:
-        else:
-            return f"FuncInterface(signature={sig})"
- def invalidate(self):
- self._is_invalidated = True
- self._is_synchronized = False
-
- @property
- def is_invalidated(self) -> bool:
- return self._is_invalidated
-
- def _preprocess_call(
- self, *args, **kwargs
- ) -> Tuple[Dict[str, Any], str, "storage.Storage", contexts.Context]:
- context = contexts.GlobalContext.current
- storage = context.storage
- if self._is_invalidated:
- raise RuntimeError(
- "This function has been invalidated due to a change in the signature, and cannot be called"
- )
- # synchronize if necessary
- storage.synchronize(self)
- # synchronize(func=self, storage=context.storage)
- inputs = bind_inputs(args, kwargs, mode=context.mode, func_op=self.func_op)
- mode = context.mode
- return inputs, mode, storage, context
-
- def call(self, args, kwargs) -> Tuple[Union[None, Any, Tuple[Any]], Optional[Call]]:
- # low-level API for more control over internal machinery
- r = runner.Runner(context=contexts.GlobalContext.current, func_op=self.func_op)
- # inputs, mode, storage, context = self._preprocess_call(*args,
- # **kwargs)
- if r.storage is not None:
- r.storage.synchronize(self)
- if r.mode == MODES.run:
- func_op = r.func_op
- r.preprocess(args, kwargs)
- if r.must_execute:
- r.pre_execute(conn=None)
- if r.tracer_option is not None:
- tracer = r.tracer_option
- with tracer:
- if isinstance(tracer, DecTracer):
- node = tracer.register_call(func=func_op.func)
- result = func_op.func(**r.func_inputs)
- if isinstance(tracer, DecTracer):
- tracer.register_return(node=node)
- else:
- result = func_op.func(**r.func_inputs)
- outputs = _postprocess_outputs(sig=func_op.sig, result=result)
- call = r.post_execute(outputs=outputs)
- else:
- call = r.load_call(conn=None)
- return r.postprocess(call=call), call
- return r.process_other_modes(args, kwargs), None
-
- def __call__(self, *args, **kwargs) -> Union[None, Any, Tuple[Any]]:
- return self.call(args, kwargs)[0]
-
-
-class AsyncFuncInterface(FuncInterface):
- async def call(
- self, args, kwargs
- ) -> Tuple[Union[None, Any, Tuple[Any]], Optional[Call]]:
- # low-level API for more control over internal machinery
- r = runner.Runner(context=contexts.GlobalContext.current, func_op=self.func_op)
- # inputs, mode, storage, context = self._preprocess_call(*args,
- # **kwargs)
- if r.storage is not None:
- r.storage.synchronize(self)
- if r.mode == MODES.run:
- func_op = r.func_op
- r.preprocess(args, kwargs)
- if r.must_execute:
- r.pre_execute(conn=None)
- if r.tracer_option is not None:
- tracer = r.tracer_option
- with tracer:
- if isinstance(tracer, DecTracer):
- node = tracer.register_call(func=func_op.func)
- result = await func_op.func(**r.func_inputs)
- if isinstance(tracer, DecTracer):
- tracer.register_return(node=node)
- else:
- result = await func_op.func(**r.func_inputs)
- outputs = _postprocess_outputs(sig=func_op.sig, result=result)
- call = r.post_execute(outputs=outputs)
- else:
- call = r.load_call(conn=None)
- return r.postprocess(call=call), call
- return r.process_other_modes(args, kwargs), None
-
- async def __call__(self, *args, **kwargs) -> Union[None, Any, Tuple[Any]]:
- return (await self.call(args, kwargs))[0]
-
-
-class FuncDecorator:
- # parametrized version of `@op` decorator
- def __init__(
- self,
- **kwargs,
- ):
- self.kwargs = kwargs
-
- def __call__(self, func: Callable) -> "func":
- # func = unwrap_decorators(func, strict=True)
- func_op = FuncOp(
- func=func,
- n_outputs_override=self.kwargs.get("n_outputs"),
- version=self.kwargs.get("version"),
- ui_name=self.kwargs.get("ui_name"),
- is_super=self.kwargs.get("is_super", False),
- )
- if inspect.iscoroutinefunction(func):
- InterfaceCls = AsyncFuncInterface
- else:
- InterfaceCls = FuncInterface
- return wraps(func)(
- InterfaceCls(
- func_op=func_op,
- executor=self.kwargs.get("executor", "python"),
- )
- )
-
-
-def op(
- version: Union[Callable, Optional[int]] = None,
- nout: Optional[int] = None,
- ui_name: Optional[str] = None,
- executor: str = "python",
-) -> "version": # a hack to make mypy/autocomplete happy
- if callable(version):
- # a hack to handle the @op case
- func = version
- # func = unwrap_decorators(func, strict=True)
- func_op = FuncOp(func=func, n_outputs_override=nout)
- if inspect.iscoroutinefunction(func):
- return wraps(func)(AsyncFuncInterface(func_op=func_op))
- else:
- return wraps(func)(FuncInterface(func_op=func_op))
- else:
- # @op(...) case
- return FuncDecorator(
- version=version,
- n_outputs=nout,
- ui_name=ui_name,
- executor=executor,
- )
-
-
-def superop(
- version: Union[Callable, Optional[int]] = None,
- ui_name: Optional[str] = None,
- executor: str = "python",
-) -> "version":
- if callable(version):
- # a hack to handle the @op case
- func = version
- # func_op = FuncOp(func=unwrap_decorators(func, strict=True),
- # is_super=True)
- func_op = FuncOp(func=func, is_super=True)
- if inspect.iscoroutinefunction(func):
- return AsyncFuncInterface(func_op=func_op)
- else:
- return FuncInterface(func_op=func_op)
- else:
- # @op(...) case
- return FuncDecorator(
- version=version, ui_name=ui_name, executor=executor, is_super=True
- )
-
-
-from . import storage
-from . import runner
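
The `op` function above supports both the bare `@op` form and the parametrized `@op(...)` form by checking whether its first argument is callable. A minimal sketch of that dual-form decorator trick, with a dict-based memoizer standing in for the real call-saving machinery (illustrative only, not mandala's decorator):

```python
import functools
from typing import Any, Callable, Optional, Union

def op(version: Union[Callable, Optional[int]] = None):
    def make_wrapper(func: Callable, ver: Optional[int]) -> Callable:
        cache: dict = {}

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = (args, tuple(sorted(kwargs.items())), ver)
            if key not in cache:  # never compute the same call twice
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper

    if callable(version):  # bare `@op` usage: `version` is the function itself
        return make_wrapper(version, ver=None)
    return lambda func: make_wrapper(func, ver=version)  # `@op(...)` usage

@op
def add(x: int, y: int) -> int:
    return x + y

@op(version=1)
def mul(x: int, y: int) -> int:
    return x * y

assert add(2, 3) == 5 and mul(2, 3) == 6
```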
diff --git a/mandala/ui/provenance.py b/mandala/ui/provenance.py
deleted file mode 100644
index ab965fd..0000000
--- a/mandala/ui/provenance.py
+++ /dev/null
@@ -1,234 +0,0 @@
-from ..common_imports import *
-from .storage import Storage
-from ..core.model import Ref, Call
-from ..core.builtins_ import StructRef, ListRef, DictRef, SetRef
-from ..core.config import Provenance, parse_output_idx
-from ..queries.weaver import ValNode, CallNode, StructOrientations, traverse_all
-from ..queries.viz import GraphPrinter
-
-BUILTIN_OP_IDS = ("__list___0", "__dict___0", "__set___0")
-OP_ID_TO_STRUCT_NAME = {
- "__list___0": "lst",
- "__dict___0": "dct",
- "__set___0": "st",
-}
-
-OP_ID_TO_ELT_NAME = {
- "__list___0": "elt",
- "__dict___0": "val",
- "__set___0": "elt",
-}
-
-ELT_NAMES = ("elt", "val")
-STRUCT_NAMES = ("lst", "dct", "st")
-IDX_NAMES = ("idx", "key")
-
-
-class ProvHelpers:
- """
- Tools for producing provenance graphs identical to runtime provenance graphs
- """
-
- def __init__(self, storage: Storage, prov_df: pd.DataFrame):
- self.storage = storage
- self.prov = prov_df.sort_index() # improve performance of .loc
-
- def get_call_df(self, call_causal_uid: str) -> pd.DataFrame:
- """
- Get the provenance dataframe for the given call causal uid
- """
- #! very inefficient
- return self.prov[self.prov[Provenance.call_causal_uid] == call_causal_uid]
-
- def get_df_as_input(self, causal_uid: str) -> pd.DataFrame:
- """
- Get the provenance dataframe for the given causal uid as an input
- """
- key = (causal_uid, "input")
- if key not in self.prov.index:
- # return empty like self.prov
- return self.prov.loc[[]]
- else:
- return self.prov.loc[[key]] # passing a list to .loc guarantees a dataframe
-
- def get_df_as_output(self, causal_uid: str) -> pd.DataFrame:
- """
- Get the provenance dataframe for the given causal uid as an output
- """
- key = (causal_uid, "output")
- if key not in self.prov.index:
- # return empty like self.prov
- return self.prov.loc[[]]
- else:
- res = self.prov.loc[[key]] # passing a list to .loc guarantees a dataframe
- if res.shape[0] > 1:
- raise NotImplementedError(
- f"causal uid {causal_uid} is the output of multiple calls"
- )
- return res
-
- def is_op_output(self, causal_uid: str) -> bool:
- """
- Check if the given causal uid is the output of a (non-struct) op call
- and verify that there is a unique such call
- """
- n = self.get_df_as_output(causal_uid=causal_uid).shape[0]
- if n > 1:
- raise NotImplementedError(
- f"causal uid {causal_uid} is the output of multiple op calls"
- )
- return n == 1
-
- def get_containing_structs(self, causal_uid: str) -> Tuple[List[str], List[str]]:
- """
- Return two lists: the causal UIDs of structs containing this ref, and
- the causal call UIDs of the corresponding structural calls
- """
- elt_rows = self.get_df_as_input(causal_uid=causal_uid)
- # restrict to structs containing this ref only
- elt_rows = elt_rows[elt_rows[Provenance.op_id].isin(BUILTIN_OP_IDS)]
- elt_rows = elt_rows[elt_rows[Provenance.name].isin(ELT_NAMES)]
- # get the call causal UIDs
- call_causal_uids = elt_rows[Provenance.call_causal_uid].tolist()
- # restrict to the rows containing the structs themselves
- x = self.prov[self.prov[Provenance.call_causal_uid].isin(call_causal_uids)]
- x = x[x[Provenance.name].isin(STRUCT_NAMES)]
- return (
- x.index.get_level_values(0).tolist(),
- x[Provenance.call_causal_uid].tolist(),
- )
-
- def get_struct_call_uids(self, causal_uid: str) -> List[str]:
- key = (causal_uid, "input")
- df = self.prov.loc[[key]]
- df = df[df[Provenance.op_id].isin(BUILTIN_OP_IDS)]
- df = df[df[Provenance.name].isin(STRUCT_NAMES)]
- return df[Provenance.call_causal_uid].tolist()
-
- def get_creator_chain(self, causal_uid: str) -> List[str]:
- """
- Return a list of call UIDs for the chain of calls creating a value (if any)
- """
- as_output_df = self.get_df_as_output(causal_uid=causal_uid)
- if as_output_df.shape[0] > 0:
- return as_output_df[Provenance.call_causal_uid].tolist()
- else:
- # gotta check for containing structs
- struct_uids, struct_call_uids = self.get_containing_structs(
- causal_uid=causal_uid
- )
-            for struct_uid, struct_call_uid in zip(struct_uids, struct_call_uids):
-                recursive_result = self.get_creator_chain(causal_uid=struct_uid)
-                if len(recursive_result) > 0:
-                    # reuse the recursive result instead of recomputing the chain
-                    return [struct_call_uid] + recursive_result
-            return []
-
- def link_call_into_graph(
- self,
- call_causal_uid: str,
- refs: Dict[str, Ref],
- orientation: Optional[str],
- calls: Dict[str, Call],
- ) -> Call:
- """
- Load the call object from storage, link in into the current graph
- (creating `Ref` instances when they are not already in `refs`), and add
- pointwise constraints to each ref (via the full UID).
-
- ! Note that for structs we don't unify the indices of the struct with the
- rest of the graph to avoid unnatural clashes between refs. We also
- don't put the indices in the `refs` object.
- """
- # load the call
- call = self.storage.rel_adapter.call_get_lazy(
- uid=call_causal_uid, by_causal=True
- )
- if call.full_uid in calls:
- return calls[call.full_uid]
- # look up inputs/outputs in `refs` or add them there
- for name, inp in call.inputs.items():
- if call.func_op.is_builtin and name in IDX_NAMES:
- continue
- if inp.causal_uid in refs:
- call.inputs[name] = refs[inp.causal_uid]
- else:
- refs[inp.causal_uid] = inp
- for idx, outp in enumerate(call.outputs):
- if outp.causal_uid in refs:
- call.outputs[idx] = refs[outp.causal_uid]
- else:
- refs[outp.causal_uid] = outp
- # actual linking step
- call.link(orientation=orientation)
- # add pointwise constraints
- for ref in itertools.chain(call.inputs.values(), call.outputs):
- ref.query.constraint = [ref.full_uid]
- calls[call.full_uid] = call
- return call
-
- def step(self, ref: Ref, refs: Dict[str, Ref], calls: Dict[str, Call]) -> List[Ref]:
- """
- Given the (*non-index*) refs constructed so far and a ref in the graph,
- link the chain of calls creating this ref. If some of the refs along the
- way already exist, use those instead of creating new refs.
- """
- # build one step of the provenance graph
- creator_chain = self.get_creator_chain(causal_uid=ref.causal_uid)
- if not creator_chain:
- if isinstance(ref, StructRef):
- constituent_call_uids = self.get_struct_call_uids(
- causal_uid=ref.causal_uid
- )
- elts = []
- for constituent_call_uid in constituent_call_uids:
- call = self.link_call_into_graph(
- call_causal_uid=constituent_call_uid,
- refs=refs,
- orientation=StructOrientations.construct,
- calls=calls,
- )
- op_id = call.func_op.sig.versioned_internal_name
- elts.append(call.inputs[OP_ID_TO_ELT_NAME[op_id]])
- return elts
- else:
- return []
- elif len(creator_chain) == 1:
- op_call = self.link_call_into_graph(
- call_causal_uid=creator_chain[0],
- refs=refs,
- orientation=None,
- calls=calls,
- )
- return list(op_call.inputs.values())
- else:
- for struct_call_uid in creator_chain[:-1]:
- self.link_call_into_graph(
- call_causal_uid=struct_call_uid,
- refs=refs,
- orientation=StructOrientations.destruct,
- calls=calls,
- )
- op_call = self.link_call_into_graph(
- call_causal_uid=creator_chain[-1],
- refs=refs,
- orientation=None,
- calls=calls,
- )
- return list(op_call.inputs.values())
-
- def get_graph(self, full_uid: str) -> Tuple[Set[ValNode], Set[CallNode]]:
- """
- Given the full UID of a ref, recover the full provenance graph for that
- ref
- """
- refs = {}
- calls = {}
- res = Ref.from_full_uid(full_uid=full_uid)
- refs[res.causal_uid] = res
- queue = [res]
- while queue:
- ref = queue.pop()
- queue.extend(self.step(ref=ref, refs=refs, calls=calls))
- return traverse_all(vqs=[res.query], direction="backward")
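
`ProvHelpers.get_graph` above walks backwards from a ref, linking in the chain of calls that created it and then the inputs of those calls. A standalone sketch of that backward closure over a toy provenance mapping (a plain dict stands in for the provenance dataframe; names are illustrative):

```python
from collections import deque
from typing import Dict, List, Set, Tuple

# value uid -> (call uid, input value uids of that call); absent key = source value
prov: Dict[str, Tuple[str, List[str]]] = {
    "w": ("call_h", ["z"]),
    "z": ("call_g", ["y"]),
    "y": ("call_f", ["x"]),
}

def provenance_closure(start: str) -> Tuple[Set[str], Set[str]]:
    values, calls = {start}, set()
    queue = deque([start])
    while queue:
        uid = queue.popleft()
        if uid not in prov:
            continue  # a source value with no creator call
        call_uid, inputs = prov[uid]
        calls.add(call_uid)
        for inp in inputs:
            if inp not in values:
                values.add(inp)
                queue.append(inp)
    return values, calls

vals, found_calls = provenance_closure("w")
assert vals == {"w", "z", "y", "x"}
assert found_calls == {"call_h", "call_g", "call_f"}
```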
diff --git a/mandala/ui/remote_utils.py b/mandala/ui/remote_utils.py
deleted file mode 100644
index d62661d..0000000
--- a/mandala/ui/remote_utils.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from pypika import Table, Query
-import pyarrow.parquet as pq
-
-from ..common_imports import *
-from ..core.config import Config
-from ..storages.rels import RelAdapter, RemoteEventLogEntry
-from ..storages.rel_impls.bases import RelStorage
-from ..storages.sigs import SigSyncer, SigAdapter
-from ..storages.remote_storage import RemoteStorage
-from ..storages.rel_impls.utils import Transactable, transaction, Connection
-
-
-class RemoteManager(Transactable):
- def __init__(
- self,
- rel_adapter: RelAdapter,
- sig_adapter: SigAdapter,
- rel_storage: RelStorage,
- sig_syncer: SigSyncer,
- root: RemoteStorage,
- ):
- self.rel_adapter = rel_adapter
- self.sig_adapter = sig_adapter
- self.rel_storage = rel_storage
- self.sig_syncer = sig_syncer
- self.root = root
-
- def _get_connection(self) -> Connection:
- return self.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_storage._end_transaction(conn=conn)
-
- @transaction()
- def bundle_to_remote(
- self, conn: Optional[Connection] = None
- ) -> RemoteEventLogEntry:
- """
- Collect the new calls according to the event log, and pack them into a
- dict of binary blobs to be sent off to the remote server.
-
- NOTE: this also renames tables and columns to their immutable internal
- names.
- """
- # Bundle event log and referenced calls into tables.
- event_log_df = self.rel_adapter.get_event_log(conn=conn)
- tables_with_changes = {}
- table_names_with_changes = event_log_df["table"].unique()
-
- event_log_table = Table(self.rel_adapter.EVENT_LOG_TABLE)
- for table_name in table_names_with_changes:
- table = Table(table_name)
- tables_with_changes[table_name] = self.rel_storage.execute_arrow(
- query=Query.from_(table)
- .join(event_log_table)
- .on(table[Config.uid_col] == event_log_table[Config.uid_col])
- .select(table.star),
- conn=conn,
- )
- # pass to internal names
- tables_with_changes = self.sig_adapter.rename_tables(
- tables_with_changes, to="internal", conn=conn
- )
- output = {}
- for table_name, table in tables_with_changes.items():
- buffer = io.BytesIO()
- pq.write_table(table, buffer)
- output[table_name] = buffer.getvalue()
- return output
-
- @transaction()
- def apply_from_remote(
- self, changes: List[RemoteEventLogEntry], conn: Optional[Connection] = None
- ):
- """
- Apply new calls from the remote server.
-
- NOTE: this also renames tables and columns to their UI names.
- """
- for raw_changeset in changes:
- changeset_data = {}
- for table_name, serialized_table in raw_changeset.items():
- buffer = io.BytesIO(serialized_table)
- deserialized_table = pq.read_table(buffer)
- changeset_data[table_name] = deserialized_table
- # pass to UI names
- changeset_data = self.sig_adapter.rename_tables(
- tables=changeset_data, to="ui", conn=conn
- )
- for table_name, deserialized_table in changeset_data.items():
- self.rel_storage.upsert(table_name, deserialized_table, conn=conn)
-
- @transaction()
- def sync_from_remote(self, conn: Optional[Connection] = None):
- """
- Pull new calls from the remote server.
-
- Note that the server's schema (i.e. signatures) can be a super-schema of
- the local schema, but all local schema elements must be present in the
- remote schema, because this is enforced by how schema updates are
- performed.
- """
- if not isinstance(self.root, RemoteStorage):
- return
- # apply signature changes from the server first, because the new calls
- # from the server may depend on the new schema.
- self.sig_syncer.sync_from_remote(conn=conn)
- # next, pull new calls
- new_log_entries, timestamp = self.root.get_log_entries_since(
- self.last_timestamp
- )
- self.apply_from_remote(new_log_entries, conn=conn)
- self.last_timestamp = timestamp
- logger.debug("synced from remote")
-
- @transaction()
- def sync_to_remote(self, conn: Optional[Connection] = None):
- """
- Send calls to the remote server.
-
- As with `sync_from_remote`, the server may have a super-schema of the
- local schema. The current signatures are first pulled and applied to the
- local schema.
- """
- if not isinstance(self.root, RemoteStorage):
- # todo: there should be a way to completely ignore the event log
- # when there's no remote
- self.rel_adapter.clear_event_log(conn=conn)
- else:
- # collect new work and send it to the server
- changes = self.bundle_to_remote(conn=conn)
- self.root.save_event_log_entry(changes)
- # clear the event log only *after* the changes have been received
- self.rel_adapter.clear_event_log(conn=conn)
- logger.debug("synced to remote")
-
- @transaction()
- def sync_with_remote(self, conn: Optional[Connection] = None):
- if not isinstance(self.root, RemoteStorage):
- return
- self.sync_to_remote(conn=conn)
- self.sync_from_remote(conn=conn)
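For reference, the blobs exchanged by `bundle_to_remote`/`apply_from_remote` above are Arrow tables serialized to Parquet in an in-memory buffer. A minimal sketch of that round trip, using only `pyarrow` (the table contents are made up):

```python
import io

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"uid": ["a", "b"], "value": [1, 2]})

# sender side: pack the table into a binary blob (cf. bundle_to_remote)
buffer = io.BytesIO()
pq.write_table(table, buffer)
blob = buffer.getvalue()

# receiver side: unpack the blob back into a table (cf. apply_from_remote)
restored = pq.read_table(io.BytesIO(blob))
assert restored.equals(table)
```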
diff --git a/mandala/ui/runner.py b/mandala/ui/runner.py
deleted file mode 100644
index 97c2ce6..0000000
--- a/mandala/ui/runner.py
+++ /dev/null
@@ -1,344 +0,0 @@
-from ..common_imports import *
-from ..core.model import Call, Ref
-from ..core.config import Config
-from ..core.builtins_ import StructOrientations
-from .storage import FuncOp, Connection
-from .utils import bind_inputs, format_as_outputs, wrap_atom
-from .contexts import Context
-
-from ..queries.weaver import call_query
-
-from .utils import MODES, debug_call, get_terminal_data, check_determinism
-
-from ..core.model import Ref, Call, FuncOp
-from ..core.builtins_ import Builtins
-from ..core.wrapping import (
- wrap_inputs,
- wrap_outputs,
- causify_outputs,
- decausify,
- unwrap,
- contains_transient,
- contains_not_in_memory,
-)
-
-from ..storages.rel_impls.utils import Transactable, transaction, Connection
-
-from ..deps.tracers import TracerABC
-from ..deps.versioner import Versioner
-
-from ..queries.weaver import StructOrientations
-
-
-class Runner(Transactable): # this is terrible
- def __init__(self, context: Optional[Context], func_op: FuncOp):
- self.context: Context = context
- self.storage = context.storage if context is not None else None
- self.func_op = func_op
- self.mode = context.mode if context is not None else MODES.noop
- self.code_state = context._code_state if context is not None else None
- #! todo - these should be set by the context
- self.recurse = False
- self.collect_calls = False
-
- ### set by preprocess
- self.linking_on: bool = None
- self.must_execute: bool = None
- self.suspended_trace_obj: Optional[Any] = None
- self.tracer_option: Optional[TracerABC] = None
- self.versioner: Optional[Versioner] = None
- self.wrapped_inputs: Dict[str, Ref] = None
- self.input_calls: List[Call] = None
- self.pre_call_uid: str = None
- self.call_option: Optional[Call] = None
-
- ### set by prep_execute
- self.must_save: bool = None
- self.is_recompute: bool = None
- self.func_inputs: Dict[str, Any] = None
-
- def process_other_modes(
- self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
- ) -> Union[None, Any, Tuple[Any]]:
- if self.mode == MODES.query:
- inputs = bind_inputs(args, kwargs, mode=self.mode, func_op=self.func_op)
- return self.process_query(inputs=inputs)
- elif self.mode == MODES.batch:
- inputs = bind_inputs(args, kwargs, mode=self.mode, func_op=self.func_op)
- return self.process_batch(inputs=inputs)
- elif self.mode == MODES.noop:
- return self.func_op.func(*args, **kwargs)
- else:
- raise NotImplementedError
-
- def process_query(self, inputs: Dict[str, Any]) -> Union[None, Any, Tuple[Any]]:
- return format_as_outputs(
- outputs=call_query(func_op=self.func_op, inputs=inputs)
- )
-
- def process_batch(self, inputs: Dict[str, Any]) -> Union[None, Any, Tuple[Any]]:
- wrapped_inputs = {k: wrap_atom(v) for k, v in inputs.items()}
- outputs, call_struct = self.storage.call_batch(
- func_op=self.func_op, inputs=wrapped_inputs
- )
- self.context._call_structs.append(call_struct)
- return format_as_outputs(outputs=outputs)
-
- def preprocess(
- self,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- conn: Optional[Connection] = None,
- ):
- self.context._call_depth += 1
- self.linking_on = self.context._call_depth == 1
- func_op = self.func_op
- inputs = bind_inputs(args, kwargs, mode=self.mode, func_op=func_op)
-
- if self.storage.versioned:
- self.versioner = self.context._cached_versioner
- self.suspended_trace_obj = self.versioner.TracerCls.get_active_trace_obj()
- self.versioner.TracerCls.set_active_trace_obj(trace_obj=None)
- else:
- self.versioner = None
- self.suspended_trace_obj = None
-
- wrapped_inputs, input_calls = wrap_inputs(
- objs=inputs,
- annotations=func_op.input_annotations,
- )
- if self.linking_on:
- for input_call in input_calls:
- input_call.link(
- orientation=StructOrientations.construct,
- )
- pre_call_uid = func_op.get_pre_call_uid(
- input_uids={k: v.uid for k, v in wrapped_inputs.items()}
- )
- call_option = self.storage.lookup_call(
- func_op=func_op,
- pre_call_uid=pre_call_uid,
- input_uids={k: v.uid for k, v in wrapped_inputs.items()},
- input_causal_uids={k: v.causal_uid for k, v in wrapped_inputs.items()},
- conn=conn,
- code_state=self.code_state,
- versioner=self.versioner,
- )
- self.tracer_option = (
- self.versioner.make_tracer() if self.storage.versioned else None
- )
-
- # condition determining whether we will actually call the underlying function
- self.must_execute = (
- call_option is None
- or (self.recurse and func_op.is_super)
- or (
- call_option is not None
- and call_option.transient
- and self.context.recompute_transient
- )
- )
- self.wrapped_inputs = wrapped_inputs
- self.input_calls = input_calls
- self.pre_call_uid = pre_call_uid
- self.call_option = call_option
-
- @transaction()
- def pre_execute(
- self,
- conn: Optional[Connection],
- ):
- call_option = self.call_option
- wrapped_inputs = self.wrapped_inputs
- func_op = self.func_op
- self.is_recompute = (
- call_option is not None
- and call_option.transient
- and self.context.recompute_transient
- )
-        needs_input_values = call_option is None or (
-            call_option is not None and call_option.transient
-        )
- pass_inputs_unwrapped = Config.autounwrap_inputs and not func_op.is_super
- self.must_save = call_option is None
- if not (self.recurse and func_op.is_super) and not self.context.allow_calls:
- raise ValueError(
- f"Call to {func_op.sig.ui_name} not found in call storage."
- )
- if needs_input_values:
- self.storage.cache.mattach(vrefs=list(wrapped_inputs.values()))
- if any(contains_not_in_memory(ref=ref) for ref in wrapped_inputs.values()):
- msg = (
- "Cannot execute function whose inputs are transient values "
- "that are not in memory. "
- "Use `recompute_transient=True` to force recomputation of these inputs."
- )
- raise ValueError(msg)
- if pass_inputs_unwrapped:
- self.func_inputs = unwrap(obj=wrapped_inputs, through_collections=True)
- else:
- self.func_inputs = wrapped_inputs
-
- def post_execute(self, outputs: List[Any]):
- call_option = self.call_option
- wrapped_inputs = self.wrapped_inputs
- pre_call_uid = self.pre_call_uid
- input_calls = self.input_calls
- func_op = self.func_op
-
- if self.tracer_option is not None:
- # check the trace against the code state hypothesis
- self.versioner.apply_state_hypothesis(
- hypothesis=self.code_state, trace_result=self.tracer_option.graph.nodes
- )
- # update the global topology and code state
- self.versioner.update_global_topology(graph=self.tracer_option.graph)
- self.code_state.add_globals_from(graph=self.tracer_option.graph)
-
- content_version, semantic_version = (
- self.versioner.get_version_ids(
- pre_call_uid=pre_call_uid,
- tracer_option=self.tracer_option,
- is_recompute=self.is_recompute,
- )
- if self.storage.versioned
- else (None, None)
- )
- call_uid = func_op.get_call_uid(
- pre_call_uid=pre_call_uid, semantic_version=semantic_version
- )
- wrapped_outputs, output_calls = wrap_outputs(
- objs=outputs,
- annotations=func_op.output_annotations,
- )
- transient = any(contains_transient(ref) for ref in wrapped_outputs)
- if self.is_recompute:
- check_determinism(
- observed_semver=semantic_version,
- stored_semver=call_option.semantic_version,
- observed_output_uids=[v.uid for v in wrapped_outputs],
- stored_output_uids=[w.uid for w in call_option.outputs],
- func_op=func_op,
- )
- if self.linking_on:
- wrapped_outputs = [x.unlinked(keep_causal=True) for x in wrapped_outputs]
- call = Call(
- uid=call_uid,
- func_op=func_op,
- inputs=wrapped_inputs,
- outputs=wrapped_outputs,
- content_version=content_version,
- semantic_version=semantic_version,
- transient=transient,
- )
- for outp in wrapped_outputs:
- decausify(ref=outp) # for now, a "clean slate" approach
- causify_outputs(refs=wrapped_outputs, call_causal_uid=call.causal_uid)
- if self.linking_on:
- call.link(orientation=None)
- output_calls = [Builtins.collect_all_calls(x) for x in wrapped_outputs]
- output_calls = [x for y in output_calls for x in y]
- for output_call in output_calls:
- output_call.link(
- orientation=StructOrientations.destruct,
- )
- if self.must_save:
- self.storage.cache.mcache_call_and_objs(
- calls=[call] + input_calls + output_calls
- )
- return call
-
- @transaction()
- def load_call(self, conn: Optional[Connection] = None):
- call_option = self.call_option
- wrapped_inputs = self.wrapped_inputs
- input_calls = self.input_calls
- func_op = self.func_op
-
- assert call_option is not None
-
- if not self.context.lazy:
- self.storage.cache.preload_objs(
- [v.uid for v in call_option.outputs], conn=conn
- )
- wrapped_outputs = [
- self.storage.cache.obj_get(v.uid) for v in call_option.outputs
- ]
- else:
- wrapped_outputs = [v for v in call_option.outputs]
- # recreate call
- call = Call(
- uid=call_option.uid,
- func_op=func_op,
- inputs=wrapped_inputs,
- outputs=wrapped_outputs,
- content_version=call_option.content_version,
- semantic_version=call_option.semantic_version,
- transient=call_option.transient,
- )
- if self.linking_on:
- call.outputs = [x.unlinked(keep_causal=False) for x in call.outputs]
- causify_outputs(refs=call.outputs, call_causal_uid=call.causal_uid)
- if not self.context.lazy:
- output_calls = [Builtins.collect_all_calls(x) for x in call.outputs]
- output_calls = [x for y in output_calls for x in y]
-
- else:
- output_calls = []
- call.link(orientation=None)
- for output_call in output_calls:
- output_call.link(
- orientation=StructOrientations.destruct,
- )
- else:
- causify_outputs(refs=call.outputs, call_causal_uid=call.causal_uid)
- if call.causal_uid != call_option.causal_uid:
- # this is a new call causally; must save it and its constituents
- output_calls = [Builtins.collect_all_calls(x) for x in call.outputs]
- output_calls = [x for y in output_calls for x in y]
- self.storage.cache.mcache_call_and_objs(
- calls=[call] + input_calls + output_calls
- )
- return call
-
- def postprocess(
- self, call: Call, output_format: str = "returns"
- ) -> Union[None, Any, Tuple[Any]]:
- func_op = self.func_op
- if self.storage.versioned and self.suspended_trace_obj is not None:
- self.versioner.TracerCls.set_active_trace_obj(
- trace_obj=self.suspended_trace_obj
- )
- terminal_data = get_terminal_data(func_op=func_op, call=call)
- # Tracer.leaf_signal(data=terminal_data)
- self.versioner.TracerCls.register_leaf_event(
- trace_obj=self.suspended_trace_obj, data=terminal_data
- )
- # self.tracer_impl.suspended_tracer.register_leaf(data=terminal_data)
- if self.context.debug_calls:
- debug_call(
- func_name=func_op.sig.ui_name,
- memoized=self.call_option is not None,
- wrapped_inputs=self.wrapped_inputs,
- wrapped_outputs=call.outputs,
- )
- if self.collect_calls:
- self.context._call_buffer.append(call)
- sig = func_op.sig
- self.context._call_uids[(sig.internal_name, sig.version)].append(call.uid)
- self.context._call_depth -= 1
- if self.context._attach_call_to_outputs:
- for output in call.outputs:
- output._call = call.detached()
- if output_format == "list":
- return call.outputs
- elif output_format == "returns":
- return format_as_outputs(outputs=call.outputs)
-
- def _get_connection(self) -> Connection:
- return self.storage.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.storage.rel_storage._end_transaction(conn=conn)
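The control flow that `Runner` implements above (derive a call key from the inputs, reuse a stored call if one exists, otherwise execute and save) is easier to see stripped of storage, versioning and linking. A toy analogue of `preprocess`/`lookup_call`/`post_execute`, not the actual `Runner` API:

```python
import hashlib
import pickle

_memo = {}  # stand-in for the call/object caches

def call_memoized(func, **inputs):
    # derive a call key from the function and its inputs (cf. get_pre_call_uid)
    key_material = pickle.dumps((func.__name__, sorted(inputs.items())))
    key = hashlib.sha256(key_material).hexdigest()
    if key not in _memo:  # cf. lookup_call returning None -> must_execute
        _memo[key] = func(**inputs)
    return _memo[key]

def add(x, y):
    return x + y

assert call_memoized(add, x=1, y=2) == 3
assert call_memoized(add, x=1, y=2) == 3  # second call is served from the cache
```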
diff --git a/mandala/ui/storage.py b/mandala/ui/storage.py
deleted file mode 100644
index 96351a2..0000000
--- a/mandala/ui/storage.py
+++ /dev/null
@@ -1,1258 +0,0 @@
-import datetime
-import tqdm
-from typing import Literal
-from collections import deque
-
-from .viz import _get_colorized_diff
-from .utils import MODES
-from .remote_utils import RemoteManager
-from . import contexts
-from .context_cache import Cache
-
-from ..common_imports import *
-from ..core.config import Config, Provenance, dump_output_name
-from ..core.model import Ref, Call, FuncOp, ValueRef, Delayed
-from ..core.builtins_ import Builtins, ListRef, DictRef, SetRef
-from ..core.wrapping import (
- unwrap,
- compare_dfs_as_relations,
-)
-from ..core.tps import Type, AnyType, ListType, DictType, SetType
-from ..core.sig import Signature
-from ..core.utils import get_uid, OpKey
-
-from ..storages.rel_impls.utils import Transactable, transaction, Connection
-from ..storages.kv import InMemoryStorage, MultiProcInMemoryStorage, KVCache
-
-if Config.has_duckdb:
- from ..storages.rel_impls.duckdb_impl import DuckDBRelStorage
-from ..storages.rel_impls.sqlite_impl import SQLiteRelStorage
-from ..storages.rels import RelAdapter, VersionAdapter
-from ..storages.sigs import SigSyncer
-from ..storages.remote_storage import RemoteStorage
-from ..deps.tracers import DecTracer
-from ..deps.versioner import Versioner, CodeState
-from ..deps.utils import get_dep_key_from_func, extract_func_obj
-from ..deps.model import DepKey, TerminalData
-
-from ..queries.workflow import CallStruct
-from ..queries.weaver import (
- ValNode,
- CallNode,
- traverse_all,
-)
-from ..queries.viz import (
- visualize_graph,
- print_graph,
- get_names,
- extract_names_from_scope,
- GraphPrinter,
- ValueLoaderABC,
-)
-from ..queries.main import Querier
-from ..queries.graphs import get_canonical_order, InducedSubgraph
-
-from ..core.prov import propagate_struct_provenance
-
-
-class Storage(Transactable):
- """
- Groups together all the components of the storage system.
-
- Responsible for things that require multiple components to work together,
- e.g.
- - committing: moving calls from the "temporary" partition to the "main"
- partition. See also `CallStorage`.
- - synchronizing: connecting an operation with the storage and performing
- any necessary updates
- """
-
- def __init__(
- self,
- db_path: Optional[Union[str, Path]] = None,
- db_backend: str = Config.db_backend,
- spillover_dir: Optional[Union[str, Path]] = None,
- spillover_threshold_mb: Optional[float] = None,
- root: Optional[Union[Path, RemoteStorage]] = None,
- timestamp: Optional[datetime.datetime] = None,
- multiproc: bool = False,
- evict_on_commit: bool = None,
- signatures: Optional[Dict[Tuple[str, int], Signature]] = None,
- _read_only: bool = False,
- ### dependency tracking config
- deps_path: Optional[Union[Path, str]] = None,
- deps_package: Optional[str] = None,
- track_methods: bool = True,
- _strict_deps: bool = True, # for testing only
- tracer_impl: Optional[type] = None,
- ):
- self.root = root
- # all objects (inputs and outputs to operations, defaults) are saved here
- # stores the memoization tables
- if db_path is None and Config._persistent_storage_testing:
- # get a temp db path
- # generate a random filename
- db_name = f"db_{get_uid()}.db"
- db_path = Path(
- os.path.join(
- os.path.dirname(os.path.abspath(__file__)),
- f"../../temp_dbs/{db_name}",
- )
- ).resolve()
- self.db_path = db_path
- self.db_backend = db_backend
- self.evict_on_commit = (
- Config.evict_on_commit if evict_on_commit is None else evict_on_commit
- )
- if Config.has_duckdb and db_backend == "duckdb":
- DBImplementation = DuckDBRelStorage
- else:
- DBImplementation = SQLiteRelStorage
- self.rel_storage = DBImplementation(
- address=None if db_path is None else str(db_path),
- _read_only=_read_only,
- )
-
- # manipulates the memoization tables
- self.rel_adapter = RelAdapter(
- rel_storage=self.rel_storage,
- spillover_dir=Path(spillover_dir) if spillover_dir is not None else None,
- spillover_threshold_mb=spillover_threshold_mb,
- )
-
- self.cache = Cache(rel_adapter=self.rel_adapter)
-
- # self.versions_adapter = VersionAdapter(rel_adapter=rel_adapter)
- self.sig_adapter = self.rel_adapter.sig_adapter
- self.sig_syncer = SigSyncer(sig_adapter=self.sig_adapter, root=self.root)
- if signatures is not None:
- self.sig_adapter.dump_state(state=signatures)
- self.last_timestamp = (
- timestamp if timestamp is not None else datetime.datetime.fromtimestamp(0)
- )
-
- self.version_adapter = VersionAdapter(rel_adapter=self.rel_adapter)
- if deps_path is not None:
- deps_path = (
- Path(deps_path).absolute().resolve()
- if deps_path != "__main__"
- else "__main__"
- )
- roots = [] if deps_path == "__main__" else [deps_path]
- self._versioned = True
- current_versioner = self.version_adapter.load_state()
- if current_versioner is not None:
- if current_versioner.paths != roots:
- raise ValueError(
- f"Found existing versioner with roots {current_versioner.paths}, but "
- f"was asked to use {roots}"
- )
- else:
- versioner = Versioner(
- paths=roots,
- TracerCls=DecTracer if tracer_impl is None else tracer_impl,
- strict=_strict_deps,
- track_methods=track_methods,
- package_name=deps_package,
- )
- self.version_adapter.dump_state(state=versioner)
- else:
- self._versioned = False
-
- if root is not None:
- self.remote_manager = RemoteManager(
- rel_adapter=self.rel_adapter,
- sig_adapter=self.sig_adapter,
- rel_storage=self.rel_storage,
- sig_syncer=self.sig_syncer,
- root=self.root,
- )
- else:
- self.remote_manager = None
-
- # set up builtins
- for func_op in Builtins.OPS.values():
- self.synchronize_op(func_op=func_op)
-
- @property
- def in_memory(self) -> bool:
- return self.db_path is None
-
- @transaction()
- def get_versioner(self, conn: Optional[Connection] = None) -> Versioner:
- result = self.version_adapter.load_state(conn=conn)
- if result is None:
- raise ValueError("This storage is not versioned.")
- return result
-
- @property
- def versioned(self) -> bool:
- return self._versioned
-
- ############################################################################
- ### `Transactable` interface
- ############################################################################
- def _get_connection(self) -> Connection:
- return self.rel_storage._get_connection()
-
- def _end_transaction(self, conn: Connection):
- return self.rel_storage._end_transaction(conn=conn)
-
- def unwrap(self, obj: Any) -> Any:
- with self.run():
- res = unwrap(obj)
- return res
-
- @transaction()
- def commit(
- self,
- calls: Optional[List[Call]] = None,
- versioner: Optional[Versioner] = None,
- conn: Optional[Connection] = None,
- ):
- """
- Flush calls and objs from the cache that haven't yet been written to the database.
- """
- self.cache.commit(
- calls=calls,
- versioner=versioner,
- version_adapter=self.version_adapter,
- conn=conn,
- )
-
- @transaction()
- def eval_df(
- self,
- full_uids_df: pd.DataFrame,
- drop_duplicates: bool = False,
- values: Literal["objs", "refs", "uids", "full_uids", "lazy"] = "objs",
- conn: Optional[Connection] = None,
- ) -> pd.DataFrame:
- """
-        - ! this function loads objects into the cache as a side effect, which
-        may not be transparent to the user
-        - Note that currently we pass the full UIDs as input, and thus we return
-        values with causal UIDs. It may be desirable in some settings to
-        disable this behavior.
-
- Can handle nulls in the input dataframe.
- """
- if values == "full_uids":
- return full_uids_df
- has_meta = set(Config.special_call_cols).issubset(full_uids_df.columns)
- ref_cols = [
- col for col in full_uids_df.columns if col not in Config.special_call_cols
- ]
- uids_df = full_uids_df[ref_cols].applymap(
- lambda uid: uid.rsplit(".", 1)[0] if uid is not None else None
- )
- if has_meta:
- uids_df[Config.special_call_cols] = full_uids_df[Config.special_call_cols]
- if values in ("objs", "refs"):
- uids_to_collect = [
- item
- for _, column in uids_df.items()
- for _, item in column.items()
- if item is not None
- ]
- self.cache.preload_objs(uids_to_collect, conn=conn)
- if values == "objs":
- result = uids_df.applymap(
- lambda uid: unwrap(self.cache.obj_get(uid))
- if uid is not None
- else None
- )
- else:
- result = full_uids_df.applymap(
- lambda full_uid: self.cache.obj_get(
- obj_uid=full_uid.rsplit(".", 1)[0],
- causal_uid=full_uid.rsplit(".", 1)[1],
- )
- if full_uid is not None
- else None
- )
- elif values == "uids":
- result = uids_df
- elif values == "lazy":
- result = full_uids_df[ref_cols].applymap(
- lambda full_uid: Ref.from_uid(
- uid=full_uid.rsplit(".", 1)[0],
- causal_uid=full_uid.rsplit(".", 1)[1],
- )
- if full_uid is not None
- else None
- )
- if has_meta:
- result[Config.special_call_cols] = full_uids_df[
- Config.special_call_cols
- ]
- else:
-            raise ValueError(
-                f"Invalid value for `values`: {values}. Must be one of "
-                "['objs', 'refs', 'uids', 'full_uids', 'lazy']"
-            )
- if drop_duplicates:
- result = result.drop_duplicates()
- return result
-
- @transaction()
- def get_table(
- self,
- func_interface: Union["funcs.FuncInterface", Any],
- meta: bool = False,
- values: Literal["objs", "uids", "full_uids", "refs", "lazy"] = "objs",
- drop_duplicates: bool = False,
- conn: Optional[Connection] = None,
- ) -> pd.DataFrame:
- full_uids_df = self.rel_storage.get_data(
- table=func_interface.func_op.sig.versioned_ui_name, conn=conn
- )
- if not meta:
- full_uids_df = full_uids_df.drop(columns=Config.special_call_cols)
- df = self.eval_df(
- full_uids_df=full_uids_df,
- values=values,
- drop_duplicates=drop_duplicates,
- conn=conn,
- )
- return df
-
- ############################################################################
- ### synchronization
- ############################################################################
- @transaction()
- def synchronize_op(
- self,
- func_op: FuncOp,
- conn: Optional[Connection] = None,
- ):
- # first, pull the current data from the remote!
- self.sig_syncer.sync_from_remote(conn=conn)
- # this step also sends the signature to the remote
- new_sig = self.sig_syncer.sync_from_local(sig=func_op.sig, conn=conn)
- func_op.sig = new_sig
- # to send any default values that were created by adding inputs
- self.sync_to_remote(conn=conn)
-
- @transaction()
- def synchronize(
- self, f: Union["funcs.FuncInterface", Any], conn: Optional[Connection] = None
- ):
- if f._is_invalidated:
- raise RuntimeError(
- "This function has been invalidated due to a change in the signature, and cannot be called"
- )
- # if f._is_synchronized:
- # if f._storage_id != id(self):
- # raise RuntimeError(
- # "This function is already synchronized with a different storage object. Re-define the function to synchronize it with this storage object."
- # )
- # return
- self.synchronize_op(func_op=f.func_op, conn=conn)
- f._is_synchronized = True
- f._storage_id = id(self)
-
- ############################################################################
- ### versioning
- ############################################################################
- @transaction()
- def guess_code_state(
- self, versioner: Optional[Versioner] = None, conn: Optional[Connection] = None
- ) -> CodeState:
- if versioner is None:
- versioner = self.get_versioner(conn=conn)
- return versioner.guess_code_state()
-
- @transaction()
- def sync_code(
- self, conn: Optional[Connection] = None
- ) -> Tuple[Versioner, CodeState]:
- versioner = self.get_versioner(conn=conn)
- code_state = self.guess_code_state(versioner=versioner, conn=conn)
- versioner.sync_codebase(code_state=code_state)
- return versioner, code_state
-
- @transaction()
- def sync_component(
- self,
- component: types.FunctionType,
- is_semantic_change: Optional[bool],
- conn: Optional[Connection] = None,
- ):
- # low-level versioning
- dep_key = get_dep_key_from_func(func=component)
- versioner = self.get_versioner(conn=conn)
- code_state = self.guess_code_state(versioner=versioner, conn=conn)
- result = versioner.sync_component(
- component=dep_key,
- is_semantic_change=is_semantic_change,
- code_state=code_state,
- )
- self.version_adapter.dump_state(state=versioner, conn=conn)
- return result
-
- @transaction()
- def _show_version_data(
- self,
- f: Union[Callable, "funcs.FuncInterface"],
- deps: bool = True,
- meta: bool = False,
- plain: bool = False,
- compact: bool = False,
- conn: Optional[Connection] = None,
- ):
- # show the versions of a function, with/without its dependencies
- func = extract_func_obj(obj=f, strict=True)
- component = get_dep_key_from_func(func=func)
- versioner = self.get_versioner(conn=conn)
- if deps:
- versioner.show_versions(
- component=component,
- include_metadata=meta,
- plain=plain,
- )
- else:
- versioner.component_dags[component].show(
- compact=compact, plain=plain, include_metadata=meta
- )
-
- @transaction()
- def versions(
- self,
- f: Union[Callable, "funcs.FuncInterface"],
- meta: bool = False,
- plain: bool = False,
- conn: Optional[Connection] = None,
- ):
- self._show_version_data(
- f=f,
- deps=True,
- meta=meta,
- plain=plain,
- compact=False,
- conn=conn,
- )
-
- @transaction()
- def sources(
- self,
- f: Union[Callable, "funcs.FuncInterface"],
- meta: bool = False,
- plain: bool = False,
- compact: bool = False,
- conn: Optional[Connection] = None,
- ):
- func = extract_func_obj(obj=f, strict=True)
- component = get_dep_key_from_func(func=func)
- versioner = self.get_versioner(conn=conn)
- print(
- f"Revision history for the source code of function {component[1]} from module {component[0]} "
- '("===HEAD===" is the current version):'
- )
- versioner.component_dags[component].show(
- compact=compact, plain=plain, include_metadata=meta
- )
-
- @transaction()
- def code(
- self, version_id: str, meta: bool = False, conn: Optional[Connection] = None
- ):
- # show a copy-pastable version of the code for a given version id. Plain
- # by design.
- result = self.get_code(version_id=version_id, show=False, meta=meta, conn=conn)
- print(result)
-
- @transaction()
- def get_code(
- self,
- version_id: str,
- show: bool = True,
- meta: bool = False,
- conn: Optional[Connection] = None,
- ) -> str:
- versioner = self.get_versioner(conn=conn)
- for dag in versioner.component_dags.values():
- if version_id in dag.commits.keys():
- text = dag.get_content(commit=version_id)
- if show:
- print(text)
- return text
- for (
- content_version,
- version,
- ) in versioner.get_flat_versions().items():
- if version_id == content_version:
- raw_string = versioner.present_dependencies(
- commits=version.semantic_expansion,
- include_metadata=meta,
- )
- if show:
- print(raw_string)
- return raw_string
- raise ValueError(f"version id {version_id} not found")
-
- @transaction()
- def diff(
- self,
- id_1: str,
- id_2: str,
- context_lines: int = 2,
- conn: Optional[Connection] = None,
- ):
- code_1: str = self.get_code(version_id=id_1, show=False, conn=conn)
- code_2: str = self.get_code(version_id=id_2, show=False, conn=conn)
- print(
- _get_colorized_diff(current=code_1, new=code_2, context_lines=context_lines)
- )
-
- ############################################################################
- ### make calls in contexts
- ############################################################################
- @transaction()
- def _load_memoization_tables(
- self, evaluate: bool = False, conn: Optional[Connection] = None
- ) -> Dict[str, pd.DataFrame]:
- """
- Get a dict of {versioned internal name: memoization table} for all
- functions. Note that memoization tables are labeled by UI arg names.
- """
- sigs = self.sig_adapter.load_state(conn=conn)
- ui_to_internal = {
- sig.versioned_ui_name: sig.versioned_internal_name for sig in sigs.values()
- }
- ui_call_data = self.rel_adapter.get_all_call_data(conn=conn)
- call_data = {ui_to_internal[k]: v for k, v in ui_call_data.items()}
- if evaluate:
- call_data = {
- k: self.eval_df(full_uids_df=v, values="objs", conn=conn)
- for k, v in call_data.items()
- }
- return call_data
-
- @transaction()
- def get_compatible_semantic_versions(
- self,
- fqs: Set[CallNode],
- conn: Optional[Connection] = None,
- ) -> Tuple[Optional[Dict[OpKey, Set[str]]], Optional[Dict[DepKey, Set[str]]]]:
- if not self.versioned:
- return None, None
- if contexts.GlobalContext.current is not None:
- code_state = contexts.GlobalContext.current._code_state
- else:
- code_state = self.guess_code_state(
- versioner=self.get_versioner(), conn=conn
- )
- result_ops = {}
- result_deps = {}
- versioner = self.get_versioner(conn=conn)
- for func_query in fqs:
- sig = func_query.func_op.sig
- op_key = (sig.internal_name, sig.version)
- # dep_key = get_dep_key_from_func(func=func_query.func_op.func)
- dep_key = (func_query.func_op._module, func_query.func_op._qualname)
- if func_query.func_op._is_builtin:
- result_ops[op_key] = None
- result_deps[dep_key] = None
- else:
- versions = versioner.get_semantically_compatible_versions(
- component=dep_key, code_state=code_state
- )
- result_ops[op_key] = set([v.semantic_version for v in versions])
- result_deps[dep_key] = result_ops[op_key]
- return result_ops, result_deps
-
- @transaction()
- def execute_query(
- self,
- selection: List[ValNode],
- vqs: Set[ValNode],
- fqs: Set[CallNode],
- names: Dict[ValNode, str],
- values: Literal["objs", "refs", "uids", "lazy"] = "objs",
- engine: Optional[Literal["sql", "naive", "_test"]] = None,
- local: bool = False,
- verbose: bool = True,
- drop_duplicates: bool = True,
- visualize_steps_at: Optional[Path] = None,
- conn: Optional[Connection] = None,
- ) -> pd.DataFrame:
- """
- Execute the given queries and return the result as a pandas DataFrame.
- """
- if engine is None:
- engine = Config.query_engine
-
- def rename_cols(
- df: pd.DataFrame, selection: List[ValNode], names: Dict[ValNode, str]
- ):
- df.columns = [str(i) for i in range(len(df.columns))]
- cols = [names[query] for query in selection]
- df.rename(columns=dict(zip(df.columns, cols)), inplace=True)
-
- Querier.validate_query(vqs=vqs, fqs=fqs, selection=selection, names=names)
- context = contexts.GlobalContext.current
- if verbose:
- print(
- "Pattern-matching to the following computational graph (all constraints apply):"
- )
- print_graph(vqs=vqs, fqs=fqs, names=names, selection=selection)
- if visualize_steps_at is not None:
- assert engine == "naive"
- if engine in ["sql", "_test"]:
- version_constraints, _ = self.get_compatible_semantic_versions(
- fqs=fqs, conn=conn
- )
- call_uids = context._call_uids if local else None
- query = Querier.compile(
- selection=selection,
- vqs=vqs,
- fqs=fqs,
- version_constraints=version_constraints,
- filter_duplicates=drop_duplicates,
- call_uids=call_uids,
- )
- start = time.time()
- sql_uids_df = self.rel_storage.execute_df(query=str(query), conn=conn)
- end = time.time()
- logger.debug(f"SQL query took {round(end - start, 3)} seconds")
- rename_cols(df=sql_uids_df, selection=selection, names=names)
- uids_df = sql_uids_df
- if engine in ["naive", "_test"]:
- memoization_tables = self._load_memoization_tables(conn=conn)
- logger.debug("Executing query naively...")
- naive_uids_df = Querier.execute_naive(
- vqs=vqs,
- fqs=fqs,
- selection=selection,
- memoization_tables=memoization_tables,
- filter_duplicates=drop_duplicates,
- table_evaluator=self.eval_df,
- visualize_steps_at=visualize_steps_at,
- )
- rename_cols(df=naive_uids_df, selection=selection, names=names)
- uids_df = naive_uids_df
- if engine == "_test":
- outcome, reason = compare_dfs_as_relations(
- df_1=sql_uids_df, df_2=naive_uids_df, return_reason=True
- )
- assert outcome, reason
- return self.eval_df(full_uids_df=uids_df, values=values, conn=conn)
-
- def _get_graph_and_names(
- self,
- objs: Tuple[Union[Ref, ValNode]],
- direction: Literal["forward", "backward", "both"] = "both",
- scope: Optional[Dict[str, Any]] = None,
- project: bool = False,
- ):
- vqs = {obj.query if isinstance(obj, Ref) else obj for obj in objs}
- vqs, fqs = traverse_all(vqs=vqs, direction=direction)
- hints = extract_names_from_scope(scope=scope) if scope is not None else {}
- g = InducedSubgraph(vqs=vqs, fqs=fqs)
- if project:
- v_proj, f_proj, _ = g.project()
- proj_hints = {v_proj[vq]: hints[vq] for vq in v_proj if vq in hints}
- names = get_names(
- hints=proj_hints,
- canonical_order=get_canonical_order(
- vqs=set(v_proj.values()), fqs=set(f_proj.values())
- ),
- )
- final_vqs = set(v_proj.values())
- final_fqs = set(f_proj.values())
- else:
- v_proj = {vq: vq for vq in vqs}
- f_proj = {fq: fq for fq in fqs}
- names = get_names(
- hints=hints,
- canonical_order=get_canonical_order(vqs=set(vqs), fqs=set(fqs)),
- )
- final_vqs = vqs
- final_fqs = fqs
- return final_vqs, final_fqs, names, v_proj, f_proj
-
- def draw_graph(
- self,
- *objs: Union[Ref, ValNode],
- traverse: Literal["forward", "backward", "both"] = "backward",
- project: bool = False,
- show_how: Literal["none", "browser", "inline", "open"] = "browser",
- ):
- scope = inspect.currentframe().f_back.f_locals
- vqs, fqs, names, v_proj, f_proj = self._get_graph_and_names(
- objs,
- direction=traverse,
- scope=scope,
- project=project,
- )
- visualize_graph(vqs, fqs, names=names, show_how=show_how)
-
- def print_graph(
- self,
- *objs: Union[Ref, ValNode],
- project: bool = False,
- traverse: Literal["forward", "backward", "both"] = "backward",
- ):
- scope = inspect.currentframe().f_back.f_locals
- vqs, fqs, names, v_proj, f_proj = self._get_graph_and_names(
- objs,
- direction=traverse,
- scope=scope,
- project=project,
- )
- print_graph(
- vqs=vqs,
- fqs=fqs,
- names=names,
- selection=[
- v_proj[obj.query] if isinstance(obj, Ref) else obj for obj in objs
- ],
- )
-
- @transaction()
- def similar(
- self,
- *objs: Union[Ref, ValNode],
- values: Literal["objs", "refs", "uids", "lazy"] = "objs",
- context: bool = False,
- verbose: Optional[bool] = None,
- local: bool = False,
- drop_duplicates: bool = True,
- engine: Literal["sql", "naive", "_test"] = None,
- _visualize_steps_at: Optional[Path] = None,
- conn: Optional[Connection] = None,
- ) -> pd.DataFrame:
- scope = inspect.currentframe().f_back.f_back.f_locals
- return self.df(
- *objs,
- direction="backward",
- scope=scope,
- values=values,
- context=context,
- skip_objs=False,
- verbose=verbose,
- local=local,
- drop_duplicates=drop_duplicates,
- engine=engine,
- _visualize_steps_at=_visualize_steps_at,
- conn=conn,
- )
-
- @transaction()
- def df(
- self,
- *objs: Union[Ref, ValNode],
- direction: Literal["forward", "backward", "both"] = "both",
- values: Literal["objs", "refs", "uids", "lazy"] = "objs",
- context: bool = False,
- skip_objs: bool = False,
- verbose: Optional[bool] = None,
- local: bool = False,
- drop_duplicates: bool = True,
- engine: Literal["sql", "naive", "_test"] = None,
- _visualize_steps_at: Optional[Path] = None,
- scope: Optional[Dict[str, Any]] = None,
- conn: Optional[Connection] = None,
- ) -> pd.DataFrame:
- """
- Universal query method over computational graphs, both imperative and
- declarative.
- """
- if verbose is None:
- verbose = Config.verbose_queries
- if not all(isinstance(obj, (Ref, ValNode)) for obj in objs):
- raise ValueError(
- "All arguments to df() must be either `Ref`s or `ValQuery`s."
- )
- #! important
- # We must sync any dirty cache elements to the db before performing a query.
- # If we don't, we'll query a store that might be missing calls and objs.
- self.commit(versioner=None)
- selection = [obj.query if isinstance(obj, Ref) else obj for obj in objs]
- # deps = get_deps(nodes=set(selection))
- vqs, fqs = traverse_all(vqs=set(selection), direction=direction)
- if scope is None:
- scope = inspect.currentframe().f_back.f_back.f_locals
- name_hints = extract_names_from_scope(scope=scope)
- v_map, f_map, target_selection, target_names = Querier.prepare_projection_query(
- vqs=vqs, fqs=fqs, selection=selection, name_hints=name_hints
- )
- target_vqs, target_fqs = set(v_map.values()), set(f_map.values())
- if context:
- g = InducedSubgraph(vqs=target_vqs, fqs=target_fqs)
- _, _, topsort = g.canonicalize()
- target_selection = [vq for vq in topsort if isinstance(vq, ValNode)]
- df = self.execute_query(
- selection=target_selection,
- vqs=set(v_map.values()),
- fqs=set(f_map.values()),
- values=values,
- names=target_names,
- verbose=verbose,
- drop_duplicates=drop_duplicates,
- visualize_steps_at=_visualize_steps_at,
- engine=engine,
- local=local,
- conn=conn,
- )
- for col in df.columns:
- try:
- df = df.sort_values(by=col)
- except Exception:
- continue
- if skip_objs:
- # drop the dtypes that are objects
- df = df.select_dtypes(exclude=["object"])
- return df
-
- def _make_terminal_data(self, func_op: FuncOp, call: Call) -> TerminalData:
- terminal_data = TerminalData(
- op_internal_name=func_op.sig.internal_name,
- op_version=func_op.sig.version,
- call_content_version=call.content_version,
- call_semantic_version=call.semantic_version,
- dep_key=get_dep_key_from_func(func=func_op.func),
- )
- return terminal_data
-
- @transaction()
- def lookup_call(
- self,
- func_op: FuncOp,
- pre_call_uid: str,
- input_uids: Dict[str, str],
- input_causal_uids: Dict[str, str],
- code_state: Optional[CodeState] = None,
- versioner: Optional[Versioner] = None,
- conn: Optional[Connection] = None,
- ) -> Optional[Call]:
- """
- Return a *detached* call for the given function and inputs, if it
- exists.
- """
- if not self.versioned:
- semantic_version = None
- else:
- assert code_state is not None
- component = get_dep_key_from_func(func=func_op.func)
- lookup_outcome = versioner.lookup_call(
- component=component, pre_call_uid=pre_call_uid, code_state=code_state
- )
- if lookup_outcome is None:
- return
- else:
- _, semantic_version = lookup_outcome
- causal_uid = func_op.get_call_causal_uid(
- input_uids=input_uids,
- input_causal_uids=input_causal_uids,
- semantic_version=semantic_version,
- )
- if self.cache.call_exists(uid=causal_uid, by_causal=True):
- return self.cache.call_get(uid=causal_uid, by_causal=True, lazy=True)
- call_uid = func_op.get_call_uid(
- pre_call_uid=pre_call_uid, semantic_version=semantic_version
- )
- if self.cache.call_exists(uid=call_uid, by_causal=False):
- return self.cache.call_get(uid=call_uid, by_causal=False, lazy=True)
- return None
-
- def call_batch(
- self, func_op: FuncOp, inputs: Dict[str, Ref]
- ) -> Tuple[List[Ref], CallStruct]:
- output_types = [Type.from_annotation(a) for a in func_op.output_annotations]
- outputs = [make_delayed(tp=tp) for tp in output_types]
- call_struct = CallStruct(func_op=func_op, inputs=inputs, outputs=outputs)
- return outputs, call_struct
-
- ############################################################################
- ### low-level provenance interfaces
- ############################################################################
- @transaction()
- def get_creators(
- self,
- refs: List[Ref],
- prov_df: Optional[pd.DataFrame] = None,
- conn: Optional[Connection] = None,
- ) -> Tuple[List[Optional[Call]], List[Optional[str]]]:
- """
- Given some Refs, return the
- - calls that created them (there may be at most one such call per Ref),
- or None if there was no such call.
- - the output name under which the refs were created
-
- ! This currently fails for refs created as a list of inputs.
- """
- if not refs:
- return [], []
- if prov_df is None:
- prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn)
- prov_df = propagate_struct_provenance(prov_df=prov_df)
- causal_uids = list([ref.causal_uid for ref in refs])
- assert all(x is not None for x in causal_uids)
- res_df = prov_df.query('causal in @causal_uids and direction_new == "output"')[
- ["causal", "call_causal", "name", "op_id"]
- ].set_index("causal")[["call_causal", "name", "op_id"]]
- if len(res_df) == 0:
- return [None] * len(refs), [None] * len(refs)
- if not res_df.index.is_unique:
-            logging.warning(
-                "Work in progress: detected a ref with multiple creators (this happens when a data structure is created explicitly out of its elements); choosing an arbitrary creator."
-            )
- causal_to_creator_call_uid = res_df["call_causal"].to_dict()
- causal_to_output_name = res_df["name"].to_dict()
- causal_to_op_id = res_df["op_id"].to_dict()
- op_groups = res_df.groupby("op_id")["call_causal"].apply(list).to_dict()
- call_causal_to_call = {}
- for op_id, call_causal_list in op_groups.items():
- internal_name, version = Signature.parse_versioned_name(
- versioned_name=op_id
- )
- versioned_ui_name = self.sig_adapter.load_state(conn=conn)[
- internal_name, version
- ].versioned_ui_name
- op_calls = self.cache.call_mget(
- uids=call_causal_list,
- by_causal=True,
- versioned_ui_name=versioned_ui_name,
- conn=conn,
- )
- call_causal_to_call.update({call.causal_uid: call for call in op_calls})
- calls = [
- call_causal_to_call[causal_to_creator_call_uid[causal_uid]]
- if causal_uid in causal_to_creator_call_uid
- else None
- for causal_uid in causal_uids
- ]
- output_names = [
- causal_to_output_name[causal_uid]
- if causal_uid in causal_to_output_name
- else None
- for causal_uid in causal_uids
- ]
- return calls, output_names
-
- @transaction()
- def get_consumers(
- self,
- refs: List[Ref],
- prov_df: Optional[pd.DataFrame] = None,
- conn: Optional[Connection] = None,
- ) -> Tuple[List[List[Call]], List[List[str]]]:
- """
- Given some Refs, return the
- - calls that use them (there may be multiple such calls per Ref), or an empty list if there were no such calls.
- - the input names under which the refs were used
- """
- if prov_df is None:
- prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn)
- prov_df = propagate_struct_provenance(prov_df=prov_df)
- causal_uids = [ref.causal_uid for ref in refs]
- assert all(x is not None for x in causal_uids)
- res_groups = prov_df.query(
- 'causal in @causal_uids and direction_new == "input"'
- )[["causal", "call_causal", "name", "op_id"]]
- op_to_causal_to_call_uids_and_inp_names = defaultdict(dict)
- for causal, call_causal, name, op_id in res_groups.itertuples(index=False):
- if causal not in op_to_causal_to_call_uids_and_inp_names[op_id]:
- op_to_causal_to_call_uids_and_inp_names[op_id][causal] = []
- op_to_causal_to_call_uids_and_inp_names[op_id][causal].append(
- (call_causal, name)
- )
- op_id_to_versioned_ui_name = {}
- for op_id in op_to_causal_to_call_uids_and_inp_names.keys():
- internal_name, version = Signature.parse_versioned_name(
- versioned_name=op_id
- )
- op_id_to_versioned_ui_name[op_id] = self.sig_adapter.load_state(conn=conn)[
- internal_name, version
- ].versioned_ui_name
- op_to_causal_to_calls_and_inp_names = defaultdict(dict)
- for (
- op_id,
- causal_to_call_uids_and_inp_names,
- ) in op_to_causal_to_call_uids_and_inp_names.items():
- versioned_ui_name = op_id_to_versioned_ui_name[op_id]
- op_calls = self.cache.call_mget(
- uids=[
- elt[0]
- for v in causal_to_call_uids_and_inp_names.values()
- for elt in v
- ],
- by_causal=True,
- versioned_ui_name=versioned_ui_name,
- conn=conn,
- )
- call_causal_to_call = {call.causal_uid: call for call in op_calls}
- op_to_causal_to_calls_and_inp_names[op_id] = {
- causal: [
- (call_causal_to_call[call_causal], name)
- for call_causal, name in call_causal_list
- ]
- for causal, call_causal_list in causal_to_call_uids_and_inp_names.items()
- }
- concat_lists = lambda l: [elt for sublist in l for elt in sublist]
- calls = [
- concat_lists(
- [
- [
- v[0]
- for v in op_to_causal_to_calls_and_inp_names[op_id].get(
- causal_uid, []
- )
- ]
- for op_id in op_to_causal_to_calls_and_inp_names.keys()
- ]
- )
- for causal_uid in causal_uids
- ]
- input_names = [
- concat_lists(
- [
- [
- v[1]
- for v in op_to_causal_to_calls_and_inp_names[op_id].get(
- causal_uid, []
- )
- ]
- for op_id in op_to_causal_to_calls_and_inp_names.keys()
- ]
- )
- for causal_uid in causal_uids
- ]
- return calls, input_names
-
- @transaction()
- def get_dependent_calls(
- self,
- refs: List[Ref],
- prov_df: Optional[pd.DataFrame] = None,
- conn: Optional[Connection] = None,
- ) -> List[Call]:
- """
- Get all calls that depend on the given refs.
- """
- if prov_df is None:
- prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn)
- prov_df = propagate_struct_provenance(prov_df=prov_df)
- res = {}
- current = refs
- while True:
- calls_list, _ = self.get_consumers(refs=current, prov_df=prov_df, conn=conn)
- if not calls_list:
- break
- for calls in calls_list:
- for call in calls:
- res[call.causal_uid] = call
- current = list(
- {
- ref.causal_uid: ref
- for calls in calls_list
- for call in calls
- for ref in call.outputs
- }.values()
- )
- return list(res.values())
-
- ############################################################################
- ### provenance
- ############################################################################
- @transaction()
- def prov(
- self,
- ref: Ref,
- conn: Optional[Connection] = None,
- uids_only: bool = False,
- debug: bool = False,
- ):
- prov_df = self.rel_storage.get_data(Config.provenance_table, conn=conn)
- prov_df = prov_df.set_index([Provenance.causal_uid, Provenance.direction])
- x = provenance.ProvHelpers(storage=self, prov_df=prov_df)
- val_nodes, call_nodes = x.get_graph(full_uid=ref.full_uid)
- show_sources_as = "values" if not uids_only else "uids"
- printer = GraphPrinter(
- vqs=val_nodes,
- fqs=call_nodes,
- names=None,
- value_loader=ValueLoader(storage=self),
- )
- print(printer.print_computational_graph(show_sources_as=show_sources_as))
- if debug:
- visualize_graph(
- vqs=val_nodes, fqs=call_nodes, names=None, show_how="browser"
- )
-
- ############################################################################
- ### spawning contexts
- ############################################################################
- def _nest(self, **updates) -> contexts.Context:
- if contexts.GlobalContext.current is not None:
- return contexts.GlobalContext.current(**updates)
- else:
- result = contexts.Context(**updates)
- contexts.GlobalContext.current = result
- return result
-
- def __call__(self, **updates) -> contexts.Context:
- return self.run(**updates)
-
- def run(
- self,
- allow_calls: bool = True,
- debug_calls: bool = False,
- attach_call_to_outputs: bool = False,
- recompute_transient: bool = False,
- lazy: Optional[bool] = None,
- **updates,
- ) -> contexts.Context:
- # spawn context to execute or retrace calls
- lazy = not self.in_memory if lazy is None else lazy
- return self._nest(
- storage=self,
- allow_calls=allow_calls,
- debug_calls=debug_calls,
- recompute_transient=recompute_transient,
- _attach_call_to_outputs=attach_call_to_outputs,
- mode=MODES.run,
- lazy=lazy,
- **updates,
- )
-
- def query(self, **updates) -> contexts.Context:
- # spawn a context to define a query
- return self._nest(
- storage=self,
- mode=MODES.query,
- **updates,
- )
-
- def batch(self, **updates) -> contexts.Context:
- # spawn a context to execute calls in batch
- return self._nest(
- storage=self,
- mode=MODES.batch,
- **updates,
- )
-
- def noop(self) -> contexts.Context:
- return self._nest(
- storage=self,
- mode=MODES.noop,
- )
-
- ############################################################################
- ### remote sync operations
- ############################################################################
- @transaction()
- def sync_from_remote(self, conn: Optional[Connection] = None):
- if self.remote_manager is not None:
- self.remote_manager.sync_from_remote(conn=conn)
-
- @transaction()
- def sync_to_remote(self, conn: Optional[Connection] = None):
- if self.remote_manager is not None:
- self.remote_manager.sync_to_remote(conn=conn)
-
- @transaction()
- def sync_with_remote(self, conn: Optional[Connection] = None):
- if self.remote_manager is not None:
- self.sync_to_remote(conn=conn)
- self.sync_from_remote(conn=conn)
-
- ############################################################################
- ### refactoring
- ############################################################################
- @property
- def is_clean(self) -> bool:
- """
- Check that the storage has no uncommitted calls or objects.
- """
- return (
- self.cache.call_cache_by_causal.is_clean and self.cache.obj_cache.is_clean
- )
-
- def _check_rename_precondition(self, func: "funcs.FuncInterface"):
- """
- In order to rename function data, the function must be synced with the
- storage, and the storage must be clean
- """
- if not func._is_synchronized:
- raise RuntimeError("Cannot rename while function is not synchronized.")
- if not self.is_clean:
- raise RuntimeError("Cannot rename while there is uncommited work.")
-
- @transaction()
- def rename_func(
- self,
- func: "funcs.FuncInterface",
- new_name: str,
- conn: Optional[Connection] = None,
- ) -> Signature:
- """
- Rename a memoized function.
-
- What happens here:
- - check renaming preconditions
- - check there is no name clash with the new name
- - rename the memoization table
- - update signature object
- - invalidate the function (making it impossible to compute with it)
- """
- self._check_rename_precondition(func=func)
- sig = self.sig_syncer.sync_rename_sig(
- sig=func.func_op.sig, new_name=new_name, conn=conn
- )
- func.invalidate()
- return sig
-
- @transaction()
- def rename_arg(
- self,
- func: "funcs.FuncInterface",
- name: str,
- new_name: str,
- conn: Optional[Connection] = None,
- ) -> Signature:
- """
- Rename memoized function argument.
-
- What happens here:
- - check renaming preconditions
- - update signature object
- - rename table
- - invalidate the function (making it impossible to compute with it)
- """
- self._check_rename_precondition(func=func)
- sig = self.sig_syncer.sync_rename_input(
- sig=func.func_op.sig, input_name=name, new_input_name=new_name, conn=conn
- )
- func.invalidate()
- return sig
-
-
-from . import funcs
-from . import provenance
-
-FuncInterface = funcs.FuncInterface
-
-
-TP_TO_CLS = {
- AnyType: ValueRef,
- ListType: ListRef,
- DictType: DictRef,
- SetType: SetRef,
-}
-
-
-def make_delayed(tp: Type) -> Ref:
- return TP_TO_CLS[type(tp)](uid="", obj=Delayed(), in_memory=False)
-
-
-class ValueLoader(ValueLoaderABC):
- def __init__(self, storage: Storage):
- self.storage = storage
-
- def load_value(self, full_uid: str) -> Any:
- uid, _ = Ref.parse_full_uid(full_uid)
- return self.storage.rel_adapter.obj_get(uid=uid)
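Throughout `eval_df` and `ValueLoader` above, a "full UID" is split on its last `.` into an object UID and a causal UID. A tiny illustration of that convention (the UID values are hypothetical):

```python
full_uid = "deadbeef.cafebabe"  # hypothetical "<obj_uid>.<causal_uid>" pair
obj_uid, causal_uid = full_uid.rsplit(".", 1)
assert (obj_uid, causal_uid) == ("deadbeef", "cafebabe")
```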
diff --git a/mandala/ui/utils.py b/mandala/ui/utils.py
deleted file mode 100644
index 7f0ec04..0000000
--- a/mandala/ui/utils.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from ..common_imports import *
-from ..core.model import FuncOp, Ref, wrap_atom, Call
-from ..core.wrapping import unwrap, causify_atom
-from ..core.config import Config, MODES
-from ..queries.weaver import ValNode, qwrap, prepare_query
-from ..deps.model import TerminalData
-from ..deps.utils import get_dep_key_from_func
-from textwrap import shorten
-
-
-T = TypeVar("T")
-
-
-def check_determinism(
- observed_semver: Optional[str],
- stored_semver: Optional[str],
- stored_output_uids: List[str],
- observed_output_uids: List[str],
- func_op: FuncOp,
-):
- # check deterministic behavior
- if stored_semver != observed_semver:
- raise ValueError(
- f"Detected non-deterministic dependencies for function "
- f"{func_op.sig.ui_name} after recomputation of transient values."
- )
- if len(stored_output_uids) != len(observed_output_uids):
- raise ValueError(
- f"Detected non-deterministic number of outputs for function "
- f"{func_op.sig.ui_name} after recomputation of transient values."
- )
- if observed_output_uids != stored_output_uids:
- raise ValueError(
- f"Detected non-deterministic outputs for function "
- f"{func_op.sig.ui_name} after recomputation of transient values. "
- f"{observed_output_uids} != {stored_output_uids}"
- )
-
-
-def get_terminal_data(func_op: FuncOp, call: Call) -> TerminalData:
- return TerminalData(
- op_internal_name=func_op.sig.internal_name,
- op_version=func_op.sig.version,
- call_content_version=call.content_version,
- call_semantic_version=call.semantic_version,
- dep_key=get_dep_key_from_func(func=func_op.func),
- )
-
-
-def wrap_ui(obj: T, recurse: bool = True) -> T:
- if isinstance(obj, Ref):
- return obj
- elif type(obj) in (list, tuple):
- if recurse:
- return type(obj)(wrap_ui(v, recurse=recurse) for v in obj)
- else:
- return obj
- elif type(obj) is dict:
- if recurse:
- return {k: wrap_ui(v, recurse=recurse) for k, v in obj.items()}
- else:
- return obj
- else:
- res = wrap_atom(obj)
- causify_atom(ref=res)
- return res
-
-
-def bind_inputs(args, kwargs, mode: str, func_op: FuncOp) -> Dict[str, Any]:
- """
- Given args and kwargs passed by the user from python, this adds defaults
- and returns a dict where they are indexed via internal names.
- """
- if mode == MODES.query:
- bound_args = func_op.py_sig.bind_partial(*args, **kwargs)
- inputs_dict = dict(bound_args.arguments)
- input_tps = func_op.input_types
- inputs_dict = {
- k: qwrap(obj=v, tp=input_tps[k], strict=True)
- for k, v in inputs_dict.items()
- }
- else:
- bound_args = func_op.py_sig.bind(*args, **kwargs)
- bound_args.apply_defaults()
- inputs_dict = dict(bound_args.arguments)
-    return inputs_dict
-
-
-def format_as_outputs(
- outputs: Union[List[Ref], List[ValNode]]
-) -> Union[None, Any, Tuple[Any]]:
- if len(outputs) == 0:
- return None
- elif len(outputs) == 1:
- return outputs[0]
- else:
- return tuple(outputs)
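`format_as_outputs` mirrors how a plain Python function returns values: no outputs map to `None`, a single output is returned bare, and multiple outputs come back as a tuple. A quick usage example, assuming the function as defined above is in scope:

```python
assert format_as_outputs([]) is None
assert format_as_outputs(["x"]) == "x"
assert format_as_outputs(["x", "y"]) == ("x", "y")
```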
-
-
-def debug_call(
- func_name: str,
- memoized: bool,
- wrapped_inputs: Dict[str, Ref],
- wrapped_outputs: List[Ref],
- io_truncate: Optional[int] = 20,
-):
- shortener = lambda s: shorten(
- repr(unwrap(s)), width=io_truncate, break_long_words=True
- )
- inputs_str = ", ".join(f"{k}={shortener(v)}" for k, v in wrapped_inputs.items())
- outputs_str = ", ".join(shortener(v) for v in wrapped_outputs)
- logging.info(
- f'{"(memoized)" if memoized else ""}: {func_name}({inputs_str}) ---> {outputs_str}'
- )
diff --git a/mandala/ui/viz.py b/mandala/ui/viz.py
deleted file mode 100644
index 8326bb9..0000000
--- a/mandala/ui/viz.py
+++ /dev/null
@@ -1,375 +0,0 @@
-"""
-A minimal OOP wrapper around dot/graphviz
-"""
-import difflib
-from ..common_imports import *
-from ..core.config import Config
-import tempfile
-import subprocess
-import webbrowser
-from typing import Literal
-
-if Config.has_pil:
- from PIL import Image
-
-
-class Color:
- def __init__(self, r: int, g: int, b: int, opacity: float = 1.0):
- self.r, self.g, self.b, self.opacity = r, g, b, opacity
-
- def __str__(self) -> str:
- opacity_int = int(self.opacity * 255)
- return f"#{self.r:02x}{self.g:02x}{self.b:02x}{opacity_int:02x}"
-
-
-SOLARIZED_LIGHT = {
- "base03": Color(0, 43, 54, 1),
- "base02": Color(7, 54, 66, 1),
- "base01": Color(88, 110, 117, 1),
- "base00": Color(101, 123, 131, 1),
- "base0": Color(131, 148, 150, 1),
- "base1": Color(147, 161, 161, 1),
- "base2": Color(238, 232, 213, 1),
- "base3": Color(253, 246, 227, 1),
- "yellow": Color(181, 137, 0, 1),
- "orange": Color(203, 75, 22, 1),
- "red": Color(220, 50, 47, 1),
- "magenta": Color(211, 54, 130, 1),
- "violet": Color(108, 113, 196, 1),
- "blue": Color(38, 139, 210, 1),
- "cyan": Color(42, 161, 152, 1),
- "green": Color(133, 153, 0, 1),
-}
-
-
-def _colorize(text: str, color: str) -> str:
- """
- Return `text` with ANSI color codes for `color` added.
- """
- colors = {
- "red": 31,
- "green": 32,
- "blue": 34,
- "yellow": 33,
- "magenta": 35,
- "cyan": 36,
- "white": 37,
- }
- return f"\033[{colors[color]}m{text}\033[0m"
-
-
-def _get_diff(current: str, new: str) -> str:
- """
-    Return a line-by-line diff of the changes between `current` and `new`. Each
-    line removed from `current` is prefixed with a '-', and each line added to
-    `new` is prefixed with a '+'.
- """
- lines = []
- for line in difflib.unified_diff(
- current.splitlines(),
- new.splitlines(),
- n=2, # number of lines of context around changes to show
- # fromfile="current", tofile="new"
- lineterm="",
- ):
- if line.startswith("@@") or line.startswith("+++") or line.startswith("---"):
- continue
- lines.append(line)
- return "\n".join(lines)
-
-
-def _get_colorized_diff(
- current: str, new: str, style: str = "multiline", context_lines: int = 2
-) -> str:
- """
-    Return a line-by-line colorized diff of the changes between `current` and
-    `new`. Each line removed from `current` is colored red, and each line added
-    to `new` is colored green.
- """
- lines = []
- for line in difflib.unified_diff(
- current.splitlines(),
- new.splitlines(),
- n=context_lines, # number of lines of context around changes to show
- # fromfile="current", tofile="new"
- lineterm="",
- ):
- if line.startswith("@@") or line.startswith("+++") or line.startswith("---"):
- continue
- if line.startswith("-"):
- if style == "inline":
- line = line[1:]
- lines.append(_colorize(line, "red"))
- elif line.startswith("+"):
- if style == "inline":
- line = line[1:]
- lines.append(_colorize(line, "green"))
- else:
- lines.append(line)
- if style == "multiline":
- return "\n".join(lines)
- elif style == "inline":
- return " ---> ".join(lines)
- else:
- raise ValueError(f"Unknown style: {style}")
-
-
-################################################################################
-### tiny graphviz model
-################################################################################
-class Cell:
- def __init__(
- self,
- text: str,
- port: Optional[str] = None,
- colspan: int = 1,
- bgcolor: Color = SOLARIZED_LIGHT["base3"],
- border_color: Color = SOLARIZED_LIGHT["base03"], # border color
- font_color: Optional[Color] = None,
- bold: bool = False,
- ):
- self.port = port
- self.text = text
- self.colspan = colspan
- self.bgcolor = bgcolor
- self.border_color = border_color
- self.font_color = font_color
- self.bold = bold
-
-    def to_dot_string(self) -> str:
-        text_str = self.text
-        if self.bold:
-            text_str = f"<b>{text_str}</b>"
-        if self.font_color is not None:
-            text_str = f'<font color="{self.font_color}">{text_str}</font>'
-        return (
-            f'<td port="{self.port}" colspan="{self.colspan}" '
-            f'bgcolor="{self.bgcolor}" color="{self.border_color}">{text_str}</td>'
-        )
-
-
-class HTMLBuilder:
- def __init__(self):
- self.rows: List[List[Cell]] = []
-
- def add_row(self, cells: List[Cell]):
- self.rows.append(cells)
-
- def to_html_like_label(self) -> str:
-        start = '<<table border="0" cellborder="1" cellspacing="0">'
-        end = "</table>>"
-        # broadcast colspan
-        row_sizes = set([len(row) for row in self.rows])
-        lcm = np.lcm.reduce(list(row_sizes))
-        for row in self.rows:
-            elt_colspan = lcm // len(row)
-            for cell in row:
-                cell.colspan = elt_colspan
-
-        rows = []
-        for row in self.rows:
-            rows.append("<tr>")
-            for cell in row:
-                rows.append(cell.to_dot_string())
-            rows.append("</tr>")
-        return start + "".join(rows) + end
-
-
-class Node:
- def __init__(
- self,
- internal_name: str,
- label: str,
- color: Color = SOLARIZED_LIGHT["base3"],
- shape: str = "rect",
- ):
- """
- `shape` can be "rect", "record" or "Mrecord" for a record with rounded corners.
- """
- self.internal_name = internal_name
- self.label = label
- self.color = color
- self.shape = shape
-
- def to_dot_string(self) -> str:
- dot_label = f'"{self.label}"' if self.shape != "plain" else f"<{self.label}>"
- return f'"{self.internal_name}" [label={dot_label}, color="{self.color}", shape="{self.shape}"];'
-
-
-class Edge:
- def __init__(
- self,
- source_node: Node,
- target_node: Node,
- source_port: Optional[str] = None,
- target_port: Optional[str] = None,
- arrowtail: Optional[str] = None,
- arrowhead: Optional[str] = None,
- label: str = "",
- color: Color = SOLARIZED_LIGHT["base03"],
- ):
- self.source_node = source_node
- self.target_node = target_node
- self.color = color
- self.label = label
- self.source_port = source_port
- self.target_port = target_port
- self.arrowtail = arrowtail
- self.arrowhead = arrowhead
-
- def to_dot_string(self) -> str:
- source = f'"{self.source_node.internal_name}"'
- target = f'"{self.target_node.internal_name}"'
- if self.source_port is not None:
- source += f":{self.source_port}"
- if self.target_port is not None:
- target += f":{self.target_port}"
- attrs = [f'label="{self.label}"', f'color="{self.color}"']
- if self.arrowtail is not None:
- attrs.append(f'arrowtail="{self.arrowtail}"')
- if self.arrowhead is not None:
- attrs.append(f'arrowhead="{self.arrowhead}"')
- return f"{source} -> {target} [{', '.join(attrs)}];"
-
-
-class Group:
- def __init__(
- self,
- label: str,
- nodes: List[Node],
- parent: Optional["Group"] = None,
- border_color: Color = SOLARIZED_LIGHT["base03"],
- ):
- self.label = label
- self.nodes = nodes
- self.border_color = border_color
- self.parent = parent
-
-
-FONT = "Helvetica"
-FONT_SIZE = 10
-
-GRAPH_CONFIG = {
- # "overlap": "scalexy",
- "overlap": "scale",
- "rankdir": "TB", # top to bottom
- "fontname": FONT,
- "fontsize": FONT_SIZE,
- "fontcolor": SOLARIZED_LIGHT["base03"],
-}
-
-NODE_CONFIG = {
- "style": "rounded",
- "shape": "rect",
- "fontname": FONT,
- "fontsize": FONT_SIZE,
- "fontcolor": SOLARIZED_LIGHT["base03"],
-}
-
-EDGE_CONFIG = {
- "fontname": FONT,
- "fontsize": FONT_SIZE,
- "fontcolor": SOLARIZED_LIGHT["base03"],
-}
-
-
-def dict_to_dot_string(d: Dict[str, Any]) -> str:
- """Converts a dict to a dot string"""
- return ",".join([f'{k}="{v}"' for k, v in d.items()])
-
-
-def _get_group_string_shallow(group: Group, children_string: str) -> str:
- nodes_string = " ".join([f'"{node.internal_name}"' for node in group.nodes])
- return f'subgraph "cluster_{group.label}" {{style="rounded"; label="{group.label}"; color="{group.border_color}"; {nodes_string};\n {children_string} }}'
-
-
-def get_group_string(group: Group, groups_forest: Dict[Group, List[Group]]) -> str:
- children = groups_forest.get(group, [])
- return _get_group_string_shallow(
- group,
- children_string="\n".join(
- [get_group_string(child, groups_forest=groups_forest) for child in children]
- ),
- )
-
-
-def to_dot_string(
- nodes: List[Node],
- edges: List[Edge],
- groups: List[Group],
- rankdir: Literal["TB", "BT", "LR", "RL"] = "TB",
-) -> str:
- """Converts a graph to a dot string"""
- joiner = "\n" + " " * 12
- ### global config
- graph_config = copy.deepcopy(GRAPH_CONFIG)
- graph_config["rankdir"] = rankdir
- graph_config = f"graph [ {dict_to_dot_string(graph_config)} ];"
- node_config = f"node [ {dict_to_dot_string(NODE_CONFIG)} ];"
- edge_config = f"edge [ {dict_to_dot_string(EDGE_CONFIG)} ];"
- preamble = joiner.join([graph_config, node_config, edge_config])
- ### nodes
- node_strings = []
- for node in nodes:
- node_strings.append(node.to_dot_string())
- nodes_part = joiner.join(node_strings)
- ### edges
- edge_strings = []
- for edge in edges:
- edge_strings.append(edge.to_dot_string())
- edges_part = joiner.join(edge_strings)
- ### groups
- groups_forest = {
- group: [g for g in groups if g.parent is group] for group in groups
- }
- roots = [group for group in groups if group.parent is None]
- group_strings = []
- for group in roots:
- group_strings.append(get_group_string(group, groups_forest=groups_forest))
- groups_part = joiner.join(group_strings)
- result = f"""
- digraph G {{
- {preamble}
- {nodes_part}
- {edges_part}
- {groups_part}
- }}
- """
- return result
-
-
-def write_output(
- dot_string: str,
- output_ext: str,
- output_path: Optional[Path] = None,
- show_how: Literal["none", "browser", "inline", "open"] = "none",
-):
- """
- Writes the given dot string to a dot file (temp file by default) and then
- optionally shows it in the browser, opens it in a program, or does nothing.
- """
- # make a temp file and write the dot string to it
- if output_path is None:
- tfile = tempfile.NamedTemporaryFile(suffix=f".{output_ext}", delete=False)
- output_path = Path(tfile.name)
- with tempfile.NamedTemporaryFile(mode="w", delete=True) as f:
- path = f.name
- with open(path, "w") as f:
- f.write(dot_string)
- cmd = f"dot -T{output_ext} -o{output_path} {path}"
- subprocess.call(cmd, shell=True)
- if show_how == "browser":
- assert output_ext in [
- "png",
- "jpg",
- "jpeg",
- "svg",
- ], "Can only show png, jpg, jpeg, or svg in browser"
- webbrowser.open(str(output_path))
- return
- if show_how == "inline" or show_how == "open":
- assert (
- Config.has_pil
- ), "Pillow is not installed. Please install it to show images inline"
- img = Image.open(output_path, "r")
- if show_how == "inline":
- return img
- else:
- img.show()
diff --git a/mandala/utils.py b/mandala/utils.py
index d3de54a..bbfb65a 100644
--- a/mandala/utils.py
+++ b/mandala/utils.py
@@ -1,4 +1,25 @@
from .common_imports import *
+import joblib
+import io
+import inspect
+import prettytable
+import sqlite3
+from .config import *
+from abc import ABC, abstractmethod
+
+def dataframe_to_prettytable(df: pd.DataFrame) -> str:
+ # Initialize a PrettyTable object
+ table = prettytable.PrettyTable()
+
+ # Set the column names
+ table.field_names = df.columns.tolist()
+
+ # Add rows to the table
+ for row in df.itertuples(index=False):
+ table.add_row(row)
+
+ # Return the pretty-printed table as a string
+ return table.get_string()
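+# Illustrative usage (hypothetical values; default PrettyTable styling assumed):
+#   df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
+#   print(dataframe_to_prettytable(df))
+#   # +---+---+
+#   # | a | b |
+#   # +---+---+
+#   # | 1 | x |
+#   # | 2 | y |
+#   # +---+---+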
def serialize(obj: Any) -> bytes:
@@ -18,25 +39,201 @@ def deserialize(value: bytes) -> Any:
return joblib.load(buffer)
-def _rename_cols_pandas(df: pd.DataFrame, mapping: Dict[str, str]) -> pd.DataFrame:
- return df.rename(columns=mapping, inplace=False)
+def get_content_hash(obj: Any) -> str:
+ if hasattr(obj, "__get_mandala_dict__"):
+ obj = obj.__get_mandala_dict__()
+ if Config.has_torch:
+ # TODO: ideally, should add a label to distinguish this from a numpy
+ # array with the same contents!
+ obj = tensor_to_numpy(obj)
+ if isinstance(obj, pd.DataFrame):
+ # DataFrames cause collisions for joblib hashing for some reason
+ # TODO: the below may be incomplete
+ obj = {
+ "columns": obj.columns,
+ "values": obj.values,
+ "index": obj.index,
+ }
+ result = joblib.hash(obj) # this hash is canonical wrt python collections
+ if result is None:
+ raise RuntimeError("joblib.hash returned None")
+ return result
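+# Intended behavior (illustrative; actual digests omitted): the result is a
+# deterministic hex string, and objects that are equal by content (including
+# DataFrames, via the columns/values/index decomposition above) map to the
+# same digest across runs.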
+
+
+def dump_output_name(index: int, output_names: Optional[List[str]] = None) -> str:
+ if output_names is not None and index < len(output_names):
+ return output_names[index]
+ else:
+ return f"output_{index}"
+
+
+def parse_output_name(name: str) -> int:
+ return int(name.split("_")[-1])
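+# Illustrative examples (hypothetical values):
+#   dump_output_name(0)                   == "output_0"
+#   dump_output_name(1, ["loss", "acc"])  == "acc"
+#   parse_output_name("output_3")         == 3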
+
+
+def get_setdict_union(
+ a: Dict[str, Set[str]], b: Dict[str, Set[str]]
+) -> Dict[str, Set[str]]:
+ return {k: a.get(k, set()) | b.get(k, set()) for k in a.keys() | b.keys()}
+
+
+def get_setdict_intersection(
+ a: Dict[str, Set[str]], b: Dict[str, Set[str]]
+) -> Dict[str, Set[str]]:
+ return {k: a[k] & b[k] for k in a.keys() & b.keys()}
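+# Example of the set-valued dict operations (illustrative values):
+#   a = {"f": {"x"}, "g": {"y"}}; b = {"f": {"z"}}
+#   get_setdict_union(a, b)        == {"f": {"x", "z"}, "g": {"y"}}
+#   get_setdict_intersection(a, b) == {"f": set()}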
+
+
+def get_dict_union_over_keys(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
+ return {k: a[k] if k in a else b[k] for k in a.keys() | b.keys()}
+
+
+def get_dict_intersection_over_keys(
+ a: Dict[str, Any], b: Dict[str, Any]
+) -> Dict[str, Any]:
+ return {k: a[k] for k in a.keys() & b.keys()}
+
+
+def get_adjacency_union(
+ a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]]
+) -> Dict[str, Dict[str, Set[str]]]:
+ return {
+ k: get_setdict_union(a.get(k, {}), b.get(k, {})) for k in a.keys() | b.keys()
+ }
+
+
+def get_adjacency_intersection(
+ a: Dict[str, Dict[str, Set[str]]], b: Dict[str, Dict[str, Set[str]]]
+) -> Dict[str, Dict[str, Set[str]]]:
+ return {k: get_setdict_intersection(a[k], b[k]) for k in a.keys() & b.keys()}
-def _rename_cols_arrow(table: pa.Table, mapping: Dict[str, str]) -> pa.Table:
- columns = table.column_names
- new_columns = [mapping.get(col, col) for col in columns]
- table = table.rename_columns(new_columns)
- return table
+def get_nullable_union(*sets: Set[str]) -> Set[str]:
+ return set.union(*sets) if len(sets) > 0 else set()
-def _rename_cols(table: TableType, mapping: Dict[str, str]) -> TableType:
- if isinstance(table, pd.DataFrame):
- return _rename_cols_pandas(df=table, mapping=mapping)
- elif isinstance(table, pa.Table):
- return _rename_cols_arrow(table=table, mapping=mapping)
+def get_nullable_intersection(*sets: Set[str]) -> Set[str]:
+ return set.intersection(*sets) if len(sets) > 0 else set()
+
+
+def get_adj_from_edges(
+ edges: Set[Tuple[str, str, str]], node_support: Optional[Set[str]] = None
+) -> Tuple[Dict[str, Dict[str, Set[str]]], Dict[str, Dict[str, Set[str]]]]:
+ """
+ Given edges, convert them into the adjacency representation used by the
+ `ComputationFrame` class.
+ """
+ out = {}
+ inp = {}
+ for src, dst, label in edges:
+ if src not in out:
+ out[src] = {}
+ if label not in out[src]:
+ out[src][label] = set()
+ out[src][label].add(dst)
+ if dst not in inp:
+ inp[dst] = {}
+ if label not in inp[dst]:
+ inp[dst][label] = set()
+ inp[dst][label].add(src)
+ if node_support is not None:
+ for node in node_support:
+ if node not in out:
+ out[node] = {}
+ if node not in inp:
+ inp[node] = {}
+ return out, inp
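+# Illustrative example (hypothetical node/edge names):
+#   out, inp = get_adj_from_edges({("f", "x", "output_0"), ("x", "g", "a")})
+#   out == {"f": {"output_0": {"x"}}, "x": {"a": {"g"}}}
+#   inp == {"x": {"output_0": {"f"}}, "g": {"a": {"x"}}}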
+
+
+def parse_returns(
+ sig: inspect.Signature,
+ returns: Any,
+ nout: Union[Literal["auto", "var"], int],
+ output_names: Optional[List[str]] = None,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ Return two dicts based on the returns:
+ - {output name: output value}
+ - {output name: output type annotation}, where things like `Tuple[T, ...]` are expanded.
+ """
+ ### figure out the number of outputs, and convert them to a tuple
+ if nout == "auto": # infer from the returns
+ if isinstance(returns, tuple):
+ nout = len(returns)
+ returns_tuple = returns
+ else:
+ nout = 1
+ returns_tuple = (returns,)
+ elif nout == "var":
+ assert isinstance(returns, tuple)
+ nout = len(returns)
+ returns_tuple = returns
+ else: # nout is an integer
+ assert isinstance(nout, int)
+ assert isinstance(returns, tuple)
+ assert len(returns) == nout
+ returns_tuple = returns
+ ### get the dict of outputs
+ outputs_dict = {
+ dump_output_name(i, output_names): returns_tuple[i] for i in range(nout)
+ }
+ ### figure out the annotations
+ annotations_dict = {}
+ output_annotation = sig.return_annotation
+ if output_annotation is inspect._empty: # no annotation
+ annotations_dict = {k: Any for k in outputs_dict.keys()}
else:
- raise NotImplementedError(f"rename_cols not implemented for {type(table)}")
+ if (
+ hasattr(output_annotation, "__origin__")
+ and output_annotation.__origin__ is tuple
+ ):
+ if (
+ len(output_annotation.__args__) == 2
+ and output_annotation.__args__[1] == Ellipsis
+ ):
+ annotations_dict = {
+ k: output_annotation.__args__[0] for k in outputs_dict.keys()
+ }
+ else:
+ annotations_dict = {
+ k: output_annotation.__args__[i]
+ for i, k in enumerate(outputs_dict.keys())
+ }
+ else:
+ assert nout == 1
+ annotations_dict = {k: output_annotation for k in outputs_dict.keys()}
+ return outputs_dict, annotations_dict
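+# Illustrative example (hypothetical function):
+#   def f(x) -> Tuple[int, str]: ...
+#   sig = inspect.signature(f)
+#   outputs, annotations = parse_returns(sig, returns=(1, "a"), nout="auto")
+#   outputs     == {"output_0": 1, "output_1": "a"}
+#   annotations == {"output_0": int, "output_1": str}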
+
+def unwrap_decorators(
+ obj: Callable, strict: bool = True
+) -> Union[types.FunctionType, types.MethodType]:
+ while hasattr(obj, "__wrapped__"):
+ obj = obj.__wrapped__
+ if not isinstance(obj, (types.FunctionType, types.MethodType)):
+ msg = f"Expected a function or method, but got {type(obj)}"
+ if strict:
+ raise RuntimeError(msg)
+ else:
+ logger.debug(msg)
+ return obj
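+# Illustrative example (hypothetical functions):
+#   def f(x): return x
+#   @functools.wraps(f)
+#   def g(x): return f(x)
+#   unwrap_decorators(g) is f  # True: functools.wraps sets g.__wrapped__ = f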
+
+def is_subdict(a: Dict, b: Dict) -> bool:
+ """
+ Check that all keys in `a` are in `b` with the same value.
+ """
+ return all((k in b and a[k] == b[k]) for k in a)
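+# Illustrative examples:
+#   is_subdict({"a": 1}, {"a": 1, "b": 2}) == True
+#   is_subdict({"a": 1}, {"a": 2})         == False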
+
+_KT, _VT = TypeVar("_KT"), TypeVar("_VT")
+def invert_dict(d: Dict[_KT, _VT]) -> Dict[_VT, List[_KT]]:
+ """
+ Invert a dictionary
+ """
+ out = {}
+ for k, v in d.items():
+ if v not in out:
+ out[v] = []
+ out[v].append(k)
+ return out
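+# Illustrative example:
+#   invert_dict({"a": 1, "b": 1, "c": 2}) == {1: ["a", "b"], 2: ["c"]}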
def ask_user(question: str, valid_options: List[str]) -> str:
"""
diff --git a/mandala/_next/viz.py b/mandala/viz.py
similarity index 100%
rename from mandala/_next/viz.py
rename to mandala/viz.py