forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon_with_cwrap.py
213 lines (184 loc) · 7.98 KB
/
common_with_cwrap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in this one place, but it is currently a copy of the cwrap code.
from copy import deepcopy
from itertools import product
def parse_arguments(args):
    """Normalize a list of argument declarations into argument dicts.

    Each entry of ``args`` is either

    * a string of the form ``"<type> <name>"``, which becomes
      ``{'type': <type>, 'name': <name>}`` (split at the first space); or
    * a dict, which is kept as-is except that a combined ``'arg'`` entry
      (``"<type> <name>"``) is split into ``'type'`` and ``'name'`` entries
      and then removed.  The dict is updated in place and the very same
      object is placed in the result.

    Returns:
        A new list holding one dict per entry of ``args``, in order.

    Raises:
        AssertionError: if an entry is neither a string nor a dict.  The
            message names the offending entry so that a malformed
            declaration can be located.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            t, _, name = arg.partition(' ')
            new_args.append({'type': t, 'name': name})
        elif isinstance(arg, dict):
            if 'arg' in arg:
                # Expand the combined "<type> <name>" shorthand in place.
                arg['type'], _, arg['name'] = arg['arg'].partition(' ')
                del arg['arg']
            new_args.append(arg)
        else:
            raise AssertionError(
                'argument declaration must be a str or a dict, got '
                '{}: {!r}'.format(type(arg).__name__, arg))
    return new_args
def set_declaration_defaults(declaration):
    """Fill in missing fields of ``declaration`` and normalize its options.

    The dict is modified in place; nothing is returned.

    * Missing ``'schema_string'``, ``'matches_jit_signature'``,
      ``'arguments'``, ``'return'`` and ``'backends'`` entries receive fixed
      defaults.
    * Missing ``'cname'`` and ``'api_name'`` entries are copied from
      ``declaration['name']`` (which is only read when one of them is
      missing).
    * When there is no ``'options'`` entry, a single option is synthesized
      from the top-level ``'arguments'``, which is then removed from the
      declaration.
    * Every option's ``'arguments'`` is passed through ``parse_arguments``.
    * Every declaration entry other than ``'options'`` is copied into each
      option that does not already define that key.
    """
    declaration.setdefault('schema_string', '')
    declaration.setdefault('matches_jit_signature', False)
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    declaration.setdefault('backends', ['CPU', 'CUDA'])
    if 'api_name' not in declaration:
        declaration['api_name'] = declaration['name']

    # Simulate multiple dispatch, even if it's not necessary: a declaration
    # without explicit options gets exactly one, built from its own arguments.
    if 'options' not in declaration:
        declaration['options'] = [{'arguments': declaration.pop('arguments')}]

    options = declaration['options']

    # Parse arguments (some of them can be strings)
    for option in options:
        option['arguments'] = parse_arguments(option['arguments'])

    # Propagate defaults from declaration to options
    # TODO(zach): why does cwrap not propagate 'name'? I need it
    # propagated for ATen
    inherited = [(key, value) for key, value in declaration.items()
                 if key != 'options']
    for option in options:
        for key, value in inherited:
            option.setdefault(key, value)
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    """Return the options whose call signature was not already claimed.

    A signature is built from an option's ``'arguments'`` with the last
    ``k`` of them treated as keyword-only:

    * the positional part joins, with ``'#'``, the type of every leading
      argument (translated through ``type_to_signature`` when the type is a
      key of that mapping);
    * the keyword-only part joins ``name#type`` of every trailing argument
      and is appended after a ``'#-#'`` separator (omitted when ``k`` is 0).

    Arguments with a truthy ``'ignore_check'`` or of type ``'CONSTANT'`` are
    left out of both parts; when ``remove_self`` is truthy, arguments named
    ``'self'`` are additionally left out of the positional part.

    For each option, ``k`` is tried from 0 up to the number of arguments
    (only 0 when ``allow_kwarg`` is falsy).  The first ``k`` producing an
    unseen signature wins: the last ``k`` arguments are flagged with
    ``'kwarg_only' = True`` (in place) and the option is kept.  Options for
    which every candidate signature has been seen are dropped.

    Returns:
        A list of the kept option dicts, in their original order.
    """
    def is_unchecked(arg):
        return arg.get('ignore_check') or arg['type'] == 'CONSTANT'

    def signature_for(option, kwarg_only_count):
        arguments = option['arguments']
        # Index at which the keyword-only tail starts.
        split_at = len(arguments) - kwarg_only_count
        positional_parts = [
            type_to_signature.get(arg['type'], arg['type'])
            for arg in arguments[:split_at]
            if not (is_unchecked(arg) or (remove_self and arg['name'] == 'self'))
        ]
        positional = '#'.join(positional_parts)
        if kwarg_only_count == 0:
            return positional
        keyword_parts = [
            arg['name'] + '#' + arg['type']
            for arg in arguments[split_at:]
            if not is_unchecked(arg)
        ]
        return positional + "#-#" + '#'.join(keyword_parts)

    claimed = set()
    kept = []
    for option in options:
        # Without keyword support only the all-positional form is tried.
        max_kwarg_only = len(option['arguments']) if allow_kwarg else 0
        for count in range(max_kwarg_only + 1):
            sig = signature_for(option, count)
            if sig in claimed:
                continue
            claimed.add(sig)
            if count > 0:
                for arg in option['arguments'][-count:]:
                    arg['kwarg_only'] = True
            kept.append(option)
            break
    return kept
def enumerate_options_due_to_default(declaration,
                                     allow_kwarg=True, type_to_signature=None, remove_self=True):
    """Expand every option into one variant per combination of defaulted args.

    For each option in ``declaration['options']``, the arguments carrying a
    ``'default'`` entry are collected.  One deep copy of the option is made
    for every way of marking each such argument as supplied or omitted.  In
    a copy:

    * ``'has_full_argument_list'`` is True only when every defaulted
      argument is supplied;
    * a ``None`` default is rewritten to the string ``'NULL'``;
    * an omitted argument keeps its original type under ``'declared_type'``,
      gets type ``'CONSTANT'`` and ``'ignore_check' = True``.

    The expanded list is passed through ``filter_unique_options`` and stored
    back into ``declaration['options']``.  Nothing is returned.

    Args:
        declaration: declaration dict with an ``'options'`` list.
        allow_kwarg, remove_self: forwarded to ``filter_unique_options``.
        type_to_signature: mapping forwarded to ``filter_unique_options``;
            ``None`` means an empty mapping.
    """
    if type_to_signature is None:
        # filter_unique_options looks types up with ``.get``, so the fallback
        # has to be a mapping (a list here would raise AttributeError).
        type_to_signature = {}

    new_options = []
    for option in declaration['options']:
        # Positions of the arguments that declare a default value.
        optional_args = [i for i, arg in enumerate(option['arguments'])
                         if 'default' in arg]
        for permutation in product((True, False), repeat=len(optional_args)):
            option_copy = deepcopy(option)
            option_copy['has_full_argument_list'] = all(permutation)
            for i, bit in zip(optional_args, permutation):
                arg = option_copy['arguments'][i]
                # PyYAML interprets NULL as None...
                if arg['default'] is None:
                    arg['default'] = 'NULL'
                if not bit:
                    # Argument omitted by the caller: turn it into a constant
                    # that is skipped during signature checks.
                    arg['declared_type'] = arg['type']
                    arg['type'] = 'CONSTANT'
                    arg['ignore_check'] = True
            new_options.append(option_copy)
    declaration['options'] = filter_unique_options(new_options,
                                                   allow_kwarg, type_to_signature, remove_self)
def sort_by_number_of_options(declaration, reverse=True):
    """Sort ``declaration['options']`` in place by number of checked arguments.

    An argument counts as checked unless its ``'ignore_check'`` entry is
    truthy.  With the default ``reverse=True`` the options with the most
    checked arguments come first.  Returns None.
    """
    def checked_arg_count(option):
        return sum(1 for arg in option['arguments']
                   if not arg.get('ignore_check', False))

    declaration['options'].sort(key=checked_arg_count, reverse=reverse)
class Function(object):
    """A named function together with an ordered list of its arguments."""

    def __init__(self, name):
        # Arguments are appended later through add_argument().
        self.name = name
        self.arguments = []

    def add_argument(self, arg):
        """Append ``arg``, which is asserted to be an ``Argument``."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        # Rendered as "name(arg1, arg2, ...)" using each argument's repr.
        rendered_args = ', '.join(repr(arg) for arg in self.arguments)
        return self.name + '(' + rendered_args + ')'
class Argument(object):
    """A single function argument: its type, its name and an optional flag."""

    def __init__(self, _type, name, is_optional):
        self.type, self.name, self.is_optional = _type, name, is_optional

    def __repr__(self):
        # Rendered as "<type> <name>"; the optional flag is not shown.
        return ' '.join((self.type, self.name))
def parse_header(path):
    """Parse the function declarations found in the header file at ``path``.

    The file is read in full and processed line by line:

    * empty lines and lines starting with ``#`` are skipped;
    * everything after ``//`` is treated as that line's comment;
    * trailing ``)``/``;`` characters and then trailing commas are removed
      from the code part, which is split on commas into fragments.

    A fragment starting with ``TH_API void THNN_`` or ``THC_API void THNN_``
    opens a new ``Function`` named after the text following that prefix.
    Any other fragment must consist of exactly two whitespace-separated
    tokens, ``<type> <name>``, and is added as an ``Argument`` to the most
    recently opened function.  If the name token contains ``*``, a ``*`` is
    appended to the type and the first character of the name is dropped.
    The argument is optional when the line's comment contains ``[OPTIONAL]``.

    Returns:
        The list of ``Function`` objects in file order.
    """
    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Break every kept line into (fragment, comment) pairs.
    fragments = []
    for raw in raw_lines:
        # Remove empty lines and '#' directives
        if not raw or raw.startswith('#'):
            continue
        # Separate the code from a trailing line comment
        code, _, comment = raw.partition('//')
        comment = comment.strip()
        # Remove trailing special signs, then split the arguments
        code = code.strip().rstrip(');').rstrip(',')
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    declaration_prefixes = ('TH_API void THNN_', 'THC_API void THNN_')
    generic_functions = []
    for fragment, comment in fragments:
        prefix = next(
            (p for p in declaration_prefixes if fragment.startswith(p)), None)
        if prefix is not None:
            fn_name = fragment[len(prefix):]
            if fn_name[0] == '(' and fn_name[-2] == ')':
                # "(name)(" -> "name"
                fn_name = fn_name[1:-2]
            else:
                # "name(" -> "name"
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
            continue

        t, name = fragment.split()
        if '*' in name:
            # Move the pointer marker from the name onto the type.
            t = t + '*'
            name = name[1:]
        generic_functions[-1].add_argument(
            Argument(t, name, '[OPTIONAL]' in comment))
    return generic_functions