How to implement a new environment

As with algorithms, every environment in the Baconian project should implement the methods and attributes defined by the `Env` class in `baconian/core/core.py`, which itself inherits from gym's `Env` class.

class Env(gym.Env, Basic):
    """
    Abstract base class for Baconian environments, built on ``gym.Env`` and ``Basic``.

    The constructor attaches a ``StatusWithSubInfo`` status object to the instance.
    ``step``, ``reset`` and ``init`` are wrapped so that every call increments the
    'step', 'reset' or 'init' counter on that status object.
    """
    key_list = ()
    STATUS_LIST = ('JUST_RESET', 'JUST_INITED', 'TRAIN', 'TEST', 'NOT_INIT')
    INIT_STATUS = 'NOT_INIT'

    @typechecked
    def __init__(self, name: str = 'env'):
        super().__init__(status=StatusWithSubInfo(obj=self), name=name)
        # Both spaces start as None here; presumably concrete subclasses assign them.
        self.action_space = None
        self.observation_space = None
        self.step_count = None
        self.recorder = Recorder()
        # Summed 'step' counter value captured at the most recent reset().
        self._last_reset_point = 0
        # Total number of steps taken, summed over every status group.
        self.total_step_count_fn = lambda: self._status.group_specific_info_key(group_way='sum',
                                                                                  info_key='step')

    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='step',
                                               under_status=('TRAIN', 'TEST'),
                                               ignore_wrong_status=True)
    def step(self, action):
        """
        Advance the environment by one action.

        The base version does nothing and returns None. The decorator still counts
        the call; because wrong statuses are ignored, the count lands under whatever
        status the environment is currently in.
        """

    @register_counter_info_to_status_decorator(increment=1, info_key='reset', under_status='JUST_RESET')
    def reset(self):
        """Switch to 'JUST_RESET' and remember the total step count at this moment."""
        self._status.set_status('JUST_RESET')
        self._last_reset_point = self.total_step_count_fn()

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='JUST_INITED')
    def init(self):
        """Switch the environment's status to 'JUST_INITED'."""
        self._status.set_status('JUST_INITED')

    def get_state(self):
        """Return the environment's current state; must be provided by subclasses."""
        raise NotImplementedError()

    def seed(self, seed=None):
        """Forward seeding to ``self.unwrapped`` and return its result."""
        # NOTE(review): gym.Env.unwrapped returns self by default, so unless a subclass
        # overrides `unwrapped` (or `seed`), this call recurses indefinitely — confirm.
        inner_env = self.unwrapped
        return inner_env.seed(seed=seed)

We use STATUS to record and control the status of an environment. `register_counter_info_to_status_decorator` is a decorator that counts, per status, how many times an environment's `step`, `reset` and `init` methods are called.

def register_counter_info_to_status_decorator(increment, info_key, under_status: (str, tuple) = None,
                                              ignore_wrong_status=False):
    """
    Build a decorator that counts calls of a method, grouped by the object's status.

    On every call of the decorated method:
    1. ``info_key`` is registered with an initial value of 0 under each status in
       ``under_status``.
    2. The method runs.
    3. Unless ``ignore_wrong_status`` is set, the object's status is checked to be
       one of the listed statuses.
    4. ``info_key`` is increased by ``increment`` under the object's current status.

    :param increment: amount added to the counter on each call.
    :param info_key: counter name, e.g. 'step', 'reset' or 'init'.
    :param under_status: one status name or a tuple of status names to register the
        counter under. When falsy, the counter is registered with ``under_status=None``
        and no status check is done.
    :param ignore_wrong_status: if True, skip the post-call status check.
    :return: a decorator for methods whose instance holds a ``StatusWithInfo`` in ``_status``.
    :raises AssertionError: at decoration time if ``under_status`` is truthy but neither
        a str nor a tuple.
    :raises ValueError: at call time if the instance has no ``StatusWithInfo`` ``_status``,
        or (unless ``ignore_wrong_status``) if the status after the call is not one of
        the listed statuses.
    """

    def wrap(fn):
        if under_status:
            assert isinstance(under_status, (str, tuple))
            final_st = (under_status,) if isinstance(under_status, str) else under_status
        else:
            final_st = (None,)

        @wraps(fn)
        def wrap_with_self(self, *args, **kwargs):
            # TODO: a record() made inside fn may miss the info_key appended below on the very first call.
            # Keep the short-circuit: the isinstance check only runs when `_status` exists.
            if not hasattr(self, '_status') or not isinstance(self._status, StatusWithInfo):
                raise ValueError(
                    'the object {} does not have a `_status` attribute holding a StatusWithInfo instance'.format(
                        self))
            obj_status = self._status
            for st in final_st:
                obj_status.append_new_info(info_key=info_key, init_value=0, under_status=st)
            res = fn(self, *args, **kwargs)
            current_status = self.get_status()['status']
            if not ignore_wrong_status:
                # The current status only has to match ONE of the listed statuses. Requiring
                # it to equal each of them made any multi-status tuple always raise.
                expected = tuple(st for st in final_st if st)
                if expected and current_status not in expected:
                    raise ValueError('register counter info under status: {} but got status {}'.format(
                        expected[0] if len(expected) == 1 else expected, current_status))
            obj_status.update_info(info_key=info_key, increment=increment, under_status=current_status)
            return res

        return wrap_with_self

    return wrap

The `EnvSpec` class stores and regulates an environment's specification: its observation space and action space, together with their sample shapes and flattened dimensions.

class EnvSpec(object):
    """
    Specification of an environment's observation space and action space.

    On construction, one sample is drawn from each space to record its shape.
    ``obs_shape`` is padded to ``(1,)`` when the observation sample is a scalar.
    ``action_shape`` is left as ``()`` when the action sample is a scalar.
    """

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self, obs_space: Space, action_space: Space):
        self._obs_space = obs_space
        self._action_space = action_space
        # An empty (scalar) observation shape is replaced by (1,).
        self.obs_shape = tuple(np.array(self.obs_space.sample()).shape) or (1,)
        # Scalar actions deliberately keep the empty shape () — no padding here.
        self.action_shape = tuple(np.array(self.action_space.sample()).shape)

    @property
    def obs_space(self):
        """The observation space passed to the constructor."""
        return self._obs_space

    @property
    def action_space(self):
        """The action space passed to the constructor."""
        return self._action_space

    @property
    def flat_obs_dim(self) -> int:
        """Size of the observation space after flattening, as reported by ``flat_dim``."""
        return int(flat_dim(self.obs_space))

    @property
    def flat_action_dim(self) -> int:
        """Size of the action space after flattening, as reported by ``flat_dim``."""
        return int(flat_dim(self.action_space))

    @staticmethod
    def flat(space: Space, obs_or_action: (np.ndarray, list)):
        """Flatten an observation or action that belongs to ``space``."""
        return flatten(space, obs_or_action)

    def flat_action(self, action: (np.ndarray, list)):
        """Flatten ``action`` with respect to this spec's action space."""
        return flatten(self.action_space, action)

    def flat_obs(self, obs: (np.ndarray, list)):
        """Flatten ``obs`` with respect to this spec's observation space."""
        return flatten(self.obs_space, obs)