diff --git a/README.md b/README.md index 806bdec..5c9cfea 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,12 @@ The network will use every midi file in ./midi_songs to train the network. The m ## Generating music -Once you have trained the network you can generate text using **predict.py** +Once you have trained the network you can generate music using the **predict.py** script with the weight file generated in the previous step. E.g. ``` -python predict.py +python predict.py weights.hdf5 ``` -You can run the prediction file right away using the **weights.hdf5** file +You can run the prediction script right away using the **weights.hdf5** file which is a weight file genereted with the default dataset provided in the midi_songs folder diff --git a/lstm.py b/lstm.py index 7b0e1e7..fe1b406 100644 --- a/lstm.py +++ b/lstm.py @@ -31,10 +31,10 @@ def get_notes(): notes = [] for file in glob.glob("midi_songs/*.mid"): - midi = converter.parse(file) - print("Parsing %s" % file) + midi = converter.parse(file) + notes_to_parse = None try: # file has instrument parts diff --git a/predict.py b/predict.py index 8ee24fa..853c966 100644 --- a/predict.py +++ b/predict.py @@ -1,5 +1,6 @@ """ This module generates notes for a midi file using the trained neural network """ +import sys import pickle import numpy from music21 import instrument, note, stream, chord @@ -10,7 +11,7 @@ from keras.layers import BatchNormalization as BatchNorm from keras.layers import Activation -def generate(): +def generate(weight_file): """ Generate a piano midi file """ #load the notes used to train the model with open('data/notes', 'rb') as filepath: @@ -22,7 +23,7 @@ def generate(): n_vocab = len(set(notes)) network_input, normalized_input = prepare_sequences(notes, pitchnames, n_vocab) - model = create_network(normalized_input, n_vocab) + model = create_network(normalized_input, n_vocab, weight_file) prediction_output = generate_notes(model, network_input, pitchnames, n_vocab) create_midi(prediction_output) @@ -49,7 +50,7 @@ def prepare_sequences(notes, pitchnames, n_vocab): return (network_input, normalized_input) -def create_network(network_input, n_vocab): +def create_network(network_input, n_vocab, weight_file): """ create the structure of the neural network """ model = Sequential() model.add(LSTM( @@ -135,4 +136,7 @@ def create_midi(prediction_output): midi_stream.write('midi', fp='test_output.mid') if __name__ == '__main__': - generate() + if len(sys.argv) == 2: + generate(sys.argv[1]) + else: + print("You need to invoke this script with an argument specifying a weight file e.g. : python predict.py weights.hdf5")