Algorithm Module¶
baconian.algo.algo.Algo¶
-
class
baconian.algo.algo.
Algo
(env_spec: baconian.core.core.EnvSpec, name: str = 'algo', warm_up_trajectories_number=0)¶ Abstract class for algorithms
-
INIT_STATUS
= 'CREATED'¶
-
STATUS_LIST
= ['CREATED', 'INITED', 'TRAIN', 'TEST']¶
-
__init__
(env_spec: baconian.core.core.EnvSpec, name: str = 'algo', warm_up_trajectories_number=0)¶ Constructor
Parameters: - env_spec (EnvSpec) – environment specifications
- name (str) – name of the algorithm
- warm_up_trajectories_number (int) – number of trajectories used to warm up training
-
append_to_memory
(*args, **kwargs)¶ For off-policy algorithms, use this API to append data to the replay buffer. The samples are read from the first argument passed to this API, e.g., algo.append_to_memory(samples=x, …)
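The argument convention described above can be sketched as follows. This is a minimal stand-in rather than Baconian's implementation: `ToyOffPolicyAlgo` and its list-backed `memory` are hypothetical names, and the sketch assumes `samples` may arrive either as a keyword or as the first positional argument.

```python
class ToyOffPolicyAlgo:
    """Hypothetical stand-in for an off-policy Algo subclass (not part of Baconian)."""

    def __init__(self):
        # A plain list stands in for a real replay buffer.
        self.memory = []

    def append_to_memory(self, *args, **kwargs):
        # Read `samples` from the keyword if given, otherwise from the
        # first positional argument (assumed convention).
        if "samples" in kwargs:
            samples = kwargs["samples"]
        elif args:
            samples = args[0]
        else:
            raise ValueError("append_to_memory expects `samples`")
        self.memory.extend(samples)


algo = ToyOffPolicyAlgo()
algo.append_to_memory(samples=[("obs0", "act0", 1.0)])  # keyword form
algo.append_to_memory([("obs1", "act1", 0.5)])          # positional form
print(len(algo.memory))  # 2
```

Accepting both forms keeps call sites flexible, at the cost of a less explicit signature.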
-
init
()¶ Initialization method, e.g., random initialization of network weights in TensorFlow
Returns:
-
is_testing
¶ A boolean indicating whether the algorithm is in testing status
Returns: True if in testing Return type: bool
-
is_training
¶ A boolean indicating whether the algorithm is in training status
Returns: True if in training Return type: bool
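How the status constants and the two boolean flags might fit together is sketched below. `ToyAlgo` is a hypothetical stand-in, not the Baconian class; the transitions (init() moving to 'INITED', train() to 'TRAIN', test() to 'TEST') and the placeholder result dicts are assumptions made for illustration and are not confirmed by this reference.

```python
class ToyAlgo:
    """Hypothetical stand-in reusing the status names documented for Algo."""

    INIT_STATUS = "CREATED"
    STATUS_LIST = ["CREATED", "INITED", "TRAIN", "TEST"]

    def __init__(self):
        self.status = self.INIT_STATUS

    def init(self):
        # Assumption: initialization moves the status to 'INITED'.
        self.status = "INITED"

    def train(self):
        # Assumption: calling train() switches the status to 'TRAIN'.
        self.status = "TRAIN"
        return {"loss": 0.0}  # placeholder training result

    def test(self):
        # Assumption: calling test() switches the status to 'TEST'.
        self.status = "TEST"
        return {"reward": 0.0}  # placeholder test result

    @property
    def is_training(self):
        return self.status == "TRAIN"

    @property
    def is_testing(self):
        return self.status == "TEST"


algo = ToyAlgo()
algo.init()
train_result = algo.train()
print(algo.is_training, algo.is_testing)  # True False
test_result = algo.test()
print(algo.is_training, algo.is_testing)  # False True
```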
-
predict
(*arg, **kwargs)¶ Predict function: given an observation as input, return the action. obs is read from the first argument passed to this API, e.g., algo.predict(obs=x, …)
Returns: predicted action Return type: np.ndarray
-
test
(*arg, **kwargs) → dict¶ Testing API. Most evaluation can be done by the agent rather than the algorithm, so this API can be skipped.
Returns: test results, e.g., rewards Return type: dict
-
train
(*arg, **kwargs) → dict¶ Training API. The specific arguments should be defined by each algorithm itself.
Returns: training results, e.g., loss Return type: dict
-
warm_up
(trajectory_data: baconian.common.sampler.sample_data.TrajectoryData)¶ Use some data to warm up the algorithm, e.g., compute the mean/std-dev of the state to perform normalization. Data used in the warm-up process will not be added to the memory.
Parameters: - trajectory_data (TrajectoryData) – TrajectoryData object
Returns: None
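The warm-up idea described above (estimating state statistics for normalization without storing the data) can be illustrated with a small sketch. `ToyNormalizingAlgo`, its `normalize` helper, and the use of a flat list of scalar states in place of a `TrajectoryData` object are all assumptions for illustration, not Baconian's API.

```python
import statistics


class ToyNormalizingAlgo:
    """Hypothetical stand-in showing warm-up statistics for state normalization."""

    def __init__(self):
        self.state_mean = 0.0
        self.state_std = 1.0
        self.memory = []  # stands in for a replay buffer

    def warm_up(self, states):
        # `states` is a flat list of scalar states, standing in for TrajectoryData.
        self.state_mean = statistics.mean(states)
        std = statistics.pstdev(states)
        # Guard against a zero std-dev so normalize() never divides by zero.
        self.state_std = std if std > 0 else 1.0
        # Warm-up data is deliberately not appended to self.memory.

    def normalize(self, state):
        return (state - self.state_mean) / self.state_std


algo = ToyNormalizingAlgo()
algo.warm_up([1.0, 2.0, 3.0, 4.0, 5.0])
print(algo.state_mean)      # 3.0
print(algo.normalize(3.0))  # 0.0
print(algo.memory)          # []
```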
-