Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ljwoods2 committed Aug 27, 2024
1 parent a1a92e8 commit 52c7854
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 189 deletions.
5 changes: 1 addition & 4 deletions imdreader/IMDClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ def stop(self):
if self._multithreaded:
if not self._stopped:
self._buf.notify_consumer_finished()
# NOTE: fix producer thread to catch errors
# where socket is disconnected by this thread
# rather than generic expection
self._disconnect()
self._stopped = True
else:
Expand Down Expand Up @@ -235,7 +232,7 @@ def _disconnect(self):
try:
disconnect = create_header_bytes(IMDHeaderType.IMD_DISCONNECT, 0)
self._conn.sendall(disconnect)
logger.debug("IMDProducer: Disconnected from server")
logger.debug("IMDClient: Disconnected from server")
except (ConnectionResetError, BrokenPipeError):
logger.debug(
f"IMDProducer: Attempted to disconnect but server already terminated the connection"
Expand Down
12 changes: 3 additions & 9 deletions imdreader/IMDREADER.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@
class IMDReader(ReaderBase):
"""
Reader for IMD protocol packets.
Default buffer size is set to 8MB for testing
Buffer_size kwarg is in bytes
We are assuming the header will never be sent without the body as in the sample code.
If this assumption is violated, the producer thread can cause a deadlock.
"""

format = "IMD"
Expand Down Expand Up @@ -171,11 +165,11 @@ def _load_imdframe_into_ts(self, imdf):
if imdf.positions is not None:
# must call copy because reference is expected to reset
# see 'test_frame_collect_all_same' in MDAnalysisTests.coordinates.base
self.ts.positions = imdf.positions.copy()
self.ts.positions = imdf.positions
if imdf.velocities is not None:
self.ts.velocities = imdf.velocities.copy()
self.ts.velocities = imdf.velocities
if imdf.forces is not None:
self.ts.forces = imdf.forces.copy()
self.ts.forces = imdf.forces

@property
def n_frames(self):
Expand Down
47 changes: 42 additions & 5 deletions imdreader/tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,23 @@ def __init__(self, traj):
def set_imdsessioninfo(self, imdsinfo):
self.imdsinfo = imdsinfo

def listen_accept_handshake_send_ts(self, host, port):
def handshake_sequence(self, host, port, first_frame=True):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((host, port))
logger.debug(f"InThreadIMDServer: Listening on {host}:{port}")
s.listen(60)
self.listen_socket = s
self.accept_thread = threading.Thread(
target=self._accept_handshake_send_ts,
)
if first_frame:
self.accept_thread = threading.Thread(
target=self._accept_handshake_send_ts,
)
else:
self.accept_thread = threading.Thread(target=self._accept_handshake)
self.accept_thread.start()

def _accept_handshake_send_ts(self):
"""Accepts the connection, sends the handshake & imdsessionfo, and sends the first frame
For testing IMDReader integration"""
logger.debug(f"InThreadIMDServer: Entering accept thread")
waited = 0

Expand All @@ -62,6 +67,25 @@ def _accept_handshake_send_ts(self):
# IMDReader will fail out if it fails to connect
return

def _accept_handshake(self):
"""Accepts the connection and sends the handshake & imdsessionfo. For testing IMDClient directly"""
waited = 0

if sock_contains_data(self.listen_socket, 5):
logger.debug(f"InThreadIMDServer: Accepting connection")
self.conn, _ = self.listen_socket.accept()
# NOTE: may need to reorganize this
self.conn.settimeout(5)
if self.imdsinfo.version == 2:
self._send_handshakeV2()
elif self.imdsinfo.version == 3:
self._send_handshakeV3()
self.expect_packet(IMDHeaderType.IMD_GO)
return
else:
# IMDReader will fail out if it fails to connect
return

def _send_handshakeV2(self):
header = struct.pack("!i", IMDHeaderType.IMD_HANDSHAKE.value)
header += struct.pack(f"{self.imdsinfo.endianness}i", 2)
Expand Down Expand Up @@ -115,7 +139,7 @@ def send_frame(self, i):
if self.imdsinfo.time:
time_header = create_header_bytes(IMDHeaderType.IMD_TIME, 1)
time = struct.pack(
f"{endianness}ff", self.traj[i].data["dt"], self.traj[i].time
f"{endianness}ff", self.traj[i].dt, self.traj[i].time
)

self.conn.sendall(time_header + time)
Expand Down Expand Up @@ -176,6 +200,19 @@ def send_frame(self, i):

self.conn.sendall(force_header + force)

def expect_packet(self, packet_type, expected_length=None):
head_buf = bytearray(IMDHEADERSIZE)
read_into_buf(self.conn, head_buf)
header = IMDHeader(head_buf)
if header.type != packet_type:
raise ValueError(
f"Expected {packet_type} packet, got {header.type}"
)
if expected_length is not None and header.length != expected_length:
raise ValueError(
f"Expected packet length {expected_length}, got {header.length}"
)

def disconnect(self):
self.conn.shutdown(socket.SHUT_RD)
self.conn.close()
Expand Down
84 changes: 83 additions & 1 deletion imdreader/tests/test_imdclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
)
import MDAnalysis as mda
import imdreader
from imdreader.IMDClient import imdframe_memsize
from imdreader.IMDClient import imdframe_memsize, IMDClient
from imdreader.IMDProtocol import IMDHeaderType
from .utils import (
IMDServerEventType,
DummyIMDServer,
get_free_port,
ExpectPauseLoopV2Behavior,
create_default_imdsinfo_v2,
create_default_imdsinfo_v3,
)
from .server import InThreadIMDServer
from MDAnalysisTests.coordinates.base import (
MultiframeReaderTest,
BaseReference,
Expand Down Expand Up @@ -59,6 +62,84 @@ def log_config():
]


class TestIMDReaderV3:

@pytest.fixture
def port(self):
return get_free_port()

@pytest.fixture
def universe(self):
return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)

@pytest.fixture
def imdsinfo(self):
return create_default_imdsinfo_v3()

@pytest.fixture
def server_client(self, universe, imdsinfo, port):
server = InThreadIMDServer(universe.trajectory)
server.set_imdsessioninfo(imdsinfo)
server.handshake_sequence("localhost", port, first_frame=False)
client = IMDClient(
f"localhost",
port,
universe.trajectory.n_atoms,
buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo)
* 2,
)
yield server, client
client.stop()
server.cleanup()

def test_pause_resume_continue(self, server_client):
server, client = server_client
server.send_frames(0, 2)
# Client's buffer is filled. client should send pause
server.expect_packet(IMDHeaderType.IMD_PAUSE)
# Empty buffer
client.get_imdframe()
# only the second call actually frees buffer memory
client.get_imdframe()
# client has free memory. should send resume
server.expect_packet(IMDHeaderType.IMD_RESUME)
server.send_frame(1)
client.get_imdframe()

def test_pause_resume_disconnect(self, server_client):
"""Client pauses because buffer is full, empties buffer and attempt to resume, but
finds that simulation has already ended and raises EOF"""
server, client = server_client
server.send_frames(0, 2)
server.expect_packet(IMDHeaderType.IMD_PAUSE)
client.get_imdframe()
client.get_imdframe()
# client has free frame. should send resume
server.expect_packet(IMDHeaderType.IMD_RESUME)
# simulation is over. client should raise EOF
server.disconnect()
with pytest.raises(EOFError):
client.get_imdframe()

def test_pause_resume_no_disconnect(self, server_client):
"""Client pauses because buffer is full, empties buffer and attempt to resume, but
finds that simulation has already ended (but has not yet disconnected) and raises EOF
"""
server, client = server_client
server.send_frames(0, 2)
server.expect_packet(IMDHeaderType.IMD_PAUSE)
client.get_imdframe()
client.get_imdframe()
# client has free frame. should send resume
server.expect_packet(IMDHeaderType.IMD_RESUME)
# simulation is over. client should raise EOF
with pytest.raises(EOFError):
client.get_imdframe()
# server should receive disconnect from client (though it doesn't have to do anything)
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)


"""
class TestIMDReaderV2:
@pytest.fixture
Expand Down Expand Up @@ -219,3 +300,4 @@ def test_change_endianness_traj_unchanged(self, ref, server, endianness):
i += 1
assert i == len(ref)
"""
2 changes: 1 addition & 1 deletion imdreader/tests/test_imdreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def ref(self):
def reader(self, ref):
# This will start the test IMD Server, waiting for a connection
# to then send handshake & first frame
ref.server.listen_accept_handshake_send_ts("localhost", ref.port)
ref.server.handshake_sequence("localhost", ref.port)
# This will connect to the test IMD Server and read the first frame
reader = ref.reader(ref.trajectory, n_atoms=ref.n_atoms)
# Send the rest of the frames- small enough to all fit in socket itself
Expand Down
Loading

0 comments on commit 52c7854

Please sign in to comment.