December 27, 2024

Explainable networks

Explainable machine learning is my white whale as a scientist looking to leverage numerical regression against the gigabytes of data I typically need to decipher. It's a step toward answering the question of how I can use ML to determine the underlying physics.

Supervised machine learning (ML) boils down to two items: the definition of an objective and an algorithmic means to carry out the regression. Today, this typically means defining a loss function that compares a model's predictions against existing data and, in the case of a neural network (NN), updating weights and biases via backpropagation. Yet this trained NN is famously a black box that takes inputs, spits out predictions, and leaves me with no understanding of the relation between the two ends. But as a scientist, this relation is what I actually care about. Explainable ML is an active field, which I know only a little of, but here is one idea that may be worth pursuing by someone who has the time for it.
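
To make the two items concrete, here is a minimal sketch, assuming PyTorch and a made-up toy relation (y = 3*x0 - x1); the layer sizes, learning rate, and data are illustrative choices, not anything prescribed by the idea itself:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data: 1000 samples, 3 inputs, 1 target with a hidden linear relation.
X = torch.randn(1000, 3)
y = (3.0 * X[:, 0] - X[:, 1]).unsqueeze(1)

# A small fully connected network.
model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))

# Item 1: the objective, here a mean-squared-error loss against the data.
loss_fn = nn.MSELoss()
# Item 2: the regression machinery, gradient descent driven by backpropagation.
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(1000):
    opt.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()   # backpropagation: gradients of the loss w.r.t. weights and biases
    opt.step()        # update weights and biases

# The trained model predicts well, but its raw weights say nothing obvious
# about the underlying relation y = 3*x0 - x1: the black box.
print("final loss:", loss_fn(model(X), y).item())
```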

The goal is to turn a NN into a map between inputs and outputs. A NN is a nested configuration of junctions (neurons) connected by weighted edges, with each neuron applying an activation function. Pictorially, this appears as sequential layers of nodes, all connected. Pruning is the sub-field of ML that studies how to preserve NN performance while literally cutting away the connections deemed non-essential. But what I see in pruning is a way to create a map. The basic premise is to take a simple problem, train a two-layer NN on it, and trim the connections. What remains is, pictorially, a map between inputs and outputs. As a scientist it is my job to check whether these surviving connections make sense, which is a lot simpler than checking every possible relation within a given dataset.
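
Here is a minimal sketch of that premise, again assuming PyTorch. The toy problem (y = x0 + 2*x2 with x1 and x3 irrelevant), the magnitude threshold used for trimming, and the layer sizes are all assumptions for illustration; real pruning work uses more careful criteria, but the point is only to show the trained net collapsing into an inspectable map:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy dataset: 4 inputs, but only x0 and x2 actually matter.
X = torch.randn(2000, 4)
y = (X[:, 0] + 2.0 * X[:, 2]).unsqueeze(1)

# A small two-layer network.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

for step in range(2000):
    opt.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    opt.step()

# "Prune": zero every weight whose magnitude falls below a threshold,
# keeping only the connections the network actually relies on.
threshold = 0.1
with torch.no_grad():
    for layer in model:
        if isinstance(layer, nn.Linear):
            mask = layer.weight.abs() >= threshold
            layer.weight *= mask

# Inspect the surviving map from inputs to output: ideally only the
# connections tracing back to x0 and x2 remain.
for name, layer in [("input->hidden", model[0]), ("hidden->output", model[2])]:
    print(name)
    for out_idx, row in enumerate(layer.weight.detach()):
        kept = [(in_idx, round(val.item(), 2))
                for in_idx, val in enumerate(row) if val != 0.0]
        if kept:
            print(f"  unit {out_idx}: {kept}")
```

The printed connections are the "map": checking whether they route through the right inputs is the kind of sanity check a scientist can actually do by eye.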