There are many visualization methods for neural networks, but most of them are limited to layered architectures. Models evolved by NeuralFit are not strictly layered, but it can still be useful to visualize them. In this example we walk through a simple visualization method to get you started. We start by importing neuralfit (for the model), matplotlib (for plotting), networkx (for graph structures) and count from itertools (for numbering the activation functions later on).
```python
import neuralfit as nf
import matplotlib.pyplot as plt
import networkx as nx
from itertools import count
```
Next, we create a random model to serve as an example. Feel free to use your own model here.
```python
model = nf.Model(1, 1, size=5)
```
We can use the get_nodes() and get_connections() functions to gather all neurons and connections that make up the network. For more information on these methods, see the NeuralFit documentation.
```python
nodes = model.get_nodes()
connections = model.get_connections()
```
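networkx's add_nodes_from accepts either bare node identifiers or (node, attribute_dict) tuples, and add_edges_from accepts (u, v) pairs. Because the coloring step later reads an 'activation' attribute from every node, this sketch assumes that get_nodes() returns (index, {'activation': ...}) tuples and get_connections() returns (source, target) pairs; that format is an assumption, so check it against the NeuralFit documentation. The data below is hypothetical, and node_attributes is a hypothetical helper that roughly mirrors how networkx reads the two node forms. It also checks for connection endpoints that are not listed as nodes, because add_edges_from adds such endpoints automatically without any attributes.

```python
# Hypothetical sample data: the exact return format of NeuralFit's
# get_nodes() and get_connections() is an assumption here.
nodes = [
    (0, {"activation": "identity"}),
    (1, {"activation": "relu"}),
    (2, {"activation": "tanh"}),
    (3, {"activation": "relu"}),
    (4, {"activation": "identity"}),
]
connections = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4), (0, 4)]


def node_attributes(nodes):
    """Hypothetical helper: roughly mirror how networkx.add_nodes_from reads nodes.

    Each item is either a bare node id or a (node, attribute_dict) tuple.
    """
    attrs = {}
    for item in nodes:
        if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], dict):
            node, data = item
        else:
            node, data = item, {}
        attrs.setdefault(node, {}).update(data)
    return attrs


attrs = node_attributes(nodes)
# add_edges_from silently adds unknown endpoints as nodes without attributes,
# which would later make graph.nodes[n]['activation'] raise a KeyError.
missing = {n for source, target in connections for n in (source, target) if n not in attrs}

print(attrs[1]["activation"])  # relu
print(missing)                 # set()
```

An empty missing set means every connection endpoint will carry an activation attribute once the graph is built.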
With these nodes and connections, we can easily construct a directed graph using networkx.
```python
graph = nx.DiGraph()
graph.add_nodes_from(nodes)
graph.add_edges_from(connections)
```
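A directed graph essentially keeps, for every neuron, the list of neurons it feeds into (successors) and the list of neurons that feed into it (predecessors). In networkx you can query these directly with graph.successors(n) and graph.predecessors(n); the plain-Python sketch below only illustrates the underlying structure, using hypothetical (source, target) pairs and a hypothetical build_adjacency helper.

```python
from collections import defaultdict

# Hypothetical (source, target) pairs standing in for model.get_connections().
connections = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4), (0, 4)]


def build_adjacency(connections):
    """Return (successors, predecessors) dicts for a list of (source, target) pairs."""
    succ = defaultdict(list)
    pred = defaultdict(list)
    for source, target in connections:
        succ[source].append(target)  # source feeds into target
        pred[target].append(source)  # target receives input from source
    return dict(succ), dict(pred)


succ, pred = build_adjacency(connections)
print(pred[3])  # [1, 2]
print(succ[0])  # [1, 2, 4]
```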
At this point we could already plot the graph, but it helps to color neurons by their activation function. The code below extracts the unique activation functions present in the network and assigns each of them a distinct number. The code is somewhat dense, but the key point is that the colors list ends up containing one integer per node, where each integer stands for a particular activation function.
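```python
# Unique activation functions present in the network
groups = set(nx.get_node_attributes(graph, 'activation').values())
# Map each activation function to an integer: 0, 1, 2, ...
mapping = dict(zip(sorted(groups), count()))
# One integer per node, in the order networkx iterates over the nodes
colors = [mapping[graph.nodes[n]['activation']] for n in graph.nodes()]
```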
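To make the coloring step concrete, here is the same mapping logic applied to a hypothetical dictionary of activations (standing in for nx.get_node_attributes(graph, 'activation')), assuming the activation values are strings or otherwise sortable. Sorting the set before numbering it means each activation function gets the same integer regardless of the set's iteration order.

```python
from itertools import count

# Hypothetical node -> activation mapping, standing in for
# nx.get_node_attributes(graph, 'activation').
activations = {0: "identity", 1: "relu", 2: "tanh", 3: "relu", 4: "identity"}

groups = set(activations.values())
# sorted() makes the name -> integer assignment independent of set order.
mapping = dict(zip(sorted(groups), count()))
colors = [mapping[activations[n]] for n in activations]

print(mapping)  # {'identity': 0, 'relu': 1, 'tanh': 2}
print(colors)   # [0, 1, 2, 1, 0]
```

dict(zip(sorted(groups), count())) is equivalent to {name: i for i, name in enumerate(sorted(groups))}; count() simply supplies 0, 1, 2, ... for as many groups as there are.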
Now we can finally visualize the graph. With with_labels=True, each neuron is also annotated with an index that represents its activation order: neurons are activated in sequence, starting from the lowest index and working up to the highest. Adding this index to the plot is important because, unlike in layered visualizations, it is otherwise not clear which neurons feed into which others (or vice versa!).
```python
nx.draw(graph, node_color=colors, with_labels=True)
plt.show()
```
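Since the labels encode activation order, you can also check each connection against that order programmatically. The sketch below, built around a hypothetical split_by_activation_order helper and hypothetical sample pairs, assumes (as described above) that a lower index means the neuron is activated earlier. Connections whose source has a lower index than their target point forward in that order; how NeuralFit evaluates connections pointing to an equal or lower index is not covered in this example, so treat that second group simply as worth a closer look.

```python
def split_by_activation_order(connections):
    """Split (source, target) pairs by the order implied by node indices.

    Assumes lower indices are activated first.
    """
    forward, backward = [], []
    for source, target in connections:
        # Source activated before target -> forward; otherwise same or later.
        (forward if source < target else backward).append((source, target))
    return forward, backward


# Hypothetical sample, including a self-connection on neuron 3.
connections = [(0, 1), (1, 3), (3, 2), (2, 4), (3, 3)]
forward, backward = split_by_activation_order(connections)
print(forward)   # [(0, 1), (1, 3), (2, 4)]
print(backward)  # [(3, 2), (3, 3)]
```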
If all went well, you should get a plot similar to the one below! Note that 0 indicates the input neuron, while 4 indicates the output neuron.
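nx.draw places nodes with a default layout, so the activation order is only visible through the labels. If you prefer the order to read left to right, you can pass your own positions through the pos argument of nx.draw, which takes a dictionary mapping each node to an (x, y) coordinate. The order_layout helper below is hypothetical and assumes the node identifiers are sortable activation indices.

```python
def order_layout(nodes, spread=1.0):
    """Hypothetical helper: place each node at x = its rank in activation order.

    The y-coordinate cycles through (0, spread, -spread) so that edges between
    non-adjacent nodes do not all lie on one horizontal line.
    """
    offsets = (0.0, spread, -spread)
    return {n: (float(i), offsets[i % 3]) for i, n in enumerate(sorted(nodes))}


pos = order_layout([0, 1, 2, 3, 4])
print(pos[0], pos[4])  # (0.0, 0.0) (4.0, 1.0)
```

With the graph from this example, this would be used as nx.draw(graph, pos=order_layout(graph.nodes()), node_color=colors, with_labels=True). Because pos is keyed by node, it does not depend on the order in which colors was built.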