From f2d25c3d581200c3fcd0ba71e0ed5f49080c59d3 Mon Sep 17 00:00:00 2001 From: currantw Date: Wed, 30 Oct 2024 10:22:43 -0700 Subject: [PATCH] Refactor implementation to use `com.github.seancfoley:ipaddress:5.4.2` (already a dependency for :core) and to match behaviour in Spark. Signed-off-by: currantw --- core/build.gradle | 1 + .../sql/expression/ip/IPFunction.java | 98 +++++++++---------- .../sql/expression/ip/IPFunctionTest.java | 68 ++++++++++--- 3 files changed, 98 insertions(+), 69 deletions(-) diff --git a/core/build.gradle b/core/build.gradle index f36777030c..c596251342 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -57,6 +57,7 @@ dependencies { api group: 'com.google.code.gson', name: 'gson', version: '2.8.9' api group: 'com.tdunning', name: 't-digest', version: '3.3' api project(':common') + implementation "com.github.seancfoley:ipaddress:5.4.2" testImplementation('org.junit.jupiter:junit-jupiter:5.9.3') testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' diff --git a/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java b/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java index c7f81bf564..1c3a7a8123 100644 --- a/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/ip/IPFunction.java @@ -5,15 +5,9 @@ package org.opensearch.sql.expression.ip; -import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.expression.function.FunctionDSL.*; - -import com.google.common.net.InetAddresses; -import java.math.BigInteger; -import java.net.InetAddress; -import java.util.regex.Matcher; -import java.util.regex.Pattern; +import inet.ipaddr.AddressStringException; +import inet.ipaddr.IPAddressString; +import inet.ipaddr.IPAddressStringParameters; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -22,13 +16,14 @@ import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.expression.function.FunctionDSL.*; + /** Utility class that defines and registers IP functions. */ @UtilityClass public class IPFunction { - private static final Pattern cidrPattern = - Pattern.compile("(?
.+)[/](?[0-9]+)"); - public void register(BuiltinFunctionRepository repository) { repository.register(cidrmatch()); } @@ -40,8 +35,8 @@ private DefaultFunctionResolver cidrmatch() { } /** - * Returns whether the given IP address is within the specified CIDR IP address range. - * Supports both IPv4 and IPv6 addresses. + * Returns whether the given IP address is within the specified CIDR IP address range. Supports + * both IPv4 and IPv6 addresses. * * @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329"). * @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or @@ -51,56 +46,51 @@ private DefaultFunctionResolver cidrmatch() { */ private ExprValue exprCidrMatch(ExprValue addressExprValue, ExprValue rangeExprValue) { - // Get address String addressString = addressExprValue.stringValue(); - if (!InetAddresses.isInetAddress(addressString)) { - return ExprValueUtils.nullValue(); - } - - InetAddress address = InetAddresses.forString(addressString); - - // Get range and network length String rangeString = rangeExprValue.stringValue(); - Matcher cidrMatcher = cidrPattern.matcher(rangeString); - if (!cidrMatcher.matches()) - throw new SemanticCheckException( - String.format("CIDR notation '%s' in not valid", rangeString)); - - String rangeAddressString = cidrMatcher.group("address"); - if (!InetAddresses.isInetAddress(rangeAddressString)) + final IPAddressStringParameters validationOptions = + new IPAddressStringParameters.Builder() + .allowEmpty(false) + .setEmptyAsLoopback(false) + .allow_inet_aton(false) + .allowSingleSegment(false) + .toParams(); + + // Get and validate IP address. + IPAddressString address = + new IPAddressString(addressExprValue.stringValue(), validationOptions); + + try { + address.validate(); + } catch (AddressStringException e) { throw new SemanticCheckException( - String.format("IP address '%s' in not valid", rangeAddressString)); - - InetAddress rangeAddress = InetAddresses.forString(rangeAddressString); - - // Address and range must use the same IP version (IPv4 or IPv6). - if (!address.getClass().equals(rangeAddress.getClass())) { - return ExprValueUtils.booleanValue(false); + String.format( + "IP address '%s' is not supported. Error details: %s", + addressString, e.getMessage())); } - int networkLengthBits = Integer.parseInt(cidrMatcher.group("networkLength")); - int addressLengthBits = address.getAddress().length * Byte.SIZE; + // Get and validate CIDR IP address range. + IPAddressString range = new IPAddressString(rangeExprValue.stringValue(), validationOptions); - if (networkLengthBits > addressLengthBits) + try { + range.validate(); + } catch (AddressStringException e) { throw new SemanticCheckException( - String.format("Network length of '%s' bits is not valid", networkLengthBits)); - - // Build bounds by converting the address to an integer, setting all the non-significant bits to - // zero for the lower bounds and one for the upper bounds, and then converting back to - // addresses. - BigInteger lowerBoundInt = InetAddresses.toBigInteger(rangeAddress); - BigInteger upperBoundInt = InetAddresses.toBigInteger(rangeAddress); + String.format( + "CIDR IP address range '%s' is not supported. Error details: %s", + rangeString, e.getMessage())); + } - int hostLengthBits = addressLengthBits - networkLengthBits; - for (int bit = 0; bit < hostLengthBits; bit++) { - lowerBoundInt = lowerBoundInt.clearBit(bit); - upperBoundInt = upperBoundInt.setBit(bit); + // Address and range must use the same IP version (IPv4 or IPv6). + if (address.isIPv4() ^ range.isIPv4()) { + throw new SemanticCheckException( + String.format( + "IP address '%s' and CIDR IP address range '%s' are not compatible. Both must be" + + " either IPv4 or IPv6.", + addressString, rangeString)); } - // Convert the address to an integer and compare it to the bounds. - BigInteger addressInt = InetAddresses.toBigInteger(address); - return ExprValueUtils.booleanValue( - (addressInt.compareTo(lowerBoundInt) >= 0) && (addressInt.compareTo(upperBoundInt) <= 0)); + return ExprValueUtils.booleanValue(range.contains(address)); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java index 96ca53ee5a..ad31402cd7 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/ip/IPFunctionTest.java @@ -5,10 +5,10 @@ package org.opensearch.sql.expression.ip; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.when; -import static org.opensearch.sql.data.model.ExprValueUtils.*; +import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE; +import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import org.junit.jupiter.api.Test; @@ -46,20 +46,45 @@ public class IPFunctionTest { @Test public void cidrmatch_invalid_address() { - assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range)); + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> execute(ExprValueUtils.stringValue("INVALID"), IPv4Range)); + assertTrue( + exception.getMessage().matches("IP address 'INVALID' is not supported. Error details: .*")); } @Test public void cidrmatch_invalid_range() { - assertThrows( - SemanticCheckException.class, - () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID"))); - assertThrows( - SemanticCheckException.class, - () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32"))); - assertThrows( - SemanticCheckException.class, - () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33"))); + SemanticCheckException exception; + + exception = + assertThrows( + SemanticCheckException.class, + () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID"))); + assertTrue( + exception + .getMessage() + .matches("CIDR IP address range 'INVALID' is not supported. Error details: .*")); + + exception = + assertThrows( + SemanticCheckException.class, + () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32"))); + assertTrue( + exception + .getMessage() + .matches("CIDR IP address range 'INVALID/32' is not supported. Error details: .*")); + + exception = + assertThrows( + SemanticCheckException.class, + () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33"))); + assertTrue( + exception + .getMessage() + .matches( + "CIDR IP address range '198.51.100.0/33' is not supported. Error details: .*")); } @Test @@ -78,8 +103,21 @@ public void cidrmatch_valid_ipv6() { @Test public void cidrmatch_valid_different_versions() { - assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range)); - assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range)); + SemanticCheckException exception; + + exception = + assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, IPv6Range)); + assertEquals( + "IP address '198.51.100.1' and CIDR IP address range '2001:0db8::/32' are not compatible." + + " Both must be either IPv4 or IPv6.", + exception.getMessage()); + + exception = + assertThrows(SemanticCheckException.class, () -> execute(IPv6AddressWithin, IPv4Range)); + assertEquals( + "IP address '2001:0db8::ff00:42:8329' and CIDR IP address range '198.51.100.0/24' are not" + + " compatible. Both must be either IPv4 or IPv6.", + exception.getMessage()); } /**