Updated example to record benchmark times
lllangWV committed Oct 16, 2024
1 parent 1955e62 commit df7ae06
Showing 1 changed file with 55 additions and 16 deletions.
71 changes: 55 additions & 16 deletions examples/scripts/Example 1 - 3D Alexandria Database.py
@@ -14,30 +14,27 @@
config.logging_config.loggers.parquetdb.level='ERROR'
config.apply()

@timeit
def read_json(json_file):
with open(json_file, 'r') as f:
data = json.load(f)
return data

@timeit
def create_dataset(db, data, **kwargs):
db.create(data=data, **kwargs)

@timeit
def read_dataset(db,**kwargs):
return db.read(**kwargs)

@timeit
def normalize_dataset(db, **kwargs):
return db.normalize(**kwargs)
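
# --- Illustrative sketch (not part of the original example) -----------------
# The @timeit decorator used above is imported/defined in the elided portion of
# this file, so its exact behavior is an assumption here. A minimal timing
# decorator of this kind might look like the hypothetical _timeit_sketch below,
# which simply reports the wall-clock time of the wrapped call.
from functools import wraps

def _timeit_sketch(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        t0 = time.time()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.time() - t0:.3f} s")
        return result
    return wrapper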

if __name__ == "__main__":
normalize=True
from_scratch=False
base_dir=os.path.join('data','external', 'alexandria', 'AlexandriaDB')


benchmark_dir=os.path.join(config.data_dir, 'benchmarks', 'alexandria')
os.makedirs(benchmark_dir, exist_ok=True)
if from_scratch and os.path.exists(base_dir):
print(f"Removing existing directory: {base_dir}")
shutil.rmtree(base_dir, ignore_errors=True)
@@ -52,14 +49,25 @@ def normalize_dataset(db, **kwargs):
db=ParquetDB(dataset_name='alexandria_3D',dir=base_dir)

print(f"Dataset dir: {db.dataset_dir}")



benchmark_dict={
'create_times':[],
'json_load_times':[],
'n_rows_per_file':[],
}
# Here, we create the dataset inside the database
start_time = time.time()
print('-'*200)
if len(os.listdir(db.dataset_dir))==0:
print("The dataset does not exist. Creating it.")
json_files=glob(os.path.join(alexandria_dir,'*.json'))
for json_file in json_files:

start_time = time.time()
data = read_json(json_file)
json_load_time = time.time() - start_time

base_name=os.path.basename(json_file)
n_materials=len(data['entries'])
@@ -69,26 +77,35 @@ def normalize_dataset(db, **kwargs):
try:
# Since we are importing a lot of data, it is best
# to normalize the database afterwards
start_time = time.time()
create_dataset(db,data['entries'], normalize_dataset=False)
create_time = time.time() - start_time

except Exception as e:
print(e)

print('-'*200)
print(f"Time taken to create dataset: {time.time() - start_time}")

benchmark_dict['create_times'].append(create_time)
benchmark_dict['json_load_times'].append(json_load_time)
benchmark_dict['n_rows_per_file'].append(n_materials)

# Now that the data is in the database, we can normalize it.
# This means we enforce that our parquet files have a certain number of rows.
# This improves performance, since fewer files means fewer I/O operations in subsequent steps.
if normalize:
print("Normalizing the database")
start_time = time.time()
normalize_dataset(db,
load_format='batches', # Uses the batch generator to normalize
load_kwargs={'batch_readahead': 10, # Controls the number of batches to load in memory ahead of time. This affects the amount of RAM consumed
'fragment_readahead': 2, # Controls the number of files to load in memory ahead of time. This affects the amount of RAM consumed
},
batch_size = 100000, # Controls the batch size to use when normalizing. This affects the amount of RAM consumed
max_rows_per_file=500000, # Controls the max number of rows per parquet file
max_rows_per_group=500000) # Controls the max number of rows per row group within each parquet file
benchmark_dict['normalization_time']=time.time() - start_time
else:
print("Skipping normalization. Change normalize=True to normalize the database.")
print("Done with normalization")
@@ -98,10 +115,13 @@ def normalize_dataset(db, **kwargs):


# Here we read specific records from the database by their ids
start_time = time.time()
table=read_dataset(db,
ids=[0,10,100,1000,10000,100000,1000000], # Controls which rows we want to read
load_format='table' # Controls the output format. The options are 'table', 'batches', 'dataset'.
)
read_time = time.time() - start_time
benchmark_dict['read_ids_time']=read_time
df=table.to_pandas() # Converts the table to a pandas dataframe
print(df.head())
print(df.shape)
@@ -112,32 +132,39 @@ def normalize_dataset(db, **kwargs):


# Here we read all the records from the database, but only read the 'id' column
start_time = time.time()
table=read_dataset(db,
columns=['id'],
load_format='table')
single_column_read_time = time.time() - start_time
benchmark_dict['read_single_column_time']=single_column_read_time
print(table.shape)
print('-'*200)

# With only a subset of columns loaded, we can use built-in pyarrow compute functions to calculate statistics on a column
start_time = time.time()
table=read_dataset(db,
columns=['energy'],
load_format='table')
print(table.shape)

result = pc.min_max(table['energy'])
# The result will be a struct with 'min' and 'max' fields
min_value = result['min'].as_py()
max_value = result['max'].as_py()
benchmark_dict['min_max_time']=time.time() - start_time

print(f"Min: {min_value}, Max: {max_value}")
print('-'*200)


# Here we filter for rows that have energy above -1.0, but only read the 'id' and 'energy' columns
start_time = time.time()
table=read_dataset(db,
columns=['id','energy'],
filters=[pc.field('energy') > -1.0],
load_format='table')
benchmark_dict['read_filtered_energy_above_-1_time']=time.time() - start_time
df=table.to_pandas() # Converts the table to a pandas dataframe
print(df.head())
print(df.shape)
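
# --- Illustrative sketch (not part of the original example) -----------------
# Filter expressions can also be combined with pyarrow's boolean operators.
# This assumes ParquetDB accepts a combined expression in the filters list the
# same way it accepts a single comparison; combined_table is a name introduced
# here for illustration only.
combined_filter = (pc.field('energy') > -1.0) & (pc.field('data.spg') == 204)
combined_table=read_dataset(db,
                columns=['id','energy','data.spg'],
                filters=[combined_filter],
                load_format='table')
print(combined_table.shape)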
@@ -146,10 +173,12 @@ def normalize_dataset(db, **kwargs):

# Here we filter on a nested subfield.
# In this case we want to filter for space group 204
start_time = time.time()
table=read_dataset(db,
columns=['id', 'data.spg'],
filters=[pc.field('data.spg') == 204],
load_format='table')
benchmark_dict['read_filtered_spg_204_time']=time.time() - start_time
df=table.to_pandas() # Converts the table to a pandas dataframe
print(df.head())
print(df.shape)
@@ -159,14 +188,16 @@ def normalize_dataset(db, **kwargs):

# We can also read in batches. This will batch the rows into groups of 1000 and return tables,
# and these tables will be filtered by the given filters and columns
start_time = time.time()
generator=read_dataset(db,
load_format='batches',
batch_size=1000,
load_kwargs={'batch_readahead': 10,
'fragment_readahead': 2,
},
columns=['id', 'data.spg'],
filters=[pc.field('data.spg') == 204])
benchmark_dict['read_filtered_spg_204_1000_batches_time']=time.time() - start_time
batch_count=0
num_rows=0
for table in generator:
@@ -181,9 +212,11 @@ def normalize_dataset(db, **kwargs):

# Here we select a nested subfield column directly.
# In this case we read the flattened 'structure.sites' column and inspect its type
start_time = time.time()
table=read_dataset(db,
columns=['id', 'structure.sites'],
load_format='table')
benchmark_dict['read_nested_column_selection_time']=time.time() - start_time
print(table.shape)
print(table['structure.sites'].type)
print(table['structure.sites'].combine_chunks().type)
@@ -192,12 +225,15 @@ def normalize_dataset(db, **kwargs):
# By default the database flattens nested structures for storage.
# However, we provide an option to rebuild the nested structure. This will create a new dataset in {dataset_name}_nested.
# After the new dataset is created, the query parameters are applied to it.
start_time = time.time()
table=read_dataset(db,
columns=['id', 'structure','data'], # Instead of using the flattened syntax, we can use the nested column names
ids=[0],
rebuild_nested_struct=True, # When set to True to rebuild the nested structure
rebuild_nested_from_scratch=False, # When set to True, the nested structure will be rebuilt from scratch
load_format='table')

benchmark_dict['read_rebuild_nested_struct_time']=time.time() - start_time
print(table.shape)
print(table['data'].type)

@@ -209,6 +245,9 @@ def normalize_dataset(db, **kwargs):
print(structure)
except Exception as e:
print(e)

with open(os.path.join(benchmark_dir, 'alexandria_benchmark.json'), 'w') as f:
json.dump(benchmark_dict, f, indent=4)
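
# --- Illustrative sketch (not part of the original example) -----------------
# The recorded benchmark file can be loaded back for a quick summary; the
# aggregates printed here are just one example of how the data might be
# inspected.
with open(os.path.join(benchmark_dir, 'alexandria_benchmark.json'), 'r') as f:
    results = json.load(f)
print(f"Total rows created: {sum(results['n_rows_per_file'])}")
print(f"Total create time: {sum(results['create_times']):.2f} s")
print(f"Total json load time: {sum(results['json_load_times']):.2f} s")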


