-
Notifications
You must be signed in to change notification settings - Fork 1
/
hostlist.py
393 lines (320 loc) · 13.4 KB
/
hostlist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Hostlist library
#
# Copyright (C) 2008 Kent Engström <[email protected]>,
# Thomas Bellman <[email protected]> and
# Pär Andersson <[email protected]>,
# National Supercomputer Centre
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Handle hostlist expressions.
This module provides operations to expand and collect hostlist
expressions.
The hostlist expression syntax is the same as in several programs
developed at LLNL (https://computing.llnl.gov/linux/). However, in
corner cases the behaviour of this module has not been compared for
compatibility with pdsh/dshbak/SLURM et al.
"""
# Version string of this module (conventional __version__ attribute).
__version__ = "1.17"
import re
import itertools
# Exception used for error reporting to the caller
class BadHostlist(Exception):
    """Error raised by this module to report bad input to the caller.

    Used for malformed hostlist expressions, forbidden characters in
    host names, results exceeding MAX_SIZE and bad task list syntax.
    """
# Configuration to guard against ridiculously long expanded lists.
# The expansion functions and the task list parser in this module compare
# their range sizes, combined result sizes and repetition counts against
# this limit and raise BadHostlist when it is exceeded.
MAX_SIZE = 100000
# Hostlist expansion
def expand_hostlist(hostlist, allow_duplicates=False, sort=False):
    """Expand a hostlist expression string to a Python list.

    Example: expand_hostlist("n[9-11],d[01-02]") ==>
             ['n9', 'n10', 'n11', 'd01', 'd02']

    The expression is split at top-level commas (commas inside brackets
    belong to a rangelist) and each non-empty part is expanded with
    expand_part().

    Unless allow_duplicates is true, duplicates will be purged
    from the results. If sort is true, the output will be sorted
    numerically (n2 before n10).

    Raises BadHostlist on nested or unbalanced brackets, and whenever
    the expansion of a part fails.
    """
    results = []
    bracket_level = 0
    part = ""

    # The extra trailing comma flushes the final part through the same
    # code path as all the other parts.
    for c in hostlist + ",":
        if c == "," and bracket_level == 0:
            # Comma at top level, split! (Empty parts are skipped.)
            if part:
                results.extend(expand_part(part))
            part = ""
        else:
            part += c

        if c == "[":
            bracket_level += 1
        elif c == "]":
            bracket_level -= 1

        if bracket_level > 1:
            raise BadHostlist("nested brackets")
        elif bracket_level < 0:
            raise BadHostlist("unbalanced brackets")

    # A bracket that is still open swallowed the trailing comma above
    if bracket_level > 0:
        raise BadHostlist("unbalanced brackets")

    if not allow_duplicates:
        results = remove_duplicates(results)
    if sort:
        results = numerically_sorted(results)
    return results
def expand_part(s):
    """Expand a part (e.g. "x[1-2]y[1-3][1-3]") (no outer level commas).

    Returns the list of all strings obtained by combining every
    expansion of the leading prefix/rangelist with every expansion of
    the rest of the part.

    Raises BadHostlist if the part cannot be parsed, if a rangelist is
    bad, or if the combined result would exceed MAX_SIZE entries.
    """
    # Base case: the empty part expand to the singleton list of ""
    if s == "":
        return [""]

    # Split into:
    # 1) prefix string (may be empty)
    # 2) rangelist in brackets (may be missing)
    # 3) the rest
    m = re.match(r'([^,\[]*)(\[[^\]]*\])?(.*)', s)
    (prefix, rangelist, rest) = m.group(1, 2, 3)

    # With neither a prefix nor a rangelist nothing was consumed (the
    # part starts with a comma or an unclosed bracket), so recursing on
    # the rest would never terminate. Report it as a bad hostlist.
    if not prefix and rangelist is None:
        raise BadHostlist("bad hostlist part")

    # Expand the rest first (here is where we recurse!)
    rest_expanded = expand_part(rest)

    # Expand our own part
    if not rangelist:
        # If there is no rangelist, our own contribution is the prefix only
        us_expanded = [prefix]
    else:
        # Otherwise expand the rangelist (adding the prefix before);
        # the surrounding brackets are stripped off first
        us_expanded = expand_rangelist(prefix, rangelist[1:-1])

    # Combine our list with the list from the expansion of the rest
    # (but guard against too large results first)
    if len(us_expanded) * len(rest_expanded) > MAX_SIZE:
        raise BadHostlist("results too large")

    return [us_part + rest_part
            for us_part in us_expanded
            for rest_part in rest_expanded]
def expand_rangelist(prefix, rangelist):
    """Expand a comma-separated rangelist (e.g. "1-10,14") with a prefix.

    Every range between the commas is expanded on its own by
    expand_range() and the results are concatenated in order.
    """
    return [expanded
            for single_range in rangelist.split(",")
            for expanded in expand_range(prefix, single_range)]
def expand_range(prefix, range_):
    """Expand a range (e.g. "1-10" or "14"), putting a prefix before.

    A single number is returned as written. For a low-high range every
    number is zero padded to the width of the lower bound as written,
    so "08-10" gives 08, 09 and 10.

    Raises BadHostlist if the range is malformed, if start > stop, or
    if the range would expand to more than MAX_SIZE entries.
    """
    # Check for a single number first
    if re.match(r'^[0-9]+$', range_):
        return ["%s%s" % (prefix, range_)]

    # Otherwise split low-high
    m = re.match(r'^([0-9]+)-([0-9]+)$', range_)
    if not m:
        raise BadHostlist("bad range")

    (s_low, s_high) = m.group(1, 2)
    low = int(s_low)
    high = int(s_high)
    width = len(s_low)  # This width includes leading zeroes

    if high < low:
        raise BadHostlist("start > stop")
    elif high - low >= MAX_SIZE:
        # The range holds high - low + 1 entries, so this rejects
        # anything longer than MAX_SIZE before it is built.
        raise BadHostlist("range too large")

    return ["%s%0*d" % (prefix, width, i) for i in range(low, high + 1)]
def remove_duplicates(l):
    """Return a new list with later repeats dropped (first-seen order kept)."""
    unique = []
    already_seen = set()
    for item in l:
        if item in already_seen:
            continue
        already_seen.add(item)
        unique.append(item)
    return unique
# Hostlist collection
def collect_hostlist(hosts, silently_discard_bad=False):
    """Collect a hostlist string from a Python list of hosts.

    Grouping starts from the rightmost numerical part of the names.
    Duplicates are removed.

    A host name containing a comma or a square bracket raises
    BadHostlist, unless silently_discard_bad is true, in which case
    the name is dropped without notice.
    """
    # Build the initial work list of (left, right) tuples. Everything
    # starts on the left; each pass moves the numerical part it has
    # collected over to the right, where it is only copied along.
    pairs = []
    for raw_name in hosts:
        # Surrounding whitespace is not significant; blank entries are skipped
        name = raw_name.strip()
        if name == "":
            continue

        # The three characters with special meaning in the hostlist
        # syntax (comma and square brackets) cannot appear in a name
        if re.search(r'[][,]', name):
            if not silently_discard_bad:
                raise BadHostlist("forbidden character")
            continue

        pairs.append((name, ""))

    # Keep running passes until one reports that nothing more can be done
    while True:
        pairs, more_work = collect_hostlist_1(pairs)
        if not more_work:
            break

    return ",".join([left + right for left, right in pairs])
def collect_hostlist_1(left_right):
    """Collect a hostlist string from a list of hosts (left+right).

    The input is a list of tuples (left, right). The left part
    is analyzed, while the right part is just passed along
    (it can contain already collected range expressions).

    Returns a tuple (results, needs_another_loop): results is the new
    list of (left, right) tuples, and needs_another_loop is true when
    a numerical part was collected in this pass, meaning the left parts
    may still contain further numerical parts to process.
    """

    def sort_key(entry):
        # Entries for hosts without a numeric part carry None markers.
        # Comparing None with str/int is an error on Python 3, so map
        # each entry to an all-comparable key. The 0/1 flag puts the
        # no-number entry before numbered entries with the same prefix,
        # which is the order Python 2 gave (None sorts first there).
        (prefix, suffix), num_int, num_width, host = entry
        if suffix is None:
            return (prefix, 0, "", 0, 0, host)
        return (prefix, 1, suffix, num_int, num_width, host)

    # Scan the list of hosts (left+right) and build two things:
    # *) a set of all hosts seen (used later)
    # *) a list where each host entry is preprocessed for correct sorting
    sortlist = []
    remaining = set()
    for left, right in left_right:
        host = left + right
        remaining.add(host)

        # Match the left part into parts
        m = re.match(r'^(.*?)([0-9]+)?([^0-9]*)$', left)
        (prefix, num_str, suffix) = m.group(1, 2, 3)

        # Add the right part unprocessed to the suffix.
        # This ensures that an already computed range expression
        # in the right part is not analyzed again.
        suffix = suffix + right

        if num_str is None:
            # A left part with no numeric part at all gets special treatment!
            # The regexp matches with the whole string as the suffix,
            # with nothing in the prefix or numeric parts.
            # We do not want that, so we move it to the prefix and put
            # None as a special marker where the suffix should be.
            assert prefix == ""
            sortlist.append(((host, None), None, None, host))
        else:
            # A left part with at least a numeric part
            # (we care about the rightmost numeric part)
            num_int = int(num_str)
            num_width = len(num_str)  # This width includes leading zeroes
            sortlist.append(((prefix, suffix), num_int, num_width, host))

    # Sort lexicographically, first on prefix, then on suffix, then on
    # num_int (numerically), then...
    # This determines the order of the final result.
    sortlist.sort(key=sort_key)

    # We are ready to collect the result parts as a list of new (left,
    # right) tuples.
    results = []
    needs_another_loop = False

    # Now group entries with the same prefix+suffix combination (the
    # key is the first element in the sortlist) to loop over them and
    # then to loop over the list of hosts sharing the same
    # prefix+suffix combination.
    for ((prefix, suffix), group) in itertools.groupby(sortlist,
                                                       key=lambda x: x[0]):
        if suffix is None:
            # Special case: a host with no numeric part
            results.append(("", prefix))  # Move everything to the right part
            remaining.remove(prefix)
        else:
            # The general case. We prepare to collect a list of
            # ranges expressed as (low, high, width) for later
            # formatting.
            range_list = []

            for ((prefix2, suffix2), num_int, num_width, host) in group:
                if host not in remaining:
                    # Below, we will loop internally to enumerate a whole
                    # range at a time. We then remove the covered hosts
                    # from the set. Therefore, skip the host here if it
                    # is gone from the set.
                    continue
                assert num_int is not None

                # Scan for a range starting at the current host
                low = num_int
                while True:
                    host = "%s%0*d%s" % (prefix, num_width, num_int, suffix)
                    if host in remaining:
                        remaining.remove(host)
                        num_int += 1
                    else:
                        break
                high = num_int - 1
                assert high >= low
                range_list.append((low, high, num_width))

            # We have a list of ranges to format. We make sure
            # we move our handled numerical part to the right to
            # stop it from being processed again.
            needs_another_loop = True
            if len(range_list) == 1 and range_list[0][0] == range_list[0][1]:
                # Special case to make sure that n1 is not shown as n[1] etc
                results.append((prefix,
                                "%0*d%s" %
                                (range_list[0][2], range_list[0][0], suffix)))
            else:
                # General case where high > low
                results.append((prefix, "[" +
                                ",".join([format_range(l, h, w)
                                          for l, h, w in range_list]) +
                                "]" + suffix))

    # At this point, the set of remaining hosts should be empty and we
    # are ready to return the result, together with the flag that says
    # if we need to loop again (we do if we have added something to a
    # left part).
    assert not remaining
    return results, needs_another_loop
def format_range(low, high, width):
    """Render the inclusive range low..high, zero padded to width digits.

    A range with equal endpoints is rendered as that single number.
    """
    if low != high:
        return "%0*d-%0*d" % (width, low, width, high)
    return "%0*d" % (width, low)
# Sort a list of hosts numerically
def numerically_sorted(l):
    """Return a new list of the hosts in l, ordered numerically.

    Digit runs are compared as numbers, so the order is n1, n2, n10
    rather than n1, n10, n2.
    """
    ordered = list(l)
    ordered.sort(key=numeric_sort_key)
    return ordered
# Splits a string into alternating runs of digits and non-digits; each
# match is a (digits, non_digits) tuple with exactly one non-empty element.
nsk_re = re.compile("([0-9]+)|([^0-9]+)")


def numeric_sort_key(x):
    """Return a sort key that orders the string x numerically.

    The key is a list with one (tag, value) tuple per run of digits or
    non-digits in x. Digit runs become (0, int) and other runs become
    (1, str). The tag keeps integers and strings from ever being
    compared with each other (an error on Python 3) and sorts a digit
    run before a non-digit run found at the same position.
    """
    return [(0 if int_nonint[0] else 1, handle_int_nonint(int_nonint))
            for int_nonint in nsk_re.findall(x)]


def handle_int_nonint(int_nonint_tuple):
    """Convert one (digits, non_digits) match tuple to an int or a str.

    Returns the digits as an int when the first element is non-empty,
    otherwise the second element unchanged.
    """
    if int_nonint_tuple[0]:
        return int(int_nonint_tuple[0])
    else:
        return int_nonint_tuple[1]
# Parse SLURM_TASKS_PER_NODE into a list of task numbers
#
# Description from the SLURM sbatch man page:
# Number of tasks to be initiated on each node. Values
# are comma separated and in the same order as
# SLURM_NODELIST. If two or more consecutive nodes are
# to have the same task count, that count is followed by
# "(x#)" where "#" is the repetition count. For example,
# "SLURM_TASKS_PER_NODE=2(x3),1" indicates that the first
# three nodes will each execute two tasks and the
# fourth node will execute one task.
def parse_slurm_tasks_per_node(s):
    """Parse a SLURM_TASKS_PER_NODE string into a list of task counts.

    Example: parse_slurm_tasks_per_node("2(x3),1") ==> [2, 2, 2, 1]

    Raises BadHostlist if an item is not of the form "N" or "N(xM)",
    or if a repetition count M is larger than MAX_SIZE.
    """
    task_counts = []
    for item in s.split(","):
        found = re.match(r'^([0-9]+)(\(x([0-9]+)\))?$', item)
        if not found:
            raise BadHostlist("bad task list syntax")

        count = int(found.group(1))
        # Group 3 is only set when the "(xM)" repetition suffix is present
        repeat_str = found.group(3)
        repeat = 1 if repeat_str is None else int(repeat_str)
        if repeat > MAX_SIZE:
            raise BadHostlist("task list repetitions too large")

        task_counts.extend([count] * repeat)
    return task_counts
#
# Keep this part to tell users where the command line interface went
#
if __name__ == '__main__':
    import os
    import sys
    sys.stderr.write("The command line utility has been moved to a separate 'hostlist' program.\n")
    # os.EX_USAGE is only defined on Unix; fall back to its customary
    # value (64) so this hint also exits cleanly on other platforms.
    sys.exit(getattr(os, "EX_USAGE", 64))