From 28825e1c994feeb118782f4ae49584b34e77930a Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Tue, 10 Dec 2024 12:53:35 +0000 Subject: [PATCH] compiler: Reduce some code after reviews --- devito/mpi/routines.py | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index b176dfbd8c..255db7420b 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -443,10 +443,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): if d in fixed: continue - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) + rpeer, lpeer = self._make_peers(d, distributor, nb) if (d, LEFT) in hse.halos: # Sending to left, receiving from right @@ -491,6 +488,14 @@ def _make_basic_mapper(self, f, fixed): return mapper + def _make_peers(self, d, distributor, nb): + rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions) + rpeer = FieldFromPointer(rname, nb) + lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions) + lpeer = FieldFromPointer(lname, nb) + + return rpeer, lpeer + def _call_haloupdate(self, name, f, hse, *args): comm = f.grid.distributor._obj_comm nb = f.grid.distributor._obj_neighborhood @@ -616,10 +621,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): if d in fixed: continue - name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) - rpeer = FieldFromPointer(name, nb) - name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) - lpeer = FieldFromPointer(name, nb) + rpeer, lpeer = self._make_peers(d, distributor, nb) if (d, LEFT) in hse.halos: # Sending to left, receiving from right @@ -1297,6 +1299,7 @@ def _as_number(self, v, args): return int(subs_op_args(v, args)) def _allocate_buffers(self, f, shape, entry): + # Allocate the send/recv buffers entry.sizes = (c_int*len(shape))(*shape) size = reduce(mul, shape)*dtype_len(self.target.dtype) ctype = dtype_to_ctype(f.dtype) @@ -1429,21 +1432,12 @@ def _arg_defaults(self, allocator, alias, args=None): if d in fixed: continue - if (d, LEFT) in self.halos: - entry = self.value[i] - i = i + 1 - # Sending to left, receiving from right - shape = mapper[(d, LEFT, OWNED)] - # Allocate the send/recv buffers - self._allocate_buffers(f, shape, entry) - - if (d, RIGHT) in self.halos: - entry = self.value[i] - i = i + 1 - # Sending to right, receiving from left - shape = mapper[(d, RIGHT, OWNED)] - # Allocate the send/recv buffers - self._allocate_buffers(f, shape, entry) + for side in (LEFT, RIGHT): + if (d, side) in self.halos: + entry = self.value[i] + i += 1 + shape = mapper[(d, side, OWNED)] + self._allocate_buffers(f, shape, entry) return {self.name: self.value}