# NOTE: this code targets the classic Gym API (env.seed(), 4-tuple step()
# returns, np_random.randint); it is not compatible with gymnasium's
# 5-tuple step() / (obs, info) reset() API without further changes.
import gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe


# atari_wrappers.py


class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking a random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)  # pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take the FIRE action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


class ImageSaver(gym.Wrapper):
    def __init__(self, env, img_path, rank):
        """Save every raw RGB observation to `img_path` as a numbered PNG."""
        gym.Wrapper.__init__(self, env)
        self._cnt = 0
        self._img_path = img_path
        self._rank = rank

    def step(self, action):
        step_result = self.env.step(action)
        obs, _, _, _ = step_result
        img = Image.fromarray(obs, 'RGB')
        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
        self._cnt += 1
        return step_result


class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # Check current lives, make loss of life terminal, then update
        # lives to handle bonus lives.
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # For Qbert we sometimes stay in the lives == 0 condition for a few
            # frames, so it is important to keep lives > 0 so that we only
            # reset once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # No-op step to advance from the terminal/lost-life state.
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame."""
        gym.Wrapper.__init__(self, env)
        # Most recent raw observations (for max pooling across time steps).
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over the last observations."""
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
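
# --- Illustrative sketch (not part of the original modules) -----------------
# MaxAndSkipEnv max-pools the last two raw frames because the Atari hardware
# renders some sprites only on alternate frames; the element-wise max removes
# that flicker. The hypothetical helper below demonstrates the same
# np.max(np.stack(...), axis=0) call the wrapper uses.
def _demo_max_pool():
    frame_a = np.zeros((210, 160, 3), dtype=np.uint8)       # sprite invisible
    frame_b = np.full((210, 160, 3), 255, dtype=np.uint8)   # sprite visible
    pooled = np.max(np.stack([frame_a, frame_b]), axis=0)
    assert (pooled == frame_b).all()  # flickering pixels survive the pool
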

class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')

    def observation(self, obs):
        # Convert RGB to luminance, then resize to 84x84 with bilinear filtering.
        frame = np.dot(obs.astype('float32'),
                       np.array([0.299, 0.587, 0.114], 'float32'))
        frame = np.array(Image.fromarray(frame).resize(
            (self.res, self.res), resample=Image.BILINEAR), dtype=np.uint8)
        return frame.reshape((self.res, self.res, 1))


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        assert shp[2] == 1  # can only stack 1-channel frames
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')

    def reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self.observation()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self.observation(), reward, done, info

    def observation(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)


def wrap_deepmind(env, episode_life=True, clip_rewards=True):
    """Configure environment for DeepMind-style Atari.

    Note: this does not include frame stacking!
    """
    assert 'NoFrameskip' in env.spec.id  # required for DeepMind-style skip
    if episode_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env


# envs.py


def make_env(env_id, img_dir, atari_wrapper, train, seed, rank):
    def _thunk():
        env = gym.make(env_id)
        env.seed(seed + rank)
        if img_dir is not None:
            env = ImageSaver(env, img_dir, rank)
        if atari_wrapper:
            env = wrap_deepmind(env, episode_life=train, clip_rewards=train)
            env = WrapPyTorch(env)
        return env

    return _thunk


def make_env_single_proc(env_id, atari_wrapper, train):
    env = gym.make(env_id)
    if atari_wrapper:
        env = wrap_deepmind(env, episode_life=train, clip_rewards=train)
        env = WrapPyTorch(env)
    return env


class WrapPyTorch(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, [1, 84, 84], dtype='float32')

    def observation(self, observation):
        # Move channels first (HWC -> CHW) for PyTorch. Note that the pixels
        # are still uint8 in [0, 255] here; scaling to the advertised [0, 1]
        # range is expected to happen downstream (e.g., in the model).
        return observation.transpose(2, 0, 1)
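
# --- Illustrative sketch (not part of the original modules) -----------------
# Minimal single-process usage, assuming the Atari ROMs are installed; the
# env id is just an example. It checks the channels-first layout WrapPyTorch
# produces:
#
#   env = make_env_single_proc("BreakoutNoFrameskip-v4", atari_wrapper=True, train=True)
#   obs = env.reset()
#   assert obs.shape == (1, 84, 84)   # (C, H, W), uint8 pixels
#   obs, reward, done, info = env.step(env.action_space.sample())
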
""" raise NotImplementedError def reset(self): """ Reset all environments """ raise NotImplementedError def close(self): pass # subproc_vec_env.py def worker(remote, env_fn_wrapper): env = env_fn_wrapper.x() while True: cmd, data = remote.recv() if cmd == 'step': ob, reward, done, info = env.step(data) if done: ob = env.reset() remote.send((ob, reward, done, info)) elif cmd == 'reset': ob = env.reset() remote.send(ob) elif cmd == 'close': remote.close() break elif cmd == 'get_spaces': remote.send((env.action_space, env.observation_space)) elif cmd == 'none': remote.send(None) else: raise NotImplementedError class CloudpickleWrapper(object): """ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) """ def __init__(self, x): self.x = x def __getstate__(self): import cloudpickle return cloudpickle.dumps(self.x) def __setstate__(self, ob): import pickle self.x = pickle.loads(ob) class SubprocVecEnv(VecEnv): def __init__(self, env_fns): """ envs: list of gym environments to run in subprocesses """ nenvs = len(env_fns) self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn))) for (work_remote, env_fn) in zip(self.work_remotes, env_fns)] for p in self.ps: p.start() self.remotes[0].send(('get_spaces', None)) self.action_space, self.observation_space = self.remotes[0].recv() def step(self, actions): for remote, action in zip(self.remotes, actions): remote.send(('step', action)) results = [remote.recv() for remote in self.remotes] obs, rews, dones, infos = zip(*results) # The original implenentation of the vectorized environments # returns observations as a stacked ndarray. Here, the code # is modified to return a python list of objects. # return np.stack(obs), np.stack(rews), np.stack(dones), infos return list(obs), np.stack(rews), np.stack(dones), infos def reset(self, mask=None): """ `mask` is `None` or `List[float]`. """ if mask is None: for remote in self.remotes: remote.send(('reset', None)) else: for mask_i, remote in zip(mask, self.remotes): if mask_i == 1.0: remote.send(('reset', None)) else: remote.send(('none', None)) # do nothing # The original implenentation of the vectorized environments # returns observations as a stacked ndarray. Here, the code # is modified to return a python list of objects. # return np.stack([remote.recv() for remote in self.remotes]) obs = [remote.recv() for remote in self.remotes] return obs def close(self): for remote in self.remotes: remote.send(('close', None)) for p in self.ps: p.join() @property def num_envs(self): return len(self.remotes) # Create the environment. def make(env_name, atari_wrapper, train, num_processes): img_dir=None import cv2 # workaround for importing cv2 in subprocesses on mac envs = SubprocVecEnv([ make_env(env_name, img_dir, atari_wrapper, train, 42, i) for i in range(num_processes) ]) return envs if __name__ == "__main__": env = make_env_single_proc("PongNoFrameskip-v4", True, True) env.reset() step = env.step(0) Image.fromarray(step[0].reshape([84, 84])).save("./test.png")