Neural network: Escaping from non-global minimum traps
Deep learning has been a hot topic in recent years and has shown excellent performance in many applications. In bioinformatics, Google has developed a deep-learning-based variant caller, DeepVariant ( https://research.googleblog.com/2017/12/deepvariant-highly-accurate-genomes.html ). In this article, I want to talk about several scenarios that may jeopardize a neural network's performance. I will try to depict these situations the way I understand them. As I am not a machine learning expert, any corrections are appreciated.
People with a basic understanding of the structure of a neural network (NN) know that back-propagation is the essential algorithm behind optimizing an NN's parameters. The simplified procedure is:
1. Based on the current weights, the output layer calculates the weighted sum of the neurons in the previous layer.
2. Establish a loss function to represent the difference between the desired output value and the actual output value.
3. Calculate the gradient of the loss function and update the weights by taking one small step in the direction opposite to the gradient.
4. Repeat steps 1 to 3 iteratively until the gradient becomes extremely small.
At this point, we have the model that fits the training data best. This is what we know as gradient descent, as shown in the figure below.
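The steps above can be sketched for the simplest possible case: a single linear neuron trained with a mean squared error loss. This is only a minimal illustration, not a full back-propagation implementation. The function name `train`, the toy data, the learning rate and the stopping tolerance are all assumptions made for this example.

```python
def train(inputs, targets, lr=0.1, tol=1e-8, max_steps=10000):
    """Gradient descent for one linear neuron (hypothetical toy example)."""
    weights = [0.0] * len(inputs[0])
    n = len(inputs)
    loss = float("inf")
    for _ in range(max_steps):
        # Step 1: weighted sum of the inputs with the current weights
        outputs = [sum(w * x for w, x in zip(weights, row)) for row in inputs]
        # Step 2: loss = mean squared difference between actual and desired output
        errors = [o - t for o, t in zip(outputs, targets)]
        loss = sum(e * e for e in errors) / n
        # Step 3: gradient of the loss, then one small step against it
        grads = [2 * sum(e * row[j] for e, row in zip(errors, inputs)) / n
                 for j in range(len(weights))]
        weights = [w - lr * g for w, g in zip(weights, grads)]
        # Step 4: stop once the gradient has become extremely small
        if max(abs(g) for g in grads) < tol:
            break
    return weights, loss

# Toy data generated from the weights (2, -1), an assumption for this sketch
inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
targets = [2.0, -1.0, 1.0, 3.0]
weights, loss = train(inputs, targets)
print(weights, loss)
```

Because this toy loss is a simple bowl with a single minimum, gradient descent recovers weights close to (2, -1). The rest of the article is about what happens when the landscape is not that friendly.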
Landscape of neural network
Now let's extend the function in the figure above to a neural network with a total of 5000 weights connecting its neurons. Let's imagine a multi-dimensional landscape as depicted in the figure below. We cut the landscape vertically 5000 times, each time in a different direction. Each cut represents one weight dimension, and the height represents the loss function's output.
We want gradient descent to eventually push the weights down to the global minimum where, by definition, the partial derivative along each of the 5000 dimensions becomes near zero. However, if we look at the figure carefully, the global minimum is not the only place that meets this criterion. Local minima, saddle points and plateaus are also such places. In these areas, the gradient along every dimension (direction) "vanishes" and the weight updates become extremely slow. In the worst case, the model stops training and returns this as the "optimized" model. When the number of parameters is high and the model is complex, you can imagine the complexity of the landscape. For these special areas, the model's ability to escape depends on the actual "landscape" of the area.
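To make the "vanishing gradient at the wrong place" idea concrete, here is a small sketch with a single weight. The loss function is a made-up polynomial with two valleys, chosen only for illustration, and the learning rate and step count are likewise assumptions. Plain gradient descent ends up in whichever valley is downhill from the starting point, and in both cases the gradient is near zero when it stops.

```python
def loss(w):
    # Hypothetical one-weight loss with a deep valley and a shallow valley
    return w**4 - 3 * w**2 + w

def grad(w):
    # Derivative of the loss above
    return 4 * w**3 - 6 * w + 1

def descend(w, lr=0.01, steps=2000):
    # Plain gradient descent: repeatedly step against the gradient
    for _ in range(steps):
        w -= lr * grad(w)
    return w

w_right = descend(2.0)    # starts on the right slope
w_left = descend(-2.0)    # starts on the left slope

for name, w in [("start at  2.0", w_right), ("start at -2.0", w_left)]:
    print(name, "-> w =", round(w, 4),
          "loss =", round(loss(w), 4),
          "gradient =", grad(w))
```

The run starting at 2.0 settles near w = 1.13, a local minimum, while the run starting at -2.0 settles near w = -1.30, where the loss is clearly lower. Judging by the gradient alone, the two end points look equally "optimized".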
The likelihood of escaping from a local minimum depends on how "deep" the valley is. Roughly speaking, the depth of a valley grows with the number of samples the model fits there. A shallow valley probably fits only a small portion of the samples. When a different mini-batch is fed to the model, the variation between samples can easily help the model escape from such a valley. Valleys that can eventually trap the model are usually relatively deep. That is why models usually perform adequately.
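The role of between-sample variation can be illustrated with a small deterministic sketch. Everything in it is an assumption made for illustration: each sample gets a hypothetical bowl-shaped loss centred on the weight value that fits it, nine "majority" samples are fitted by w = 0, and one "minority" sample is fitted by w = 4. The average loss then has a deep valley near 0 and a shallow valley near 4. The sketch finds the bottom of the shallow valley with full-batch gradient descent and compares the full-batch gradient there with the gradients of two mini-batches.

```python
import math

def sample_grad(w, centre):
    # Gradient of the hypothetical per-sample loss -exp(-(w - centre)**2 / 2)
    return (w - centre) * math.exp(-((w - centre) ** 2) / 2)

def batch_grad(w, centres):
    # Average gradient over the samples in one batch
    return sum(sample_grad(w, c) for c in centres) / len(centres)

majority = [0.0] * 9   # samples fitted by w = 0 (deep valley)
minority = [4.0]       # single sample fitted by w = 4 (shallow valley)
dataset = majority + minority

# Full-batch gradient descent started inside the shallow valley stays there
w = 4.0
for _ in range(500):
    w -= 1.0 * batch_grad(w, dataset)
w_valley = w

full_grad = batch_grad(w_valley, dataset)       # all samples
major_grad = batch_grad(w_valley, majority[:3]) # mini-batch of majority samples
minor_grad = batch_grad(w_valley, minority)     # mini-batch with the minority sample

print("bottom of shallow valley:", w_valley)
print("full-batch gradient:     ", full_grad)
print("majority mini-batch:     ", major_grad)
print("minority mini-batch:     ", minor_grad)
```

At the bottom of the shallow valley the full-batch gradient is essentially zero, but the mini-batch gradients are not: a mini-batch made of majority samples has a positive gradient, so a step against it moves the weight toward the deep valley, while the minority sample pulls it back. Whether this noise actually carries the model over the barrier depends on the learning rate, the batch composition and the depth of the valley, none of which this toy example tries to settle.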
A saddle point is a place where the gradient vanishes but the surface curves uphill along some dimensions and downhill along others. Considering the high dimensionality, when gradient descent becomes really slow, the model has most likely ( P = 1-0.5**5000, assuming each dimension independently curves up or down with equal chance ) encountered a saddle point instead of a local / global minimum. The likelihood of escaping depends on the number of dimensions along which the surface curves downhill. The number of such dimensions is proportional to the number of weights that fit the current training data. The more dimensions that lead downhill, the more easily the model escapes the saddle point.
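Both points in this paragraph can be checked with a short sketch. The first part evaluates the size of 0.5**5000 through logarithms, since the number is too small for ordinary floating point. The second part uses the textbook saddle f(x, y) = x**2 - y**2, a toy surface chosen for illustration with one uphill and one downhill dimension. The learning rate, step count and starting points are assumptions.

```python
import math

# Chance that all 5000 dimensions curve uphill, under the coin-flip assumption
n_weights = 5000
log10_p_minimum = n_weights * math.log10(0.5)
print("0.5**5000 is about 10 **", round(log10_p_minimum, 1))

def grad(x, y):
    # Gradient of the toy saddle f(x, y) = x**2 - y**2
    return 2 * x, -2 * y

def descend(x, y, lr=0.1, steps=100):
    # Plain gradient descent on the toy saddle
    for _ in range(steps):
        gx, gy = grad(x, y)
        x, y = x - lr * gx, y - lr * gy
    return x, y

# Starting exactly on the uphill axis: descent slides into the saddle and stalls
stuck = descend(1.0, 0.0)
# A tiny offset along the downhill dimension is enough to escape
escaped = descend(1.0, 1e-6)

print("no offset:  ", stuck)
print("tiny offset:", escaped)
```

With no component along the downhill dimension, the run converges to the saddle at (0, 0). With even a very small component, that component grows at every step and the run leaves the saddle. This toy surface is unbounded below, so "escaping" here just means moving away along y. With many downhill dimensions available, it becomes ever less likely that a starting point has no component along any of them.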
On a plateau, the gradient along each dimension is also near 0, so the model updates slowly as well. A plateau can be viewed as the case where large variation exists among many training samples, and therefore the loss function's output value is large. If the plateau is large and really flat, it poses a greater threat to model training.
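A plateau can be sketched with a single sigmoid neuron. Note the assumption: this example produces the flat region through a saturated sigmoid, which is one common way a plateau arises and not necessarily the mechanism described above. The target value, learning rate and starting weights are also made up for illustration. The loss is large on the plateau, yet the gradient is so small that a thousand updates barely move the weight.

```python
import math

def sigmoid(w):
    return 1.0 / (1.0 + math.exp(-w))

def loss(w):
    # Squared error of a one-weight sigmoid neuron against the target 1.0
    return (sigmoid(w) - 1.0) ** 2

def grad(w):
    # Chain rule: d(loss)/dw = 2 * (s - 1) * s * (1 - s)
    s = sigmoid(w)
    return 2.0 * (s - 1.0) * s * (1.0 - s)

def descend(w, lr=1.0, steps=1000):
    # Plain gradient descent with the same budget for both starting points
    for _ in range(steps):
        w -= lr * grad(w)
    return w

w_plateau = descend(-10.0)  # starts deep inside the flat region
w_slope = descend(0.0)      # starts where the surface has a clear slope

print("start at -10:", w_plateau, "loss =", loss(w_plateau))
print("start at   0:", w_slope, "loss =", loss(w_slope))
```

After the same number of steps, the run that started on the slope has driven the loss close to zero, while the run that started on the plateau is still sitting at a loss of almost 1.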
Although we have been trying to reach the global minimum, we should keep in mind that the global minimum may not give the optimal parameters for the model. It may instead be the place where the model overfits. There may be a number of local minima across the landscape that give the model better performance.