Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement RateLimiter into enroll_verb_handler and add unit test #1547

Merged
merged 4 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/at_secondary_server/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,8 @@ testing:
enrollment:
# The maximum time in hours for an enrollment to expire, beyond which any action on enrollment is forbidden.
# Default values is 48 hours.
expiryInHours: 48
expiryInHours: 48
# The maximum number of requests allowed within the time window.
maxRequestsPerTimeFrame: 5
# The duration of the time window in hours.
timeFrameInHours: 1
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import 'dart:io';

import 'package:at_secondary/src/connection/inbound/inbound_connection_metadata.dart';
import 'package:at_secondary/src/server/at_secondary_config.dart';
import 'package:at_server_spec/at_server_spec.dart';

/// A dummy implementation of [InboundConnection] class which returns a dummy inbound connection.
class DummyInboundConnection implements InboundConnection {
var metadata = InboundConnectionMetadata();

@override
int maxRequestsPerTimeFrame = AtSecondaryConfig.maxEnrollRequestsAllowed;

@override
int timeFrameInMillis = AtSecondaryConfig.timeFrameInMills;

@override
void acceptRequests(Function(String p1, InboundConnection p2) callback,
Function(List<int>, InboundConnection) streamCallback) {}
Expand Down Expand Up @@ -54,4 +61,9 @@ class DummyInboundConnection implements InboundConnection {

@override
Socket? receiverSocket;

@override
bool isRequestAllowed() {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import 'dart:collection';
import 'dart:io';
import 'dart:math';

import 'package:at_secondary/src/connection/base_connection.dart';
import 'package:at_secondary/src/connection/inbound/inbound_connection_metadata.dart';
import 'package:at_secondary/src/connection/inbound/inbound_connection_pool.dart';
import 'package:at_secondary/src/connection/inbound/inbound_message_listener.dart';
import 'package:at_secondary/src/server/at_secondary_config.dart';
import 'package:at_secondary/src/server/server_context.dart';
import 'package:at_secondary/src/server/at_secondary_impl.dart';
import 'package:at_secondary/src/utils/logging_util.dart';
Expand Down Expand Up @@ -42,6 +44,17 @@ class InboundConnectionImpl extends BaseConnection
late double lowWaterMarkRatio;
late bool progressivelyReduceAllowableInboundIdleTime;

/// The maximum number of requests allowed within the specified time frame.
@override
late int maxRequestsPerTimeFrame;

/// The duration of the time frame within which requests are limited.
@override
late int timeFrameInMillis;

/// A list of timestamps representing the times when requests were made.
late final Queue<int> requestTimestampQueue;

InboundConnectionImpl(Socket? socket, String? sessionId, {this.owningPool})
: super(socket) {
metaData = InboundConnectionMetadata()
Expand Down Expand Up @@ -69,6 +82,10 @@ class InboundConnectionImpl extends BaseConnection
secondaryContext.authenticatedInboundIdleTimeMillis;
authenticatedMinAllowableIdleTimeMillis =
secondaryContext.authenticatedMinAllowableIdleTimeMillis;

maxRequestsPerTimeFrame = AtSecondaryConfig.maxEnrollRequestsAllowed;
timeFrameInMillis = AtSecondaryConfig.timeFrameInMills;
requestTimestampQueue = Queue();
}

/// Returns true if the underlying socket is not null and socket's remote address and port match.
Expand Down Expand Up @@ -230,4 +247,31 @@ class InboundConnectionImpl extends BaseConnection
metaData, 'SENT: ${BaseConnection.truncateForLogging(data)}'));
}
}

@override
bool isRequestAllowed() {
int currentTimeInMills = DateTime.now().toUtc().millisecondsSinceEpoch;
_checkAndUpdateQueue(currentTimeInMills);
if (requestTimestampQueue.length < maxRequestsPerTimeFrame) {
requestTimestampQueue.addLast(currentTimeInMills);
return true;
}
return false;
}

/// Checks and updates the request timestamp queue based on the current time.
///
/// This method removes timestamps from the queue that are older than the specified
/// time window.
///
/// [currentTimeInMillis] is the current time in milliseconds since epoch.
void _checkAndUpdateQueue(int currentTimeInMillis) {
if (requestTimestampQueue.isEmpty) return;
int calculatedTime = (currentTimeInMillis - requestTimestampQueue.first);
while (calculatedTime >= timeFrameInMillis) {
requestTimestampQueue.removeFirst();
if (requestTimestampQueue.isEmpty) break;
calculatedTime = (currentTimeInMillis - requestTimestampQueue.first);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class GlobalExceptionHandler {
exception is KeyNotFoundException ||
exception is AtConnectException ||
exception is SocketException ||
exception is AtTimeoutException) {
exception is AtTimeoutException ||
exception is AtThrottleLimitExceeded) {
logger.info(exception.toString());
await _sendResponseForException(exception, atConnection);
} else if (exception is InternalServerError) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,22 @@ class AtSecondaryConfig {
? ConfigUtil.getPubspecConfig()!['version']
: null;

static final int _enrollmentExpiryInHours = 48;

static final Map<String, String> _envVars = Platform.environment;

static String? get secondaryServerVersion => _secondaryServerVersion;

// Enrollment Configurations
static const int _enrollmentExpiryInHours = 48;
static int _maxEnrollRequestsAllowed = 5;

static final int _timeFrameInHours = 1;

// For easy of testing, duration in hours is long. Hence introduced "timeFrameInMills"
// to have a shorter time frame. This is defaulted to "_timeFrameInHours", can be modified
// via the config verb
static int _timeFrameInMills =
Duration(hours: _timeFrameInHours).inMilliseconds;

static int get enrollmentExpiryInHours => _enrollmentExpiryInHours;

// TODO: Medium priority: Most (all?) getters in this class return a default value but the signatures currently
Expand Down Expand Up @@ -716,6 +726,54 @@ class AtSecondaryConfig {
}
}

static int get maxEnrollRequestsAllowed {
// For easy of testing purpose, we need to reduce the number of requests.
// So, in testing mode, enable to modify the "maxEnrollRequestsAllowed"
// can be set via the config verb
// Defaults to value in config.yaml
if (testingMode) {
return _maxEnrollRequestsAllowed;
}
var result = _getIntEnvVar('maxEnrollRequestsAllowed');
if (result != null) {
return result;
}
try {
return getConfigFromYaml(['enrollment', 'maxRequestsPerTimeFrame']);
} on ElementNotFoundException {
return _maxEnrollRequestsAllowed;
}
}

static set maxEnrollRequestsAllowed(int value) {
_maxEnrollRequestsAllowed = value;
}

static int get timeFrameInMills {
// For easy of testing purpose, we need to reduce the time frame.
// So, in testing mode, enable to modify the "timeFrameInMills"
// can be set via the config verb
// Defaults to value in config.yaml
if (testingMode) {
return _timeFrameInMills;
}
var result = _getIntEnvVar('enrollTimeFrameInHours');
if (result != null) {
return Duration(hours: result).inMilliseconds;
}
try {
return Duration(
hours: getConfigFromYaml(['enrollment', 'timeFrameInHours']))
.inMilliseconds;
} on ElementNotFoundException {
return Duration(hours: _timeFrameInHours).inMilliseconds;
}
}

static set timeFrameInMills(int timeWindowInMills) {
_timeFrameInMills = timeWindowInMills;
}

//implementation for config:set. This method returns a data stream which subscribers listen to for updates
static Stream<dynamic>? subscribe(ModifiableConfigs configName) {
if (testingMode) {
Expand Down Expand Up @@ -786,6 +844,10 @@ class AtSecondaryConfig {
return false;
case ModifiableConfigs.doCacheRefreshNow:
return false;
case ModifiableConfigs.maxRequestsPerTimeFrame:
return maxEnrollRequestsAllowed;
case ModifiableConfigs.timeFrameInMills:
return Duration(hours: _timeFrameInHours).inMilliseconds;
}
}

Expand Down Expand Up @@ -866,7 +928,9 @@ enum ModifiableConfigs {
maxNotificationRetries,
checkCertificateReload,
shouldReloadCertificates,
doCacheRefreshNow
doCacheRefreshNow,
maxRequestsPerTimeFrame,
timeFrameInMills
}

class ModifiableConfigurationEntry {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,14 @@ class AtSecondaryServerImpl implements AtSecondaryServer {
notificationResourceManager.setMaxRetries(newCount);
QueueManager.getInstance().setMaxRetries(newCount);
});

AtSecondaryConfig.subscribe(ModifiableConfigs.maxRequestsPerTimeFrame)?.listen((maxEnrollRequestsAllowed) {
AtSecondaryConfig.maxEnrollRequestsAllowed = maxEnrollRequestsAllowed;
});

AtSecondaryConfig.subscribe(ModifiableConfigs.timeFrameInMills)?.listen((timeWindowInMills) {
AtSecondaryConfig.timeFrameInMills = timeWindowInMills;
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,17 @@ class EnrollVerbHandler extends AbstractVerbHandler {
/// and its corresponding state.
///
/// Throws "AtEnrollmentException", if the OTP provided is invalid.
/// Throws [AtThrottleLimitExceeded], if the number of requests exceed within
/// a time window.
Future<void> _handleEnrollmentRequest(
EnrollParams enrollParams,
currentAtSign,
Map<dynamic, dynamic> responseJson,
InboundConnection atConnection) async {
if (!atConnection.isRequestAllowed()) {
throw AtThrottleLimitExceeded(
'Enrollment requests have exceeded the limit within the specified time frame');
}
if (!atConnection.getMetaData().isAuthenticated) {
var otp = enrollParams.otp;
if (otp == null ||
Expand Down
2 changes: 1 addition & 1 deletion packages/at_secondary_server/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
at_utils: 3.0.15
at_chops: 1.0.4
at_lookup: 3.0.40
at_server_spec: 3.0.14
at_server_spec: 3.0.15
at_persistence_spec: 2.0.14
at_persistence_secondary_server: 3.0.57
expire_cache: ^2.0.1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import 'dart:io';

import 'package:at_secondary/src/connection/inbound/inbound_connection_impl.dart';
import 'package:at_server_spec/at_server_spec.dart';
import 'package:test/test.dart';

void main(){
group('A test to verify the rate limiter on inbound connection', () {
test('A test to verify requests exceeding the limit are rejected', () {
Socket? dummySocket;
AtConnection connection1 = InboundConnectionImpl(dummySocket, 'aaa');
(connection1 as InboundConnectionImpl).maxRequestsPerTimeFrame = 1;
connection1.timeFrameInMillis =
Duration(milliseconds: 10).inMilliseconds;
expect(connection1.isRequestAllowed(), true);
expect(connection1.isRequestAllowed(), false);
});

test('A test to verify requests after the time window are accepted',
() async {
Socket? dummySocket;
AtConnection connection1 = InboundConnectionImpl(dummySocket, 'aaa');
(connection1 as InboundConnectionImpl).maxRequestsPerTimeFrame = 1;
connection1.timeFrameInMillis = Duration(milliseconds: 2).inMilliseconds;
expect(connection1.isRequestAllowed(), true);
expect(connection1.isRequestAllowed(), false);
await Future.delayed(Duration(milliseconds: 2));
expect(connection1.isRequestAllowed(), true);
});

test('A test to verify request from different connection is allowed', () {
Socket? dummySocket;
AtConnection connection1 = InboundConnectionImpl(dummySocket, 'aaa');
AtConnection connection2 = InboundConnectionImpl(dummySocket, 'aaa');
(connection1 as InboundConnectionImpl).maxRequestsPerTimeFrame = 1;
(connection2 as InboundConnectionImpl).maxRequestsPerTimeFrame = 1;
connection1.timeFrameInMillis =
Duration(milliseconds: 10).inMilliseconds;
expect(connection1.isRequestAllowed(), true);
expect(connection1.isRequestAllowed(), false);
expect(connection2.isRequestAllowed(), true);
});
});
}
2 changes: 1 addition & 1 deletion tests/at_functional_test/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
ref: trunk
at_chops: ^1.0.1
at_lookup: ^3.0.32
at_commons: ^3.0.53
at_commons: ^3.0.55
uuid: ^3.0.7
elliptic: ^0.3.8

Expand Down
Loading