This is a JAX-based variant of the classic NEAT (NeuroEvolution of Augmenting Topologies) algorithm in which the weights of every network in the population are optimized with backpropagation, in parallel, while evolution searches over network structure.
With this approach, we can quickly evolve minimal neural networks that have strong inductive biases for specific tasks. I showcase its performance on three classification tasks: XOR, Circle, and Spiral.
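
To make the idea of training a whole population in parallel concrete, here is a minimal sketch in JAX. It is not the implementation used in this project. It assumes every individual is a tiny fixed-shape network, with a per-individual binary connection mask standing in for an evolved topology. All names (`init_population`, `forward`, `loss_fn`, `train_step`, `POP`, `HIDDEN`, `LR`) are hypothetical and chosen for illustration only. The point is that `jax.vmap` over `jax.value_and_grad` gives one gradient step for every network at once.

```python
import jax
import jax.numpy as jnp

# XOR dataset: four 2-D inputs with binary labels.
X = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
Y = jnp.array([0.0, 1.0, 1.0, 0.0])

# Hypothetical hyperparameters, chosen only for this sketch.
POP, HIDDEN, LR = 8, 4, 0.1


def init_population(key, pop=POP, hidden=HIDDEN):
    """Create stacked parameters and connection masks for `pop` networks."""
    k1, k2, k3 = jax.random.split(key, 3)
    params = {
        "w1": jax.random.normal(k1, (pop, 2, hidden)),
        "b1": jnp.zeros((pop, hidden)),
        "w2": jax.random.normal(k2, (pop, hidden)),
        "b2": jnp.zeros((pop,)),
    }
    # Assumption: a random binary mask per individual stands in for its topology.
    masks = {
        "w1": jax.random.bernoulli(k3, 0.7, (pop, 2, hidden)).astype(jnp.float32)
    }
    return params, masks


def forward(params, mask, x):
    """Forward pass of ONE network; masked-out connections contribute nothing."""
    h = jnp.tanh(x @ (params["w1"] * mask["w1"]) + params["b1"])
    return h @ params["w2"] + params["b2"]  # one logit per example


def loss_fn(params, mask, x, y):
    """Numerically stable binary cross-entropy on logits."""
    z = forward(params, mask, x)
    return jnp.mean(jnp.maximum(z, 0.0) - z * y + jnp.log1p(jnp.exp(-jnp.abs(z))))


@jax.jit
def train_step(params, masks):
    """One gradient-descent step for every network in the population at once."""
    per_network = jax.vmap(jax.value_and_grad(loss_fn), in_axes=(0, 0, None, None))
    losses, grads = per_network(params, masks, X, Y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, losses  # losses are measured before the update


params, masks = init_population(jax.random.PRNGKey(0))
_, initial_losses = train_step(params, masks)
for _ in range(300):
    params, losses = train_step(params, masks)

# The per-network loss can then act as a fitness signal for selection.
best = int(jnp.argmin(losses))
```

In a full NEAT-style loop, the evolutionary step (selection, mutation, crossover of topologies) would sit around this inner training loop. Here the masks are fixed and random, so only the weight-optimization half is shown.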