-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_logic.py
769 lines (661 loc) · 29 KB
/
graph_logic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
import streamlit as st
import pandas as pd
import prototype as pt
import plotly.graph_objects as go
import re
import requests
from utils import log
class FilterData:
    """Hold the option lists that back the sidebar filter widgets.

    Attributes:
        continents, countries, venues, publication_types, research_areas:
            the selectable options for the corresponding filter.
        min_publication_count: numeric publication threshold stored
            alongside the option lists.
    """

    def __init__(
        self,
        continents=None,
        countries=None,
        venues=None,
        min_publication_count=0,
        publication_types=None,
        research_areas=None,
    ):
        # ``None`` sentinels instead of ``[]`` defaults: a mutable default is
        # created once and would be shared by every instance built without
        # arguments.
        self.continents = [] if continents is None else continents
        self.countries = [] if countries is None else countries
        self.venues = [] if venues is None else venues
        self.publication_types = [] if publication_types is None else publication_types
        self.research_areas = [] if research_areas is None else research_areas
        self.min_publication_count = min_publication_count

    def is_any_list_empty(self):
        """Return True if at least one attribute is falsy.

        NOTE(review): ``min_publication_count`` takes part in the check, so an
        instance whose count is 0 is always reported as "empty", even when
        every list is populated. That behaviour is kept unchanged here.
        """
        return not all(
            (
                self.continents,
                self.countries,
                self.venues,
                self.min_publication_count,
                self.publication_types,
                self.research_areas,
            )
        )
# Display all the filters that the user can select
def display_filters():
    """Load the filter options, render the sidebar widgets and handle submit.

    The option lists are read from the CSV files under ``filters/`` and kept
    in ``st.session_state.filters``. Pressing "Submit and Compare" forwards
    the current widget values to ``update_graph``.
    """
    state = st.session_state
    if "filters" not in state:
        state.filters = FilterData()

    if state.filters.is_any_list_empty():
        # Every option list is produced the same way: read the CSV, take the
        # relevant column and sort its values ascending.
        countries_table = pd.read_csv("filters/Countries.csv")
        unknown_row = pd.DataFrame(
            {
                "Country": ["Unknown"],
                "Continent": ["Unknown"],
            },
        )
        # Extra "Unknown" entry so that it can be selected in the filters
        countries_table = pd.concat([countries_table, unknown_row])
        state.country_continent_dataframe = countries_table
        state.min_venue_publications = 0
        state.venue_data = pd.read_csv("filters/Venues.csv")
        publication_types_table = pd.read_csv("filters/PublicationTypes.csv")
        research_areas_table = pd.read_csv("filters/ResearchAreas.csv")

        state.filters = FilterData(
            sorted(set(countries_table["Continent"])),
            sorted(countries_table["Country"]),
            update_min_venue_publications(1),
            0,
            sorted(publication_types_table["PublicationType"]),
            sorted(research_areas_table["ResearchArea"]),
        )

    prefill_graph()

    author_position_options = (
        "First author woman",
        "Middle author woman",
        "Last author woman",
        "Any author woman",
        "First author man",
        "Middle author man",
        "Last author man",
        "Any author man",
    )

    with st.sidebar:
        st.subheader("Filters")
        chosen_research_areas = st.multiselect(
            "Filter by Research Area$\\newline$(selected conferences):",
            state.filters.research_areas,
            key="research_area",
        )
        chosen_publication_types = st.multiselect(
            "Filter by publication type:",
            state.filters.publication_types,
            key="publication_type",
        )
        chosen_venues = st.multiselect(
            "Filter by Conference/Journals:",
            state.filters.venues,
            format_func=format_function,
            key="venue",
        )
        chosen_min_publication_count = st.number_input(
            "Minimum number of publications per venue:",
            min_value=0,
            max_value=100000,
            value=1,
            step=5,
            on_change=update_min_venue_publications,
            key="min_publication_count",
        )
        chosen_continents = st.multiselect(
            "Filter by Continent$\\newline$(only authors with known affiliation):",
            state.filters.continents,
            key="cont",
        )
        # A changed continent selection narrows the countries offered below
        if chosen_continents != state.widget_continents:
            state.widget_continents = chosen_continents
            update_available_countries()
        chosen_countries = st.multiselect(
            "Filter by Country$\\newline$(only authors with known affiliation):",
            state.filters.countries,
            key="country",
        )
        chosen_author_position = st.radio(
            "Filter by Gender Author Position:",
            author_position_options,
            key="author_position",
        )
        st.button("Clear Filters", on_click=clear_filters)

        # The graph is only updated once the submit button was clicked, so it
        # does not change while the user is still adjusting the filters
        if st.button("**Submit and Compare**"):
            if state.is_first_submit:
                state.is_first_submit = False
            update_graph(
                chosen_venues,
                chosen_min_publication_count,
                chosen_countries,
                chosen_continents,
                chosen_publication_types,
                chosen_author_position,
                chosen_research_areas,
                state.widget_data_representation,
            )
def format_function(name):
    """
    Build the label of a venue option: its number of publications, a
    separator and the venue name.

    The count is read from ``st.session_state.pub_counts`` at the position
    that ``name`` has in ``st.session_state.filters.venues``.
    """
    position = st.session_state.filters.venues.index(name)
    publication_count = st.session_state.pub_counts[position]
    return f"{publication_count} | {name}"
def update_min_venue_publications(minimum=None):
    """
    Rebuild the list of selectable venues from their publication counts.

    Gets called by the on_change attribute of the number input and once
    while the filters are initialised.

    The threshold is ``st.session_state.min_publication_count`` when that key
    exists, otherwise ``minimum``; if neither is available, 0 is used so that
    no venue is filtered out by count. Venues whose name contains a ``"``
    character are always skipped.

    Side effects: stores the venue names in
    ``st.session_state.filters.venues`` and the matching publication counts,
    in the same order, in ``st.session_state.pub_counts``.

    Returns:
        The list of venue names that passed the filter.
    """
    if "min_publication_count" in st.session_state:
        threshold = st.session_state.min_publication_count
    else:
        threshold = minimum
    if threshold is None:
        threshold = 0

    venues = []
    pub_counts = []
    # df.iterrows yields (index, row) pairs; only the row is needed
    for _, row in st.session_state.venue_data.iterrows():
        name = row["Name"]
        count = row["NumOfPublications"]
        if count < threshold:
            continue
        # Skip the few venues that include a " char. The skip has to happen
        # here, together with the count, so both lists stay index-aligned for
        # the lookup in format_function.
        if '"' in name:
            continue
        venues.append(name)
        pub_counts.append(count)

    st.session_state.pub_counts = pub_counts
    st.session_state.filters.venues = venues
    return venues
def clear_history():
    """Forget all stored graphs and drop the currently stored figure."""
    st.session_state["y_columns"] = []
    st.session_state["graph"] = None
def clear_filters():
    """Reset every filter widget value in the session state.

    The author position goes back to "First author woman"; all multiselect
    values are replaced by empty lists.
    """
    st.session_state["author_position"] = "First author woman"
    for widget_key in ("cont", "venue", "country", "publication_type", "research_area"):
        st.session_state[widget_key] = []
# Update the year range.
# As soon as the year range changes, the graph is rebuilt.
def update_year_range():
    """Normalise the selected year range, store the year list and repaint.

    If both ends of ``st.session_state.year_range`` are the same year, the
    range is widened by five years: upwards when the year is less than five
    above the minimum in ``st.session_state.min_max``, downwards otherwise.
    The inclusive list of years is written to ``st.session_state.graph_years``.
    """
    state = st.session_state
    first_year = state.year_range[0]
    last_year = state.year_range[1]

    if first_year == last_year:
        if first_year < state.min_max[0] + 5:
            last_year = last_year + 5
        else:
            first_year = first_year - 5
        state["year_range"] = (first_year, last_year)

    # range() excludes its stop value, hence the "+ 1" to keep the highest
    # selected year in the list
    state.graph_years = list(range(state.year_range[0], state.year_range[1] + 1))
    paint_graph()
def update_available_countries():
    """Restrict the country options to the selected continents.

    Reads ``st.session_state.widget_continents``. With no continent selected
    every country of ``st.session_state.country_continent_dataframe`` is
    offered; otherwise only the countries of the selected continents. The
    sorted result is stored in ``st.session_state.filters.countries``.
    """
    table = st.session_state.country_continent_dataframe
    chosen_continents = st.session_state.widget_continents

    if chosen_continents:
        # Collect the countries continent by continent
        matching = [
            country
            for continent in chosen_continents
            for country in table[table["Continent"] == continent]["Country"]
        ]
    else:
        matching = list(table["Country"])

    st.session_state.filters.countries = sorted(matching)
def prefill_graph():
    """Create the default graphs, one per continent, on the first run.

    Does nothing unless ``st.session_state.is_first_run`` equals True.
    ``is_first_submit`` is switched off while the graphs are created, because
    ``populate_graph`` returns early while that flag is set, and switched
    back on afterwards.
    """
    if st.session_state.is_first_run == True:  # noqa: E712
        st.session_state.is_first_submit = False
        for continent in ("Europe", "Asia", "North America", "South America", "Africa", "Oceania"):
            update_graph(
                widget_venues=[],
                widget_min_publication_count=0,
                widget_countries=[],
                widget_continents=[continent],
                widget_publication_types=[],
                widget_author_position="First author woman",
                widget_research_areas=[],
                widget_data_representation="Relative numbers",
            )
        st.session_state.is_first_run = False
        st.session_state.is_first_submit = True
# Insert all the data gotten by the form into the session state and populate the graph
def update_graph(
    widget_venues,
    widget_min_publication_count,
    widget_countries,
    widget_continents,
    widget_publication_types,
    widget_author_position,
    widget_research_areas,
    widget_data_representation,
):
    """Remember the submitted filter values and create the graph for them.

    Every argument is stored in ``st.session_state`` under its own parameter
    name. All of them except ``widget_data_representation`` are then passed
    on to ``populate_graph``.
    """
    submitted = {
        "widget_venues": widget_venues,
        "widget_min_publication_count": widget_min_publication_count,
        "widget_countries": widget_countries,
        "widget_continents": widget_continents,
        "widget_publication_types": widget_publication_types,
        "widget_author_position": widget_author_position,
        "widget_research_areas": widget_research_areas,
        "widget_data_representation": widget_data_representation,
    }
    for state_key, value in submitted.items():
        st.session_state[state_key] = value

    populate_graph(
        widget_venues,
        widget_min_publication_count,
        widget_countries,
        widget_continents,
        widget_publication_types,
        widget_author_position,
        widget_research_areas,
    )
# Creates dynamic queries based on the selection and
# runs the query to generate the counts that populate the line graphs
def populate_graph(venue, min_publication_count, country, cont, publication_type, author_position, research_area):
    """Query the data for one filter selection and add it as a new graph.

    Args:
        venue, country, cont, publication_type, research_area: lists of the
            selected option strings. The values of one list are OR-ed, the
            lists are AND-ed. The value "Unknown" is translated to IS NULL.
        min_publication_count: minimum ``NumOfPublications`` of the venue.
        author_position: one of the "<First|Middle|Last|Any> author
            <woman|man>" labels. Any other value adds no position filter and
            leaves the gender in the query empty.

    Returns early, without doing anything, while
    ``st.session_state.is_first_submit`` is set or when a graph whose name
    starts with the same label already exists in
    ``st.session_state.y_columns``. Otherwise the result is appended to
    ``st.session_state.y_columns``, ``st.session_state.graph_years`` is set to
    the selected year range and the graph is repainted.

    NOTE(review): the WHERE clause is assembled by string interpolation; the
    values are expected to come from the filter option lists only.
    """
    if st.session_state.is_first_submit:
        return

    # For each available filter, check if the user has filtered something there
    # If so, go through every selection and add them as a filter group (statement OR statement OR...)
    def build_filter(filter_list, field_name, y_name):
        if not filter_list:
            return "", y_name
        filter_str = "({})".format(
            " or ".join(
                # Both forms qualify the column with the "alto" alias so the
                # reference stays unambiguous next to the joined Venue table
                f'alto.{field_name} = "{item}"' if item != "Unknown" else f"alto.{field_name} IS NULL"
                for item in filter_list
            )
        )
        y_name += ", ".join(filter_list) + ", "
        return filter_str, y_name

    # The label of the graph, assembled from all selected filter values
    y_name = ""
    f_1, y_name = build_filter(venue, "Venue", y_name)
    f_2, y_name = build_filter(research_area, "ResearchArea", y_name)
    f_3, y_name = build_filter(country, "Country", y_name)
    f_4, y_name = build_filter(cont, "Continent", y_name)
    f_6, y_name = build_filter(publication_type, "PublicationType", y_name)
    f_7 = f"v.NumOfPublications >= {min_publication_count}"
    # If there is already a venue in the filter, it does not make sense to
    # communicate the min pub count
    if min_publication_count > 1 and venue == []:
        y_name += f"min. pub. count: {min_publication_count} "

    author_position_filters = {
        "First author woman": ('Position = "1"', "woman"),
        "Last author woman": ("CAST(Position AS INT) = AuthorCount", "woman"),
        "Middle author woman": ("Position > 1 AND CAST(Position AS INT) < AuthorCount", "woman"),
        "First author man": ('Position = "1"', "man"),
        "Last author man": ("CAST(Position AS INT) = AuthorCount", "man"),
        "Middle author man": ("Position > 1 AND CAST(Position AS INT) < AuthorCount", "man"),
    }
    sql_gender = ""
    if author_position in {"Any author woman", "Any author man"}:
        f_5 = ""
        y_name += author_position
        sql_gender = "woman" if author_position == "Any author woman" else "man"
    elif author_position in author_position_filters:
        filter_str, sql_gender = author_position_filters[author_position]
        f_5 = f"({filter_str})"
        y_name += author_position
    else:
        f_5 = ""

    # Combine each non-empty filter group with an AND operation
    newf = " AND ".join(f for f in (f_1, f_2, f_3, f_4, f_5, f_6, f_7) if f)

    # Convert the data from the range selector into a list
    # that includes all the years within this range
    year = list(range(st.session_state.year_range[0], st.session_state.year_range[1] + 1))

    # Basic SQL query structure
    # The query creates a table with Year | Absolute | Relative columns
    # Absolute counts the distinct publications that match the WHERE
    # conditions and have an author row of the requested gender.
    # Relative is that count times 100 divided by all distinct publications
    # matching the WHERE conditions.
    sql_start = f"""SELECT
alto.Year,
COUNT(DISTINCT
CASE
WHEN alto.Gender = '{sql_gender}' THEN alto.PublicationID
END
) AS Absolute,
COUNT(DISTINCT
CASE
WHEN alto.Gender = '{sql_gender}' THEN alto.PublicationID
END
) * 100 / COUNT(DISTINCT alto.PublicationID) AS Relative
FROM AllTogether alto
INNER JOIN Venue v ON alto.Venue = v.Name
"""
    sql_filter_start = """\nWHERE """
    sql_end = """\nGROUP BY alto.Year;"""

    # If the query was already requested, do nothing.
    # .startswith() is used, because item.name has a " (Total: ...)" at the end
    if any(item.name.startswith(y_name) for item in st.session_state.y_columns):
        return

    # Set the specific graph color with colors and the modulo
    # of the length of colors. This ensures, that the graph color of
    # one specific graph does not change if another graph is added
    # The first color is the theme color, the other ones the standard
    # plotly colors
    COLORS = [
        "#b1073b",
        "#636EFA",
        "#00CC96",
        "#AB63FA",
        "#FFA15A",
        "#19D3F3",
        "#FF6692",
        "#B6E880",
        "#FF97FF",
        "#FECB52",
    ]
    # Upper bound for the logging request, so an unresponsive logging
    # endpoint cannot block the app indefinitely
    log_timeout_seconds = 5

    with st.spinner("Creating graph..."):
        sql_query = sql_start + (sql_filter_start if newf else "") + newf + sql_end
        # Run the sql query and process it, so that it's ready for the graph
        grouped_absolutes, grouped_relatives = query_and_process(sql_query)
        total_absolutes = sum(grouped_absolutes["Absolute"])
        # Saves the total in session_state for later use, keyed by the first
        # comma-separated part of the label (e.g. "Europe" for the label
        # "Europe, First author woman")
        st.session_state.grouped_absolutes[y_name.split(",")[0]] = total_absolutes
        y_name = y_name + f" (Total: {total_absolutes})"
        color_index = len(st.session_state.y_columns) % len(COLORS)
        # Add all the gotten data into the y_columns session state,
        # that provides the data for the graph history, the change between
        # relative and absolute numbers and some other features
        st.session_state.y_columns.append(
            pt.GraphData(
                y_name,
                True,
                grouped_absolutes.sort_index().to_dict()["Absolute"],
                grouped_relatives.sort_index().to_dict()["Relative"],
                COLORS[color_index],
            )
        )
        # The default graphs (exactly one known continent, no other list
        # filter, "First author woman") are not logged
        is_default_graph = (
            (len(cont) == 1 and cont[0] != "Unknown")
            and not any([research_area, publication_type, venue, country])
            and author_position == "First author woman"
        )
        if not is_default_graph:
            try:
                requests.get(
                    'http://localhost:6502/log_graph_creation',
                    params={
                        'research_areas': research_area,
                        'publication_types': publication_type,
                        'venues': venue,
                        'continents': cont,
                        'countries': country,
                        'author_position': author_position,
                    },
                    timeout=log_timeout_seconds,
                )
            except requests.exceptions.RequestException as e:
                # Logging is best effort; a failure must not break the graph
                print(f"Error logging graph: {e}")

    # The graph_years are important for displaying only the
    # selected years on the chart
    st.session_state.graph_years = year
    # Visualize the collected data
    paint_graph()
@st.cache_data(max_entries=1000, show_spinner=False)
def query_and_process(sql_query):
    """Run ``sql_query`` and return its yearly absolute and relative numbers.

    The query result needs the columns Year, Absolute and Relative. The year
    2024 is removed from the result, and every year of
    ``st.session_state.min_max`` that the result lacks is added with value 0.

    Returns:
        A tuple ``(grouped_absolutes, grouped_relatives)`` of DataFrames
        indexed by Year; each one keeps every result column except the
        other frame's value column.

    NOTE(review): the function reads ``st.session_state.connection`` and
    ``st.session_state.min_max`` although only ``sql_query`` is a parameter
    of the cached function - confirm that this is intended.
    """
    output = pd.read_sql(sql_query, st.session_state.connection)

    # Index by year and remove 2024, because that data is not relevant
    by_year = output.set_index("Year").drop(2024, axis=0, errors="ignore")
    grouped_absolutes = by_year.drop("Relative", axis=1)
    grouped_relatives = by_year.drop("Absolute", axis=1)

    # Every selectable year has to be present for the further processing,
    # including the years without any result row
    selectable_years = range(st.session_state.min_max[0], st.session_state.min_max[1] + 1)
    for frame, value_column in ((grouped_absolutes, "Absolute"), (grouped_relatives, "Relative")):
        for selectable_year in selectable_years:
            if selectable_year not in frame.index:
                frame.loc[selectable_year] = {value_column: 0}  # type: ignore

    return grouped_absolutes, grouped_relatives
# Determines the font color of the hover
# based on the luminance of the background color (trace color)
def get_hover_font_color(bg_color):
    """Return "black" or "white", whichever suits the given background.

    Args:
        bg_color: a six-digit hex color, with or without a leading "#".

    Returns:
        "black" if the computed luminance of the color is above 0.3,
        otherwise "white".

    Raises:
        ValueError: if ``bg_color`` is not a six-digit hex color.
    """
    parsed = re.search(r"^#?([A-Fa-f0-9]{6})$", bg_color)
    if not parsed:
        raise ValueError(f"Invalid hex color: {bg_color}")

    # Split the six hex digits into the red, green and blue channel and
    # scale each one to the range 0..1
    digits = parsed.group(1)
    channels = [int(digits[start : start + 2], 16) / 255.0 for start in (0, 2, 4)]

    # Linearise every channel before weighting it
    linear = []
    for channel in channels:
        if channel <= 0.03928:
            linear.append(channel / 12.92)
        else:
            linear.append(((channel + 0.055) / 1.055) ** 2.4)

    luminance = 0.2126 * linear[0] + 0.7152 * linear[1] + 0.0722 * linear[2]
    return "black" if luminance > 0.3 else "white"
# Functionality for visualizing the collected data
def paint_graph():
    """Build the plotly figure for all visible graphs and store it.

    Takes the data of the selected graphs from ``get_selected_df``, limits it
    to the years in ``st.session_state.graph_years``, adds one line trace per
    data column and writes the finished figure to ``st.session_state.graph``.
    Whether absolute or relative numbers are shown depends on
    ``st.session_state.widget_data_representation``.
    """
    # Set pandas graph processing to plotly library
    pd.options.plotting.backend = "plotly"
    # Get only the dataframes that were
    # selected in graph history
    line_graph_data = get_selected_df()
    # Filter the data by the year range, that the user wants
    # to be displayed
    line_graph_data = line_graph_data[
        (line_graph_data["Year"] >= min(st.session_state.graph_years))
        & (line_graph_data["Year"] <= max(st.session_state.graph_years))
    ]
    line_graph_data = line_graph_data.set_index("Year")
    # NOTE(review): this figure is replaced by a second go.Figure() a few
    # lines below, before anything was added to it
    fig = go.Figure()
    filtered_y_columns = [
        y_column for y_column in st.session_state.y_columns
        if y_column.isVisible
    ]
    data_column_names = list(line_graph_data.columns)
    # Create the figure
    fig = go.Figure()
    for idx, column in enumerate(data_column_names):
        if column in [y_column.name for y_column in st.session_state.y_columns]:
            # The idx-th data column is taken to belong to the idx-th visible
            # graph; get_selected_df inserts its columns in y_columns order
            index = filtered_y_columns[idx]
            value_title = (
                "Count"
                if st.session_state.widget_data_representation == "Absolute numbers"
                else "Share of Publications"
            )
            # filters out zeros
            filtered_data = [(x, y) for x, y in zip(line_graph_data.index, line_graph_data[column]) if y != 0]
            # splits data into separate lists
            filtered_x = [x for x, y in filtered_data]
            filtered_y = [y for x, y in filtered_data]
            # customdata holds the value shown in the hover label of every
            # point. It uses the same conditions as the points above (year
            # inside the displayed range, value not 0).
            if(st.session_state.widget_data_representation == "Relative numbers"):
                # Shows the percentage and, if neither the percentage nor the
                # absolute number is 0, "(absolute/total)", where the total
                # is derived from the absolute number and the percentage
                customdata = [
                    f"{v}%" if (v == 0 or index.absoluteData[k] == 0) else
                    f"{v}% ({index.absoluteData[k]}/{int(index.absoluteData[k] / (v / 100))})"# calculates %
                    for k, v in index.relativeData.items()
                    if st.session_state.graph_years[0] <= k <=
                    st.session_state.graph_years[-1] and line_graph_data[column][k] != 0
                ]
            else:
                customdata =[
                    index.absoluteData[k]
                    for k, v in index.relativeData.items()
                    if st.session_state.graph_years[0] <= k <= st.session_state.graph_years[-1] and line_graph_data[column][k] != 0
                ]
        else:
            # NOTE(review): this branch does not assign "index" or
            # "value_title" itself; it relies on the values of an earlier
            # loop iteration and raises a NameError if there was none
            filtered_data = line_graph_data
            filtered_x = line_graph_data.index
            filtered_y = line_graph_data[column]
            customdata = [
                index.absoluteData[k]
                for k, v in index.relativeData.items()
                if st.session_state.graph_years[0] <= k <=
                st.session_state.graph_years[-1]
            ]
        fig.add_trace(
            go.Scatter(
                x=filtered_x,
                y=filtered_y,
                mode="lines",
                name=column,
                line_shape="spline",
                line_smoothing=0.7,
                meta=[column, value_title],
                # The list to display the value alongside with the absolute numbers
                # if the selected data representation is "Relative numbers"
                customdata=customdata,
                hovertemplate=
                # Plotly's hovertemplate uses %{...} syntax to access data from the plot's data
                # and customdata attributes. To access the name of the index, we use %{meta[0]}.
                # To access the x-axis value, we use %{x}, and to access the y-axis value, we use
                # %{customdata}.
                "<b>%{meta[0]}</b><br>Year: %{x}<br>%{meta[1]}: %{customdata}<extra></extra>",
                # We can customize the appearance of the hover label using the hoverlabel attribute.
                # The bgcolor attribute sets the background color, and the font attribute sets
                # the font properties.
                hoverlabel=dict(
                    bgcolor=index.color,
                    font=dict(
                        color=get_hover_font_color(index.color),
                    ),
                ),
                marker=dict(color=filtered_y_columns[idx].color),
            ))
    fig.update_layout(
        font_size=13,
        legend_title="Filters (click to toggle on/off)",
        autosize=True,
        height=500,
        legend=dict(
            orientation="v",
            yanchor="top",
            y=-0.1,
            xanchor="left",
            x=0,
        ),
    )
    fig.update_xaxes(tickformat="d")
    fig.update_yaxes(automargin=True, rangemode="tozero")
    if st.session_state.widget_data_representation == "Relative numbers":
        fig.update_layout(yaxis_title="Share of Publications", yaxis_ticksuffix="%")
    else:
        fig.update_layout(yaxis_title="Number of Publications")
    # Update the session state graph
    # -> Because of this update, it will
    # automatically rebuild the chart
    st.session_state.graph = fig
# Get all the graphs that the user selected in "Graph History"
def get_selected_df():
    """Return one DataFrame with the data of all visible graphs.

    The frame has a leading "Year" column covering
    ``st.session_state.min_max`` (both ends included), followed by one column
    per visible entry of ``st.session_state.y_columns``, named after the
    entry. The column values are the entry's absolute numbers if
    ``st.session_state.widget_data_representation`` is "Absolute numbers",
    otherwise its relative numbers.

    NOTE(review): the values are matched to the years by position, which
    presumably relies on every entry holding exactly the years of min_max in
    ascending order - verify against the code that fills y_columns.
    """
    selected = pd.DataFrame()

    for entry in st.session_state.y_columns:
        # Only graphs that are ticked in the graph history are displayed
        if entry.isVisible is not True:
            continue
        if st.session_state.widget_data_representation == "Absolute numbers":
            yearly_values = entry.absoluteData
        else:
            yearly_values = entry.relativeData
        selected.insert(
            loc=len(selected.columns),
            column=entry.name,
            value=pd.Series(list(yearly_values.values())),
        )

    # Insert the year column in front of the data columns
    first_year = st.session_state.min_max[0]
    last_year = st.session_state.min_max[1]
    selected.insert(
        loc=0,
        column="Year",
        value=pd.Series(list(range(first_year, last_year + 1))),
    )
    return selected
# Display the checkboxes for the Graph history with the logic of selecting/unselecting the checkboxes
def display_graph_checkboxes():
    """Render one checkbox per stored graph plus the "Clear History" button.

    Nothing is rendered while ``st.session_state.y_columns`` is empty.
    Otherwise the graphs are sorted by name in descending order and every
    graph gets a checkbox with the key ``graph_checkbox_<position>``; the
    state of the checkbox is written back to the graph's ``isVisible``
    attribute. Changing a checkbox calls ``change_graph_checkbox`` with the
    position of the graph.
    """
    if not st.session_state.y_columns:
        return

    st.subheader("Graph history")
    # Sort the graphs by name, descending
    st.session_state.y_columns.sort(key=lambda x: x.name, reverse=True)
    for i, graph in enumerate(st.session_state.y_columns):
        # The checkbox value is assigned directly instead of being routed
        # through globals(): module globals are shared by everything that
        # imports this module, so they are no place for per-run values
        graph.isVisible = st.checkbox(
            graph.name,
            value=graph.isVisible,
            # Setting the key makes this specific value
            # accessible via the session state
            key=f"graph_checkbox_{i}",
            on_change=change_graph_checkbox,
            args=(i,),
        )
    st.button("Clear History", on_click=clear_history)
def change_graph_checkbox(i):
    """Apply the state of checkbox ``i`` to its graph and repaint.

    Args:
        i: position of the graph in ``st.session_state.y_columns``; the
            checkbox state is read from the key ``graph_checkbox_<i>``.
    """
    is_checked = bool(st.session_state[f"graph_checkbox_{i}"])
    st.session_state.y_columns[i].isVisible = is_checked
    # After a checkbox has been changed,
    # automatically repaint the graph
    paint_graph()