Skip to content

Commit

Permalink
Refactor implementation to use `com.github.seancfoley:ipaddress:5.4.2…
Browse files Browse the repository at this point in the history
…` (already a dependency for :core) and to match behaviour in Spark.

Signed-off-by: currantw <[email protected]>
  • Loading branch information
currantw committed Oct 30, 2024
1 parent b4a233c commit f2d25c3
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 69 deletions.
1 change: 1 addition & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies {
api group: 'com.google.code.gson', name: 'gson', version: '2.8.9'
api group: 'com.tdunning', name: 't-digest', version: '3.3'
api project(':common')
implementation "com.github.seancfoley:ipaddress:5.4.2"

testImplementation('org.junit.jupiter:junit-jupiter:5.9.3')
testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1'
Expand Down
98 changes: 44 additions & 54 deletions core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,9 @@

package org.opensearch.sql.expression.ip;

import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.*;

import com.google.common.net.InetAddresses;
import java.math.BigInteger;
import java.net.InetAddress;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import inet.ipaddr.AddressStringException;
import inet.ipaddr.IPAddressString;
import inet.ipaddr.IPAddressStringParameters;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand All @@ -22,13 +16,14 @@
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;

import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.*;

/** Utility class that defines and registers IP functions. */
@UtilityClass
public class IPFunction {

private static final Pattern cidrPattern =
Pattern.compile("(?<address>.+)[/](?<networkLength>[0-9]+)");

public void register(BuiltinFunctionRepository repository) {
repository.register(cidrmatch());
}
Expand All @@ -40,8 +35,8 @@ private DefaultFunctionResolver cidrmatch() {
}

/**
* Returns whether the given IP address is within the specified CIDR IP address range.
* Supports both IPv4 and IPv6 addresses.
* Returns whether the given IP address is within the specified CIDR IP address range. Supports
* both IPv4 and IPv6 addresses.
*
* @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329").
* @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or
Expand All @@ -51,56 +46,51 @@ private DefaultFunctionResolver cidrmatch() {
*/
private ExprValue exprCidrMatch(ExprValue addressExprValue, ExprValue rangeExprValue) {

// Get address
String addressString = addressExprValue.stringValue();
if (!InetAddresses.isInetAddress(addressString)) {
return ExprValueUtils.nullValue();
}

InetAddress address = InetAddresses.forString(addressString);

// Get range and network length
String rangeString = rangeExprValue.stringValue();

Matcher cidrMatcher = cidrPattern.matcher(rangeString);
if (!cidrMatcher.matches())
throw new SemanticCheckException(
String.format("CIDR notation '%s' in not valid", rangeString));

String rangeAddressString = cidrMatcher.group("address");
if (!InetAddresses.isInetAddress(rangeAddressString))
final IPAddressStringParameters validationOptions =
new IPAddressStringParameters.Builder()
.allowEmpty(false)
.setEmptyAsLoopback(false)
.allow_inet_aton(false)
.allowSingleSegment(false)
.toParams();

// Get and validate IP address.
IPAddressString address =
new IPAddressString(addressExprValue.stringValue(), validationOptions);

try {
address.validate();
} catch (AddressStringException e) {
throw new SemanticCheckException(
String.format("IP address '%s' in not valid", rangeAddressString));

InetAddress rangeAddress = InetAddresses.forString(rangeAddressString);

// Address and range must use the same IP version (IPv4 or IPv6).
if (!address.getClass().equals(rangeAddress.getClass())) {
return ExprValueUtils.booleanValue(false);
String.format(
"IP address '%s' is not supported. Error details: %s",
addressString, e.getMessage()));
}

int networkLengthBits = Integer.parseInt(cidrMatcher.group("networkLength"));
int addressLengthBits = address.getAddress().length * Byte.SIZE;
// Get and validate CIDR IP address range.
IPAddressString range = new IPAddressString(rangeExprValue.stringValue(), validationOptions);

if (networkLengthBits > addressLengthBits)
try {
range.validate();
} catch (AddressStringException e) {
throw new SemanticCheckException(
String.format("Network length of '%s' bits is not valid", networkLengthBits));

// Build bounds by converting the address to an integer, setting all the non-significant bits to
// zero for the lower bounds and one for the upper bounds, and then converting back to
// addresses.
BigInteger lowerBoundInt = InetAddresses.toBigInteger(rangeAddress);
BigInteger upperBoundInt = InetAddresses.toBigInteger(rangeAddress);
String.format(
"CIDR IP address range '%s' is not supported. Error details: %s",
rangeString, e.getMessage()));
}

int hostLengthBits = addressLengthBits - networkLengthBits;
for (int bit = 0; bit < hostLengthBits; bit++) {
lowerBoundInt = lowerBoundInt.clearBit(bit);
upperBoundInt = upperBoundInt.setBit(bit);
// Address and range must use the same IP version (IPv4 or IPv6).
if (address.isIPv4() ^ range.isIPv4()) {
throw new SemanticCheckException(
String.format(
"IP address '%s' and CIDR IP address range '%s' are not compatible. Both must be"
+ " either IPv4 or IPv6.",
addressString, rangeString));
}

// Convert the address to an integer and compare it to the bounds.
BigInteger addressInt = InetAddresses.toBigInteger(address);
return ExprValueUtils.booleanValue(
(addressInt.compareTo(lowerBoundInt) >= 0) && (addressInt.compareTo(upperBoundInt) <= 0));
return ExprValueUtils.booleanValue(range.contains(address));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.sql.expression.ip;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.*;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -46,20 +46,45 @@ public class IPFunctionTest {

@Test
public void cidrmatch_invalid_address() {
assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range));
SemanticCheckException exception =
assertThrows(
SemanticCheckException.class,
() -> execute(ExprValueUtils.stringValue("INVALID"), IPv4Range));
assertTrue(
exception.getMessage().matches("IP address 'INVALID' is not supported. Error details: .*"));
}

@Test
public void cidrmatch_invalid_range() {
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID")));
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32")));
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33")));
SemanticCheckException exception;

exception =
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID")));
assertTrue(
exception
.getMessage()
.matches("CIDR IP address range 'INVALID' is not supported. Error details: .*"));

exception =
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32")));
assertTrue(
exception
.getMessage()
.matches("CIDR IP address range 'INVALID/32' is not supported. Error details: .*"));

exception =
assertThrows(
SemanticCheckException.class,
() -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33")));
assertTrue(
exception
.getMessage()
.matches(
"CIDR IP address range '198.51.100.0/33' is not supported. Error details: .*"));
}

@Test
Expand All @@ -78,8 +103,21 @@ public void cidrmatch_valid_ipv6() {

@Test
public void cidrmatch_valid_different_versions() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range));
SemanticCheckException exception;

exception =
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, IPv6Range));
assertEquals(
"IP address '198.51.100.1' and CIDR IP address range '2001:0db8::/32' are not compatible."
+ " Both must be either IPv4 or IPv6.",
exception.getMessage());

exception =
assertThrows(SemanticCheckException.class, () -> execute(IPv6AddressWithin, IPv4Range));
assertEquals(
"IP address '2001:0db8::ff00:42:8329' and CIDR IP address range '198.51.100.0/24' are not"
+ " compatible. Both must be either IPv4 or IPv6.",
exception.getMessage());
}

/**
Expand Down

0 comments on commit f2d25c3

Please sign in to comment.