-
Notifications
You must be signed in to change notification settings - Fork 256
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First cut at new lazy Python API (#498)
- Loading branch information
Showing
8 changed files
with
328 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,42 @@ | ||
""" | ||
Ad-hoc test for compilation. | ||
This needs to be integrated with the pytest test suite. | ||
Tests for compilation. | ||
""" | ||
|
||
import pytest | ||
|
||
import weld | ||
import weld.compile | ||
import weld.types | ||
|
||
# Single argument program. | ||
program = "|x: i32| x + 1" | ||
add_one = weld.compile.compile(program, [weld.types.I32()], [None], weld.types.I32(), None) | ||
|
||
print(add_one(1)) | ||
print(add_one(5)) | ||
|
||
|
||
inner = weld.types.WeldStruct((weld.types.I32(), weld.types.I32())) | ||
outer = weld.types.WeldStruct((weld.types.I32(), inner)) | ||
print(outer) | ||
|
||
program = "|x: i32, y: i32| {x + y, {1, 1}}" | ||
add = weld.compile.compile(program, | ||
def test_simple(): | ||
# Single argument program. | ||
program = "|x: i32| x + 1" | ||
add_one = weld.compile.compile(program, | ||
[weld.types.I32()], | ||
[None], | ||
weld.types.I32(), | ||
None) | ||
assert add_one(1)[0] == 2 | ||
assert add_one(5)[0] == 6 | ||
assert add_one(-4)[0] == -3 | ||
|
||
def test_exception(): | ||
with pytest.raises(weld.WeldError): | ||
# Single argument program. | ||
program = "|x: i32| x + ERROR" | ||
program = weld.compile.compile(program, | ||
[weld.types.I32()], | ||
[None], | ||
weld.types.I32(), | ||
None) | ||
|
||
def test_nested(): | ||
inner = weld.types.WeldStruct((weld.types.I32(), weld.types.I32())) | ||
outer = weld.types.WeldStruct((weld.types.I32(), inner)) | ||
program = "|x: i32, y: i32| {x + y, {1, 1}}" | ||
add = weld.compile.compile(program, | ||
[weld.types.I32(), weld.types.I32()], | ||
[None, None], | ||
outer, | ||
None) | ||
|
||
print(add(5, 5)) | ||
assert add(5, 5)[0] == (10, (1, 1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
Tests for constructing and evaluating lazy operations. | ||
""" | ||
|
||
from weld.encoders import PrimitiveWeldEncoder, PrimitiveWeldDecoder | ||
from weld.types import * | ||
from weld.lazy import * | ||
|
||
def test_simple(): | ||
a = 1 | ||
b = 2 | ||
a = PhysicalValue(1, I32(), PrimitiveWeldEncoder()) | ||
b = PhysicalValue(2, I32(), PrimitiveWeldEncoder()) | ||
comp1 = WeldLazy( | ||
"{0} + {1}".format(a.id, b.id), | ||
[a, b], | ||
I32(), | ||
PrimitiveWeldDecoder()) | ||
x, _ = comp1.evaluate() | ||
assert x == 3 | ||
|
||
def test_dependencies(): | ||
a = 1 | ||
b = 2 | ||
a = PhysicalValue(1, I32(), PrimitiveWeldEncoder()) | ||
b = PhysicalValue(2, I32(), PrimitiveWeldEncoder()) | ||
comp1 = WeldLazy( | ||
"{0} + {1}".format(a.id, b.id), | ||
[a, b], | ||
I32(), | ||
PrimitiveWeldDecoder()) | ||
comp2 = WeldLazy( | ||
"{0} + {1}".format(comp1.id, comp1.id), | ||
# These should be de-duplicated. | ||
[comp1, comp1], | ||
I32(), | ||
PrimitiveWeldDecoder()) | ||
x, _ = comp2.evaluate() | ||
assert x == 6 | ||
|
||
def test_long_chain(): | ||
a = 1 | ||
b = 2 | ||
a = PhysicalValue(1, I32(), PrimitiveWeldEncoder()) | ||
b = PhysicalValue(2, I32(), PrimitiveWeldEncoder()) | ||
def add_previous(prev1, prev2): | ||
# Returns an expression representing prev1 + prev2 | ||
return WeldLazy( | ||
"{0} + {1}".format(prev1.id, prev2.id), | ||
[prev1, prev2], | ||
I32(), | ||
PrimitiveWeldDecoder()) | ||
|
||
length = 100 | ||
expr = add_previous(a, b) | ||
for i in range(length): | ||
expr = add_previous(expr, a) | ||
x, _ = expr.evaluate() | ||
assert x == (length + 3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
|
||
import weld.compile | ||
|
||
from abc import ABC, abstractmethod | ||
from collections import namedtuple | ||
|
||
class NodeId(object): | ||
""" | ||
A node ID, which provides a unique name for a node in a tree. | ||
""" | ||
__slots__ = ['name'] | ||
|
||
def __eq__(self, other): | ||
return self.name == other.name | ||
|
||
def __hash__(self): | ||
return hash(self.name) | ||
|
||
def __init__(self, name): | ||
self.name = name | ||
|
||
def __str__(self): | ||
return self.name | ||
|
||
class WeldNode(ABC): | ||
""" | ||
Base class for nodes encapsulating a DAG of Weld computations. | ||
""" | ||
|
||
# ---------------------- Abstract Methods ------------------------------ | ||
|
||
@property | ||
@abstractmethod | ||
def children(self): | ||
""" List of nodes this node depends on. """ | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def output_type(self): | ||
""" The Weld output type of this node. """ | ||
pass | ||
|
||
|
||
# ---------------------- Provided Methods ------------------------------ | ||
|
||
@classmethod | ||
def prefix(cls): | ||
""" | ||
Prefix used for naming identifiers generated by this node. | ||
By default, this is the class name lowercased. This can be overridden. | ||
""" | ||
return cls.__name__.lower() | ||
|
||
@classmethod | ||
def counter_(cls): | ||
if not hasattr(cls, "counter_value_"): | ||
setattr(cls, "counter_value_", 0) | ||
return getattr(cls, "counter_value_") | ||
|
||
@classmethod | ||
def set_counter_(cls, value): | ||
setattr(cls, "counter_value_", value) | ||
|
||
@classmethod | ||
def generate_id(cls): | ||
""" Generates a unique ID for this node. """ | ||
cur_value = cls.counter_() | ||
cur_value += 1 | ||
cls.set_counter_(cur_value) | ||
return NodeId("{0}{1}".format(cls.prefix(), cur_value)) | ||
|
||
@property | ||
def id(self): | ||
if not hasattr(self, "node_id_"): | ||
setattr(self, "node_id_", self.generate_id()) | ||
return getattr(self, "node_id_") | ||
|
||
def _walk_bottomup(self, f, context, visited): | ||
""" Recursive bottom up DAG walk implementation. """ | ||
if self in visited: | ||
return | ||
visited.add(self) | ||
for dep in self.children: | ||
dep._walk_bottomup(f, context, visited) | ||
f(self, context) | ||
|
||
def walk(self, f, context, mode="bottomup"): | ||
""" Walk the DAG in the specified order. | ||
Each node in the DAG is visited exactly once. | ||
Parameters | ||
__________ | ||
f : A function to apply to each record. The function takes an operation | ||
and an optional context (i.e., any object) as arguments. | ||
context : An initial context. | ||
mode : The order in which to process the DAG. "topdown" (the default) | ||
traverses each node as its visited in breadth-first order. "bottomup" | ||
traverses the graph depth-first, so the roots are visited after the | ||
leaves (i.e., nodes are represented in "execution order" where | ||
dependencies are processed first). | ||
""" | ||
|
||
if mode == "bottomup": | ||
return self._walk_bottomup(f, context, set()) | ||
|
||
assert mode == "topdown" | ||
|
||
visited = set() | ||
queue = deque([self]) | ||
while len(queue) != 0: | ||
cur = queue.popleft() | ||
if cur not in visited: | ||
f(cur, context) | ||
visited.add(cur) | ||
for child in cur.children: | ||
queue.append(child) | ||
|
||
def __eq__(self, other): | ||
return self.id == other.id | ||
|
||
def __hash__(self): | ||
return hash(self.id) | ||
|
||
|
||
class PhysicalValue(WeldNode): | ||
""" | ||
A physical value that a lazy computation depends on. | ||
""" | ||
def __init__(self, value, ty, encoder): | ||
self.value = value | ||
self.encoder = encoder | ||
self.ty_ = ty | ||
|
||
@classmethod | ||
def prefix(cls): | ||
return "inp" | ||
|
||
@property | ||
def children(self): | ||
return [] | ||
|
||
@property | ||
def output_type(self): | ||
return self.ty_ | ||
|
||
class WeldLazy(WeldNode): | ||
""" | ||
A lazy value that encapsulates a Weld computation. | ||
""" | ||
|
||
def __init__(self, expression, dependencies, ty, decoder): | ||
""" | ||
Creates a new lazy Weld computation. | ||
Parameters | ||
---------- | ||
expression : str | ||
A weld expression. | ||
dependencies : list[WeldNode] | ||
A list of dependencies. The expression should only use names | ||
from this list. | ||
ty : WeldType | ||
The output type of this computation. | ||
decoder : A decoder for decoding the Weld result of this computation. | ||
""" | ||
# Remove duplicates here | ||
self.children_ = list(set(dependencies)) | ||
self.expression = expression | ||
self.decoder = decoder | ||
self.ty_ = ty | ||
|
||
@property | ||
def children(self): | ||
return self.children_ | ||
|
||
@property | ||
def output_type(self): | ||
return self.ty_ | ||
|
||
def _create_function_header(self, inputs): | ||
arguments = ["{0}: {1}".format(inp.id, str(inp.output_type)) for inp in inputs] | ||
return "|" + ", ".join(arguments) + "|" | ||
|
||
def evaluate(self): | ||
# Collect nodes in execution order. | ||
nodes_to_execute = [] | ||
self.walk(lambda node, expressions: expressions.append(node), nodes_to_execute) | ||
|
||
# Inputs are PhysicalValue objects. | ||
inputs = [node for node in nodes_to_execute if isinstance(node, PhysicalValue)] | ||
inputs.sort(key=lambda e: e.id.name) | ||
arg_types = [inp.output_type for inp in inputs] | ||
encoders = [inp.encoder for inp in inputs] | ||
values = [inp.value for inp in inputs] | ||
|
||
# Collect the expressions from the remaining nodes. | ||
expressions = [ | ||
"let {name} = ({expr});".format(name=node.id, expr=node.expression) for node in nodes_to_execute if isinstance(node, WeldLazy)] | ||
assert nodes_to_execute[-1] is self | ||
expressions.append(str(self.id)) | ||
|
||
program = self._create_function_header(inputs) + " " + "\n".join(expressions) | ||
print(program) | ||
|
||
# TODO(shoumik): cache me! | ||
program = weld.compile.compile(program, arg_types, encoders, self.output_type, self.decoder) | ||
return program(*values) |
Oops, something went wrong.