fix materialization task count when buffering an existing job (#400)
dsschult authored Oct 27, 2024
1 parent 5a61d27 commit 8301751
Showing 1 changed file with 10 additions and 4 deletions.
iceprod/materialization/materialize.py (10 additions, 4 deletions)
@@ -60,20 +60,25 @@ async def run_once(self, only_dataset=None, set_status=None, num=10000, dryrun=F
             # check that last job was buffered correctly
             job_index = max(jobs[i]['job_index'] for i in jobs)+1 if jobs else 0
             num_tasks = sum(tasks.values())
+            logger.info('job_index: %d', job_index)
+            logger.info('num_tasks: %d', num_tasks)
+            logger.info('tasks_per_job: %d', dataset['tasks_per_job'])
             while num_tasks % dataset['tasks_per_job'] != 0 and job_index > 0:
                 # a job must have failed to buffer, so check in reverse order
                 job_index -= 1
+                logger.info('a job must have failed to buffer, so check in reverse order. job_index=%d, num_tasks=%d', job_index, num_tasks)
                 job_tasks = await self.rest_client.request('GET', f'/datasets/{dataset_id}/tasks',
                                                            {'job_index': job_index, 'keys': 'task_id|job_id|task_index'})
                 if len(job_tasks) != dataset['tasks_per_job']:
                     logger.info('fixing buffer of job %d for dataset %s', job_index, dataset_id)
                     ret = await self.rest_client.request('GET', f'/datasets/{dataset_id}/jobs',
                                                          {'job_index': job_index, 'keys': 'job_id'})
                     job_id = list(ret.keys())[0]
+                    logger.info(' fixing job_id %s, num existing tasks: %d', job_id, len(job_tasks))
                     tasks_buffered = await self.buffer_job(dataset, job_index, job_id=job_id,
                                                            tasks=list(job_tasks.values()),
                                                            set_status=set_status, dryrun=dryrun)
                     num_tasks += tasks_buffered
+                    logger.info('buffered %d tasks. num_tasks increased to: %d', tasks_buffered, num_tasks)
 
             # now try buffering new tasks
             job_index = max(jobs[i]['job_index'] for i in jobs)+1 if jobs else 0
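The while-loop arithmetic is the detection mechanism: a fully buffered dataset always has num_tasks divisible by tasks_per_job, so a nonzero remainder means some earlier job was only partially buffered. A simplified sketch of that arithmetic (the helper find_partial_jobs and all counts below are hypothetical, not from the commit):

def find_partial_jobs(num_tasks, tasks_per_job, job_index, tasks_in_job):
    # tasks_in_job maps job_index -> number of tasks that currently exist
    fixed = []
    while num_tasks % tasks_per_job != 0 and job_index > 0:
        job_index -= 1
        missing = tasks_per_job - tasks_in_job[job_index]
        if missing:
            fixed.append(job_index)
            num_tasks += missing  # mirrors num_tasks += tasks_buffered above
    return fixed

# 3 complete jobs (9 tasks) plus job 3, which buffered only 1 of its 3 tasks:
counts = {0: 3, 1: 3, 2: 3, 3: 1}
print(find_partial_jobs(num_tasks=10, tasks_per_job=3, job_index=4, tasks_in_job=counts))
# -> [3]; after repair num_tasks == 12, the remainder is 0, and the loop exits

This loop is also why the return-value fix further down matters: if buffer_job over-reports how many tasks it created, num_tasks drifts away from the true count and the remainder is checked against a wrong total.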
@@ -123,7 +128,7 @@ async def buffer_job(self, dataset, job_index, job_id=None, tasks=None, set_stat
         if dryrun:
             job_id = {'result': 'DRYRUN'}
             task_ids = []
-            task_iter = enumerate(task_names)
+            task_iter = list(enumerate(task_names))
         elif job_id:
             task_ids = [task['task_id'] for task in tasks] if tasks else []
             task_indexes = {task['task_index'] for task in tasks} if tasks else {}
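The list() wrap in both branches is needed because enumerate returns a one-shot iterator: it has no len(), and it is exhausted by the buffering loop, while the fixed return statement at the end of the method calls len(task_iter). A minimal demonstration (the task names are made up):

task_names = ['generate', 'process', 'verify']  # hypothetical task names

it = enumerate(task_names)
# len(it) would raise TypeError: object of type 'enumerate' has no len()
list(it)                # iterating once consumes the iterator...
assert list(it) == []   # ...leaving it empty on the second pass

task_iter = list(enumerate(task_names))
for task_index, name in task_iter:
    pass
assert len(task_iter) == 3  # still sized and reusable after the loop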
@@ -134,10 +139,11 @@ async def buffer_job(self, dataset, job_index, job_id=None, tasks=None, set_stat
             ret = await self.rest_client.request('POST', '/jobs', args)
             job_id = ret['result']
             task_ids = []
-            task_iter = enumerate(task_names)
+            task_iter = list(enumerate(task_names))
 
         # buffer tasks
         for task_index,name in task_iter:
+            logger.info(' buffering task_index %d, name %s', task_index, name)
             depends = await self.get_depends(config, job_index,
                                              task_index, task_ids)
             config['options']['job'] = job_index
@@ -167,7 +173,7 @@ async def buffer_job(self, dataset, job_index, job_id=None, tasks=None, set_stat
             p = await self.prio.get_task_prio(dataset_id, task_id)
             await self.rest_client.request('PATCH', f'/tasks/{task_id}', {'priority': p})
 
-        return len(task_ids)
+        return len(task_iter)
 
     async def get_config(self, dataset_id):
         """Get dataset config"""
