Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Apr 29, 2024
1 parent 71ab960 commit fb16984
Show file tree
Hide file tree
Showing 6 changed files with 985 additions and 8 deletions.
444 changes: 444 additions & 0 deletions evaluation_LR.txt

Large diffs are not rendered by default.

221 changes: 221 additions & 0 deletions evaluation_hard_coded.txt

Large diffs are not rendered by default.

137 changes: 137 additions & 0 deletions script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from collections import Counter, defaultdict
import json
import os
import pickle
import shutil

import pandas as pd
from sdv.datasets.demo import download_demo, get_available_demos
from sdv.metadata.multi_table import MultiTableMetadata
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np
from sdv._utils import train_foreign_key_detector

def dump_relationships(metadata, outdir):
relationships = set()
for relation in metadata.relationships:
relationships.add((
relation['parent_table_name'],
relation['parent_primary_key'],
relation['child_table_name'],
relation['child_foreign_key']
))
with open(f'{outdir}/relationships.pkl', 'wb') as f:
pickle.dump(relationships, f)

def store_datasets():
if os.path.exists('test_set'):
answer = input('Test set already exists. Press "y" to overwrite: ')
if answer != 'y':
return
shutil.rmtree('test_set')

os.mkdir('test_set')
for demo_name in get_available_demos('multi_table')['dataset_name']:
outdir = f'test_set/{demo_name}'
os.mkdir(outdir)
data, metadata = download_demo('multi_table', demo_name)
for table_name, table_data in data.items():
table_data.to_csv(f'{outdir}/{table_name}.csv', index=False)

metadata.save_to_json(f'{outdir}/metadata.json')
dump_relationships(metadata, outdir)

def confusion_matrix(set1, set2):
true_positive, false_positive, false_negative = set(), set(), set()
for key in set1:
if key in set2:
true_positive.add(key)
else:
false_positive.add(key)

for key in set2:
if key not in set1:
false_negative.add(key)

return {
'True Positive': true_positive,
'False Positive': false_positive,
'False Negative': false_negative
}

def accuracy(set1, set2):
return len(set1.intersection(set2)) / len(set1.union(set2))

def evaluate():
total, i, tp, fp, fn = 0, 0, 0, 0, 0
# total
with open('evaluation.txt', 'w') as file:
demo_names = get_available_demos('multi_table')['dataset_name']
#demo_names = ['world_v1']
for demo_name in demo_names:
with open(f'test_set/{demo_name}/relationships.pkl', 'rb') as f:
true_relationships = pickle.load(f)
with open(f'predicted/{demo_name}/relationships.pkl', 'rb') as f:
predicted_relationships = pickle.load(f)

cm = confusion_matrix(predicted_relationships, true_relationships)
ac = accuracy(true_relationships, predicted_relationships)
file.write(f'{demo_name}\n')
file.write(f'Confusion Matrix: {cm}\n')
file.write(f'Num Foreign Keys: {len(cm['True Positive']) + len(cm['False Positive']) + len(cm['False Negative'])}\n')
file.write(f'Num True Positive: {len(true_relationships)}\n')
file.write(f'Num False Positive: {len(cm["False Positive"])}\n')
file.write(f'Num False Negative: {len(cm["False Negative"])}\n')
file.write(f'Accuracy: {ac}\n\n')
total += ac
i += 1
tp += len(cm["True Positive"])
fp += len(cm["False Positive"])
fn += len(cm["False Negative"])

file.write(f'Average Accuracy: {total / i}') # It's actually the Jaccard index
file.write(f'\nNum True Positive: {tp}')
file.write(f'\nNum False Positive: {fp}')
file.write(f'\nNum False Negative: {fn}')

def predict():
if os.path.exists('predicted'):
#answer = input('Predicted relationships already exist. Press "y" to overwrite: ')
#if answer != 'y':
# return
shutil.rmtree('predicted')

os.mkdir('predicted')
for demo_name in os.listdir('test_set'):
os.mkdir(f'predicted/{demo_name}')
data = {}
for table_name in os.listdir(f'test_set/{demo_name}'):
if table_name.endswith('.csv'):
data[table_name[:-4]] = pd.read_csv(f'test_set/{demo_name}/{table_name}', low_memory=False)

metadata = MultiTableMetadata()
metadata = metadata.load_from_json(f'test_set/{demo_name}/metadata.json')
metadata.relationships = []
metadata._detect_relationships_hard_coded(data)
dump_relationships(metadata, f'predicted/{demo_name}')

def visualize_metadata(dataset):
with open(f'test_set/{dataset}/metadata.json', 'r') as f:
metadata = json.load(f)
metadata = MultiTableMetadata.load_from_dict(metadata)
fig = metadata.visualize()
fig.view()

def add_metadata():
metadata = MultiTableMetadata()
metadata.detect_from_csvs('instacart')
metadata.save_to_json(f'instacart/metadata.json')

#store_datasets()
#predict()
#evaluate()
#visualize_metadata('world_v1')
#train_foreign_key_detector()
add_metadata()
83 changes: 82 additions & 1 deletion sdv/_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""Miscellaneous utility functions."""
import operator
import os
import pickle
import uuid
import warnings
from collections import defaultdict
from collections import Counter, defaultdict
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from sklearn.discriminant_analysis import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from sdv import version
from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError
Expand Down Expand Up @@ -409,3 +415,78 @@ def generate_synthesizer_id(synthesizer):
synth_version = version.public
unique_id = ''.join(str(uuid.uuid4()).split('-'))
return f'{class_name}_{synth_version}_{unique_id}'

def _generate_feature_vector(data, foreign_key):
parent_name = foreign_key[0]
parent_col, child_col = data[foreign_key[0]][foreign_key[1]], data[foreign_key[2]][foreign_key[3]]
parent_set, child_set = set(parent_col), set(child_col)

return [
len(child_set) / (len(parent_set) + 1e-5),
len(child_set) / (len(child_col) + 1e-5),
1.0 if parent_col.name == child_col.name else 0.0,
1.0 if child_col.name.lower().endswith('id') or child_col.name.lower().endswith('key') else 0.0,
1.0 if parent_name[:-1] in child_col else 0.0,
]

def confusion_matrix(set1, set2):
true_positive, false_positive, false_negative = set(), set(), set()
for key in set1:
if key in set2:
true_positive.add(key)
else:
false_positive.add(key)

for key in set2:
if key not in set1:
false_negative.add(key)

return {
'True Positive': true_positive,
'False Positive': false_positive,
'False Negative': false_negative
}

def train_foreign_key_detector():
"""Generate a foreign key detection model using logistic regression and pickle it.
This function is used to create and train a foreign key detection model.
"""
features, target = np.empty(shape=(0,5)), np.empty(shape=(0,))
pipeline = Pipeline([
('scaler', StandardScaler()),
('detector', LogisticRegression())
])

# Load the data
for demo_name in os.listdir('test_set'):
with open(f'test_set/{demo_name}/relationships.pkl', 'rb') as f:
true_relationships = pickle.load(f)
with open(f'predicted/{demo_name}/relationships.pkl', 'rb') as f:
predicted_relationships = pickle.load(f)

data = {}
for table_name in os.listdir(f'test_set/{demo_name}'):
if table_name.endswith('.csv'):
data[table_name[:-4]] = pd.read_csv(f'test_set/{demo_name}/{table_name}', low_memory=False)

cm = confusion_matrix(predicted_relationships, true_relationships)
for foreign_key in cm['True Positive']:
features = np.vstack((features, _generate_feature_vector(data, foreign_key)))
target = np.append(target, 1.)

for foreign_key in cm['False Positive']:
features = np.vstack((features, _generate_feature_vector(data, foreign_key)))
target = np.append(target, 0.)

pipeline.fit(features, target)
with open('trained_model.pkl', 'wb') as f:
pickle.dump(pipeline, f)


def predict_foreign_keys(data, parent_candidate, primary_key, child_candidate, column_name, threshold):
features = np.array(_generate_feature_vector(data, (parent_candidate, primary_key, child_candidate, column_name))).reshape(1, -1)
trained_model = pickle.load(open('trained_model.pkl', 'rb'))
if trained_model.predict_proba(features)[0, 1] > threshold:
return True
return False
2 changes: 1 addition & 1 deletion sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_data(modality, output_folder_name, in_memory_directory):
for filename, file_ in in_memory_directory.items():
if filename.endswith('.csv'):
table_name = Path(filename).stem
data[table_name] = pd.read_csv(io.StringIO(file_.decode()))
data[table_name] = pd.read_csv(io.StringIO(file_.decode()), low_memory=False)

if modality != 'multi_table':
data = data.popitem()[1]
Expand Down
Loading

0 comments on commit fb16984

Please sign in to comment.