Skip to content

Commit

Permalink
Implement source address filtering (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajsutton authored Dec 20, 2022
1 parent 7da0181 commit 9829c28
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

package org.ethereum.beacon.discovery;

import java.net.InetSocketAddress;
import org.ethereum.beacon.discovery.schema.NodeRecord;

@FunctionalInterface
public interface AddressAccessPolicy {
AddressAccessPolicy ALLOW_ALL = __ -> true;

boolean allow(InetSocketAddress address);

default boolean allow(NodeRecord record) {
return record.getTcpAddress().map(this::allow).orElse(true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.ethereum.beacon.discovery.pipeline.handler.NodeSessionRequestHandler;
import org.ethereum.beacon.discovery.pipeline.handler.OutgoingParcelHandler;
import org.ethereum.beacon.discovery.pipeline.handler.PacketDispatcherHandler;
import org.ethereum.beacon.discovery.pipeline.handler.PacketSourceFilter;
import org.ethereum.beacon.discovery.pipeline.handler.UnauthorizedMessagePacketHandler;
import org.ethereum.beacon.discovery.pipeline.handler.UnknownPacketTagToSender;
import org.ethereum.beacon.discovery.pipeline.handler.WhoAreYouPacketHandler;
Expand Down Expand Up @@ -66,6 +67,7 @@ public class DiscoveryManagerImpl implements DiscoveryManager {
private final Pipeline incomingPipeline = new PipelineImpl();
private final Pipeline outgoingPipeline = new PipelineImpl();
private final LocalNodeRecordStore localNodeRecordStore;
private final AddressAccessPolicy addressAccessPolicy;
private volatile DiscoveryClient discoveryClient;
private final NodeSessionManager nodeSessionManager;

Expand All @@ -78,8 +80,10 @@ public DiscoveryManagerImpl(
final Scheduler taskScheduler,
final ExpirationSchedulerFactory expirationSchedulerFactory,
final TalkHandler talkHandler,
final ExternalAddressSelector externalAddressSelector) {
final ExternalAddressSelector externalAddressSelector,
final AddressAccessPolicy addressAccessPolicy) {
this.localNodeRecordStore = localNodeRecordStore;
this.addressAccessPolicy = addressAccessPolicy;
final NodeRecord homeNodeRecord = localNodeRecordStore.getLocalNodeRecord();

this.discoveryServer = discoveryServer;
Expand All @@ -91,6 +95,7 @@ public DiscoveryManagerImpl(
outgoingPipeline,
expirationSchedulerFactory);
incomingPipeline
.addHandler(new PacketSourceFilter(addressAccessPolicy))
.addHandler(new IncomingDataPacker(homeNodeRecord.getNodeId()))
.addHandler(new WhoAreYouSessionResolver(nodeSessionManager))
.addHandler(new UnknownPacketTagToSender())
Expand All @@ -99,7 +104,11 @@ public DiscoveryManagerImpl(
.addHandler(new WhoAreYouPacketHandler(outgoingPipeline, taskScheduler))
.addHandler(
new HandshakeMessagePacketHandler(
outgoingPipeline, taskScheduler, nodeRecordFactory, nodeSessionManager))
outgoingPipeline,
taskScheduler,
nodeRecordFactory,
nodeSessionManager,
addressAccessPolicy))
.addHandler(new MessagePacketHandler(nodeRecordFactory))
.addHandler(new UnauthorizedMessagePacketHandler())
.addHandler(
Expand All @@ -111,7 +120,7 @@ public DiscoveryManagerImpl(
.addHandler(new BadPacketHandler());
final FluxSink<NetworkParcel> outgoingSink = outgoingMessages.sink();
outgoingPipeline
.addHandler(new OutgoingParcelHandler(outgoingSink))
.addHandler(new OutgoingParcelHandler(outgoingSink, addressAccessPolicy))
.addHandler(new NodeSessionRequestHandler())
.addHandler(nodeSessionManager)
.addHandler(new NewTaskHandler())
Expand Down Expand Up @@ -178,7 +187,7 @@ public CompletableFuture<Collection<NodeRecord>> findNodes(
new Request<>(
new CompletableFuture<>(),
reqId -> new FindNodeMessage(reqId, distances),
new FindNodeResponseHandler(distances));
new FindNodeResponseHandler(distances, addressAccessPolicy));
return executeTaskImpl(nodeRecord, request);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class DiscoverySystemBuilder {
private TalkHandler talkHandler = TalkHandler.NOOP;
private NettyDiscoveryServer discoveryServer = null;
private ExternalAddressSelector externalAddressSelector = null;
private AddressAccessPolicy addressAccessPolicy = AddressAccessPolicy.ALLOW_ALL;
private final Clock clock = Clock.systemUTC();
private final LivenessChecker livenessChecker = new LivenessChecker(clock);

Expand Down Expand Up @@ -151,6 +152,11 @@ public DiscoverySystemBuilder externalAddressSelector(
return this;
}

public DiscoverySystemBuilder addressAccessPolicy(final AddressAccessPolicy addressAccessPolicy) {
this.addressAccessPolicy = addressAccessPolicy;
return this;
}

private void createDefaults() {
newAddressHandler =
requireNonNullElseGet(
Expand Down Expand Up @@ -247,7 +253,8 @@ DiscoveryManagerImpl buildDiscoveryManager() {
schedulers.newSingleThreadDaemon("discovery-client-" + clientNumber),
expirationSchedulerFactory,
talkHandler,
externalAddressSelector);
externalAddressSelector,
addressAccessPolicy);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.tuweni.bytes.Bytes;
import org.ethereum.beacon.discovery.AddressAccessPolicy;
import org.ethereum.beacon.discovery.message.V5Message;
import org.ethereum.beacon.discovery.packet.HandshakeMessagePacket;
import org.ethereum.beacon.discovery.pipeline.Envelope;
Expand All @@ -32,16 +33,19 @@ public class HandshakeMessagePacketHandler implements EnvelopeHandler {
private final Scheduler scheduler;
private final NodeRecordFactory nodeRecordFactory;
private final NodeSessionManager nodeSessionManager;
private final AddressAccessPolicy addressAccessPolicy;

public HandshakeMessagePacketHandler(
Pipeline outgoingPipeline,
Scheduler scheduler,
NodeRecordFactory nodeRecordFactory,
NodeSessionManager nodeSessionManager) {
NodeSessionManager nodeSessionManager,
AddressAccessPolicy addressAccessPolicy) {
this.outgoingPipeline = outgoingPipeline;
this.scheduler = scheduler;
this.nodeRecordFactory = nodeRecordFactory;
this.nodeSessionManager = nodeSessionManager;
this.addressAccessPolicy = addressAccessPolicy;
}

@Override
Expand Down Expand Up @@ -97,9 +101,17 @@ public void handle(Envelope envelope) {
// Check the node record matches the ID we expect
if (!nodeRecordMaybe.map(r -> r.getNodeId().equals(session.getNodeId())).orElse(false)) {
LOG.debug(
String.format(
"Incorrect node ID for message [%s] from node %s in status %s",
packet, session.getNodeRecord(), session.getState()));
"Incorrect node ID for message [{}] from node {} in status {}",
packet,
session.getNodeRecord(),
session.getState());
markHandshakeAsFailed(envelope, session);
return;
} else if (!enr.map(addressAccessPolicy::allow).orElse(true)) {
LOG.debug(
"Rejecting handshake from node {} because the ENR was disallowed: {}",
session.getNodeRecord(),
enr);
markHandshakeAsFailed(envelope, session);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.ethereum.beacon.discovery.AddressAccessPolicy;
import org.ethereum.beacon.discovery.network.NetworkParcel;
import org.ethereum.beacon.discovery.pipeline.Envelope;
import org.ethereum.beacon.discovery.pipeline.EnvelopeHandler;
Expand All @@ -22,9 +23,12 @@ public class OutgoingParcelHandler implements EnvelopeHandler {
private static final Logger LOG = LogManager.getLogger(OutgoingParcelHandler.class);

private final FluxSink<NetworkParcel> outgoingSink;
private final AddressAccessPolicy addressAccessPolicy;

public OutgoingParcelHandler(FluxSink<NetworkParcel> outgoingSink) {
public OutgoingParcelHandler(
FluxSink<NetworkParcel> outgoingSink, final AddressAccessPolicy addressAccessPolicy) {
this.outgoingSink = outgoingSink;
this.addressAccessPolicy = addressAccessPolicy;
}

@Override
Expand All @@ -41,7 +45,10 @@ public void handle(Envelope envelope) {
if (envelope.get(Field.INCOMING) instanceof NetworkParcel) {
NetworkParcel parcel = (NetworkParcel) envelope.get(Field.INCOMING);
if (parcel.getPacket().getBytes().size() > IncomingDataPacker.MAX_PACKET_SIZE) {
LOG.error(() -> "Outgoing packet is too large, dropping it: " + parcel.getPacket());
LOG.error("Outgoing packet is too large, dropping it: {}", parcel.getPacket());
} else if (!addressAccessPolicy.allow(parcel.getDestination())) {
LOG.debug(
"Dropping outgoing packet to disallowed destination: {}", parcel.getDestination());
} else {
outgoingSink.next(parcel);
envelope.remove(Field.INCOMING);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

package org.ethereum.beacon.discovery.pipeline.handler;

import java.net.InetSocketAddress;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.ethereum.beacon.discovery.AddressAccessPolicy;
import org.ethereum.beacon.discovery.pipeline.Envelope;
import org.ethereum.beacon.discovery.pipeline.EnvelopeHandler;
import org.ethereum.beacon.discovery.pipeline.Field;
import org.ethereum.beacon.discovery.pipeline.HandlerUtil;

public class PacketSourceFilter implements EnvelopeHandler {
private static final Logger LOG = LogManager.getLogger(PacketSourceFilter.class);

private final AddressAccessPolicy addressAccessPolicy;

public PacketSourceFilter(final AddressAccessPolicy addressAccessPolicy) {
this.addressAccessPolicy = addressAccessPolicy;
}

@Override
public void handle(final Envelope envelope) {
if (!HandlerUtil.requireField(Field.REMOTE_SENDER, envelope)) {
return;
}
final InetSocketAddress sender = envelope.get(Field.REMOTE_SENDER);
if (!addressAccessPolicy.allow(sender)) {
envelope.remove(Field.INCOMING);
LOG.debug("Ignoring message from disallowed source {}", sender);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.ethereum.beacon.discovery.AddressAccessPolicy;
import org.ethereum.beacon.discovery.message.NodesMessage;
import org.ethereum.beacon.discovery.schema.NodeRecord;
import org.ethereum.beacon.discovery.schema.NodeSession;
Expand All @@ -19,11 +20,14 @@ public class FindNodeResponseHandler implements MultiPacketResponseHandler<Nodes
private static final int MAX_TOTAL_PACKETS = 16;
private final List<NodeRecord> foundNodes = new ArrayList<>();
private final Collection<Integer> distances;
private final AddressAccessPolicy addressAccessPolicy;
private int totalPackets = NOT_SET;
private int receivedPackets = 0;

public FindNodeResponseHandler(final Collection<Integer> distances) {
public FindNodeResponseHandler(
final Collection<Integer> distances, final AddressAccessPolicy addressAccessPolicy) {
this.distances = distances;
this.addressAccessPolicy = addressAccessPolicy;
}

@Override
Expand Down Expand Up @@ -53,6 +57,7 @@ public synchronized boolean handleResponseMessage(NodesMessage message, NodeSess
message.getNodeRecords().stream()
.filter(this::isValid)
.filter(record -> hasCorrectDistance(session, record))
.filter(addressAccessPolicy::allow)
.forEach(
nodeRecord -> {
foundNodes.add(nodeRecord);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.ethereum.beacon.discovery.AddressAccessPolicy.ALLOW_ALL;
import static org.ethereum.beacon.discovery.TestUtil.NODE_RECORD_FACTORY_NO_VERIFICATION;
import static org.ethereum.beacon.discovery.TestUtil.TEST_TRAFFIC_READ_LIMIT;
import static org.ethereum.beacon.discovery.TestUtil.waitFor;
Expand Down Expand Up @@ -80,7 +81,8 @@ public void test() throws Exception {
Schedulers.createDefault().newSingleThreadDaemon("tasks-1"),
expirationSchedulerFactory,
TalkHandler.NOOP,
ExternalAddressSelector.NOOP);
ExternalAddressSelector.NOOP,
ALLOW_ALL);
livenessChecker1.setPinger(discoveryManager1::ping);
DiscoveryManagerImpl discoveryManager2 =
new DiscoveryManagerImpl(
Expand All @@ -97,7 +99,8 @@ public void test() throws Exception {
Schedulers.createDefault().newSingleThreadDaemon("tasks-2"),
expirationSchedulerFactory,
TalkHandler.NOOP,
ExternalAddressSelector.NOOP);
ExternalAddressSelector.NOOP,
ALLOW_ALL);
livenessChecker2.setPinger(discoveryManager2::ping);

// 3) Expect standard 1 => 2 dialog
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.ethereum.beacon.discovery;

import static java.util.Collections.singletonList;
import static org.ethereum.beacon.discovery.AddressAccessPolicy.ALLOW_ALL;
import static org.ethereum.beacon.discovery.TestUtil.NODE_RECORD_FACTORY_NO_VERIFICATION;
import static org.ethereum.beacon.discovery.pipeline.Field.BAD_PACKET;
import static org.ethereum.beacon.discovery.pipeline.Field.MASKING_IV;
Expand Down Expand Up @@ -168,7 +169,7 @@ public void authHandlerWithMessageRoundTripTest() throws Exception {
new Request<>(
new CompletableFuture<>(),
id -> new FindNodeMessage(id, singletonList(1)),
new FindNodeResponseHandler(singletonList(1)));
new FindNodeResponseHandler(singletonList(1), ALLOW_ALL));
nodeSessionAt1For2.createNextRequest(request);

RawPacket whoAreYouRawPacket = outgoing2Packets.poll(1, TimeUnit.SECONDS);
Expand All @@ -181,7 +182,8 @@ public void authHandlerWithMessageRoundTripTest() throws Exception {
outgoingPipeline,
taskScheduler,
NODE_RECORD_FACTORY_NO_VERIFICATION,
mock(NodeSessionManager.class));
mock(NodeSessionManager.class),
ALLOW_ALL);
Envelope envelopeAt2From1 = new Envelope();
RawPacket handshakeRawPacket = outgoing1Packets.poll(1, TimeUnit.SECONDS);
envelopeAt2From1.put(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

package org.ethereum.beacon.discovery.pipeline.handler;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import org.apache.tuweni.bytes.Bytes;
import org.ethereum.beacon.discovery.AddressAccessPolicy;
import org.ethereum.beacon.discovery.network.NetworkParcel;
import org.ethereum.beacon.discovery.network.NetworkParcelV5;
import org.ethereum.beacon.discovery.packet.impl.RawPacketImpl;
import org.ethereum.beacon.discovery.pipeline.Envelope;
import org.ethereum.beacon.discovery.pipeline.Field;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.FluxSink;

class OutgoingParcelHandlerTest {

public static final RawPacketImpl PACKET =
new RawPacketImpl(Bytes.fromHexString("0x12341234123412341234123412341234"));
private static final InetSocketAddress DISALLOWED_ADDRESS =
new InetSocketAddress(InetAddress.getLoopbackAddress(), 12345);
private static final InetSocketAddress ALLOWED_ADDRESS =
new InetSocketAddress(InetAddress.getLoopbackAddress(), 8080);

@SuppressWarnings("unchecked")
private final FluxSink<NetworkParcel> outgoingSink =
(FluxSink<NetworkParcel>) mock(FluxSink.class);

private final AddressAccessPolicy addressAccessPolicy =
address -> !address.equals(DISALLOWED_ADDRESS);

private final OutgoingParcelHandler handler =
new OutgoingParcelHandler(outgoingSink, addressAccessPolicy);

@Test
void shouldNotSendPacketsToDisallowedHosts() {
final Envelope envelope = new Envelope();
envelope.put(Field.INCOMING, new NetworkParcelV5(PACKET, DISALLOWED_ADDRESS));
handler.handle(envelope);

verifyNoInteractions(outgoingSink);
}

@Test
void shouldSendPacketsToAllowedHosts() {
final Envelope envelope = new Envelope();
final NetworkParcelV5 parcel = new NetworkParcelV5(PACKET, ALLOWED_ADDRESS);
envelope.put(Field.INCOMING, parcel);
handler.handle(envelope);

verify(outgoingSink).next(parcel);
}
}
Loading

0 comments on commit 9829c28

Please sign in to comment.