Modifier and Type | Field and Description |
---|---|
protected EarlyStoppingConfiguration<ComputationGraph> |
ComputationGraphSpace.earlyStoppingConfiguration |
protected EarlyStoppingConfiguration<ComputationGraph> |
ComputationGraphSpace.Builder.earlyStoppingConfiguration |
Modifier and Type | Method and Description |
---|---|
ComputationGraphSpace.Builder |
ComputationGraphSpace.Builder.earlyStoppingConfiguration(EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration)
Early stopping configuration (optional).
|
Modifier and Type | Method and Description |
---|---|
abstract double |
BaseNetScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator) |
double |
EvaluationScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator) |
double |
RegressionScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator) |
double |
ROCScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator) |
double |
TestSetAccuracyScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator)
Deprecated.
|
double |
TestSetF1ScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator)
Deprecated.
|
double |
TestSetLossScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator) |
double |
TestSetRegressionScoreFunction.score(ComputationGraph graph,
DataSetIterator iterator)
Deprecated.
|
abstract double |
BaseNetScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator) |
double |
EvaluationScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator) |
double |
RegressionScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator) |
double |
ROCScoreFunction.score(ComputationGraph net,
MultiDataSetIterator iterator) |
double |
TestSetAccuracyScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator)
Deprecated.
|
double |
TestSetF1ScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator)
Deprecated.
|
double |
TestSetLossScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator) |
double |
TestSetRegressionScoreFunction.score(ComputationGraph graph,
MultiDataSetIterator iterator)
Deprecated.
|
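The score functions above reduce an evaluation of a ComputationGraph to a single number that a hyperparameter search can compare. As a rough illustration of what accuracy- and F1-style scores compute in the binary case, here is a plain-Java sketch working from confusion-matrix counts. It does not use DL4J's Evaluation class. The class name and the convention of returning 0.0 when F1 is undefined are assumptions.

```java
/** Binary-classification scores from confusion-matrix counts (illustrative sketch, not DL4J's Evaluation). */
final class BinaryScoreSketch {
    private BinaryScoreSketch() {}

    /** Fraction of all predictions that were correct. */
    static double accuracy(long tp, long fp, long tn, long fn) {
        long total = tp + fp + tn + fn;
        return total == 0 ? 0.0 : (double) (tp + tn) / total;
    }

    /** Harmonic mean of precision and recall; returns 0.0 when it is undefined (assumed convention). */
    static double f1(long tp, long fp, long fn) {
        double precision = (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        double sum = precision + recall;
        return sum == 0.0 ? 0.0 : 2.0 * precision * recall / sum;
    }
}
```

Whichever score is used, the search must also know its direction: accuracy and F1 are maximized, while a loss-based score is minimized.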
Modifier and Type | Method and Description |
---|---|
static Evaluation |
ScoreUtil.getEvaluation(ComputationGraph model,
DataSetIterator testData)
Get the evaluation for the given model and test dataset
|
static Evaluation |
ScoreUtil.getEvaluation(ComputationGraph model,
MultiDataSetIterator testData)
Get the evaluation for the given model and test dataset
|
static double |
ScoreUtil.score(ComputationGraph model,
DataSetIterator testData,
boolean average)
Score based on the loss function
|
static double |
ScoreUtil.score(ComputationGraph model,
DataSetIterator testSet,
RegressionValue regressionValue)
Run a RegressionEvaluation over a DataSetIterator |
static double |
ScoreUtil.score(ComputationGraph model,
MultiDataSetIterator testData,
boolean average)
Score based on the loss function
|
static double |
ScoreUtil.score(ComputationGraph model,
MultiDataSetIterator testSet,
RegressionValue regressionValue) |
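ScoreUtil.score(model, testData, average) is described only as "score based on the loss function". One plausible reading of the `average` flag, sketched below in plain Java, is that per-minibatch mean losses are weighted by minibatch size and then either summed or divided by the total number of examples. This is an assumption about the semantics, not a copy of ScoreUtil, and the class and parameter names are hypothetical.

```java
/** Combines per-minibatch mean losses into a single score (sketch; the semantics of `average` are assumed). */
final class LossAggregationSketch {
    private LossAggregationSketch() {}

    /**
     * @param meanLosses mean loss of each minibatch, averaged over that minibatch's examples
     * @param batchSizes number of examples in each minibatch
     * @param average    true: per-example average over the whole set; false: summed loss
     */
    static double score(double[] meanLosses, int[] batchSizes, boolean average) {
        if (meanLosses.length != batchSizes.length) {
            throw new IllegalArgumentException("one batch size per minibatch loss is required");
        }
        double sum = 0.0;
        long totalExamples = 0;
        for (int i = 0; i < meanLosses.length; i++) {
            // Undo the per-batch mean so that a larger minibatch carries proportionally more weight
            sum += meanLosses[i] * batchSizes[i];
            totalExamples += batchSizes[i];
        }
        if (!average) {
            return sum;
        }
        return totalExamples == 0 ? 0.0 : sum / totalExamples;
    }
}
```

Weighting by batch size matters when minibatches differ in size: with mean losses {1.0, 2.0} over 10 and 30 examples, the per-example average is 1.75, whereas a plain mean of the two batch losses would give 1.5.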
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
LocalFileGraphSaver.getBestModel() |
ComputationGraph |
LocalFileGraphSaver.getLatestModel() |
Modifier and Type | Method and Description |
---|---|
void |
LocalFileGraphSaver.saveBestModel(ComputationGraph net,
double score) |
void |
LocalFileGraphSaver.saveLatestModel(ComputationGraph net,
double score) |
Modifier and Type | Method and Description |
---|---|
double |
DataSetLossCalculatorCG.calculateScore(ComputationGraph network)
Deprecated.
|
Constructor and Description |
---|
EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
DataSetIterator train) |
EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
DataSetIterator train,
EarlyStoppingListener<ComputationGraph> listener)
Constructor for training using a DataSetIterator |
EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
MultiDataSetIterator train,
EarlyStoppingListener<ComputationGraph> listener)
Constructor for training using a MultiDataSetIterator |
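EarlyStoppingGraphTrainer trains a ComputationGraph under an EarlyStoppingConfiguration<ComputationGraph>. The plain-Java loop below sketches the usual logic: track the best score seen so far, remember which epoch produced it, and stop once the score has not improved for a given number of epochs or a maximum epoch count is reached. This is roughly what score-improvement and max-epochs termination conditions express. The class, its fields, and the `patience` parameter are hypothetical, and lower scores are assumed to be better.

```java
/** Patience-based early stopping over per-epoch validation scores, where lower is better (illustrative sketch). */
final class EarlyStoppingSketch {
    private EarlyStoppingSketch() {}

    static final class Result {
        final int bestEpoch;      // epoch whose model a "best model" saver would keep
        final double bestScore;
        final int epochsRun;      // how many epochs were actually trained

        Result(int bestEpoch, double bestScore, int epochsRun) {
            this.bestEpoch = bestEpoch;
            this.bestScore = bestScore;
            this.epochsRun = epochsRun;
        }
    }

    /**
     * @param epochScores validation score after each epoch (stands in for "fit one epoch, then evaluate")
     * @param patience    stop after this many consecutive epochs without improvement
     * @param maxEpochs   hard upper bound on the number of epochs
     */
    static Result run(double[] epochScores, int patience, int maxEpochs) {
        int limit = Math.min(maxEpochs, epochScores.length);
        int bestEpoch = -1;
        double bestScore = Double.POSITIVE_INFINITY;
        int withoutImprovement = 0;
        int epochsRun = 0;
        for (int epoch = 0; epoch < limit; epoch++) {
            epochsRun++;
            double score = epochScores[epoch];
            if (score < bestScore) {
                bestScore = score;
                bestEpoch = epoch;       // a saver such as LocalFileGraphSaver.saveBestModel would be called here
                withoutImprovement = 0;
            } else if (++withoutImprovement >= patience) {
                break;                   // patience exhausted
            }
        }
        return new Result(bestEpoch, bestScore, epochsRun);
    }
}
```

With scores {0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.5} and a patience of 3, training stops after six epochs with epoch 2 as the best, so the later 0.5 is never reached. Patience trades wasted epochs against stopping too early.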
Modifier and Type | Method and Description |
---|---|
GraphVertex |
ElementWiseVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
FrozenVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
abstract GraphVertex |
GraphVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype)
Create a GraphVertex instance for the given computation graph, based on this configuration instance. |
GraphVertex |
L2NormalizeVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
L2Vertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
LayerVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
MergeVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
PoolHelperVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
PreprocessorVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
ReshapeVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
ScaleVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
ShiftVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
StackVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
SubsetVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
GraphVertex |
UnstackVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
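ElementWiseVertex combines several same-shape inputs with an ElementWiseVertex.Op, while ScaleVertex and ShiftVertex multiply by, or add, a fixed scalar (the scaleFactor and shiftFactor constructor arguments). The sketch below shows those semantics on flat Java arrays. The Op constants are assumed to mirror DL4J's operation set (Add, Subtract, Product, Average, Max), which may vary by version; the enum and class here are stand-ins, not the library types.

```java
import java.util.List;

/** Element-wise combination of same-length inputs, plus scalar scale and shift (illustrative sketch). */
final class ElementWiseSketch {
    private ElementWiseSketch() {}

    /** Stand-in for ElementWiseVertex.Op; the constant set is assumed, not copied from DL4J. */
    enum Op { ADD, SUBTRACT, PRODUCT, AVERAGE, MAX }

    static double[] combine(Op op, List<double[]> inputs) {
        if (inputs.isEmpty()) throw new IllegalArgumentException("at least one input is required");
        if (op == Op.SUBTRACT && inputs.size() != 2) {
            throw new IllegalArgumentException("SUBTRACT takes exactly two inputs in this sketch");
        }
        double[] out = inputs.get(0).clone(); // never modify the caller's arrays
        for (int k = 1; k < inputs.size(); k++) {
            double[] x = inputs.get(k);
            if (x.length != out.length) throw new IllegalArgumentException("inputs must have equal shapes");
            for (int i = 0; i < out.length; i++) {
                switch (op) {
                    case ADD:
                    case AVERAGE:  out[i] += x[i]; break;
                    case SUBTRACT: out[i] -= x[i]; break;
                    case PRODUCT:  out[i] *= x[i]; break;
                    case MAX:      out[i] = Math.max(out[i], x[i]); break;
                }
            }
        }
        if (op == Op.AVERAGE) {
            for (int i = 0; i < out.length; i++) out[i] /= inputs.size();
        }
        return out;
    }

    /** ScaleVertex-style: multiply every activation by scaleFactor. */
    static double[] scale(double[] x, double scaleFactor) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] * scaleFactor;
        return out;
    }

    /** ShiftVertex-style: add shiftFactor to every activation. */
    static double[] shift(double[] x, double shiftFactor) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] + shiftFactor;
        return out;
    }
}
```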
Modifier and Type | Method and Description |
---|---|
GraphVertex |
DuplicateToTimeSeriesVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
LastTimeStepVertex |
LastTimeStepVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
ReverseTimeSeriesVertex |
ReverseTimeSeriesVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
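LastTimeStepVertex and ReverseTimeSeriesVertex operate on recurrent activations. The sketch below shows the idea for a single example stored as a [features][timeSteps] Java array: take the last time step, or the last one whose mask entry is non-zero when a mask is given, and reverse the time axis. The array layout and class name are assumptions, and the reversal does not model how the real vertex treats masks.

```java
/** Time-series vertex behaviour for one example laid out as [features][timeSteps] (illustrative sketch). */
final class TimeSeriesSketch {
    private TimeSeriesSketch() {}

    /** Activations at the last time step whose mask value is non-zero; the final step if mask is null. */
    static double[] lastTimeStep(double[][] x, double[] mask) {
        int t = x[0].length - 1;
        if (mask != null) {
            while (t >= 0 && mask[t] == 0.0) t--;   // skip padded (masked-out) trailing steps
            if (t < 0) throw new IllegalArgumentException("mask has no active time steps");
        }
        double[] out = new double[x.length];
        for (int f = 0; f < x.length; f++) out[f] = x[f][t];
        return out;
    }

    /** Reverse the time axis (masks are not handled in this sketch). */
    static double[][] reverseTime(double[][] x) {
        int steps = x[0].length;
        double[][] out = new double[x.length][steps];
        for (int f = 0; f < x.length; f++) {
            for (int t = 0; t < steps; t++) out[f][t] = x[f][steps - 1 - t];
        }
        return out;
    }
}
```

The mask matters for padded variable-length sequences: without it, the "last" step of a padded example would be padding rather than real data.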
Modifier and Type | Method and Description |
---|---|
GraphVertex |
SameDiffVertex.instantiate(ComputationGraph graph,
String name,
int idx,
INDArray paramsView,
boolean initializeParams,
DataType networkDatatype) |
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
ComputationGraph.clone() |
ComputationGraph |
ComputationGraph.convertDataType(@NonNull DataType dataType)
Return a copy of the network with the parameters and activations set to use the specified (floating point) data type.
|
static ComputationGraph |
ComputationGraph.load(File f,
boolean loadUpdater)
Restore a ComputationGraph from a file saved using save(File) or ModelSerializer |
Modifier and Type | Field and Description |
---|---|
protected ComputationGraph |
BaseGraphVertex.graph |
Constructor and Description |
---|
BaseGraphVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
DataType dataType) |
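Each vertex is created with a vertexIndex plus the indices of its input and output vertices, and a computation graph has to run its vertices so that every vertex's inputs are computed before the vertex itself. The sketch below derives such an order with Kahn's algorithm over a name-based adjacency map. It illustrates the technique only: the map representation and class name are assumptions, and it is not DL4J's implementation.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Orders graph vertices so that every vertex comes after all of its inputs (Kahn's algorithm; sketch). */
final class TopoSortSketch {
    private TopoSortSketch() {}

    /**
     * @param inputsOf vertex name -> names of the vertices it reads from; network inputs map to an empty list.
     *                 Insertion order is used to break ties, so the result is deterministic.
     */
    static List<String> order(LinkedHashMap<String, List<String>> inputsOf) {
        Map<String, Integer> pending = new HashMap<>();        // inputs not yet scheduled, per vertex
        Map<String, List<String>> consumers = new HashMap<>(); // reverse edges: vertex -> vertices that read it
        for (Map.Entry<String, List<String>> e : inputsOf.entrySet()) {
            pending.put(e.getKey(), e.getValue().size());
            for (String in : e.getValue()) {
                if (!inputsOf.containsKey(in)) {
                    throw new IllegalArgumentException("unknown input vertex: " + in);
                }
                consumers.computeIfAbsent(in, k -> new ArrayList<>()).add(e.getKey());
            }
        }
        Deque<String> ready = new ArrayDeque<>();
        for (String v : inputsOf.keySet()) {
            if (pending.get(v) == 0) ready.add(v);
        }
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            String v = ready.poll();
            order.add(v);
            for (String c : consumers.getOrDefault(v, Collections.emptyList())) {
                if (pending.merge(c, -1, Integer::sum) == 0) ready.add(c); // all inputs of c are now scheduled
            }
        }
        if (order.size() != inputsOf.size()) {
            throw new IllegalStateException("graph contains a cycle");
        }
        return order;
    }
}
```

The declaration order of vertices does not need to follow the data flow: a merge vertex declared before its two inputs is still scheduled after them, and a cycle is reported instead of producing a partial order.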
Constructor and Description |
---|
ElementWiseVertex(ComputationGraph graph,
String name,
int vertexIndex,
ElementWiseVertex.Op op,
DataType dataType) |
ElementWiseVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
ElementWiseVertex.Op op,
DataType dataType) |
InputVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] outputVertices,
DataType dataType) |
L2NormalizeVertex(ComputationGraph graph,
String name,
int vertexIndex,
int[] dimension,
double eps,
DataType dataType) |
L2NormalizeVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
int[] dimension,
double eps,
DataType dataType) |
L2Vertex(ComputationGraph graph,
String name,
int vertexIndex,
double eps,
DataType dataType) |
L2Vertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
double eps,
DataType dataType) |
LayerVertex(ComputationGraph graph,
String name,
int vertexIndex,
Layer layer,
InputPreProcessor layerPreProcessor,
boolean outputVertex,
DataType dataType)
Create a layer vertex that wraps the given Layer.
|
LayerVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
Layer layer,
InputPreProcessor layerPreProcessor,
boolean outputVertex,
DataType dataType) |
MergeVertex(ComputationGraph graph,
String name,
int vertexIndex,
DataType dataType,
int mergeAxis) |
MergeVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
DataType dataType,
int mergeAxis) |
PoolHelperVertex(ComputationGraph graph,
String name,
int vertexIndex,
DataType dataType) |
PoolHelperVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
DataType dataType) |
PreprocessorVertex(ComputationGraph graph,
String name,
int vertexIndex,
InputPreProcessor preProcessor,
DataType dataType) |
PreprocessorVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
InputPreProcessor preProcessor,
DataType dataType) |
ReshapeVertex(ComputationGraph graph,
String name,
int vertexIndex,
char order,
int[] newShape,
int[] maskShape,
DataType dataType) |
ReshapeVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
char order,
int[] newShape,
int[] maskShape,
DataType dataType) |
ScaleVertex(ComputationGraph graph,
String name,
int vertexIndex,
double scaleFactor,
DataType dataType) |
ScaleVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
double scaleFactor,
DataType dataType) |
ShiftVertex(ComputationGraph graph,
String name,
int vertexIndex,
double shiftFactor,
DataType dataType) |
ShiftVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
double shiftFactor,
DataType dataType) |
StackVertex(ComputationGraph graph,
String name,
int vertexIndex,
DataType dataType) |
StackVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
DataType dataType) |
SubsetVertex(ComputationGraph graph,
String name,
int vertexIndex,
int from,
int to,
DataType dataType) |
SubsetVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
int from,
int to,
DataType dataType) |
UnstackVertex(ComputationGraph graph,
String name,
int vertexIndex,
int from,
int stackSize,
DataType dataType) |
UnstackVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
int from,
int stackSize,
DataType dataType) |
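Several of these constructors take parameters that define simple array transformations: `from`/`to` for SubsetVertex, `eps` for L2NormalizeVertex, and `from`/`stackSize` for UnstackVertex. The sketch below shows those transformations on plain Java arrays. It assumes that the SubsetVertex range is inclusive at both ends, that the L2 norm is clamped below at eps, and that UnstackVertex returns the `from`-th of `stackSize` equal blocks of rows. The class and method names are hypothetical.

```java
import java.util.Arrays;

/** Slicing and normalization behaviour of a few vertices, on plain arrays (illustrative sketch). */
final class VertexOpsSketch {
    private VertexOpsSketch() {}

    /** SubsetVertex-style: features from..to, both ends inclusive. */
    static double[] subset(double[] x, int from, int to) {
        if (from < 0 || to >= x.length || from > to) throw new IllegalArgumentException("invalid range");
        return Arrays.copyOfRange(x, from, to + 1); // copyOfRange excludes its upper bound, hence to + 1
    }

    /** L2NormalizeVertex-style: divide by the L2 norm, clamped below at eps to avoid division by zero. */
    static double[] l2Normalize(double[] x, double eps) {
        double sumSq = 0.0;
        for (double v : x) sumSq += v * v;
        double norm = Math.max(Math.sqrt(sumSq), eps);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] / norm;
        return out;
    }

    /** UnstackVertex-style: rows were stacked as stackSize equal blocks; return block number `from`. */
    static double[][] unstack(double[][] rows, int from, int stackSize) {
        if (stackSize <= 0 || rows.length % stackSize != 0) {
            throw new IllegalArgumentException("row count must be a multiple of stackSize");
        }
        if (from < 0 || from >= stackSize) throw new IllegalArgumentException("block index out of range");
        int step = rows.length / stackSize;
        return Arrays.copyOfRange(rows, from * step, (from + 1) * step); // shallow copy: row arrays are shared
    }
}
```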
Constructor and Description |
---|
DuplicateToTimeSeriesVertex(ComputationGraph graph,
String name,
int vertexIndex,
String inputVertexName,
DataType dataType) |
DuplicateToTimeSeriesVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
String inputName,
DataType dataType) |
LastTimeStepVertex(ComputationGraph graph,
String name,
int vertexIndex,
String inputName,
DataType dataType) |
LastTimeStepVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
String inputName,
DataType dataType) |
ReverseTimeSeriesVertex(ComputationGraph graph,
String name,
int vertexIndex,
String inputName,
DataType dataType) |
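DuplicateToTimeSeriesVertex goes from 2d activations to a 3d time series by repeating each example's feature vector at every time step. The number of time steps comes from another input of the graph, which is why the constructors take an input name. A minimal sketch of the shape transformation for one example, with the time-step count passed in directly; the class name and [features][timeSteps] layout are assumptions.

```java
import java.util.Arrays;

/** DuplicateToTimeSeriesVertex-style: repeat one example's feature vector across T time steps (sketch). */
final class DuplicateToTimeSeriesSketch {
    private DuplicateToTimeSeriesSketch() {}

    /** @param timeSteps in the real vertex, taken from the time series of the referenced input */
    static double[][] duplicate(double[] features, int timeSteps) {
        double[][] out = new double[features.length][timeSteps];
        for (int f = 0; f < features.length; f++) {
            Arrays.fill(out[f], features[f]); // same activation at every time step
        }
        return out;
    }
}
```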
Constructor and Description |
---|
SameDiffGraphVertex(SameDiffVertex config,
ComputationGraph graph,
String name,
int vertexIndex,
INDArray paramsView,
boolean initParams,
DataType dataType) |
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
KerasModel.getComputationGraph()
Build a ComputationGraph from this Keras Model configuration and import weights.
|
ComputationGraph |
KerasModel.getComputationGraph(boolean importWeights)
Build a ComputationGraph from this Keras Model configuration and (optionally) import weights.
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(InputStream modelHdf5Stream)
Load Keras (Functional API) Model saved using model.save(...).
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(InputStream modelHdf5Stream,
boolean enforceTrainingConfig)
Load Keras (Functional API) Model saved using model.save(...).
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(String modelHdf5Filename)
Load Keras (Functional API) Model saved using model.save(...).
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(String modelHdf5Filename,
boolean enforceTrainingConfig)
Load Keras (Functional API) Model saved using model.save(...).
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(String modelHdf5Filename,
int[] inputShape,
boolean enforceTrainingConfig)
Load Keras (Functional API) Model saved using model.save(...).
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(String modelJsonFilename,
String weightsHdf5Filename)
Load Keras (Functional API) Model for which the configuration and weights were
saved separately using calls to model.to_json() and model.save_weights(...).
|
static ComputationGraph |
KerasModelImport.importKerasModelAndWeights(String modelJsonFilename,
String weightsHdf5Filename,
boolean enforceTrainingConfig)
Load Keras (Functional API) Model for which the configuration and weights were
saved separately using calls to model.to_json() and model.save_weights(...).
|
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
MultiLayerNetwork.toComputationGraph()
Convert this MultiLayerNetwork to a ComputationGraph
|
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
TransferLearning.GraphBuilder.build()
Returns a computation graph built to specifications.
|
ComputationGraph |
TransferLearningHelper.unfrozenGraph()
Returns the unfrozen subset of the original computation graph as a computation graph.
Note that with each call to featurizedFit, the parameters of the original computation graph are also updated.
|
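TransferLearningHelper works with a graph in which some vertices are frozen, and its featurized fit trains only the unfrozen part. The idea behind featurizing is that frozen layers' outputs never change during training, so they can be computed once per example and reused in every epoch. The plain-Java sketch below shows that caching idea. The frozen part is an arbitrary function, the trainable part is a hypothetical one-layer linear head fitted by gradient descent on squared error, and none of this is the helper's actual implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Frozen-feature caching for transfer learning: run the frozen part once, train only the head (sketch). */
final class FeaturizeSketch {
    private FeaturizeSketch() {}

    /** Apply the frozen part once per example; its parameters never change, so the outputs can be reused. */
    static List<double[]> featurize(List<double[]> inputs, Function<double[], double[]> frozenPart) {
        List<double[]> features = new ArrayList<>(inputs.size());
        for (double[] x : inputs) {
            features.add(frozenPart.apply(x));
        }
        return features;
    }

    /** Train a linear head y ~ w . f with per-example gradient descent on squared error, using cached features. */
    static double[] fitHead(List<double[]> features, double[] targets, int epochs, double learningRate) {
        int d = features.get(0).length;
        double[] w = new double[d];
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int n = 0; n < features.size(); n++) {
                double[] f = features.get(n);
                double prediction = 0.0;
                for (int j = 0; j < d; j++) prediction += w[j] * f[j];
                double error = prediction - targets[n];
                for (int j = 0; j < d; j++) w[j] -= learningRate * error * f[j]; // gradient of 0.5 * error^2
            }
        }
        return w;
    }
}
```

However many epochs the head is trained for, the frozen function runs exactly once per example. That saving is the motivation for featurizing.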
Constructor and Description |
---|
GraphBuilder(ComputationGraph origGraph)
Computation Graph to tweak for transfer learning
|
TransferLearningHelper(ComputationGraph orig)
Expects a computation graph where some vertices are frozen
|
TransferLearningHelper(ComputationGraph orig,
String... frozenOutputAt)
Modifies the given computation graph in place.
|
Constructor and Description |
---|
ComputationGraphUpdater(ComputationGraph graph) |
ComputationGraphUpdater(ComputationGraph graph,
INDArray updaterState) |
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
CheckpointListener.loadCheckpointCG(Checkpoint checkpoint)
Load a ComputationGraph for the given checkpoint
|
static ComputationGraph |
CheckpointListener.loadCheckpointCG(File rootDir,
Checkpoint checkpoint)
Load a ComputationGraph for the given checkpoint from the specified root directory
|
static ComputationGraph |
CheckpointListener.loadCheckpointCG(File rootDir,
int checkpointNum)
Load a ComputationGraph for the given checkpoint that resides in the specified root directory
|
ComputationGraph |
CheckpointListener.loadCheckpointCG(int checkpointNum)
Load a ComputationGraph for the given checkpoint
|
static ComputationGraph |
CheckpointListener.loadLastCheckpointCG(File rootDir)
Load the last (most recent) checkpoint from the specified root directory
|
Modifier and Type | Field and Description |
---|---|
protected ComputationGraph |
JsonModelServer.cgModel |
Constructor and Description |
---|
Builder(@NonNull ComputationGraph cgModel) |
JsonModelServer(@NonNull ComputationGraph cgModel,
InferenceAdapter<I,O> inferenceAdapter,
JsonSerializer<O> serializer,
JsonDeserializer<I> deserializer,
BinarySerializer<O> binarySerializer,
BinaryDeserializer<I> binaryDeserializer,
int port,
@NonNull InferenceMode inferenceMode,
int numWorkers) |
Modifier and Type | Field and Description |
---|---|
protected ComputationGraph |
ActorCriticCompGraph.cg |
Constructor and Description |
---|
ActorCriticCompGraph(ComputationGraph cg) |
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
TrainingWorker.getInitialModelGraph()
Get the initial model when training a ComputationGraph/SparkComputationGraph
|
Modifier and Type | Method and Description |
---|---|
R |
TrainingWorker.getFinalResult(ComputationGraph graph)
Get the final result to be returned to the driver
|
Pair<R,SparkTrainingStats> |
TrainingWorker.getFinalResultWithStats(ComputationGraph graph)
As per TrainingWorker.getFinalResult(ComputationGraph), but used when SparkTrainingStats are being collected |
R |
TrainingWorker.processMinibatch(DataSet dataSet,
ComputationGraph graph,
boolean isLast)
Process (fit) a minibatch for a ComputationGraph
|
R |
TrainingWorker.processMinibatch(MultiDataSet dataSet,
ComputationGraph graph,
boolean isLast)
Process (fit) a minibatch for a ComputationGraph using a MultiDataSet
|
Pair<R,SparkTrainingStats> |
TrainingWorker.processMinibatchWithStats(DataSet dataSet,
ComputationGraph graph,
boolean isLast)
As per TrainingWorker.processMinibatch(DataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected |
Pair<R,SparkTrainingStats> |
TrainingWorker.processMinibatchWithStats(MultiDataSet dataSet,
ComputationGraph graph,
boolean isLast)
As per TrainingWorker.processMinibatch(MultiDataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected |
Modifier and Type | Method and Description |
---|---|
EarlyStoppingResult<ComputationGraph> |
SparkEarlyStoppingGraphTrainer.pretrain() |
Modifier and Type | Method and Description |
---|---|
double |
SparkLossCalculatorComputationGraph.calculateScore(ComputationGraph network) |
Constructor and Description |
---|
SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc,
TrainingMaster trainingMaster,
EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
org.apache.spark.api.java.JavaRDD<MultiDataSet> train) |
SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc,
TrainingMaster trainingMaster,
EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
org.apache.spark.api.java.JavaRDD<MultiDataSet> train,
EarlyStoppingListener<ComputationGraph> listener) |
SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc,
TrainingMaster trainingMaster,
EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
org.apache.spark.api.java.JavaRDD<MultiDataSet> train,
int examplesPerFit,
int totalExamples) |
SparkEarlyStoppingGraphTrainer(org.apache.spark.SparkContext sc,
TrainingMaster trainingMaster,
EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
org.apache.spark.api.java.JavaRDD<MultiDataSet> train) |
SparkEarlyStoppingGraphTrainer(org.apache.spark.SparkContext sc,
TrainingMaster trainingMaster,
EarlyStoppingConfiguration<ComputationGraph> esConfig,
ComputationGraph net,
org.apache.spark.api.java.JavaRDD<MultiDataSet> train,
int examplesPerFit,
int totalExamples) |
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
SparkComputationGraph.fit(org.apache.spark.api.java.JavaRDD<DataSet> rdd)
Fit the ComputationGraph with the given data set
|
ComputationGraph |
SparkComputationGraph.fit(org.apache.spark.rdd.RDD<DataSet> rdd)
Fit the ComputationGraph with the given data set
|
ComputationGraph |
SparkComputationGraph.fit(String path)
Fit the SparkComputationGraph network using a directory of serialized DataSet objects.
The directory is assumed to contain a number of DataSet objects, each serialized using
DataSet.save(OutputStream). |
ComputationGraph |
SparkComputationGraph.fit(String path,
int minPartitions)
Deprecated.
|
ComputationGraph |
SparkComputationGraph.fitMultiDataSet(org.apache.spark.api.java.JavaRDD<MultiDataSet> rdd)
Fit the ComputationGraph with the given data set
|
ComputationGraph |
SparkComputationGraph.fitMultiDataSet(org.apache.spark.rdd.RDD<MultiDataSet> rdd)
Fit the ComputationGraph with the given data set
|
ComputationGraph |
SparkComputationGraph.fitMultiDataSet(String path)
Fit the SparkComputationGraph network using a directory of serialized MultiDataSet objects.
The directory is assumed to contain a number of serialized MultiDataSet objects. |
ComputationGraph |
SparkComputationGraph.fitMultiDataSet(String path,
int minPartitions)
Deprecated.
|
ComputationGraph |
SparkComputationGraph.fitPaths(org.apache.spark.api.java.JavaRDD<String> paths)
Fit the network using a list of paths for serialized DataSet objects.
|
ComputationGraph |
SparkComputationGraph.fitPaths(org.apache.spark.api.java.JavaRDD<String> paths,
DataSetLoader loader) |
ComputationGraph |
SparkComputationGraph.fitPaths(org.apache.spark.api.java.JavaRDD<String> paths,
MultiDataSetLoader loader) |
ComputationGraph |
SparkComputationGraph.fitPathsMultiDataSet(org.apache.spark.api.java.JavaRDD<String> paths)
Fit the network using a list of paths for serialized MultiDataSet objects.
|
ComputationGraph |
SparkComputationGraph.getNetwork() |
Modifier and Type | Method and Description |
---|---|
void |
SparkComputationGraph.setNetwork(ComputationGraph network) |
Constructor and Description |
---|
SparkComputationGraph(org.apache.spark.api.java.JavaSparkContext javaSparkContext,
ComputationGraph network,
TrainingMaster trainingMaster) |
SparkComputationGraph(org.apache.spark.SparkContext sparkContext,
ComputationGraph network,
TrainingMaster trainingMaster)
Instantiate a SparkComputationGraph instance with the given context, network and training master.
|
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
ParameterAveragingTrainingWorker.getInitialModelGraph() |
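ParameterAveragingTrainingWorker starts each worker from the initial ComputationGraph it returns. In parameter-averaging training, each worker fits its own copy on a partition of the data, and the shared parameters are then replaced by the element-wise mean of the workers' parameters. A plain-Java sketch of just the averaging step over flattened parameter vectors; the class name and flattened-array representation are assumptions.

```java
/** Parameter averaging: the new shared parameters are the element-wise mean of the workers' copies (sketch). */
final class ParameterAveragingSketch {
    private ParameterAveragingSketch() {}

    /** @param workerParams one flattened parameter vector per worker, all of the same length */
    static double[] average(double[][] workerParams) {
        if (workerParams.length == 0) throw new IllegalArgumentException("no workers");
        int n = workerParams[0].length;
        double[] mean = new double[n];
        for (double[] params : workerParams) {
            if (params.length != n) throw new IllegalArgumentException("parameter vectors differ in length");
            for (int i = 0; i < n; i++) mean[i] += params[i];
        }
        for (int i = 0; i < n; i++) mean[i] /= workerParams.length;
        return mean;
    }
}
```

In a full training loop, each round would start all workers from the same parameters, let them fit some minibatches independently, and then broadcast the average. How many minibatches pass between averaging steps is a tuning choice that trades communication cost against divergence between workers.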
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
SharedTrainingWorker.getInitialModelGraph() |
Modifier and Type | Method and Description |
---|---|
SharedTrainingResult |
SharedTrainingWorker.getFinalResult(ComputationGraph network) |
Pair<SharedTrainingResult,SparkTrainingStats> |
SharedTrainingWorker.getFinalResultWithStats(ComputationGraph graph) |
SharedTrainingResult |
SharedTrainingWorker.processMinibatch(DataSet dataSet,
ComputationGraph graph,
boolean isLast) |
SharedTrainingResult |
SharedTrainingWorker.processMinibatch(MultiDataSet dataSet,
ComputationGraph graph,
boolean isLast) |
Pair<SharedTrainingResult,SparkTrainingStats> |
SharedTrainingWorker.processMinibatchWithStats(DataSet dataSet,
ComputationGraph graph,
boolean isLast) |
Pair<SharedTrainingResult,SparkTrainingStats> |
SharedTrainingWorker.processMinibatchWithStats(MultiDataSet dataSet,
ComputationGraph graph,
boolean isLast) |
Modifier and Type | Method and Description |
---|---|
static ComputationGraph |
ModelSerializer.restoreComputationGraph(@NonNull File file)
Load a computation graph from a file
|
static ComputationGraph |
ModelSerializer.restoreComputationGraph(@NonNull File file,
boolean loadUpdater)
Load a computation graph from a file
|
static ComputationGraph |
ModelSerializer.restoreComputationGraph(@NonNull InputStream is)
Load a computation graph from an InputStream
|
static ComputationGraph |
ModelSerializer.restoreComputationGraph(@NonNull InputStream is,
boolean loadUpdater)
Load a computation graph from an InputStream
|
static ComputationGraph |
ModelSerializer.restoreComputationGraph(@NonNull String path)
Load a computation graph from a file
|
static ComputationGraph |
ModelSerializer.restoreComputationGraph(@NonNull String path,
boolean loadUpdater)
Load a computation graph from a file
|
static ComputationGraph |
NetworkUtils.toComputationGraph(MultiLayerNetwork net)
Convert a MultiLayerNetwork to a ComputationGraph
|
Modifier and Type | Method and Description |
---|---|
static Pair<ComputationGraph,Normalizer> |
ModelSerializer.restoreComputationGraphAndNormalizer(@NonNull File file,
boolean loadUpdater)
Restore a ComputationGraph and Normalizer (the Normalizer is null if none was saved) from a File
|
static Pair<ComputationGraph,Normalizer> |
ModelSerializer.restoreComputationGraphAndNormalizer(@NonNull InputStream is,
boolean loadUpdater)
Restore a ComputationGraph and Normalizer (the Normalizer is null if none was saved) from the InputStream.
|
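restoreComputationGraphAndNormalizer returns the normalizer stored alongside the model, because inputs at inference time must be transformed with the same statistics that were used during training. The plain-Java sketch below shows one such normalizer, per-feature z-score standardization, similar in spirit to ND4J's NormalizerStandardize but not its implementation. Using the population standard deviation and leaving zero-variance features unscaled are assumed conventions.

```java
/** Per-feature standardization: fit statistics on training data, reuse them at inference (sketch). */
final class StandardizeSketch {
    final double[] mean;
    final double[] std;

    private StandardizeSketch(double[] mean, double[] std) {
        this.mean = mean;
        this.std = std;
    }

    static StandardizeSketch fit(double[][] data) {
        int d = data[0].length;
        double[] mean = new double[d];
        double[] std = new double[d];
        for (double[] row : data) for (int j = 0; j < d; j++) mean[j] += row[j];
        for (int j = 0; j < d; j++) mean[j] /= data.length;
        for (double[] row : data) {
            for (int j = 0; j < d; j++) {
                double centered = row[j] - mean[j];
                std[j] += centered * centered;
            }
        }
        for (int j = 0; j < d; j++) std[j] = Math.sqrt(std[j] / data.length); // population standard deviation
        return new StandardizeSketch(mean, std);
    }

    double[] transform(double[] x) {
        double[] out = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            double s = std[j] == 0.0 ? 1.0 : std[j]; // assumed convention for constant features
            out[j] = (x[j] - mean[j]) / s;
        }
        return out;
    }
}
```

If a model is restored without its normalizer and fed raw inputs, its predictions are computed on a different input distribution from the one it was trained on. Persisting both together avoids that mismatch.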
Modifier and Type | Method and Description |
---|---|
static Double |
NetworkUtils.getLearningRate(ComputationGraph net,
String layerName)
Get the current learning rate, for the specified layer, from the network.
|
static void |
NetworkUtils.setLearningRate(ComputationGraph net,
double newLr)
Set the learning rate for all layers in the network to the specified value.
|
static void |
NetworkUtils.setLearningRate(ComputationGraph net,
ISchedule newLrSchedule)
Set the learning rate schedule for all layers in the network to the specified schedule.
|
static void |
NetworkUtils.setLearningRate(ComputationGraph net,
String layerName,
double newLr)
Set the learning rate for a single layer in the network to the specified value.
|
static void |
NetworkUtils.setLearningRate(ComputationGraph net,
String layerName,
ISchedule lrSchedule)
Set the learning rate schedule for a single layer in the network to the specified schedule.
|
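The NetworkUtils.setLearningRate overloads accept either a fixed value or an ISchedule, which maps training progress to a learning rate. In the sketch below, a hypothetical LrSchedule interface stands in for ISchedule, with a step-decay implementation that halves (or otherwise scales) the rate every `step` iterations. The interface, class, and parameter names are assumptions for illustration, not ND4J's API.

```java
/** Hypothetical stand-in for a learning-rate schedule: maps training progress to a learning rate. */
interface LrSchedule {
    double valueAt(int iteration, int epoch);
}

/** Step decay: multiply the initial rate by `decay` once every `step` iterations. */
final class StepDecay implements LrSchedule {
    private final double initial;
    private final double decay;
    private final int step;

    StepDecay(double initial, double decay, int step) {
        if (step <= 0) throw new IllegalArgumentException("step must be positive");
        this.initial = initial;
        this.decay = decay;
        this.step = step;
    }

    @Override
    public double valueAt(int iteration, int epoch) {
        // Integer division counts completed steps; this schedule ignores the epoch argument
        return initial * Math.pow(decay, iteration / step);
    }
}
```

With an initial rate of 0.1, decay 0.5, and step 100, iterations 0 to 99 use 0.1, iterations 100 to 199 use 0.05, and iteration 250 uses 0.025. Passing a schedule rather than a fixed value lets the rate change during training without further calls.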
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
Darknet19.init() |
ComputationGraph |
FaceNetNN4Small2.init() |
ComputationGraph |
InceptionResNetV1.init() |
ComputationGraph |
NASNet.init() |
ComputationGraph |
ResNet50.init() |
ComputationGraph |
SqueezeNet.init() |
ComputationGraph |
TinyYOLO.init() |
ComputationGraph |
UNet.init() |
ComputationGraph |
VGG16.init() |
ComputationGraph |
VGG19.init() |
ComputationGraph |
Xception.init() |
ComputationGraph |
YOLO2.init() |
ComputationGraph |
SqueezeNet.initPretrained(PretrainedType pretrainedType) |
Copyright © 2020. All rights reserved.