Skip to content

Commit

Permalink
feat(python-example): learn it button with dropdown
Browse files Browse the repository at this point in the history
  • Loading branch information
Az-r-ow-Kaliop committed Mar 28, 2024
1 parent ac65a29 commit acd16b8
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 15 deletions.
2 changes: 2 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

## IN PROGRESS :

- [ ] Add default arguments to python bindings
- [ ] Add verbose argument for progess bar
- [ ] Interactive Python example
- [ ] Python tests
- [ ] Optimize `Catch2`'s build
Expand Down
21 changes: 12 additions & 9 deletions examples/train-predict-MNIST/guess_it.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
margin = s(10)
screen_width = s(640)
screen_height = s(480)
guess_button_width = s(100)
guess_button_height = s(50)
button_width = s(100)
button_height = s(50)
guess_text_width = s(200)
guess_text_height = s(75)
drawing_surface_width = drawing_surface_height = s(300)
Expand All @@ -34,24 +34,27 @@
drawing_surface.set_clip(None)

manager = pygame_gui.UIManager((screen_width, screen_height))
guess_button_rect = pygame.Rect(0, 0, guess_button_width, guess_button_height)
guess_button_rect = pygame.Rect(0, 0, button_width, button_height)
guess_button_rect.bottomright = (-margin, -margin)
guess_button = pygame_gui.elements.UIButton(relative_rect=guess_button_rect, text="Guess it", manager=manager, anchors={'right': 'right', 'bottom': 'bottom'})
clear_button_rect = pygame.Rect((-(margin + (button_width * 2))), -(margin + button_height), button_width, button_height)
clear_button = pygame_gui.elements.UIButton(relative_rect=clear_button_rect, text="Clear", manager=manager, anchors={'right': 'right', 'bottom': 'bottom'})
guess_text_rect = pygame.Rect(0, margin, guess_text_width, guess_text_height)
guess_text = pygame_gui.elements.UITextBox(html_text="", relative_rect=guess_text_rect, manager=manager, anchors={'centerx': 'centerx'})
learn_button_rect = pygame.Rect(0, 0, guess_button_width, guess_button_height)
learn_button_rect.bottomleft = (margin, -margin)
learn_button = pygame_gui.elements.UIButton(relative_rect=learn_button_rect, text="Learn It", manager=manager, anchors={'left': 'left', 'bottom': 'bottom'})
dropdown_rect = pygame.Rect(0,-(margin + guess_button_height), dropdown_width, dropdown_height)
dropdown = pygame_gui.elements.UIDropDownMenu(DROPDOWN_OPTIONS, DROPDOWN_OPTIONS[0], dropdown_rect, manager, anchors={'centerx': 'centerx', 'bottom': 'bottom'})
learn_button_rect = pygame.Rect(margin, margin, button_width, button_height)
learn_button = pygame_gui.elements.UIButton(relative_rect=learn_button_rect, text="Learn It", manager=manager, anchors={'left': 'left', 'top': 'top'})
dropdown_rect = pygame.Rect(0, 0, dropdown_width, dropdown_height)
dropdown_rect.topleft = (margin, margin + button_height)
dropdown = pygame_gui.elements.UIDropDownMenu(DROPDOWN_OPTIONS, DROPDOWN_OPTIONS[0], dropdown_rect, manager=manager, anchors={'left': 'left', 'top': 'top'})

# This dict will be passed to event handlers
ui_elements = {
"guess_button": guess_button,
"guess_text": guess_text,
"drawing_surface": drawing_surface,
"learn_button": learn_button,
"dropdown": dropdown
"dropdown": dropdown,
'clear_button': clear_button
}

# Fill the main_window with white
Expand Down
41 changes: 37 additions & 4 deletions examples/train-predict-MNIST/helpers/event_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,53 @@

NNP.models.Model.load_from_file("model.bin", network)

network.setup(optimizer=NNP.optimizers.SGD(0.001), loss=NNP.LOSS.MCE)

drawing = False # Track drawing
erasing = False # Track erasing
drawing_color = COLORS["black"]
erasing_color = COLORS["white"]

def get_drawing(context):
"""
Get the drawing from the drawing_surface
Args:
context
Returns:
normalized_img: The normalized drawing
"""
drawing_image = pygame.surfarray.pixels3d(context["ui_elements"]["drawing_surface"])
grayscale_image = format_image_grayscale(drawing_image, (28, 28))
normalized_image = normalize_img(numpy.transpose(grayscale_image))
return normalized_image


def handle_ui_button_pressed(context):
if context["event"].ui_element == context["ui_elements"]["guess_button"]:
drawing_image = pygame.surfarray.pixels3d(context["ui_elements"]["drawing_surface"])
grayscale_image = format_image_grayscale(drawing_image, (28, 28))
normalized_image = normalize_img(numpy.transpose(grayscale_image))
normalized_image = get_drawing(context)
prediction = find_highest_indexes_in_matrix(network.predict([normalized_image]))
context["ui_elements"]["guess_text"].append_html_text(f"I'm guessing : {prediction[0]}<br>")

if context["event"].ui_element == context["ui_elements"]["learn_button"]:
print("I'm learning")
normalized_image = get_drawing(context)
target = float(context["ui_elements"]["dropdown"].selected_option)
loss = network.train([normalized_image], [target], 1)
context["ui_elements"]["guess_text"].append_html_text(f"I'm learning that it's a {int(target)}<br>loss : {loss}")

if context["event"].ui_element == context["ui_elements"]["dropdown"]:
print("dropdown has been clicked")

if context["event"].ui_element == context["ui_elements"]["clear_button"]:
context["ui_elements"]["drawing_surface"].fill(erasing_color)
return

def handle_dropdown_change(context):
event = context['event']
if event.ui_element == context["ui_elements"]["dropdown"]:
print("Selected Option ", event.text)
context["ui_elements"]["dropdown"].close = True
return

def handle_mouse_button_down(context):
Expand Down Expand Up @@ -64,6 +96,7 @@ def handle_quit(context):

EVENT_HANDLER_MAP = {
pygame_gui.UI_BUTTON_PRESSED: handle_ui_button_pressed,
pygame_gui.UI_DROP_DOWN_MENU_CHANGED: handle_dropdown_change,
pygame.MOUSEBUTTONDOWN: handle_mouse_button_down,
pygame.MOUSEBUTTONUP: handle_mouse_button_up,
pygame.MOUSEMOTION: handle_mouse_motion,
Expand Down
4 changes: 2 additions & 2 deletions examples/train-predict-MNIST/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from helpers.utils import *
from halo import Halo

NUM_TRAININGS = 60000
NUM_TRAININGS = 10000
NUM_PREDICTIONS = 1000
MNIST_DATASET_FILE = "./dataset/mnist.npz"

Expand All @@ -35,7 +35,7 @@
network.addLayer(NNP.layers.Dense(10, NNP.ACTIVATION.SOFTMAX, NNP.WEIGHT_INIT.LECUN))

# Setting up the networks parameters
network.setup(optimizer=NNP.optimizers.Adam(0.001), loss=NNP.LOSS.MCE)
network.setup(optimizer=NNP.optimizers.Adam(0.01), loss=NNP.LOSS.MCE)

# combining the data with the labels for later shuffling
combined = list(zip(x_train, y_train))
Expand Down
8 changes: 8 additions & 0 deletions src/bindings/NeuralNetPy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ PYBIND11_MODULE(NeuralNetPy, m) {
static_cast<double (Network::*)(
std::vector<std::vector<double>>, std::vector<double>, int,
const std::vector<std::shared_ptr<Callback>>)>(&Network::train),
py::arg("inputs"),
py::arg("targets"),
py::arg("epochs"),
py::arg("callbacks") = std::vector<std::shared_ptr<Callback>>(),
R"pbdoc(
Train the network by passing it 2 dimensional inputs (vectors).
Expand Down Expand Up @@ -458,6 +462,10 @@ PYBIND11_MODULE(NeuralNetPy, m) {
std::vector<std::vector<std::vector<double>>>,
std::vector<double>, int,
const std::vector<std::shared_ptr<Callback>>)>(&Network::train),
py::arg("inputs"),
py::arg("targets"),
py::arg("epochs"),
py::arg("callbacks") = std::vector<std::shared_ptr<Callback>>(),
R"pbdoc(
Train the network by passing it a list of 3 dimensional inputs (matrices).
Expand Down

0 comments on commit acd16b8

Please sign in to comment.