Algorithm Module

baconian.algo.algo.Algo

class baconian.algo.algo.Algo(env_spec: baconian.core.core.EnvSpec, name: str = 'algo', warm_up_trajectories_number=0)

Abstract class for algorithms

INIT_STATUS = 'CREATED'
STATUS_LIST = ['CREATED', 'INITED', 'TRAIN', 'TEST']
__init__(env_spec: baconian.core.core.EnvSpec, name: str = 'algo', warm_up_trajectories_number=0)

Constructor

Parameters:
  • env_spec (EnvSpec) – environment specifications
  • name (str) – name of the algorithm
  • warm_up_trajectories_number (int) – number of trajectories used to warm up the training
append_to_memory(*args, **kwargs)

For off-policy algorithms, use this API to append data to the replay buffer. samples is read as the first argument passed to this API, e.g., algo.append_to_memory(samples=x, …)
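The following is a minimal sketch of how a subclass might honor this calling convention. ToyOffPolicyAlgo, its capacity parameter, and the deque-backed buffer are hypothetical stand-ins for illustration and are not part of baconian; only the convention that samples is the first argument comes from the description above.

```python
from collections import deque


class ToyOffPolicyAlgo:
    """Hypothetical stand-in for illustration; not baconian's Algo."""

    def __init__(self, capacity=1000):
        # Bounded buffer: the oldest samples are dropped once capacity is reached.
        self.memory = deque(maxlen=capacity)

    def append_to_memory(self, *args, **kwargs):
        # Read samples as the first positional argument, or as samples=...
        samples = args[0] if args else kwargs['samples']
        self.memory.extend(samples)


buffer_algo = ToyOffPolicyAlgo(capacity=3)
buffer_algo.append_to_memory(samples=[(0, 1, 0.5), (1, 0, 1.0)])
buffer_algo.append_to_memory([(2, 1, 0.0), (3, 0, 2.0)])
```

Both call styles reach the same buffer. With a capacity of 3, the first sample is evicted when the fourth arrives.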

init()

Initialization method, e.g., random initialization of network weights in TensorFlow

Returns:
is_testing

A boolean indicating whether the algorithm is in testing status

Returns:True if in testing
Return type:bool
is_training

A boolean indicating whether the algorithm is in training status

Returns:True if in training
Return type:bool
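As a rough sketch of how INIT_STATUS, STATUS_LIST, is_training and is_testing could fit together, the toy class below tracks a status string. The transitions (init() moving to 'INITED', train() to 'TRAIN', test() to 'TEST') and the _set_status helper are assumptions made for illustration; this page does not state how baconian switches status internally.

```python
class ToyStatusAlgo:
    """Hypothetical stand-in for illustration; not baconian's Algo."""

    INIT_STATUS = 'CREATED'
    STATUS_LIST = ['CREATED', 'INITED', 'TRAIN', 'TEST']

    def __init__(self):
        self._status = self.INIT_STATUS

    def _set_status(self, new_status):
        # Hypothetical helper: reject any status outside STATUS_LIST.
        if new_status not in self.STATUS_LIST:
            raise ValueError('unknown status: {}'.format(new_status))
        self._status = new_status

    def init(self):
        self._set_status('INITED')

    def train(self):
        self._set_status('TRAIN')  # assumed transition
        return {}

    def test(self):
        self._set_status('TEST')  # assumed transition
        return {}

    @property
    def is_training(self):
        return self._status == 'TRAIN'

    @property
    def is_testing(self):
        return self._status == 'TEST'


status_algo = ToyStatusAlgo()
status_algo.init()
status_algo.train()
in_training_after_train = status_algo.is_training
status_algo.test()
```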
predict(*arg, **kwargs)

Predict function: given the obs as input, return the action. obs is read as the first argument passed to this API, e.g., algo.predict(obs=x, …)

Returns:predicted action
Return type:np.ndarray
test(*arg, **kwargs) → dict

Testing API. Most evaluation can be done by the agent instead of the algorithm, so this API can be skipped

Returns:test results, e.g., rewards
Return type:dict
train(*arg, **kwargs) → dict

Training API. Specific arguments should be defined by each algorithm itself.

Returns:training results, e.g., loss
Return type:dict
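To illustrate the "arguments defined by each algorithm, results returned as a dict" contract, here is a small sketch. ToyTrainableAlgo, the batch_data argument name, and the one-parameter least-squares model are all hypothetical and chosen only to keep the example runnable; they are not taken from baconian.

```python
class ToyTrainableAlgo:
    """Hypothetical stand-in for illustration; not baconian's Algo."""

    def __init__(self, learning_rate=0.1):
        self.w = 0.0  # single scalar weight of the model y = w * x
        self.learning_rate = learning_rate

    def train(self, *args, **kwargs):
        # batch_data is a hypothetical argument name: a list of (x, y) pairs.
        batch = args[0] if args else kwargs['batch_data']
        n = len(batch)
        # Mean squared error of the current weight, measured before the update.
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        # One plain gradient-descent step on that loss.
        grad = sum(2.0 * (self.w * x - y) * x for x, y in batch) / n
        self.w -= self.learning_rate * grad
        return {'loss': loss}


toy_algo = ToyTrainableAlgo()
data = [(1.0, 2.0), (2.0, 4.0)]
first = toy_algo.train(batch_data=data)
second = toy_algo.train(data)
```

Returning a dict keeps the caller independent of which metrics a given algorithm reports; here the loss drops between the two calls.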
warm_up(trajectory_data: baconian.common.sampler.sample_data.TrajectoryData)

Use some data to warm up the algorithm, e.g., compute the mean/std-dev of the state to perform normalization. Data used in the warm-up process will not be added to the memory.

Parameters:
  • trajectory_data (TrajectoryData) – TrajectoryData object

Returns:None
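The normalization use case mentioned for warm_up can be sketched as follows. This is an illustration under stated assumptions: ToyNormalizingAlgo, its normalize helper, and the use of a plain list of state vectors in place of a TrajectoryData object are all hypothetical, since the layout of TrajectoryData is not described here.

```python
import math


class ToyNormalizingAlgo:
    """Hypothetical stand-in for illustration; not baconian's Algo."""

    def __init__(self):
        self.memory = []  # replay memory, left untouched by warm_up
        self.state_mean = None
        self.state_std = None

    def warm_up(self, trajectory_states):
        # trajectory_states: list of equally sized state vectors (an assumed
        # stand-in for the states stored in a TrajectoryData object).
        n = len(trajectory_states)
        dims = list(zip(*trajectory_states))  # one tuple per state dimension
        self.state_mean = [sum(d) / n for d in dims]
        self.state_std = [
            math.sqrt(sum((v - m) ** 2 for v in d) / n)
            for d, m in zip(dims, self.state_mean)
        ]
        return None

    def normalize(self, state, eps=1e-8):
        # eps avoids division by zero for constant dimensions.
        return [
            (s - m) / (sd + eps)
            for s, m, sd in zip(state, self.state_mean, self.state_std)
        ]


norm_algo = ToyNormalizingAlgo()
norm_algo.warm_up([[0.0, 10.0], [2.0, 10.0], [4.0, 10.0]])
centered = norm_algo.normalize([2.0, 10.0])
```

The warm-up data only updates the running statistics; the memory list stays empty, matching the description above.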