public class StandardDQN extends BaseDQNAlgorithm
Fields inherited from class BaseDQNAlgorithm: qNetworkNextObservation, targetQNetworkNextObservation
Fields inherited from class BaseTDTargetAlgorithm: gamma, qNetworkSource
Constructor and Description |
---|
StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) |
StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) |
Modifier and Type | Method and Description |
---|---|
protected double | computeTarget(int batchIdx, double reward, boolean isTerminal) In the literature, this corresponds to: Q(s_t, a_t) = R_{t+1} + \gamma * \max_{a} Q_{tar}(s_{t+1}, a) |
protected void | initComputation(INDArray observations, INDArray nextObservations) Called just before the calculation starts. |
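The target formula above can be illustrated with a minimal, library-free sketch. It is not the class's actual implementation: the class name `StandardDqnTargetSketch`, the field `maxTargetQNext`, and the use of plain `double[][]` arrays in place of `INDArray` are all hypothetical. Returning only the reward for a terminal transition is the usual DQN convention and is assumed here; the summary above does not state it.

```java
public class StandardDqnTargetSketch {
    private final double gamma;

    // Hypothetical cache: for each batch row, the maximum target-network
    // Q-value over all actions in the next state.
    private double[] maxTargetQNext;

    public StandardDqnTargetSketch(double gamma) {
        this.gamma = gamma;
    }

    // Stand-in for the initComputation hook. For simplicity it receives the
    // target network's Q-values for the next observations directly
    // (one row per transition, one column per action).
    public void initComputation(double[][] targetQNext) {
        maxTargetQNext = new double[targetQNext.length];
        for (int i = 0; i < targetQNext.length; i++) {
            double best = Double.NEGATIVE_INFINITY;
            for (double q : targetQNext[i]) {
                best = Math.max(best, q);
            }
            maxTargetQNext[i] = best;
        }
    }

    // R_{t+1} + gamma * max_a Q_tar(s_{t+1}, a); assumed to be just the
    // reward when the transition is terminal.
    public double computeTarget(int batchIdx, double reward, boolean isTerminal) {
        if (isTerminal) {
            return reward;
        }
        return reward + gamma * maxTargetQNext[batchIdx];
    }

    public static void main(String[] args) {
        StandardDqnTargetSketch sketch = new StandardDqnTargetSketch(0.5);
        sketch.initComputation(new double[][] {{0.5, 2.0, 1.0}, {4.0, 3.0, 0.0}});
        System.out.println(sketch.computeTarget(0, 1.0, false)); // 1.0 + 0.5 * 2.0
        System.out.println(sketch.computeTarget(1, 1.0, true));  // terminal: reward only
    }
}
```

Precomputing the per-row maximum once per batch keeps the per-transition call cheap, which is one plausible reason for having a hook that runs just before the calculation starts.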
Methods inherited from class BaseTDTargetAlgorithm: computeTDTargets
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma)
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp)
protected void initComputation(INDArray observations, INDArray nextObservations)
Description copied from class: BaseTDTargetAlgorithm
Called just before the calculation starts.
Overrides: initComputation in class BaseDQNAlgorithm
Parameters:
observations - An INDArray of all observations stacked on dimension 0
nextObservations - An INDArray of all next observations stacked on dimension 0
protected double computeTarget(int batchIdx, double reward, boolean isTerminal)
Specified by: computeTarget in class BaseTDTargetAlgorithm
Parameters:
batchIdx - The index in the batch of the current transition
reward - The reward of the current transition
isTerminal - True if it's the last transition of the "game"
Copyright © 2020. All rights reserved.