public class ParameterAveragingTrainingWorker extends BaseTrainingWorker<ParameterAveragingTrainingResult>
| Constructor and Description |
|---|
| `ParameterAveragingTrainingWorker(org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, Collection<TrainingHook> trainingHooks, Collection<TrainingListener> listeners, StatsStorageRouterProvider routerProvider)` |
public ParameterAveragingTrainingWorker(org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, Collection<TrainingHook> trainingHooks, Collection<TrainingListener> listeners, StatsStorageRouterProvider routerProvider)
public void removeHook(TrainingHook trainingHook)
trainingHook - the training hook to remove
public void addHook(TrainingHook trainingHook)
trainingHook - the training hook to add
public MultiLayerNetwork getInitialModel()
TrainingWorker
public ComputationGraph getInitialModelGraph()
TrainingWorker
public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast)
TrainingWorker
dataSet - Data set to train on
network - Network to train
isLast - If true: last data set currently available. If false: more data sets will be processed for this executor
public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
dataSet - Data set to train on
graph - Network to train
isLast - If true: last data set currently available. If false: more data sets will be processed for this executor
public ParameterAveragingTrainingResult processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
dataSet - Data set to train on
graph - Network to train
isLast - If true: last data set currently available. If false: more data sets will be processed for this executor
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast)
TrainingWorker
Same as TrainingWorker.processMinibatch(DataSet, MultiLayerNetwork, boolean), but used when SparkTrainingStats are being collected
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
Same as TrainingWorker.processMinibatch(DataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
Same as TrainingWorker.processMinibatch(MultiDataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network)
TrainingWorker
network - Current state of the network
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network)
TrainingWorker
network - Current state of the network
public ParameterAveragingTrainingResult getFinalResultNoData()
TrainingWorker
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultNoDataWithStats()
TrainingWorker
Same as TrainingWorker.getFinalResultNoData(), but used when SparkTrainingStats are being collected
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network)
TrainingWorker
Same as TrainingWorker.getFinalResult(MultiLayerNetwork), but used when SparkTrainingStats are being collected
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph)
TrainingWorker
Same as TrainingWorker.getFinalResult(ComputationGraph), but used when SparkTrainingStats are being collected
public WorkerConfiguration getDataConfiguration()
TrainingWorker
Returns the WorkerConfiguration that contains information such as minibatch sizes, etc.
public long getInstanceId()
Copyright © 2020. All rights reserved.