public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResult> implements TrainingWorker<SharedTrainingResult>
Constructor and Description |
---|
SharedTrainingWorker(long instanceId,
org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcastModel,
org.apache.spark.broadcast.Broadcast<SharedTrainingConfiguration> broadcastConfiguration,
List<TrainingListener> listeners,
StatsStorageRouter router,
Boolean workerTogglePeriodicGC,
Integer workerPeriodicGCFrequency) |
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
getInstanceId
public SharedTrainingWorker(long instanceId, org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcastModel, org.apache.spark.broadcast.Broadcast<SharedTrainingConfiguration> broadcastConfiguration, List<TrainingListener> listeners, StatsStorageRouter router, Boolean workerTogglePeriodicGC, Integer workerPeriodicGCFrequency)
public void removeHook(TrainingHook trainingHook)
TrainingWorker
removeHook
in interface TrainingWorker<SharedTrainingResult>
trainingHook
- the training hook to remove
public void addHook(TrainingHook trainingHook)
TrainingWorker
addHook
in interface TrainingWorker<SharedTrainingResult>
trainingHook
- the training hook to add
public MultiLayerNetwork getInitialModel()
TrainingWorker
getInitialModel
in interface TrainingWorker<SharedTrainingResult>
public ComputationGraph getInitialModelGraph()
TrainingWorker
getInitialModelGraph
in interface TrainingWorker<SharedTrainingResult>
public SharedTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast)
TrainingWorker
processMinibatch
in interface TrainingWorker<SharedTrainingResult>
dataSet
- Data set to train on
network
- Network to train
isLast
- If true: last data set currently available. If false: more data sets will be processed for this executor
public SharedTrainingResult processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
processMinibatch
in interface TrainingWorker<SharedTrainingResult>
dataSet
- Data set to train on
graph
- Network to train
isLast
- If true: last data set currently available. If false: more data sets will be processed for this executor
public SharedTrainingResult processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
processMinibatch
in interface TrainingWorker<SharedTrainingResult>
dataSet
- Data set to train on
graph
- Network to train
isLast
- If true: last data set currently available. If false: more data sets will be processed for this executor
public Pair<SharedTrainingResult,SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast)
TrainingWorker
As per TrainingWorker.processMinibatch(DataSet, MultiLayerNetwork, boolean)
but used when SparkTrainingStats
are being collected
processMinibatchWithStats
in interface TrainingWorker<SharedTrainingResult>
public Pair<SharedTrainingResult,SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
As per TrainingWorker.processMinibatch(DataSet, ComputationGraph, boolean)
but used when SparkTrainingStats
are being collected
processMinibatchWithStats
in interface TrainingWorker<SharedTrainingResult>
public Pair<SharedTrainingResult,SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
As per TrainingWorker.processMinibatch(MultiDataSet, ComputationGraph, boolean)
but used when SparkTrainingStats
are being collected
processMinibatchWithStats
in interface TrainingWorker<SharedTrainingResult>
public SharedTrainingResult getFinalResult(MultiLayerNetwork network)
TrainingWorker
getFinalResult
in interface TrainingWorker<SharedTrainingResult>
network
- Current state of the network
public SharedTrainingResult getFinalResult(ComputationGraph network)
TrainingWorker
getFinalResult
in interface TrainingWorker<SharedTrainingResult>
network
- Current state of the network
public SharedTrainingResult getFinalResultNoData()
TrainingWorker
getFinalResultNoData
in interface TrainingWorker<SharedTrainingResult>
public Pair<SharedTrainingResult,SparkTrainingStats> getFinalResultNoDataWithStats()
TrainingWorker
As per TrainingWorker.getFinalResultNoData()
but used when SparkTrainingStats
are being collected
getFinalResultNoDataWithStats
in interface TrainingWorker<SharedTrainingResult>
public Pair<SharedTrainingResult,SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network)
TrainingWorker
As per TrainingWorker.getFinalResult(MultiLayerNetwork)
but used when SparkTrainingStats
are being collected
getFinalResultWithStats
in interface TrainingWorker<SharedTrainingResult>
public Pair<SharedTrainingResult,SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph)
TrainingWorker
As per TrainingWorker.getFinalResult(ComputationGraph)
but used when SparkTrainingStats
are being collected
getFinalResultWithStats
in interface TrainingWorker<SharedTrainingResult>
public WorkerConfiguration getDataConfiguration()
TrainingWorker
Get the WorkerConfiguration
that contains information such as minibatch sizes, etc.
getDataConfiguration
in interface TrainingWorker<SharedTrainingResult>
Copyright © 2020. All rights reserved.