Package smile.regression
Class NeuralNetwork
- java.lang.Object
-
- smile.regression.NeuralNetwork
-
- All Implemented Interfaces:
java.io.Serializable, OnlineRegression<double[]>, Regression<double[]>
public class NeuralNetwork extends java.lang.Object implements OnlineRegression<double[]>
Multilayer perceptron neural network for regression. An MLP consists of several layers of nodes, interconnected through weighted acyclic arcs from each preceding layer to the following, without lateral or feedback connections. Each node calculates a transformed weighted linear combination of its inputs (the output activations of the preceding layer), with one of the weights acting as a trainable bias connected to a constant input. The transformation, called the activation function, is a bounded non-decreasing (non-linear) function such as the logistic sigmoid, which ranges from 0 to 1. Another popular activation function is the hyperbolic tangent, which has the same shape as the sigmoid but ranges from -1 to 1.
- Author:
- Sam Erickson
- See Also:
- Serialized Form
-
-
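As a concrete illustration of the two activation functions described above, the following sketch compares the logistic sigmoid and the hyperbolic tangent (the class and method names here are illustrative, not part of the Smile API):

```java
// Illustrative comparison of the two activation functions; not Smile code.
public class Activations {
    /** Logistic sigmoid: bounded, non-decreasing, ranges from 0 to 1. */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Hyperbolic tangent: same shape as the sigmoid, but ranges from -1 to 1. */
    public static double tanh(double x) {
        return Math.tanh(x);
    }

    public static void main(String[] args) {
        // tanh is a rescaled sigmoid: tanh(x) = 2 * sigmoid(2x) - 1
        for (double x = -2.0; x <= 2.0; x += 0.5) {
            System.out.printf("x=%5.2f  tanh=%8.5f  2*sigmoid(2x)-1=%8.5f%n",
                    x, tanh(x), 2.0 * sigmoid(2.0 * x) - 1.0);
        }
    }
}
```

The printed columns agree, which makes the "equivalent in shape" remark concrete: the two functions differ only by scaling and shifting of input and output.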
Nested Class Summary
Nested Classes
- static class NeuralNetwork.ActivationFunction
- static class NeuralNetwork.Trainer: Trainer for neural networks.
-
Constructor Summary
Constructors
- NeuralNetwork(int... numUnits): Constructor.
- NeuralNetwork(NeuralNetwork.ActivationFunction activation, int... numUnits): Constructor.
- NeuralNetwork(NeuralNetwork.ActivationFunction activation, double alpha, double lambda, int... numUnits): Constructor.
-
Method Summary
All Methods | Instance Methods | Concrete Methods
- NeuralNetwork clone()
- double getLearningRate(): Returns the learning rate.
- double getMomentum(): Returns the momentum factor.
- double[][] getWeight(int layer): Returns the weights of a layer.
- double getWeightDecay(): Returns the weight decay factor.
- void learn(double[][] x, double[] y): Trains the neural network with the given dataset for one epoch by stochastic gradient descent.
- void learn(double[] x, double y): Online update the regression model with a new training instance.
- double learn(double[] x, double y, double weight): Update the neural network with the given instance and associated target value.
- double predict(double[] x): Predicts the dependent variable of an instance.
- void setLearningRate(double eta): Sets the learning rate.
- void setMomentum(double alpha): Sets the momentum factor.
- void setWeightDecay(double lambda): Sets the weight decay factor.
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
-
Methods inherited from interface smile.regression.Regression
predict
-
Constructor Detail
-
NeuralNetwork
public NeuralNetwork(int... numUnits)
Constructor. The default activation function is the logistic sigmoid function.
- Parameters:
numUnits - the number of units in each layer.
-
NeuralNetwork
public NeuralNetwork(NeuralNetwork.ActivationFunction activation, int... numUnits)
Constructor.
- Parameters:
activation - the activation function of the output layer.
numUnits - the number of units in each layer.
-
NeuralNetwork
public NeuralNetwork(NeuralNetwork.ActivationFunction activation, double alpha, double lambda, int... numUnits)
Constructor.
- Parameters:
activation - the activation function of the output layer.
alpha - the momentum factor.
lambda - the weight decay factor for regularization.
numUnits - the number of units in each layer.
-
-
Method Detail
-
clone
public NeuralNetwork clone()
- Overrides:
clone in class java.lang.Object
-
setLearningRate
public void setLearningRate(double eta)
Sets the learning rate.
- Parameters:
eta - the learning rate.
-
getLearningRate
public double getLearningRate()
Returns the learning rate.
-
setMomentum
public void setMomentum(double alpha)
Sets the momentum factor.
- Parameters:
alpha - the momentum factor.
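The momentum factor smooths successive weight updates by adding a fraction of the previous update to the current one. The page does not spell out the exact update rule, so the following is a sketch of the classical momentum update, not the Smile internals (all names are illustrative):

```java
// Illustrative classical momentum update:
// velocity = alpha * velocity - eta * gradient; weight += velocity.
public class MomentumDemo {
    /** Minimizes f(w) = w^2 by gradient descent with momentum. */
    public static double minimize(double w, double eta, double alpha, int steps) {
        double v = 0.0; // accumulated velocity
        for (int i = 0; i < steps; i++) {
            double grad = 2.0 * w;      // gradient of f(w) = w^2
            v = alpha * v - eta * grad; // fraction alpha of the previous update carries over
            w += v;
        }
        return w;
    }

    public static void main(String[] args) {
        System.out.println("w after 100 steps: " + minimize(5.0, 0.1, 0.5, 100));
    }
}
```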
-
getMomentum
public double getMomentum()
Returns the momentum factor.
-
setWeightDecay
public void setWeightDecay(double lambda)
Sets the weight decay factor. After each weight update, every weight is simply "decayed" or shrunk according to w = w * (1 - eta * lambda).
- Parameters:
lambda - the weight decay factor for regularization.
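A minimal sketch of the decay step described by the formula above (the class name is illustrative, not part of the Smile API):

```java
// Illustrative weight decay step: w = w * (1 - eta * lambda).
public class WeightDecayDemo {
    /** Shrinks every weight toward zero by the factor (1 - eta * lambda). */
    public static double[] decay(double[] w, double eta, double lambda) {
        double shrink = 1.0 - eta * lambda;
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            out[i] = w[i] * shrink; // applied after each weight update
        }
        return out;
    }
}
```

For example, with eta = 0.1 and lambda = 0.5 every weight is multiplied by 0.95 after each update, which keeps the weights small and acts as L2-style regularization.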
-
getWeightDecay
public double getWeightDecay()
Returns the weight decay factor.
-
getWeight
public double[][] getWeight(int layer)
Returns the weights of a layer.
- Parameters:
layer - the layer of the neural network, 0 for the input layer.
-
predict
public double predict(double[] x)
Description copied from interface: Regression
Predicts the dependent variable of an instance.
- Specified by:
predict in interface Regression<double[]>
- Parameters:
x - the instance.
- Returns:
- the predicted value of the dependent variable.
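As a sketch of the computation a prediction performs, here is a one-hidden-layer forward pass with a sigmoid hidden layer and a linear output unit. This is illustrative only: the real class generalizes to any number of layers, and the weight layout used here (bias stored as the last entry of each row) is an assumption, not the documented Smile storage format.

```java
// Illustrative forward pass for a 1-hidden-layer MLP regression model.
public class ForwardPass {
    /** hiddenW[j] holds the weights of hidden unit j; its last entry is the bias. */
    public static double predict(double[] x, double[][] hiddenW, double[] outputW) {
        double[] h = new double[hiddenW.length];
        for (int j = 0; j < hiddenW.length; j++) {
            double z = hiddenW[j][x.length]; // bias weight on a constant input of 1
            for (int i = 0; i < x.length; i++) {
                z += hiddenW[j][i] * x[i];   // weighted linear combination of inputs
            }
            h[j] = 1.0 / (1.0 + Math.exp(-z)); // logistic sigmoid activation
        }
        double y = outputW[h.length]; // output bias
        for (int j = 0; j < h.length; j++) {
            y += outputW[j] * h[j];   // linear output unit, typical for regression
        }
        return y;
    }
}
```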
-
learn
public double learn(double[] x, double y, double weight)
Update the neural network with the given instance and associated target value. Note that this method is NOT multi-thread safe.
- Parameters:
x - the training instance.
y - the target value.
weight - a positive weight value associated with the training instance.
- Returns:
- the weighted training error before back-propagation.
-
learn
public void learn(double[] x, double y)
Description copied from interface: OnlineRegression
Online update the regression model with a new training instance. In general, this method may be NOT multi-thread safe.
- Specified by:
learn in interface OnlineRegression<double[]>
- Parameters:
x - training instance.
y - response variable.
-
learn
public void learn(double[][] x, double[] y)
Trains the neural network with the given dataset for one epoch by stochastic gradient descent.
- Parameters:
x - training instances.
y - the response variables corresponding to the training instances.
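To make "one epoch by stochastic gradient descent" concrete, here is a sketch on a single linear unit (y ≈ w*x + b) that updates the parameters after every instance, one pass over the dataset per call. This is illustrative only, not the Smile implementation, which applies the same pattern to the full network via back-propagation:

```java
// Illustrative: one SGD epoch for a single linear unit under squared error.
public class SgdEpoch {
    /** One pass over the data; returns the updated {w, b}. */
    public static double[] epoch(double[] x, double[] y, double w, double b, double eta) {
        for (int i = 0; i < x.length; i++) {   // visit every training instance once
            double err = (w * x[i] + b) - y[i]; // prediction error on instance i
            w -= eta * err * x[i];              // gradient of 0.5*err^2 w.r.t. w
            b -= eta * err;                     // gradient of 0.5*err^2 w.r.t. b
        }
        return new double[]{w, b};
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3, 4};
        double[] y = {1, 3, 5, 7, 9}; // generated by y = 2x + 1
        double w = 0, b = 0;
        for (int e = 0; e < 200; e++) { // call once per epoch until convergence
            double[] p = epoch(x, y, w, b, 0.05);
            w = p[0]; b = p[1];
        }
        System.out.printf("w=%.3f b=%.3f%n", w, b);
    }
}
```

Repeated calls converge to w = 2, b = 1 on this toy data; likewise, the network method above is typically called once per epoch inside a training loop.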
-
-