Skip to content

Commit

Permalink
First cut at new lazy Python API (#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
sppalkia authored Feb 11, 2020
1 parent 7cb875e commit 225c2d5
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 37 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ python/pyweld/build/
python/pyweld/dist/
python/grizzly/build
*.egg-info/
weldenv/

# Environments etc.
weld-python/doc/build
site-packages/
weld-dev/
weldenv/

llvmext/test

Expand Down
2 changes: 1 addition & 1 deletion weld-python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from setuptools_rust import RustExtension

setup_requires = ["setuptools-rust>=0.10.1", "wheel"]
install_requires = []
install_requires = ["numpy"]

setup(
name="weld",
Expand Down
2 changes: 1 addition & 1 deletion weld-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ trait ToPyErr<T, E> {

impl<T> ToPyErr<T, weld::WeldError> for weld::WeldResult<T> {
fn to_py(self) -> PyResult<T> {
self.map_err(|e| PyErr::new::<WeldError, _>(e.message().to_str().unwrap().to_string()))
self.map_err(|e| WeldError::py_err(e.message().to_str().unwrap().to_string()))
}
}

Expand Down
51 changes: 31 additions & 20 deletions weld-python/tests/compile_test.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
"""
Ad-hoc test for compilation.
This needs to be integrated with the pytest test suite.
Tests for compilation.
"""

import pytest

import weld
import weld.compile
import weld.types

# Single argument program.
program = "|x: i32| x + 1"
add_one = weld.compile.compile(program, [weld.types.I32()], [None], weld.types.I32(), None)

print(add_one(1))
print(add_one(5))


inner = weld.types.WeldStruct((weld.types.I32(), weld.types.I32()))
outer = weld.types.WeldStruct((weld.types.I32(), inner))
print(outer)

program = "|x: i32, y: i32| {x + y, {1, 1}}"
add = weld.compile.compile(program,
def test_simple():
    """Compile a one-argument increment program and check a few inputs."""
    add_one = weld.compile.compile(
        "|x: i32| x + 1",
        [weld.types.I32()],
        [None],
        weld.types.I32(),
        None)
    # The compiled callable's result is indexed at [0] to get the value.
    for argument, expected in ((1, 2), (5, 6), (-4, -3)):
        assert add_one(argument)[0] == expected

def test_exception():
    """Compiling a program that uses an undeclared name raises WeldError."""
    # `ERROR` is not a parameter of the lambda, so compilation is expected
    # to fail.
    bad_program = "|x: i32| x + ERROR"
    with pytest.raises(weld.WeldError):
        weld.compile.compile(
            bad_program,
            [weld.types.I32()],
            [None],
            weld.types.I32(),
            None)

def test_nested():
    """Compile a two-argument program that returns a nested struct."""
    inner = weld.types.WeldStruct((weld.types.I32(), weld.types.I32()))
    outer = weld.types.WeldStruct((weld.types.I32(), inner))
    program = "|x: i32, y: i32| {x + y, {1, 1}}"
    add = weld.compile.compile(program,
                               [weld.types.I32(), weld.types.I32()],
                               [None, None],
                               outer,
                               None)

    # The nested Weld struct is expected to come back as nested tuples.
    assert add(5, 5)[0] == (10, (1, 1))
59 changes: 59 additions & 0 deletions weld-python/tests/lazy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Tests for constructing and evaluating lazy operations.
"""

from weld.encoders import PrimitiveWeldEncoder, PrimitiveWeldDecoder
from weld.types import *
from weld.lazy import *

def test_simple():
    """Evaluate a single lazy addition over two physical inputs."""
    a = PhysicalValue(1, I32(), PrimitiveWeldEncoder())
    b = PhysicalValue(2, I32(), PrimitiveWeldEncoder())
    comp1 = WeldLazy(
        "{0} + {1}".format(a.id, b.id),
        [a, b],
        I32(),
        PrimitiveWeldDecoder())
    # evaluate() returns a pair; only the first element (the value) matters.
    x, _ = comp1.evaluate()
    assert x == 3

def test_dependencies():
    """Evaluate a lazy value that depends twice on another lazy value."""
    a = PhysicalValue(1, I32(), PrimitiveWeldEncoder())
    b = PhysicalValue(2, I32(), PrimitiveWeldEncoder())
    comp1 = WeldLazy(
        "{0} + {1}".format(a.id, b.id),
        [a, b],
        I32(),
        PrimitiveWeldDecoder())
    comp2 = WeldLazy(
        "{0} + {1}".format(comp1.id, comp1.id),
        # These should be de-duplicated.
        [comp1, comp1],
        I32(),
        PrimitiveWeldDecoder())
    x, _ = comp2.evaluate()
    # (1 + 2) + (1 + 2)
    assert x == 6

def test_long_chain():
    """Evaluate a long chain of dependent lazy additions."""
    a = PhysicalValue(1, I32(), PrimitiveWeldEncoder())
    b = PhysicalValue(2, I32(), PrimitiveWeldEncoder())

    def add_previous(prev1, prev2):
        # Returns an expression representing prev1 + prev2
        return WeldLazy(
            "{0} + {1}".format(prev1.id, prev2.id),
            [prev1, prev2],
            I32(),
            PrimitiveWeldDecoder())

    length = 100
    # Start at a + b == 3, then add a (== 1) `length` more times.
    expr = add_previous(a, b)
    for _ in range(length):
        expr = add_previous(expr, a)
    x, _ = expr.evaluate()
    assert x == (length + 3)
4 changes: 3 additions & 1 deletion weld-python/weld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
Weld bindings.
"""

from .error import WeldError
# Compiled from Rust bindings.
from .core import *
# Currently required to get WeldError interop with Rust
from .error import WeldError
216 changes: 216 additions & 0 deletions weld-python/weld/lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@

import logging
from abc import ABC, abstractmethod
from collections import deque
from collections import namedtuple

import weld.compile

class NodeId(object):
    """
    A node ID, which provides a unique name for a node in a tree.

    Two IDs are equal exactly when their names are equal, and an ID hashes
    by its name, so IDs can be used as set members and dictionary keys.
    """
    __slots__ = ['name']

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching `other.name`) lets
        # comparisons against unrelated objects evaluate to False rather
        # than raising AttributeError.
        if not isinstance(other, NodeId):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return "NodeId({0!r})".format(self.name)

class WeldNode(ABC):
    """
    Base class for nodes encapsulating a DAG of Weld computations.

    Subclasses must provide `children` and `output_type`. Every node lazily
    receives a unique `id`; nodes compare and hash by that ID, which is what
    allows them to be tracked in sets while the DAG is walked.
    """

    # Last counter value handed out for each ID prefix, shared by all
    # subclasses. Keying on the prefix keeps generated IDs unique even when a
    # subclass inherits its parent's prefix; a counter stored as a per-class
    # attribute would be read through inheritance and then diverge, handing
    # out the same ID twice.
    counters_by_prefix_ = {}

    # ---------------------- Abstract Methods ------------------------------

    @property
    @abstractmethod
    def children(self):
        """ List of nodes this node depends on. """
        pass

    @property
    @abstractmethod
    def output_type(self):
        """ The Weld output type of this node. """
        pass

    # ---------------------- Provided Methods ------------------------------

    @classmethod
    def prefix(cls):
        """
        Prefix used for naming identifiers generated by this node.

        By default, this is the class name lowercased. This can be overridden.
        """
        return cls.__name__.lower()

    @classmethod
    def counter_(cls):
        """ Returns the last counter value used for this class's prefix. """
        return WeldNode.counters_by_prefix_.get(cls.prefix(), 0)

    @classmethod
    def set_counter_(cls, value):
        """ Records `value` as the last counter used for this class's prefix. """
        WeldNode.counters_by_prefix_[cls.prefix()] = value

    @classmethod
    def generate_id(cls):
        """ Generates a unique ID for this node. """
        cur_value = cls.counter_() + 1
        cls.set_counter_(cur_value)
        return NodeId("{0}{1}".format(cls.prefix(), cur_value))

    @property
    def id(self):
        """ The ID of this node, generated on first access and then reused. """
        if not hasattr(self, "node_id_"):
            setattr(self, "node_id_", self.generate_id())
        return getattr(self, "node_id_")

    def _walk_bottomup(self, f, context, visited):
        """
        Bottom up DAG walk implementation.

        Visits the same nodes in the same order as a recursive depth-first
        post-order walk, but keeps an explicit stack so that long dependency
        chains cannot exhaust the interpreter's recursion limit.
        """
        if self in visited:
            return
        visited.add(self)
        # Each stack entry pairs a node with an iterator over the children
        # that still have to be explored for it.
        stack = [(self, iter(self.children))]
        while stack:
            node, pending = stack[-1]
            for dep in pending:
                if dep not in visited:
                    visited.add(dep)
                    stack.append((dep, iter(dep.children)))
                    # Descend into `dep` first; `node` is resumed later.
                    break
            else:
                # All children were handled, so the node itself is next.
                stack.pop()
                f(node, context)

    def walk(self, f, context, mode="bottomup"):
        """
        Walk the DAG in the specified order.

        Each node in the DAG is visited exactly once.

        Parameters
        ----------
        f : A function to apply to each node. The function takes a node
            and a context (i.e., any object) as arguments.
        context : An initial context, passed unchanged to every call of `f`.
        mode : The order in which to process the DAG. "bottomup" (the
            default) traverses the graph depth-first, so the roots are
            visited after the leaves (i.e., nodes are presented in
            "execution order" where dependencies are processed first).
            "topdown" visits each node in breadth-first order starting from
            this node.

        Raises
        ------
        ValueError
            If `mode` is neither "bottomup" nor "topdown".
        """
        if mode == "bottomup":
            return self._walk_bottomup(f, context, set())

        if mode != "topdown":
            raise ValueError(
                "unknown walk mode {0!r}: expected 'bottomup' or 'topdown'".format(mode))

        visited = set()
        queue = deque([self])
        while queue:
            cur = queue.popleft()
            if cur in visited:
                continue
            f(cur, context)
            visited.add(cur)
            queue.extend(cur.children)

    def __eq__(self, other):
        # Nodes are identified by ID; anything that is not a node is left to
        # Python's default handling instead of failing on a missing `id`.
        if not isinstance(other, WeldNode):
            return NotImplemented
        return self.id == other.id

    def __hash__(self):
        return hash(self.id)


class PhysicalValue(WeldNode):
    """
    A leaf node holding a concrete value that lazy computations consume.

    The node stores the value itself, its Weld type, and the encoder that
    accompanies the value when a computation using it is compiled.
    """

    def __init__(self, value, ty, encoder):
        self.ty_ = ty
        self.encoder = encoder
        self.value = value

    @property
    def output_type(self):
        """ The Weld type supplied at construction. """
        return self.ty_

    @property
    def children(self):
        """ Physical values depend on nothing, so this is always empty. """
        return list()

    @classmethod
    def prefix(cls):
        """ IDs of physical values are named "inp" followed by a counter. """
        return "inp"

class WeldLazy(WeldNode):
    """
    A lazy value that encapsulates a Weld computation.
    """

    def __init__(self, expression, dependencies, ty, decoder):
        """
        Creates a new lazy Weld computation.

        Parameters
        ----------
        expression : str
            A weld expression.
        dependencies : list[WeldNode]
            A list of dependencies. The expression should only use names
            from this list. Duplicate entries are collapsed.
        ty : WeldType
            The output type of this computation.
        decoder : A decoder for decoding the Weld result of this computation.
        """
        # Remove duplicates while keeping first-seen order. Going through a
        # plain set would make the child order (and with it the text of the
        # generated program) depend on hash ordering.
        seen = set()
        unique = []
        for dep in dependencies:
            if dep not in seen:
                seen.add(dep)
                unique.append(dep)
        self.children_ = unique
        self.expression = expression
        self.decoder = decoder
        self.ty_ = ty

    @property
    def children(self):
        """ The de-duplicated dependencies of this computation. """
        return self.children_

    @property
    def output_type(self):
        """ The Weld type supplied at construction. """
        return self.ty_

    def _create_function_header(self, inputs):
        """ Builds the `|name: type, ...|` parameter list for `inputs`. """
        arguments = ["{0}: {1}".format(inp.id, str(inp.output_type)) for inp in inputs]
        return "|" + ", ".join(arguments) + "|"

    def evaluate(self):
        """
        Compiles and runs the computation rooted at this node.

        Every PhysicalValue reachable from this node becomes a parameter of
        the generated Weld function, and every WeldLazy node becomes a `let`
        binding emitted after the bindings it depends on.

        Returns
        -------
        Whatever the compiled program returns when called with the input
        values. Callers in the accompanying tests unpack it as a pair whose
        first element is the decoded result.
        """
        # Collect nodes in execution order.
        nodes_to_execute = []
        self.walk(lambda node, collected: collected.append(node), nodes_to_execute)

        # Inputs are PhysicalValue objects. Sorting by name keeps the
        # parameters, types, encoders and values in one consistent order.
        inputs = [node for node in nodes_to_execute if isinstance(node, PhysicalValue)]
        inputs.sort(key=lambda e: e.id.name)
        arg_types = [inp.output_type for inp in inputs]
        encoders = [inp.encoder for inp in inputs]
        values = [inp.value for inp in inputs]

        # Collect the expressions from the remaining nodes.
        expressions = [
            "let {name} = ({expr});".format(name=node.id, expr=node.expression)
            for node in nodes_to_execute if isinstance(node, WeldLazy)]
        # The bottom-up walk visits the root last, so the final binding is
        # this node and its name is the value of the whole program.
        assert nodes_to_execute[-1] is self
        expressions.append(str(self.id))

        program = self._create_function_header(inputs) + " " + "\n".join(expressions)
        # Emit the generated program through logging rather than stdout, so
        # library users only see it when they enable debug output.
        logging.getLogger(__name__).debug("%s", program)

        # TODO(shoumik): cache me!
        program = weld.compile.compile(program, arg_types, encoders, self.output_type, self.decoder)
        return program(*values)
Loading

0 comments on commit 225c2d5

Please sign in to comment.