Skip to content

Commit

Permalink
Sort studies by name and add basic study intersection check. (#1023)
Browse files Browse the repository at this point in the history
* Sort studies by name and add basic study intersection validation.

* Fix wildcard version comparison, add tests.

* Make sure to properly cover empty channel/platform lists (although this is not allowed).
  • Loading branch information
goodov authored Apr 26, 2024
1 parent 90b2af5 commit 86ceff9
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions seed/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import time
import proto.variations_seed_pb2 as variations_seed_pb2
import argparse
import collections

SEED_BIN_PATH = "./seed.bin"
SERIALNUMBER_PATH = "./serialnumber"
Expand All @@ -26,6 +27,73 @@
'RELEASE': study_pb2.Study.Channel.STABLE
}

def version_to_int_array(version_str):
version_list = []
if version_str is None:
return version_list

parts = version_str.split('.')
for part in parts:
if part == '*':
version_list.append(part)
break
version_list.append(int(part))

return version_list

def compare_versions(version1, version2):
min_len = None
if not version1:
min_len = 0
elif version1[-1] == '*':
version1 = version1[:-1]
min_len = len(version1)

if not version2:
min_len = 0
elif version2[-1] == '*':
version2 = version2[:-1]
if min_len is not None:
min_len = min(min_len, len(version2))
else:
min_len = len(version2)

if min_len is not None:
version1 = version1[:min_len]
version2 = version2[:min_len]

if version1 > version2:
return 1
elif version1 < version2:
return -1
else:
return 0

def test_version_comparison():
# //base/version_unittest.cc VersionTest.CompareToWildcardString
test_cases = [
["1.0", "1.*", 0],
["1.0", "0.*", 1],
["1.0", "2.*", -1],
["1.2.3", "1.2.3.*", 0],
["10.0", "1.0.*", 1],
["1.0", "3.0.*", -1],
["1.4", "1.3.0.*", 1],
["1.3.9", "1.3.*", 0],
["1.4.1", "1.3.*", 1],
["1.3", "1.4.5.*", -1],
["1.5", "1.4.5.*", 1],
["1.3.9", "1.3.*", 0],
["1.2.0.0.0.0", "1.2.*", 0],
[None, None, 0],
[None, "1", 0],
["1", None, 0],
]
for test_case in test_cases:
version1 = version_to_int_array(test_case[0])
version2 = version_to_int_array(test_case[1])
assert compare_versions(version1, version2) == test_case[2]

def validate(seed):
for study in seed['studies']:
total_proba = 0
Expand All @@ -44,6 +112,57 @@ def validate(seed):
print("platform not in ", PLATFORMS)
return False

feature_names_to_studies = collections.defaultdict(list)
for study in seed['studies']:
used_feature_names = set()
for experiment in study['experiments']:
feature_association = experiment.get('feature_association')
if feature_association:
for enable_feature in feature_association.get('enable_feature', []):
used_feature_names.add(enable_feature)
for disable_feature in feature_association.get('disable_feature', []):
used_feature_names.add(disable_feature)

for used_feature_names in used_feature_names:
feature_names_to_studies[used_feature_names].append(study)

def get_study_platforms(study):
return set(study.get('filter', {}).get('platform', []))

def get_study_channels(study):
return set(study.get('filter', {}).get('channel', []))

def get_study_version_range(study):
return [
version_to_int_array(study.get('filter', {}).get('min_version')),
version_to_int_array(study.get('filter', {}).get('max_version')),
]

def is_filter_set_intersect(a, b):
return not a or not b or a.intersection(b)

def is_version_range_intersect(range1, range2):
return compare_versions(range1[1], range2[0]) >= 0 and compare_versions(range2[1], range1[0]) >= 0

test_version_comparison()
for studies in feature_names_to_studies.values():
for i, study1 in enumerate(studies):
study1_platform = get_study_platforms(study1)
study1_channel = get_study_channels(study1)
study1_version_range = get_study_version_range(study1)
for j in range(i + 1, len(studies)):
study2 = studies[j]
study2_platform = get_study_platforms(study2)
study2_channel = get_study_channels(study2)
study2_version_range = get_study_version_range(study2)
# Check if the studies overlap in platform
if is_filter_set_intersect(study1_platform, study2_platform):
# Check if the studies overlap in channel
if is_filter_set_intersect(study1_channel, study2_channel):
# Check if the studies overlap in version
if is_version_range_intersect(study1_version_range, study2_version_range):
raise ValueError(f"Studies overlap:\n{json.dumps(study1, indent=2)}\n\n{json.dumps(study2, indent=2)}")

return True


Expand Down Expand Up @@ -213,6 +332,7 @@ def main():

print("Load", args.seed_path.name)
seed_data = json.load(args.seed_path)
seed_data['studies'].sort(key=lambda study: study['name'])

print("Validate seed data")
if not validate(seed_data):
Expand Down

0 comments on commit 86ceff9

Please sign in to comment.