public class TinyYOLO extends ZooModel
Pretrained ImageNet+VOC weights are available for this model. They were converted from https://pjreddie.com/darknet/yolo/ using https://github.com/allanzelener/YAD2K and the following code:
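```java
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

// seed, iterations, workspaceMode and priorBoxes are fields of the enclosing TinyYOLO class.
// Import the Keras HDF5 model produced by YAD2K (false = do not enforce a training config).
String filename = "tiny-yolo-voc.h5";
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(filename, false);

// Anchor box priors: an [N, 2] array of {width, height}, in grid-cell units.
INDArray priors = Nd4j.create(priorBoxes);

// Training settings applied to the layers of the imported graph.
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .gradientNormalizationThreshold(1.0)
        .updater(new Adam.Builder().learningRate(1e-3).build())
        .l2(0.00001)
        .activation(Activation.IDENTITY)
        .trainingWorkspaceMode(workspaceMode)
        .inferenceWorkspaceMode(workspaceMode)
        .build();

// Attach a YOLOv2 output layer on top of the last convolutional layer, "conv2d_9".
ComputationGraph model = new TransferLearning.GraphBuilder(graph)
        .fineTuneConfiguration(fineTuneConf)
        .addLayer("outputs",
                new Yolo2OutputLayer.Builder()
                        .boundingBoxPriors(priors)
                        .build(),
                "conv2d_9")
        .setOutputs("outputs")
        .build();

// Print the architecture for a 416x416x3 input, then save without updater state.
System.out.println(model.summary(InputType.convolutional(416, 416, 3)));
ModelSerializer.writeModel(model, "tiny-yolo-voc_dl4j_inference.v1.zip", false);
```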
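In the conversion code, `priorBoxes` supplies the anchor boxes that `Yolo2OutputLayer` uses. As a hedged illustration of what a prior does, the plain-Java sketch below decodes one raw box prediction using the parameterisation from the YOLOv2 paper (a sigmoid offset inside the grid cell, and the prior scaled by an exponential). It is not DL4J's implementation. It assumes priors are given in grid-cell units and that a 416x416 input maps to a 13x13 output grid (32 pixels per cell); `Yolo2Decode` and `decode` are hypothetical names.

```java
final class Yolo2Decode {
    /** Logistic sigmoid, squashing any real value into (0, 1). */
    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /**
     * Decodes one raw prediction (tx, ty, tw, th) at grid cell (cx, cy) using an
     * anchor prior (pw, ph) in grid-cell units (assumed convention).
     * Returns {centerX, centerY, width, height} in pixels.
     */
    public static double[] decode(double tx, double ty, double tw, double th,
                                  int cx, int cy, double pw, double ph,
                                  int imageSize, int gridSize) {
        double cell = (double) imageSize / gridSize; // pixels per grid cell
        double bx = cx + sigmoid(tx);  // center stays inside cell (cx, cy)
        double by = cy + sigmoid(ty);
        double bw = pw * Math.exp(tw); // prior width rescaled by the network
        double bh = ph * Math.exp(th);
        return new double[] {bx * cell, by * cell, bw * cell, bh * cell};
    }
}
```

With all raw outputs at zero, `decode(0, 0, 0, 0, 6, 6, 2.0, 3.0, 416, 13)` yields a box centered at pixel (208, 208) whose size equals the prior, 64x96 pixels. Objectness and class scores, which the real output layer also predicts, are left out of this sketch.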
The channels of the 416x416 input images need to be in RGB order (not BGR), with values normalized within [0, 1].

Modifier and Type | Method and Description |
---|---|
ComputationGraphConfiguration | conf() |
ComputationGraph | init() |
ModelMetaData | metaData() |
java.lang.Class<? extends Model> | modelType() |
long | pretrainedChecksum(PretrainedType pretrainedType) |
java.lang.String | pretrainedUrl(PretrainedType pretrainedType) |
void | setInputShape(int[][] inputShape) |
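To meet this model's input requirements (416x416 pixels, RGB channel order, values in [0, 1]), an image can be prepared with plain Java. The sketch below is a minimal, hedged example: it assumes the network takes NCHW input of shape [1, 3, 416, 416] (channels first, consistent with `InputType.convolutional(416, 416, 3)`), and `YoloInput` and `toRgbNchw` are hypothetical names, not part of DL4J.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

final class YoloInput {
    static final int SIZE = 416;

    /**
     * Resizes an image to SIZE x SIZE and returns its pixels as a flat array in
     * NCHW order [1, 3, SIZE, SIZE]: the full red plane, then green, then blue,
     * each value scaled from [0, 255] to [0, 1].
     */
    public static float[] toRgbNchw(BufferedImage src) {
        BufferedImage resized = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, SIZE, SIZE, null);
        g.dispose();

        int plane = SIZE * SIZE;
        float[] out = new float[3 * plane];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                // getRGB always returns packed 0xAARRGGBB, whatever the storage order.
                int rgb = resized.getRGB(x, y);
                int i = y * SIZE + x;
                out[i] = ((rgb >> 16) & 0xFF) / 255f;         // R plane first
                out[plane + i] = ((rgb >> 8) & 0xFF) / 255f;  // then G
                out[2 * plane + i] = (rgb & 0xFF) / 255f;     // then B
            }
        }
        return out;
    }
}
```

Reading channels through `getRGB` avoids the BGR ordering that some image loaders produce. The resulting array can then be wrapped in an `INDArray` of shape [1, 3, 416, 416] with an `Nd4j.create` overload that takes data and a shape (check the overloads available in your ND4J version).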
Methods inherited from class org.deeplearning4j.zoo.ZooModel: initPretrained (two overloads), modelName, pretrainedAvailable
public java.lang.String pretrainedUrl(PretrainedType pretrainedType)
public long pretrainedChecksum(PretrainedType pretrainedType)
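`pretrainedUrl` and `pretrainedChecksum` pair a download location with a checksum for the pretrained weights. A minimal sketch of verifying a downloaded file follows, assuming the checksum is an Adler-32 value computed over the file's bytes; this is believed to be what `ZooModel` uses, but confirm it against your DL4J version. `ChecksumCheck`, `adler32`, and `matches` are hypothetical names.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.Adler32;
import java.util.zip.CheckedInputStream;

final class ChecksumCheck {
    /** Computes the Adler-32 checksum of a file's contents (assumed algorithm). */
    public static long adler32(Path file) throws IOException {
        try (CheckedInputStream in =
                     new CheckedInputStream(Files.newInputStream(file), new Adler32())) {
            byte[] buf = new byte[8192];
            while (in.read(buf) != -1) {
                // Each read feeds the bytes into in.getChecksum().
            }
            return in.getChecksum().getValue();
        }
    }

    /** True if the file's checksum equals the expected value, e.g. from pretrainedChecksum. */
    public static boolean matches(Path file, long expected) throws IOException {
        return adler32(file) == expected;
    }
}
```

A mismatch usually means a truncated or corrupted download, so the cached file should be deleted and fetched again from the URL returned by `pretrainedUrl`.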
public java.lang.Class<? extends Model> modelType()
public ComputationGraphConfiguration conf()
public ComputationGraph init()
public ModelMetaData metaData()
public void setInputShape(int[][] inputShape)