Tandem learning example
Neuro-evolution is great for finding tailored and compact architectures, but backpropagation is the gold standard for reaching low errors. We can combine the two optimization techniques through tandem learning. We start by importing neuralfit (for evolution), numpy (for arrays) and matplotlib (for plotting). Keras (for backpropagation) must also be installed, since the model is converted to a Keras model for that step, but it does not need to be imported directly in this example.
```python
import neuralfit as nf
import matplotlib.pyplot as plt
import numpy as np
```
The next step is to define the dataset, in this case the step function with a step at $x=0.5$ on the domain $x \in [0,1]$. For this example we will sample 100 evenly spaced datapoints in the domain, but feel free to change the sampling method or the number of points.
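```python
# 100 evenly spaced inputs in [0, 1], shaped (100, 1): one feature per sample
x = np.linspace(0, 1, 100).reshape(-1, 1)
# Step target: 0 for x <= 0.5, 1 for x > 0.5 (multiplying by 1 turns booleans into integers)
y = (x > 0.5) * 1
```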
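As one alternative to the evenly spaced grid, the sketch below draws 1000 inputs uniformly at random. The seed and sample size are arbitrary choices for illustration, not values from NeuralFit; the inputs are sorted only so that the line plot at the end of this example is drawn from left to right. It could replace the two dataset lines above.

```python
import numpy as np

# Hypothetical alternative sampling: n points drawn uniformly at random from [0, 1]
rng = np.random.default_rng(seed=0)  # fixed seed so the sample is reproducible
n = 1000
x = np.sort(rng.uniform(0.0, 1.0, size=n)).reshape(-1, 1)  # sorted for clean line plots
y = (x > 0.5).astype(float)  # same step target: 0 for x <= 0.5, 1 otherwise

print(x.shape, y.shape)
```

Random sampling places points at irregular distances from the step, so the fit near $x=0.5$ may look slightly different from the grid version.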
Afterwards we can create the NeuralFit model, specifying 1 input and 1 output.
```python
model = nf.Model(inputs=1, outputs=1)
```
Now comes the interesting part: we repeatedly evolve the model, convert it to a Keras model, train it with backpropagation, and convert it back to a NeuralFit model. This is not particularly efficient, but we plan to incorporate full tandem learning in NeuralFit in the future. Note that the model must be recompiled after every conversion between NeuralFit and Keras.
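```python
for i in range(50):
    # Evolution phase (NeuralFit): recompile, since the model may have just come back from Keras
    model.compile(optimizer='alpha', loss='mse', monitors=['size'])
    model.evolve(x, y, epochs=10)

    # Backpropagation phase (Keras): convert, then recompile with a gradient-based optimizer
    model = model.to_keras()
    model.compile(optimizer='adam', loss='mse')
    model.fit(x, y, epochs=10)

    # Convert back to NeuralFit for the next round of evolution
    model = nf.from_keras(model)
```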
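To make the alternation itself concrete without NeuralFit or Keras, here is a toy sketch in plain NumPy. It is not how NeuralFit works internally: the "network" is a hypothetical two-parameter sigmoid step $\sigma(a(x - c))$, the evolution phase is a simple (1+λ) evolution strategy standing in for NeuralFit's evolution, and the gradient phase is plain gradient descent with step halving standing in for Adam. All function names are made up for illustration.

```python
import numpy as np


def predict(theta, x):
    """Toy 'network': a sigmoid step with steepness theta[0] and location theta[1]."""
    a, c = theta
    z = np.clip(a * (x - c), -500.0, 500.0)  # clip to avoid overflow in np.exp
    return 1.0 / (1.0 + np.exp(-z))


def mse(theta, x, y):
    """Mean squared error of the toy model on (x, y)."""
    return float(np.mean((predict(theta, x) - y) ** 2))


def grad_mse(theta, x, y):
    """Analytic gradient of the MSE with respect to (a, c), ignoring the clipping."""
    a, c = theta
    s = predict(theta, x)
    dz = 2.0 * (s - y) * s * (1.0 - s)  # per-sample dLoss/dz before averaging
    return np.array([np.mean(dz * (x - c)), np.mean(dz * -a)])


def evolve(theta, x, y, epochs, rng, n_mutants=10, sigma=0.5):
    """Gradient-free phase: a (1+lambda) evolution strategy on the parameters."""
    best, best_loss = theta, mse(theta, x, y)
    for _ in range(epochs):
        # Perturb the current best with Gaussian noise; keep a mutant only if it is better
        mutants = best + sigma * rng.standard_normal((n_mutants, best.size))
        losses = [mse(m, x, y) for m in mutants]
        i = int(np.argmin(losses))
        if losses[i] < best_loss:
            best, best_loss = mutants[i], losses[i]
    return best


def fit(theta, x, y, epochs, lr=1.0):
    """Gradient phase: gradient descent with a backtracking (step-halving) step size."""
    theta = np.asarray(theta, dtype=float).copy()
    for _ in range(epochs):
        grad = grad_mse(theta, x, y)
        loss, step = mse(theta, x, y), lr
        # Halve the step until it no longer increases the loss
        while step > 1e-8 and mse(theta - step * grad, x, y) > loss:
            step *= 0.5
        if step > 1e-8:
            theta = theta - step * grad
    return theta


rng = np.random.default_rng(0)
x = np.linspace(0, 1, 100).reshape(-1, 1)
y = (x > 0.5) * 1.0
theta = np.array([1.0, 0.2])  # arbitrary start: a shallow step at x = 0.2
initial_loss = mse(theta, x, y)

# Same evolve-then-backpropagate alternation as the loop above, on a smaller budget
for _ in range(20):
    theta = evolve(theta, x, y, epochs=10, rng=rng)
    theta = fit(theta, x, y, epochs=10)

final_loss = mse(theta, x, y)
print(f"MSE: {initial_loss:.4f} -> {final_loss:.4f}")
```

In this toy, both phases only accept steps that do not increase the loss, so the error never goes up; Adam in the real loop gives no such per-step guarantee.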
In the above loop, each of the 50 iterations evolves the model for 10 epochs and then backpropagates for 10 epochs, for a total of 1000 epochs (500 of each). Feel free to change the distribution of epochs, but be cautious: while spending more epochs on backpropagation early on gives a lower error at first, it drives the model parameters to a local minimum, which makes subsequent evolution less effective. After running the above code, we can visualize the results.
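One way to redistribute the same 1000-epoch budget is to give evolution most of each iteration at the start and shift the balance toward backpropagation later, in line with the caution above. The helper below is a hypothetical sketch, not part of NeuralFit; the 0.8 and 0.2 shares are arbitrary starting points.

```python
def epoch_schedule(iterations, epochs_per_iteration, start_evolve_share, end_evolve_share):
    """Split each iteration's epochs between evolution and backpropagation.

    The evolution share moves linearly from start_evolve_share to end_evolve_share.
    Returns a list of (evolve_epochs, fit_epochs) pairs.
    """
    for share in (start_evolve_share, end_evolve_share):
        if not 0.0 <= share <= 1.0:
            raise ValueError("shares must lie in [0, 1]")
    schedule = []
    for i in range(iterations):
        # t goes from 0 (first iteration) to 1 (last iteration)
        t = i / (iterations - 1) if iterations > 1 else 0.0
        share = start_evolve_share + t * (end_evolve_share - start_evolve_share)
        evolve_epochs = int(round(share * epochs_per_iteration))
        schedule.append((evolve_epochs, epochs_per_iteration - evolve_epochs))
    return schedule


schedule = epoch_schedule(50, 20, start_evolve_share=0.8, end_evolve_share=0.2)
print(schedule[0], schedule[-1])
print(sum(e + f for e, f in schedule))
```

To use it, the loop could iterate over `for evolve_epochs, fit_epochs in schedule:` and pass these values as `epochs` to `model.evolve` and `model.fit`. Whether this particular ramp helps has not been verified here, so choose shares that leave at least one epoch for each phase and experiment.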
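```python
# Get model predictions
y_hat = model.predict(x)

# Plot the true step function and the model's predictions
plt.plot(x, y, label='True', color='k', linestyle='--')
plt.plot(x, y_hat, label='Predicted', color='#52C560', linewidth=2)
plt.legend()  # display the labels set above
plt.show()
```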
If all went well, you should get a plot similar to the one below!
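Besides eyeballing the plot, the fit can be summarized with two numbers: the mean squared error and the fraction of points that land on the correct side of 0.5. The helper name and the 0.5 threshold are assumptions for illustration; the stand-in `y_hat` below is a smooth step, not actual model output, so in practice it would be replaced by `model.predict(x)`.

```python
import numpy as np


def step_fit_metrics(y, y_hat, threshold=0.5):
    """Return (MSE, fraction of points on the correct side of the threshold)."""
    y = np.asarray(y, dtype=float).reshape(-1)
    y_hat = np.asarray(y_hat, dtype=float).reshape(-1)
    mse = float(np.mean((y_hat - y) ** 2))
    accuracy = float(np.mean((y_hat > threshold) == (y > threshold)))
    return mse, accuracy


# Stand-in data so the sketch runs on its own; in the tutorial, use y_hat = model.predict(x)
x = np.linspace(0, 1, 100).reshape(-1, 1)
y = (x > 0.5) * 1
y_hat = 1.0 / (1.0 + np.exp(-50.0 * (x - 0.5)))  # a smooth step, not real model output

mse, accuracy = step_fit_metrics(y, y_hat)
print(f"MSE = {mse:.4f}, accuracy = {accuracy:.0%}")
```

Flattening both arrays with `reshape(-1)` lets the helper accept predictions shaped either `(n,)` or `(n, 1)`.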