Task Fusion, Constant Conversion Optimization, and 27pt stencil benchmark #150

Open

wants to merge 55 commits into base: branch-24.03

Commits
c7575d8
double add fused op
shivsundram Sep 18, 2021
9c3879c
all inputs serialized for fused op
shivsundram Sep 20, 2021
0a5fae3
fusion via inlining, as well as function based fusion in fused_binary
shivsundram Sep 27, 2021
2b0b3b0
scalars reductions and opids, need to remove dynamic allocations
shivsundram Sep 29, 2021
c301582
fusion metadata passed via serialization now
shivsundram Sep 29, 2021
01f9956
reuse serializer
shivsundram Oct 1, 2021
72e03b3
some timing scripts, also stuff
shivsundram Oct 4, 2021
4232c3c
remove profiling files
shivsundram Oct 4, 2021
5c67081
re add in examples
shivsundram Oct 4, 2021
983985a
partial fusion
Oct 14, 2021
2ce8002
merge attempt 1
Oct 14, 2021
caac1ee
finishing merge
Oct 15, 2021
ba128b1
op registry working
Oct 25, 2021
5bc9d7b
add new fused dir
Oct 25, 2021
3c9698b
Update the package version for the release
marcinz Oct 27, 2021
02d8ce6
Fix #111
magnatelee Oct 27, 2021
37273ae
Merge pull request #116 from magnatelee/typo-fix
magnatelee Oct 27, 2021
b698b33
Decrease relative tolerance in allclose for float16 values
marcinz Oct 29, 2021
ac314de
Revert "Decrease relative tolerance in allclose for float16 values"
marcinz Oct 29, 2021
4167c7a
Allow greater margin of error for tensordot with float16
marcinz Oct 29, 2021
395ff6d
gpu fused op
Nov 1, 2021
465a044
reduction fix
Nov 6, 2021
65ffccf
merge
Nov 8, 2021
5d3eab1
merge again
Nov 13, 2021
615c95a
re add cuda fused
Nov 13, 2021
13f95ad
fixing fuse file
Nov 13, 2021
ce77d59
more fused stuff
Nov 13, 2021
00897a6
Merge branch 'shiv1/op_fusion2' of github.com:shivsundram/legate.nump…
Nov 13, 2021
0d60b82
last merge fixes
Nov 13, 2021
e840297
constant optimization
Nov 15, 2021
8eb7694
better constant opt
Nov 22, 2021
becf41a
batch syncs for black scholes
Nov 22, 2021
b52c7bd
fused op cleanup
Nov 22, 2021
8c97fe0
merging in new branch
Nov 22, 2021
1ae85c5
black scholes adjustment
Dec 1, 2021
7a58b6d
add missing header
Dec 2, 2021
5a38ef1
27 pt stencil
Dec 4, 2021
f230fcb
only do constant optimization for deferred arrays for some reason
Dec 6, 2021
7922c3d
remove old files, change to constant optimization
Dec 12, 2021
d5908e1
cleanup
Dec 12, 2021
dce6226
cleanup
Dec 12, 2021
1c9fd1f
cleanup
Dec 12, 2021
1665cb5
constant opt adjustment
Dec 12, 2021
e541cf8
merging
Dec 12, 2021
aedf22d
merging
Dec 12, 2021
a3dd95a
undo last change
Dec 12, 2021
1fe56d3
cleanup fused op
Dec 13, 2021
c8c69b8
more cleanup
Dec 13, 2021
4c104dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2021
b4302a3
omp changes
Dec 13, 2021
e265929
merge conflict
Dec 13, 2021
92d8590
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2021
7391628
one more merge conflict
Dec 13, 2021
00687d9
merge conflict
Dec 13, 2021
1783d65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2021
83 changes: 62 additions & 21 deletions cunumeric/array.py
@@ -24,7 +24,14 @@

from legate.core import Array

from .config import BinaryOpCode, UnaryOpCode, UnaryRedCode
from .config import (
BinaryOpCode,
CuNumericOpCode,
FusedOpCode,
UnaryOpCode,
UnaryRedCode,
)
from .deferred import DeferredArray
from .doc_utils import copy_docstring
from .runtime import runtime
from .utils import unimplemented
@@ -464,7 +471,6 @@ def __ge__(self, rhs):
)

# __getattribute__

def _convert_key(self, key, stacklevel=2, first=True):
# Convert any arrays stored in a key to a cuNumeric array
if (
@@ -1953,7 +1959,7 @@ def perform_unary_reduction(
)
return dst

# Return a new cuNumeric array for a binary operation
# Return a new legate array for a binary operation
@classmethod
def perform_binary_op(
cls,
@@ -2017,29 +2023,64 @@ def perform_binary_op(
if out_dtype is None:
out_dtype = cls.find_common_type(one, two)
if check_types:
isDeferred = isinstance(one._thunk, DeferredArray) or isinstance(
two._thunk, DeferredArray
)
if one.dtype != two.dtype:
common_type = cls.find_common_type(one, two)
if one.dtype != common_type:
temp = ndarray(
shape=one.shape,
dtype=common_type,
stacklevel=(stacklevel + 1),
inputs=(one, two, where),
)
temp._thunk.convert(
one._thunk, stacklevel=(stacklevel + 1)
)
# remove convert ops
if isDeferred and one.shape == ():
temp = ndarray(
shape=one.shape,
dtype=common_type,
# buffer = one._thunk.array.astype(common_type),
stacklevel=(stacklevel + 1),
inputs=(one, two, where),
)
temp._thunk = runtime.create_scalar(
one._thunk.array.astype(common_type),
common_type,
shape=one.shape,
wrap=True,
)
else:
temp = ndarray(
shape=one.shape,
dtype=common_type,
stacklevel=(stacklevel + 1),
inputs=(one, two, where),
)
temp._thunk.convert(
one._thunk, stacklevel=(stacklevel + 1)
)
one = temp
if two.dtype != common_type:
temp = ndarray(
shape=two.shape,
dtype=common_type,
stacklevel=(stacklevel + 1),
inputs=(one, two, where),
)
temp._thunk.convert(
two._thunk, stacklevel=(stacklevel + 1)
)
# remove convert ops
if isDeferred and two.shape == ():
temp = ndarray(
shape=two.shape,
dtype=common_type,
# buffer = two._thunk.array.astype(common_type),
stacklevel=(stacklevel + 1),
inputs=(one, two, where),
)
temp._thunk = runtime.create_scalar(
two._thunk.array.astype(common_type),
common_type,
shape=two.shape,
wrap=True,
)
else:
temp = ndarray(
shape=two.shape,
dtype=common_type,
stacklevel=(stacklevel + 1),
inputs=(one, two, where),
)
temp._thunk.convert(
two._thunk, stacklevel=(stacklevel + 1)
)
Author comment: the code above is a scalar constant optimization. It avoids dispatching a CONVERT operation for a scalar constant, since the constant's value is embedded in the program and is therefore already known on the host.
two = temp
if out.dtype != out_dtype:
temp = ndarray(
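To make that comment concrete, here is a minimal, self-contained sketch of the fast path; the function name and the is_deferred flag are illustrative and not part of the cuNumeric API:

    import numpy as np

    def convert_operand(value, common_type, is_deferred):
        # Sketch of the scalar-constant fast path in perform_binary_op.
        # A scalar constant (shape == ()) has a value already known on
        # the host, so it can be cast eagerly with NumPy instead of
        # dispatching a deferred CONVERT task to the runtime.
        if is_deferred and np.ndim(value) == 0:
            # Fast path: eager host-side cast; no CONVERT task launched.
            return np.asarray(value).astype(common_type)
        # General path: cuNumeric would launch a CONVERT task on the
        # array's thunk; the sketch models it with the same eager cast.
        return np.asarray(value).astype(common_type)

    # Example: a float64 constant mixed with float32 arrays is promoted
    # on the host, so no conversion task enters the task stream.
    scalar = convert_operand(2.0, np.float32, is_deferred=True)
    print(scalar.dtype)  # float32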
9 changes: 9 additions & 0 deletions cunumeric/config.py
@@ -104,6 +104,12 @@ class CuNumericOpCode(IntEnum):
UNARY_RED = _cunumeric.CUNUMERIC_UNARY_RED
WHERE = _cunumeric.CUNUMERIC_WHERE
WRITE = _cunumeric.CUNUMERIC_WRITE
FUSED_OP = _cunumeric.CUNUMERIC_FUSED_OP


@unique
class FusedOpCode(IntEnum):
FUSE = 1


# Match these to BinaryOpCode in binary_op_util.h
@@ -197,3 +203,6 @@ class CuNumericRedopCode(IntEnum):
class CuNumericTunable(IntEnum):
NUM_GPUS = _cunumeric.CUNUMERIC_TUNABLE_NUM_GPUS
MAX_EAGER_VOLUME = _cunumeric.CUNUMERIC_TUNABLE_MAX_EAGER_VOLUME


cunumeric_context.fused_id = CuNumericOpCode.FUSED_OP
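The final line above registers the fused task's op code with the Legate context. Conceptually, this lets the runtime replace a window of buffered tasks with a single FUSED_OP launch whose metadata records the original op codes (the commit "fusion metadata passed via serialization now" refers to this metadata). A toy model of that substitution, with entirely hypothetical names:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class FusedTask:
        op_id: int                                        # e.g. CuNumericOpCode.FUSED_OP
        sub_ops: List[int] = field(default_factory=list)  # original op codes (fusion metadata)

    def fuse_window(window, fused_id):
        # Collapse a list of (op_id, args) tasks into one fused task
        # whose metadata preserves the order of the original operations.
        fused = FusedTask(op_id=fused_id)
        for op_id, _args in window:
            fused.sub_ops.append(op_id)
        return fused

    # Example: three buffered binary ops become one FUSED_OP launch.
    window = [(12, None), (12, None), (14, None)]
    print(fuse_window(window, fused_id=99).sub_ops)  # [12, 12, 14]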
4 changes: 3 additions & 1 deletion cunumeric/deferred.py
@@ -254,6 +254,9 @@ def __numpy_array__(self, stacklevel=0):
return np.empty(shape=self.shape, dtype=self.dtype)

if self.scalar:
if not self.base._storage:
self.runtime.legate_runtime._launch_outstanding()

result = np.full(
self.shape,
self.get_scalar_array(stacklevel=(stacklevel + 1)),
@@ -1568,7 +1571,6 @@ def binary_op(

task.add_alignment(lhs, rhs1)
task.add_alignment(lhs, rhs2)

task.execute()

@profile
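The __numpy_array__ hunk above guards a host-side read: with fusion, tasks can sit in a buffer until a whole window is launched, so a scalar whose storage has not materialized yet must first flush outstanding work. A toy model of that interaction follows; only the _launch_outstanding name comes from the diff, everything else is hypothetical:

    class TinyRuntime:
        # Toy model of a runtime that buffers tasks for fusion.
        def __init__(self):
            self.window = []   # buffered tasks awaiting launch
            self.storage = {}  # materialized results

        def launch(self, key, fn):
            self.window.append((key, fn))  # deferred, not yet executed

        def _launch_outstanding(self):
            # Flush: actually execute everything still in the window.
            for key, fn in self.window:
                self.storage[key] = fn()
            self.window.clear()

        def read_scalar(self, key):
            if key not in self.storage:     # storage not yet materialized
                self._launch_outstanding()  # mirrors the check in the diff
            return self.storage[key]

    rt = TinyRuntime()
    rt.launch("sum", lambda: 42)
    print(rt.read_scalar("sum"))  # 42, produced by the on-demand flush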
16 changes: 11 additions & 5 deletions examples/black_scholes.py
@@ -79,11 +79,17 @@ def run_black_scholes(N, D):
N *= 1000
start = datetime.datetime.now()
S, X, T, R, V = initialize(N, D)
call, put = black_scholes(S, X, T, R, V)
# Check the result for NaNs to synchronize before stopping timing
call_sum = np.sum(call)
put_sum = np.sum(put)
assert not math.isnan(call_sum) and not math.isnan(put_sum)
trials = 300
ends = [None for i in range(trials)]
for i in range(trials):
call, put = black_scholes(S, X, T, R, V)
# Check the result for NaNs to synchronize before stopping timing
call_sum = np.sum(call)
put_sum = np.sum(put)
ends[i] = (call_sum, put_sum)
for i in range(trials):
call_sum, put_sum = ends[i]
assert not math.isnan(call_sum) and not math.isnan(put_sum)
stop = datetime.datetime.now()
delta = stop - start
total = delta.total_seconds() * 1000.0
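The benchmark now runs 300 trials, keeps each trial's checksums deferred inside the hot loop, and performs all NaN checks after the loop. Since each math.isnan forces a blocking read, batching the checks gives the runtime one long stream of fusable work instead of a synchronization per trial. The pattern, reduced to a runnable sketch with plain NumPy standing in for cuNumeric:

    import math
    import numpy as np  # cunumeric in the real benchmark

    def trial(x):
        return np.sqrt(x) + 1.0  # stand-in for one Black-Scholes evaluation

    x = np.ones(1000)
    sums = [np.sum(trial(x)) for _ in range(300)]  # no sync inside the loop
    for s in sums:                                 # one batch of syncs at the end
        assert not math.isnan(float(s))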
1 change: 1 addition & 0 deletions examples/stencil.py
@@ -49,6 +49,7 @@ def run(grid, I, N): # noqa: E741
# delta = np.sum(np.absolute(work - center))
center[:] = work
total = np.sum(center)
# return total
return total / (N ** 2)


158 changes: 158 additions & 0 deletions examples/stencil_27.py
@@ -0,0 +1,158 @@
#!/usr/bin/env python

# Copyright 2021 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

import argparse
import datetime
import math

from benchmark import run_benchmark

import cunumeric as np


def initialize(N):
print("Initializing stencil grid...")
grid = np.zeros((N + 2, N + 2, N + 2))
grid[:, :, 0] = -273.15
grid[:, 0, :] = -273.15
grid[0, :, :] = -273.15
grid[:, :, -1] = 273.15
grid[:, -1, :] = 273.15
grid[-1, :, :] = 273.15

return grid


def run(grid, I, N): # noqa: E741
print("Running Jacobi 27 stencil...")

# one
g000 = grid[0:-2, 0:-2, 0:-2]
g001 = grid[0:-2, 0:-2, 1:-1]
g002 = grid[0:-2, 0:-2, 2:]

g010 = grid[0:-2, 1:-1, 0:-2]
g011 = grid[0:-2, 1:-1, 1:-1]
g012 = grid[0:-2, 1:-1, 2:]

g020 = grid[0:-2, 2:, 0:-2]
g021 = grid[0:-2, 2:, 1:-1]
g022 = grid[0:-2, 2:, 2:]

# two
g100 = grid[1:-1, 0:-2, 0:-2]
g101 = grid[1:-1, 0:-2, 1:-1]
g102 = grid[1:-1, 0:-2, 2:]

g110 = grid[1:-1, 1:-1, 0:-2]
g111 = grid[1:-1, 1:-1, 1:-1]
g112 = grid[1:-1, 1:-1, 2:]

g120 = grid[1:-1, 2:, 0:-2]
g121 = grid[1:-1, 2:, 1:-1]
g122 = grid[1:-1, 2:, 2:]

# three
g200 = grid[2:, 0:-2, 0:-2]
g201 = grid[2:, 0:-2, 1:-1]
g202 = grid[2:, 0:-2, 2:]

g210 = grid[2:, 1:-1, 0:-2]
g211 = grid[2:, 1:-1, 1:-1]
g212 = grid[2:, 1:-1, 2:]

g220 = grid[2:, 2:, 0:-2]
g221 = grid[2:, 2:, 1:-1]
g222 = grid[2:, 2:, 2:]

for i in range(I):
g00 = g000 + g001 + g002
g01 = g010 + g011 + g012
g02 = g020 + g021 + g022
g10 = g100 + g101 + g102
g11 = g110 + g111 + g112
g12 = g120 + g121 + g122
g20 = g200 + g201 + g202
g21 = g210 + g211 + g212
g22 = g220 + g221 + g222

g0 = g00 + g01 + g02
g1 = g10 + g11 + g12
g2 = g20 + g21 + g22

res = g0 + g1 + g2
work = 0.037 * res
g111[:] = work
total = np.sum(g111)
return total / (N ** 2)


def run_stencil(N, I, timing): # noqa: E741
start = datetime.datetime.now()
grid = initialize(N)
average = run(grid, I, N)
# This will sync the timing because we will need to wait for the result
assert not math.isnan(average)
stop = datetime.datetime.now()
print("Average energy is %.8g" % average)
delta = stop - start
total = delta.total_seconds() * 1000.0
if timing:
print("Elapsed Time: " + str(total) + " ms")
return total


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--iter",
type=int,
default=100,
dest="I",
help="number of iterations to run",
)
parser.add_argument(
"-n",
"--num",
type=int,
default=100,
dest="N",
help="number of elements in one dimension",
)
parser.add_argument(
"-t",
"--time",
dest="timing",
action="store_true",
help="perform timing",
)
parser.add_argument(
"-b",
"--benchmark",
type=int,
default=1,
dest="benchmark",
help="number of times to benchmark this application (default 1 "
"- normal execution)",
)
args = parser.parse_args()
run_benchmark(
run_stencil, args.benchmark, "Stencil", (args.N, args.I, args.timing)
)
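A usage note: like the existing stencil example, this benchmark would presumably be launched through the legate driver, e.g. legate examples/stencil_27.py -n 400 -i 100 -t, with the flags defined by the parser above; the exact invocation depends on the local Legate installation. The update weight 0.037 is approximately 1/27, so each interior point is replaced by a near-unweighted average of its 27-point neighborhood.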
2 changes: 1 addition & 1 deletion install.py
@@ -456,7 +456,7 @@ def driver():
"--clean",
dest="clean_first",
action=BooleanFlag,
default=True,
default=False,
help="Clean before build.",
)
parser.add_argument(
3 changes: 3 additions & 0 deletions src/cunumeric.mk
@@ -17,6 +17,7 @@
# since we have to add the -fopenmp flag to CC_FLAGS for them
GEN_CPU_SRC += cunumeric/ternary/where.cc \
cunumeric/binary/binary_op.cc \
cunumeric/fused/fused_op.cc \
cunumeric/binary/binary_red.cc \
cunumeric/unary/scalar_unary_red.cc \
cunumeric/unary/unary_op.cc \
@@ -46,6 +47,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \
ifeq ($(strip $(USE_OPENMP)),1)
GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \
cunumeric/binary/binary_op_omp.cc \
cunumeric/fused/fused_op_omp.cc \
cunumeric/binary/binary_red_omp.cc \
cunumeric/unary/unary_op_omp.cc \
cunumeric/unary/scalar_unary_red_omp.cc \
@@ -75,6 +77,7 @@ GEN_CPU_SRC += cunumeric/cunumeric.cc # This must always be the last file!

GEN_GPU_SRC += cunumeric/ternary/where.cu \
cunumeric/binary/binary_op.cu \
cunumeric/fused/fused_op.cu \
cunumeric/binary/binary_red.cu \
cunumeric/unary/scalar_unary_red.cu \
cunumeric/unary/unary_red.cu \
1 change: 1 addition & 0 deletions src/cunumeric/cunumeric_c.h
@@ -46,6 +46,7 @@ enum CuNumericOpCode {
CUNUMERIC_UNARY_RED,
CUNUMERIC_WHERE,
CUNUMERIC_WRITE,
CUNUMERIC_FUSED_OP,
};

// Match these to CuNumericRedopCode in cunumeric/config.py