Skip to content

Commit

Permalink
feat: implement TCP-endpoint-based throttling (#46)
Browse files Browse the repository at this point in the history
We need this feature to write test for StreamAllContext, which in turn
is one of the TODOs referenced by
ooni/probe#2654.
  • Loading branch information
bassosimone authored Feb 5, 2024
1 parent 3cc1ea5 commit 14e4ce9
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 2 deletions.
54 changes: 53 additions & 1 deletion dpithrottle.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type DPIThrottleTrafficForTLSSNI struct {
// Logger is the MANDATORY logger to use.
Logger Logger

// PLR is the OPTIONAL extra packet loss rate to apply to the packet
// PLR is the OPTIONAL extra packet loss rate to apply to the packet.
PLR float64

// SNI is the OPTIONAL offending SNI
Expand Down Expand Up @@ -70,3 +70,55 @@ func (r *DPIThrottleTrafficForTLSSNI) Filter(
}
return policy, true
}

// DPIThrottleTrafficForTCPEndpoint is a [DPIRule] that throttles traffic
// for a given TCP endpoint. The zero value is not valid. Make sure
// you initialize all fields marked as MANDATORY.
type DPIThrottleTrafficForTCPEndpoint struct {
// Delay is the OPTIONAL extra delay to add to the flow.
Delay time.Duration

// Logger is the MANDATORY logger to use.
Logger Logger

// PLR is the OPTIONAL extra packet loss rate to apply to the packet.
PLR float64

// ServerIPAddress is the MANDATORY server endpoint IP address.
ServerIPAddress string

// ServerPort is the MANDATORY server endpoint port.
ServerPort uint16
}

var _ DPIRule = &DPIThrottleTrafficForTCPEndpoint{}

// Filter implements DPIRule
func (r *DPIThrottleTrafficForTCPEndpoint) Filter(
direction DPIDirection, packet *DissectedPacket) (*DPIPolicy, bool) {
// short circuit for the return path
if direction != DPIDirectionClientToServer {
return nil, false
}

// make sure the packet is TCP and for the proper endpoint
if !packet.MatchesDestination(layers.IPProtocolTCP, r.ServerIPAddress, r.ServerPort) {
return nil, false
}

r.Logger.Infof(
"netem: dpi: throttling flow %s:%d %s:%d/%s because the endpoint is filtered",
packet.SourceIPAddress(),
packet.SourcePort(),
packet.DestinationIPAddress(),
packet.DestinationPort(),
packet.TransportProtocol(),
)
policy := &DPIPolicy{
Delay: r.Delay,
Flags: 0,
PLR: r.PLR,
Spoofed: nil,
}
return policy, true
}
155 changes: 154 additions & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ func TestDPITCPThrottleForSNI(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Log("checking for TLS flow throttling", tc.name)

// throttle the offending SNI to have high latency and hig losses
// throttle the offending SNI to have high latency and high losses
dpiEngine := netem.NewDPIEngine(log.Log)
dpiEngine.AddRule(&netem.DPIThrottleTrafficForTLSSNI{
Delay: 10 * time.Millisecond,
Expand Down Expand Up @@ -581,6 +581,159 @@ func TestDPITCPThrottleForSNI(t *testing.T) {
}
}

// TestDPITCPThrottleForTCPEndpoint verifies we can use the DPI to throttle
// connections using a specific TCP endpoint.
func TestDPITCPThrottleForTCPEndpoint(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}

// testcase describes a test case
type testcase struct {
// name is the name of the test case
name string

// endpointAddress is the address of the endpoint to block.
endpointAddress string

// endpointPort is the port of the endpoint to block.
endpointPort uint16

// checkAvgSpeed is a function the check whether
// the speed is consistent with expectations
checkAvgSpeed func(t *testing.T, speed float64)
}

var testcases = []testcase{{
name: "when the client is using a throttled endpoint",
endpointAddress: "10.0.0.1",
endpointPort: 443,
checkAvgSpeed: func(t *testing.T, speed float64) {
// See above comment regarding expected performance
// under the given RTT, MSS, and PLR constraints
const expectation = 5
if speed > expectation {
t.Fatal("goodput", speed, "above expectation", expectation)
}
},
}, {
name: "when the client is not using a throttled endpoint",
endpointAddress: "10.0.0.1",
endpointPort: 555, // different port
checkAvgSpeed: func(t *testing.T, speed float64) {
// See above comment regarding expected performance
// under the given RTT, MSS, and PLR constraints
const expectation = 5
if speed < expectation {
t.Fatal("goodput", speed, "below expectation", expectation)
}
},
}}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
t.Log("checking for TLS flow throttling", tc.name)

// throttle the offending endpoint to have high latency and high losses
dpiEngine := netem.NewDPIEngine(log.Log)
dpiEngine.AddRule(&netem.DPIThrottleTrafficForTCPEndpoint{
Delay: 10 * time.Millisecond,
Logger: log.Log,
PLR: 0.1,
ServerIPAddress: tc.endpointAddress,
ServerPort: tc.endpointPort,
})
lc := &netem.LinkConfig{
DPIEngine: dpiEngine,
LeftToRightDelay: 100 * time.Microsecond,
RightToLeftDelay: 100 * time.Microsecond,
}

// create a point-to-point topology, which consists of a single
// [Link] connecting two userspace network stacks.
topology := netem.MustNewPPPTopology(
"10.0.0.2",
"10.0.0.1",
log.Log,
lc,
)
defer topology.Close()

// make sure we have a deadline bound context
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// add DNS server to resolve the clientSNI domain
dnsConfig := netem.NewDNSConfig()
dnsConfig.AddRecord("ndt0.local", "", "10.0.0.1")
dnsServer, err := netem.NewDNSServer(log.Log, topology.Server, "10.0.0.1", dnsConfig)
if err != nil {
t.Fatal(err)
}
defer dnsServer.Close()

// start an NDT0 server in the background
ready, serverErrorCh := make(chan net.Listener, 1), make(chan error, 1)
go netem.RunNDT0Server(
ctx,
topology.Server,
net.ParseIP("10.0.0.1"),
443,
log.Log,
ready,
serverErrorCh,
true,
"ndt0.local",
"ndt0.xyz",
)

// await for the NDT0 server to be listening
listener := <-ready
defer listener.Close()

// run NDT0 client in the background and measure speed
clientErrorCh := make(chan error, 1)
perfch := make(chan *netem.NDT0PerformanceSample)
go netem.RunNDT0Client(
ctx,
topology.Client,
net.JoinHostPort("ndt0.local", "443"),
log.Log,
true,
clientErrorCh,
perfch,
)

// collect the average speed
var avgSpeed float64
for p := range perfch {
if p.Final {
avgSpeed = p.AvgSpeedMbps()
}
}

// make sure we have collected samples
if avgSpeed <= 0 {
t.Fatal("did not collect the average speed")
}

// make sure that neither the client nor the server
// reported a fundamental error
if err := <-clientErrorCh; err != nil {
t.Fatal(err)
}
if err := <-serverErrorCh; err != nil {
t.Fatal(err)
}

t.Log("measured goodput", avgSpeed)

// make sure that the speed is consistent with expectations
tc.checkAvgSpeed(t, avgSpeed)
})
}
}

// TestDPITCPResetForSNI verifies we can use the DPI to reset TCP
// connections using specific TLS SNI values.
func TestDPITCPResetForSNI(t *testing.T) {
Expand Down

0 comments on commit 14e4ce9

Please sign in to comment.