From 9829c28005942fbc6f630fdd1bb30bbbf72c977f Mon Sep 17 00:00:00 2001 From: Adrian Sutton Date: Wed, 21 Dec 2022 06:06:59 +1000 Subject: [PATCH] Implement source address filtering (#171) --- .../beacon/discovery/AddressAccessPolicy.java | 19 ++++++ .../discovery/DiscoveryManagerImpl.java | 17 ++++-- .../discovery/DiscoverySystemBuilder.java | 9 ++- .../HandshakeMessagePacketHandler.java | 20 +++++-- .../handler/OutgoingParcelHandler.java | 11 +++- .../pipeline/handler/PacketSourceFilter.java | 36 +++++++++++ .../info/FindNodeResponseHandler.java | 7 ++- .../discovery/DiscoveryNetworkTest.java | 7 ++- .../discovery/HandshakeHandlersTest.java | 6 +- .../handler/OutgoingParcelHandlerTest.java | 60 +++++++++++++++++++ .../handler/PacketSourceFilterTest.java | 50 ++++++++++++++++ .../info/FindNodeResponseHandlerTest.java | 28 +++++++-- 12 files changed, 250 insertions(+), 20 deletions(-) create mode 100644 src/main/java/org/ethereum/beacon/discovery/AddressAccessPolicy.java create mode 100644 src/main/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilter.java create mode 100644 src/test/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandlerTest.java create mode 100644 src/test/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilterTest.java diff --git a/src/main/java/org/ethereum/beacon/discovery/AddressAccessPolicy.java b/src/main/java/org/ethereum/beacon/discovery/AddressAccessPolicy.java new file mode 100644 index 000000000..7131c6335 --- /dev/null +++ b/src/main/java/org/ethereum/beacon/discovery/AddressAccessPolicy.java @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.ethereum.beacon.discovery; + +import java.net.InetSocketAddress; +import org.ethereum.beacon.discovery.schema.NodeRecord; + +@FunctionalInterface +public interface AddressAccessPolicy { + AddressAccessPolicy ALLOW_ALL = __ -> true; + + boolean allow(InetSocketAddress address); + + default boolean allow(NodeRecord record) { + return record.getTcpAddress().map(this::allow).orElse(true); + } +} diff --git a/src/main/java/org/ethereum/beacon/discovery/DiscoveryManagerImpl.java b/src/main/java/org/ethereum/beacon/discovery/DiscoveryManagerImpl.java index f067975c4..19a0eff06 100644 --- a/src/main/java/org/ethereum/beacon/discovery/DiscoveryManagerImpl.java +++ b/src/main/java/org/ethereum/beacon/discovery/DiscoveryManagerImpl.java @@ -39,6 +39,7 @@ import org.ethereum.beacon.discovery.pipeline.handler.NodeSessionRequestHandler; import org.ethereum.beacon.discovery.pipeline.handler.OutgoingParcelHandler; import org.ethereum.beacon.discovery.pipeline.handler.PacketDispatcherHandler; +import org.ethereum.beacon.discovery.pipeline.handler.PacketSourceFilter; import org.ethereum.beacon.discovery.pipeline.handler.UnauthorizedMessagePacketHandler; import org.ethereum.beacon.discovery.pipeline.handler.UnknownPacketTagToSender; import org.ethereum.beacon.discovery.pipeline.handler.WhoAreYouPacketHandler; @@ -66,6 +67,7 @@ public class DiscoveryManagerImpl implements DiscoveryManager { private final Pipeline incomingPipeline = new PipelineImpl(); private final Pipeline outgoingPipeline = new PipelineImpl(); private final LocalNodeRecordStore localNodeRecordStore; + private final AddressAccessPolicy addressAccessPolicy; private volatile DiscoveryClient discoveryClient; private final NodeSessionManager nodeSessionManager; @@ -78,8 +80,10 @@ public DiscoveryManagerImpl( final Scheduler taskScheduler, final ExpirationSchedulerFactory expirationSchedulerFactory, final TalkHandler talkHandler, - final ExternalAddressSelector externalAddressSelector) { + final ExternalAddressSelector externalAddressSelector, + final AddressAccessPolicy addressAccessPolicy) { this.localNodeRecordStore = localNodeRecordStore; + this.addressAccessPolicy = addressAccessPolicy; final NodeRecord homeNodeRecord = localNodeRecordStore.getLocalNodeRecord(); this.discoveryServer = discoveryServer; @@ -91,6 +95,7 @@ public DiscoveryManagerImpl( outgoingPipeline, expirationSchedulerFactory); incomingPipeline + .addHandler(new PacketSourceFilter(addressAccessPolicy)) .addHandler(new IncomingDataPacker(homeNodeRecord.getNodeId())) .addHandler(new WhoAreYouSessionResolver(nodeSessionManager)) .addHandler(new UnknownPacketTagToSender()) @@ -99,7 +104,11 @@ public DiscoveryManagerImpl( .addHandler(new WhoAreYouPacketHandler(outgoingPipeline, taskScheduler)) .addHandler( new HandshakeMessagePacketHandler( - outgoingPipeline, taskScheduler, nodeRecordFactory, nodeSessionManager)) + outgoingPipeline, + taskScheduler, + nodeRecordFactory, + nodeSessionManager, + addressAccessPolicy)) .addHandler(new MessagePacketHandler(nodeRecordFactory)) .addHandler(new UnauthorizedMessagePacketHandler()) .addHandler( @@ -111,7 +120,7 @@ public DiscoveryManagerImpl( .addHandler(new BadPacketHandler()); final FluxSink outgoingSink = outgoingMessages.sink(); outgoingPipeline - .addHandler(new OutgoingParcelHandler(outgoingSink)) + .addHandler(new OutgoingParcelHandler(outgoingSink, addressAccessPolicy)) .addHandler(new NodeSessionRequestHandler()) .addHandler(nodeSessionManager) .addHandler(new NewTaskHandler()) @@ -178,7 +187,7 @@ public CompletableFuture> findNodes( new Request<>( new CompletableFuture<>(), reqId -> new FindNodeMessage(reqId, distances), - new FindNodeResponseHandler(distances)); + new FindNodeResponseHandler(distances, addressAccessPolicy)); return executeTaskImpl(nodeRecord, request); } diff --git a/src/main/java/org/ethereum/beacon/discovery/DiscoverySystemBuilder.java b/src/main/java/org/ethereum/beacon/discovery/DiscoverySystemBuilder.java index 010a28ece..94441338b 100644 --- a/src/main/java/org/ethereum/beacon/discovery/DiscoverySystemBuilder.java +++ b/src/main/java/org/ethereum/beacon/discovery/DiscoverySystemBuilder.java @@ -58,6 +58,7 @@ public class DiscoverySystemBuilder { private TalkHandler talkHandler = TalkHandler.NOOP; private NettyDiscoveryServer discoveryServer = null; private ExternalAddressSelector externalAddressSelector = null; + private AddressAccessPolicy addressAccessPolicy = AddressAccessPolicy.ALLOW_ALL; private final Clock clock = Clock.systemUTC(); private final LivenessChecker livenessChecker = new LivenessChecker(clock); @@ -151,6 +152,11 @@ public DiscoverySystemBuilder externalAddressSelector( return this; } + public DiscoverySystemBuilder addressAccessPolicy(final AddressAccessPolicy addressAccessPolicy) { + this.addressAccessPolicy = addressAccessPolicy; + return this; + } + private void createDefaults() { newAddressHandler = requireNonNullElseGet( @@ -247,7 +253,8 @@ DiscoveryManagerImpl buildDiscoveryManager() { schedulers.newSingleThreadDaemon("discovery-client-" + clientNumber), expirationSchedulerFactory, talkHandler, - externalAddressSelector); + externalAddressSelector, + addressAccessPolicy); } /** diff --git a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/HandshakeMessagePacketHandler.java b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/HandshakeMessagePacketHandler.java index c67a79c72..0013ae847 100644 --- a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/HandshakeMessagePacketHandler.java +++ b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/HandshakeMessagePacketHandler.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.tuweni.bytes.Bytes; +import org.ethereum.beacon.discovery.AddressAccessPolicy; import org.ethereum.beacon.discovery.message.V5Message; import org.ethereum.beacon.discovery.packet.HandshakeMessagePacket; import org.ethereum.beacon.discovery.pipeline.Envelope; @@ -32,16 +33,19 @@ public class HandshakeMessagePacketHandler implements EnvelopeHandler { private final Scheduler scheduler; private final NodeRecordFactory nodeRecordFactory; private final NodeSessionManager nodeSessionManager; + private final AddressAccessPolicy addressAccessPolicy; public HandshakeMessagePacketHandler( Pipeline outgoingPipeline, Scheduler scheduler, NodeRecordFactory nodeRecordFactory, - NodeSessionManager nodeSessionManager) { + NodeSessionManager nodeSessionManager, + AddressAccessPolicy addressAccessPolicy) { this.outgoingPipeline = outgoingPipeline; this.scheduler = scheduler; this.nodeRecordFactory = nodeRecordFactory; this.nodeSessionManager = nodeSessionManager; + this.addressAccessPolicy = addressAccessPolicy; } @Override @@ -97,9 +101,17 @@ public void handle(Envelope envelope) { // Check the node record matches the ID we expect if (!nodeRecordMaybe.map(r -> r.getNodeId().equals(session.getNodeId())).orElse(false)) { LOG.debug( - String.format( - "Incorrect node ID for message [%s] from node %s in status %s", - packet, session.getNodeRecord(), session.getState())); + "Incorrect node ID for message [{}] from node {} in status {}", + packet, + session.getNodeRecord(), + session.getState()); + markHandshakeAsFailed(envelope, session); + return; + } else if (!enr.map(addressAccessPolicy::allow).orElse(true)) { + LOG.debug( + "Rejecting handshake from node {} because the ENR was disallowed: {}", + session.getNodeRecord(), + enr); markHandshakeAsFailed(envelope, session); return; } diff --git a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandler.java b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandler.java index 5d4013980..649e7e505 100644 --- a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandler.java +++ b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandler.java @@ -6,6 +6,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.ethereum.beacon.discovery.AddressAccessPolicy; import org.ethereum.beacon.discovery.network.NetworkParcel; import org.ethereum.beacon.discovery.pipeline.Envelope; import org.ethereum.beacon.discovery.pipeline.EnvelopeHandler; @@ -22,9 +23,12 @@ public class OutgoingParcelHandler implements EnvelopeHandler { private static final Logger LOG = LogManager.getLogger(OutgoingParcelHandler.class); private final FluxSink outgoingSink; + private final AddressAccessPolicy addressAccessPolicy; - public OutgoingParcelHandler(FluxSink outgoingSink) { + public OutgoingParcelHandler( + FluxSink outgoingSink, final AddressAccessPolicy addressAccessPolicy) { this.outgoingSink = outgoingSink; + this.addressAccessPolicy = addressAccessPolicy; } @Override @@ -41,7 +45,10 @@ public void handle(Envelope envelope) { if (envelope.get(Field.INCOMING) instanceof NetworkParcel) { NetworkParcel parcel = (NetworkParcel) envelope.get(Field.INCOMING); if (parcel.getPacket().getBytes().size() > IncomingDataPacker.MAX_PACKET_SIZE) { - LOG.error(() -> "Outgoing packet is too large, dropping it: " + parcel.getPacket()); + LOG.error("Outgoing packet is too large, dropping it: {}", parcel.getPacket()); + } else if (!addressAccessPolicy.allow(parcel.getDestination())) { + LOG.debug( + "Dropping outgoing packet to disallowed destination: {}", parcel.getDestination()); } else { outgoingSink.next(parcel); envelope.remove(Field.INCOMING); diff --git a/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilter.java b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilter.java new file mode 100644 index 000000000..5493bb5b2 --- /dev/null +++ b/src/main/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilter.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.ethereum.beacon.discovery.pipeline.handler; + +import java.net.InetSocketAddress; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.ethereum.beacon.discovery.AddressAccessPolicy; +import org.ethereum.beacon.discovery.pipeline.Envelope; +import org.ethereum.beacon.discovery.pipeline.EnvelopeHandler; +import org.ethereum.beacon.discovery.pipeline.Field; +import org.ethereum.beacon.discovery.pipeline.HandlerUtil; + +public class PacketSourceFilter implements EnvelopeHandler { + private static final Logger LOG = LogManager.getLogger(PacketSourceFilter.class); + + private final AddressAccessPolicy addressAccessPolicy; + + public PacketSourceFilter(final AddressAccessPolicy addressAccessPolicy) { + this.addressAccessPolicy = addressAccessPolicy; + } + + @Override + public void handle(final Envelope envelope) { + if (!HandlerUtil.requireField(Field.REMOTE_SENDER, envelope)) { + return; + } + final InetSocketAddress sender = envelope.get(Field.REMOTE_SENDER); + if (!addressAccessPolicy.allow(sender)) { + envelope.remove(Field.INCOMING); + LOG.debug("Ignoring message from disallowed source {}", sender); + } + } +} diff --git a/src/main/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandler.java b/src/main/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandler.java index 8ec2460b2..618885011 100644 --- a/src/main/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandler.java +++ b/src/main/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandler.java @@ -8,6 +8,7 @@ import java.util.List; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.ethereum.beacon.discovery.AddressAccessPolicy; import org.ethereum.beacon.discovery.message.NodesMessage; import org.ethereum.beacon.discovery.schema.NodeRecord; import org.ethereum.beacon.discovery.schema.NodeSession; @@ -19,11 +20,14 @@ public class FindNodeResponseHandler implements MultiPacketResponseHandler foundNodes = new ArrayList<>(); private final Collection distances; + private final AddressAccessPolicy addressAccessPolicy; private int totalPackets = NOT_SET; private int receivedPackets = 0; - public FindNodeResponseHandler(final Collection distances) { + public FindNodeResponseHandler( + final Collection distances, final AddressAccessPolicy addressAccessPolicy) { this.distances = distances; + this.addressAccessPolicy = addressAccessPolicy; } @Override @@ -53,6 +57,7 @@ public synchronized boolean handleResponseMessage(NodesMessage message, NodeSess message.getNodeRecords().stream() .filter(this::isValid) .filter(record -> hasCorrectDistance(session, record)) + .filter(addressAccessPolicy::allow) .forEach( nodeRecord -> { foundNodes.add(nodeRecord); diff --git a/src/test/java/org/ethereum/beacon/discovery/DiscoveryNetworkTest.java b/src/test/java/org/ethereum/beacon/discovery/DiscoveryNetworkTest.java index 40add0ad0..b7b712865 100644 --- a/src/test/java/org/ethereum/beacon/discovery/DiscoveryNetworkTest.java +++ b/src/test/java/org/ethereum/beacon/discovery/DiscoveryNetworkTest.java @@ -6,6 +6,7 @@ import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; +import static org.ethereum.beacon.discovery.AddressAccessPolicy.ALLOW_ALL; import static org.ethereum.beacon.discovery.TestUtil.NODE_RECORD_FACTORY_NO_VERIFICATION; import static org.ethereum.beacon.discovery.TestUtil.TEST_TRAFFIC_READ_LIMIT; import static org.ethereum.beacon.discovery.TestUtil.waitFor; @@ -80,7 +81,8 @@ public void test() throws Exception { Schedulers.createDefault().newSingleThreadDaemon("tasks-1"), expirationSchedulerFactory, TalkHandler.NOOP, - ExternalAddressSelector.NOOP); + ExternalAddressSelector.NOOP, + ALLOW_ALL); livenessChecker1.setPinger(discoveryManager1::ping); DiscoveryManagerImpl discoveryManager2 = new DiscoveryManagerImpl( @@ -97,7 +99,8 @@ public void test() throws Exception { Schedulers.createDefault().newSingleThreadDaemon("tasks-2"), expirationSchedulerFactory, TalkHandler.NOOP, - ExternalAddressSelector.NOOP); + ExternalAddressSelector.NOOP, + ALLOW_ALL); livenessChecker2.setPinger(discoveryManager2::ping); // 3) Expect standard 1 => 2 dialog diff --git a/src/test/java/org/ethereum/beacon/discovery/HandshakeHandlersTest.java b/src/test/java/org/ethereum/beacon/discovery/HandshakeHandlersTest.java index cbebe6fd0..27b3effa8 100644 --- a/src/test/java/org/ethereum/beacon/discovery/HandshakeHandlersTest.java +++ b/src/test/java/org/ethereum/beacon/discovery/HandshakeHandlersTest.java @@ -5,6 +5,7 @@ package org.ethereum.beacon.discovery; import static java.util.Collections.singletonList; +import static org.ethereum.beacon.discovery.AddressAccessPolicy.ALLOW_ALL; import static org.ethereum.beacon.discovery.TestUtil.NODE_RECORD_FACTORY_NO_VERIFICATION; import static org.ethereum.beacon.discovery.pipeline.Field.BAD_PACKET; import static org.ethereum.beacon.discovery.pipeline.Field.MASKING_IV; @@ -168,7 +169,7 @@ public void authHandlerWithMessageRoundTripTest() throws Exception { new Request<>( new CompletableFuture<>(), id -> new FindNodeMessage(id, singletonList(1)), - new FindNodeResponseHandler(singletonList(1))); + new FindNodeResponseHandler(singletonList(1), ALLOW_ALL)); nodeSessionAt1For2.createNextRequest(request); RawPacket whoAreYouRawPacket = outgoing2Packets.poll(1, TimeUnit.SECONDS); @@ -181,7 +182,8 @@ public void authHandlerWithMessageRoundTripTest() throws Exception { outgoingPipeline, taskScheduler, NODE_RECORD_FACTORY_NO_VERIFICATION, - mock(NodeSessionManager.class)); + mock(NodeSessionManager.class), + ALLOW_ALL); Envelope envelopeAt2From1 = new Envelope(); RawPacket handshakeRawPacket = outgoing1Packets.poll(1, TimeUnit.SECONDS); envelopeAt2From1.put( diff --git a/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandlerTest.java b/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandlerTest.java new file mode 100644 index 000000000..3d64eb606 --- /dev/null +++ b/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/OutgoingParcelHandlerTest.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.ethereum.beacon.discovery.pipeline.handler; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import org.apache.tuweni.bytes.Bytes; +import org.ethereum.beacon.discovery.AddressAccessPolicy; +import org.ethereum.beacon.discovery.network.NetworkParcel; +import org.ethereum.beacon.discovery.network.NetworkParcelV5; +import org.ethereum.beacon.discovery.packet.impl.RawPacketImpl; +import org.ethereum.beacon.discovery.pipeline.Envelope; +import org.ethereum.beacon.discovery.pipeline.Field; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.FluxSink; + +class OutgoingParcelHandlerTest { + + public static final RawPacketImpl PACKET = + new RawPacketImpl(Bytes.fromHexString("0x12341234123412341234123412341234")); + private static final InetSocketAddress DISALLOWED_ADDRESS = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 12345); + private static final InetSocketAddress ALLOWED_ADDRESS = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 8080); + + @SuppressWarnings("unchecked") + private final FluxSink outgoingSink = + (FluxSink) mock(FluxSink.class); + + private final AddressAccessPolicy addressAccessPolicy = + address -> !address.equals(DISALLOWED_ADDRESS); + + private final OutgoingParcelHandler handler = + new OutgoingParcelHandler(outgoingSink, addressAccessPolicy); + + @Test + void shouldNotSendPacketsToDisallowedHosts() { + final Envelope envelope = new Envelope(); + envelope.put(Field.INCOMING, new NetworkParcelV5(PACKET, DISALLOWED_ADDRESS)); + handler.handle(envelope); + + verifyNoInteractions(outgoingSink); + } + + @Test + void shouldSendPacketsToAllowedHosts() { + final Envelope envelope = new Envelope(); + final NetworkParcelV5 parcel = new NetworkParcelV5(PACKET, ALLOWED_ADDRESS); + envelope.put(Field.INCOMING, parcel); + handler.handle(envelope); + + verify(outgoingSink).next(parcel); + } +} diff --git a/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilterTest.java b/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilterTest.java new file mode 100644 index 000000000..677330d22 --- /dev/null +++ b/src/test/java/org/ethereum/beacon/discovery/pipeline/handler/PacketSourceFilterTest.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.ethereum.beacon.discovery.pipeline.handler; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import org.apache.tuweni.bytes.Bytes; +import org.ethereum.beacon.discovery.AddressAccessPolicy; +import org.ethereum.beacon.discovery.pipeline.Envelope; +import org.ethereum.beacon.discovery.pipeline.Field; +import org.junit.jupiter.api.Test; + +class PacketSourceFilterTest { + + private static final InetSocketAddress DISALLOWED_ADDRESS = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 12345); + private static final InetSocketAddress ALLOWED_ADDRESS = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 8080); + + private final AddressAccessPolicy addressAccessPolicy = + address -> !address.equals(DISALLOWED_ADDRESS); + + private final PacketSourceFilter filter = new PacketSourceFilter(addressAccessPolicy); + + @Test + void shouldAllowPacketsWhenSourceAllowed() { + final Envelope envelope = new Envelope(); + final Bytes incoming = Bytes.fromHexString("0x1234"); + envelope.put(Field.INCOMING, incoming); + envelope.put(Field.REMOTE_SENDER, ALLOWED_ADDRESS); + filter.handle(envelope); + + assertThat(envelope.get(Field.INCOMING)).isEqualTo(incoming); + } + + @Test + void shouldDropPacketsWhenSourceDisallowed() { + final Envelope envelope = new Envelope(); + final Bytes incoming = Bytes.fromHexString("0x1234"); + envelope.put(Field.INCOMING, incoming); + envelope.put(Field.REMOTE_SENDER, DISALLOWED_ADDRESS); + filter.handle(envelope); + + assertThat(envelope.get(Field.INCOMING)).isNull(); + } +} diff --git a/src/test/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandlerTest.java b/src/test/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandlerTest.java index d895c9124..650c89d25 100644 --- a/src/test/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandlerTest.java +++ b/src/test/java/org/ethereum/beacon/discovery/pipeline/info/FindNodeResponseHandlerTest.java @@ -7,6 +7,7 @@ import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.ethereum.beacon.discovery.AddressAccessPolicy.ALLOW_ALL; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -15,6 +16,7 @@ import java.util.List; import org.apache.tuweni.bytes.Bytes; +import org.ethereum.beacon.discovery.AddressAccessPolicy; import org.ethereum.beacon.discovery.TestUtil; import org.ethereum.beacon.discovery.TestUtil.NodeInfo; import org.ethereum.beacon.discovery.message.NodesMessage; @@ -27,8 +29,10 @@ import org.junit.jupiter.params.provider.ValueSource; class FindNodeResponseHandlerTest { + private static final Bytes PEER_ID = Bytes.fromHexString("0x1234567890ABCDEF"); private static final Bytes REQUEST_ID = Bytes.fromHexString("0x1234"); + public static final AddressAccessPolicy DISALLOW_ALL = record -> false; private final NodeSession session = mock(NodeSession.class); @BeforeEach @@ -40,7 +44,8 @@ public void setUp() { public void shouldAddReceivedRecordsToNodeTableButNotNodeBuckets() { final NodeInfo nodeInfo = TestUtil.generateNode(9000); final int distance = Functions.logDistance(PEER_ID, nodeInfo.getNodeRecord().getNodeId()); - final FindNodeResponseHandler handler = new FindNodeResponseHandler(singletonList(distance)); + final FindNodeResponseHandler handler = + new FindNodeResponseHandler(singletonList(distance), ALLOW_ALL); final List records = singletonList(nodeInfo.getNodeRecord()); final NodesMessage message = new NodesMessage(REQUEST_ID, records.size(), records); @@ -49,11 +54,26 @@ public void shouldAddReceivedRecordsToNodeTableButNotNodeBuckets() { verify(session).onNodeRecordReceived(nodeInfo.getNodeRecord()); } + @Test + public void shouldRejectReceivedRecordsThatAreDisallowed() { + final NodeInfo nodeInfo = TestUtil.generateNode(9000); + final int distance = Functions.logDistance(PEER_ID, nodeInfo.getNodeRecord().getNodeId()); + final FindNodeResponseHandler handler = + new FindNodeResponseHandler(singletonList(distance), DISALLOW_ALL); + + final List records = singletonList(nodeInfo.getNodeRecord()); + final NodesMessage message = new NodesMessage(REQUEST_ID, records.size(), records); + assertThat(handler.handleResponseMessage(message, session)).isTrue(); + + verify(session, never()).onNodeRecordReceived(nodeInfo.getNodeRecord()); + } + @Test public void shouldRejectReceivedRecordsThatAreInvalid() { final NodeInfo nodeInfo = TestUtil.generateInvalidNode(9000); final int distance = Functions.logDistance(PEER_ID, nodeInfo.getNodeRecord().getNodeId()); - final FindNodeResponseHandler handler = new FindNodeResponseHandler(singletonList(distance)); + final FindNodeResponseHandler handler = + new FindNodeResponseHandler(singletonList(distance), ALLOW_ALL); final List records = singletonList(nodeInfo.getNodeRecord()); final NodesMessage message = new NodesMessage(REQUEST_ID, records.size(), records); handler.handleResponseMessage(message, session); @@ -66,7 +86,7 @@ public void shouldRejectReceivedRecordsThatAreNotAtCorrectDistance() { final NodeInfo nodeInfo = TestUtil.generateNode(9000); final int distance = Functions.logDistance(PEER_ID, nodeInfo.getNodeRecord().getNodeId()); final FindNodeResponseHandler handler = - new FindNodeResponseHandler(singletonList(distance + 1)); + new FindNodeResponseHandler(singletonList(distance + 1), ALLOW_ALL); final List records = singletonList(nodeInfo.getNodeRecord()); final NodesMessage message = new NodesMessage(REQUEST_ID, records.size(), records); handler.handleResponseMessage(message, session); @@ -78,7 +98,7 @@ public void shouldRejectReceivedRecordsThatAreNotAtCorrectDistance() { @ValueSource(ints = {-1, 0, 17}) public void shouldRejectInvalidTotalPackets(final int numPackets) { final NodesMessage message = new NodesMessage(REQUEST_ID, numPackets, emptyList()); - final FindNodeResponseHandler handler = new FindNodeResponseHandler(emptyList()); + final FindNodeResponseHandler handler = new FindNodeResponseHandler(emptyList(), ALLOW_ALL); assertThatThrownBy(() -> handler.handleResponseMessage(message, session)) .hasMessageContaining("Invalid number of total packets") .isInstanceOf(RuntimeException.class);