69 changes: 11 additions & 58 deletions dqfd/fragile_learning/memory.py
@@ -1,76 +1,29 @@
from typing import Iterable

from fragile.core import Swarm
import numpy

from dqfd.kerasrl.memory import (
PartitionedRingBuffer,
SumSegmentTree,
MinSegmentTree,
PartitionedMemory as KrlPartitionedMemory,
)


class SwarmReplayMemory:
def __init__(self, max_size: int, names: Iterable[str], mode: str = "best"):
self.max_len = max_size
self.names = names
self.mode = mode
for name in names:
setattr(self, name, None)

def __len__(self):
if getattr(self, self.names[0]) is None:
return 0
return len(getattr(self, self.names[0]))

def memorize(self, swarm: Swarm):
# extract data from the swarm
if self.mode == "best":
data = next(swarm.tree.iterate_branch(swarm.best_id, batch_size=-1, names=self.names))
else:
data = next(swarm.tree.iterate_nodes_at_random(batch_size=-1, names=self.names))
# Concatenate the data to the current memory
for name, val in zip(self.names, data):
if len(val.shape) == 1: # Scalar vectors are transformed to columns
val = val.reshape(-1, 1)
processed = (
val if getattr(self, name) is None else numpy.vstack([val, getattr(self, name)])
)
if len(processed) > self.max_len:
processed = processed[: self.max_len]
setattr(self, name, processed)
print("Memory now contains %s samples" % len(self))


class DQFDMemory(SwarmReplayMemory):
def __init__(self, max_size: int):
names = ["observs", "actions", "rewards", "oobs"]
super(DQFDMemory, self).__init__(max_size=max_size, mode="best", names=names)
from dqfd.kerasrl.memory import PartitionedMemory as KrlPartitionedMemory

def iterate_data(self):
if len(self) == 0:
raise ValueError("Memory is empty. Call memorize before iterating data.")
for i in range(len(self)):
vals = [numpy.squeeze(getattr(self, name)[i]) for name in self.names]
yield vals
from dqfd.fragile_learning.runner import FragileRunner


class PartitionedMemory(KrlPartitionedMemory):
def __init__(
self,
limit,
swarm_memory: DQFDMemory,
runner: FragileRunner,
alpha=0.4,
start_beta=1.0,
end_beta=1.0,
steps_annealed=1,
**kwargs
):
pre_load_data = [
(obs, action, reward, end) for obs, action, reward, end in swarm_memory.iterate_data()
]
print("LEN PRELOAD DATA", len(pre_load_data))
def iterate_values(runner):
if len(runner) == 0:
raise ValueError("Memory is empty. Call memorize before iterating data.")
for i in range(len(runner)):
vals = [numpy.squeeze(getattr(runner, name)[i]) for name in runner.names]
yield vals

pre_load_data = list(iterate_values(runner))
super(PartitionedMemory, self).__init__(
pre_load_data=pre_load_data,
limit=limit,
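Note on the rewritten memory module: `PartitionedMemory` no longer consumes a separate `DQFDMemory`; it builds its `pre_load_data` by iterating arrays stored directly on the runner. A minimal, self-contained sketch of that contract follows; the `_FakeRunner` class, its array shapes, and the sample count are illustrative assumptions, not part of the fragile API.

import numpy


class _FakeRunner:
    """Stand-in for the runner: one numpy array per stored name (shapes are illustrative)."""

    names = ["observs", "actions", "rewards", "oobs"]

    def __init__(self, n_transitions):
        self.observs = numpy.zeros((n_transitions, 84, 84))          # stacked frames
        self.actions = numpy.zeros((n_transitions, 1), dtype=int)    # chosen actions
        self.rewards = numpy.ones((n_transitions, 1))                # per-step rewards
        self.oobs = numpy.zeros((n_transitions, 1), dtype=bool)      # episode-end flags

    def __len__(self):
        return len(self.observs)


def iterate_values(runner):
    # Same contract as the helper added above: yield one squeezed value per
    # stored name, one transition at a time, as (obs, action, reward, end).
    if len(runner) == 0:
        raise ValueError("Memory is empty. Call memorize before iterating data.")
    for i in range(len(runner)):
        yield [numpy.squeeze(getattr(runner, name)[i]) for name in runner.names]


pre_load_data = list(iterate_values(_FakeRunner(8)))
print(len(pre_load_data), len(pre_load_data[0]))  # 8 transitions, 4 fields each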
65 changes: 48 additions & 17 deletions dqfd/fragile_learning/runner.py
@@ -1,56 +1,87 @@
from fragile.core import DiscreteEnv, DiscreteUniform
from fragile.core.tree import HistoryTree
from fragile.core.swarm import Swarm
from fragile.distributed import ParallelEnv
from fragile.core import DiscreteUniform
from fragile.distributed import ReplayCreator
from plangym import AtariEnvironment

from dqfd.fragile_learning.env import AtariEnvironment
from dqfd.fragile_learning.memory import DQFDMemory


class FragileRunner:
def __init__(
self,
game_name: str,
n_swarms: int = 2,
n_workers_per_swarm: int = 2,
n_walkers: int = 32,
max_epochs: int = 200,
reward_scale: float = 2.0,
distance_scale: float = 1.0,
n_workers: int = 8,
memory_size: int = 200,
score_limit: float = 600,
):

self.env = ParallelEnv(lambda: AtariEnvironment(game_name), n_workers=n_workers)
self.env = AtariEnvironment(game_name)
self.n_actions = self.env.n_actions
self.game_name = game_name
self.env_callable = lambda: self.env
self.model_callable = lambda env: DiscreteUniform(env=self.env)
self.prune_tree = True
self.memory_size = memory_size
self.score_limit = score_limit
# A bigger number will increase the quality of the trajectories sampled.
self.n_walkers = n_walkers
self.max_epochs = max_epochs # Increase to sample longer games.
self.reward_scale = reward_scale # Rewards are more important than diversity.
self.distance_scale = distance_scale
self.minimize = False # We want to get the maximum score possible.
store_data = ["observs", "actions", "rewards", "oobs"]
self.swarm = Swarm(
self.names = ["observs", "actions", "rewards", "oobs"]
self.swarm = ReplayCreator(
n_swarms=n_swarms,
n_workers_per_swarm=n_workers_per_swarm,
num_examples=self.memory_size,
max_examples=int(self.memory_size * 1.5),
model=self.model_callable,
env=self.env_callable,
tree=lambda: HistoryTree(names=store_data, prune=True),
names=self.names,
n_walkers=self.n_walkers,
max_epochs=self.max_epochs,
prune_tree=self.prune_tree,
reward_scale=self.reward_scale,
distance_scale=self.distance_scale,
minimize=self.minimize,
score_limit=score_limit,
score_limit=self.score_limit,
)
self.memory = DQFDMemory(max_size=memory_size)
for name in self.names:
setattr(self, name, None)

def __len__(self):
if getattr(self, self.names[0]) is None:
return 0
return len(getattr(self, self.names[0]))

def iterate_memory(self):
return self.swarm.iterate_values()

def run(self):
while len(self.memory) < self.memory.max_len - 1:
print("Creating fractal replay memory...")
_ = self.swarm.run()
print("Max. fractal cum_rewards:", self.swarm.best_reward)
self.memory.memorize(swarm=self.swarm)
self.swarm.run()
for name in self.names:
setattr(self, name, getattr(self.swarm, name))


"""
swarm = ReplayCreator(
names=names,
num_examples=num_examples,
max_examples=300,
n_swarms=n_swarms,
n_workers_per_swarm=n_workers_per_swarm,
model=model_callable,
env=env_callable,
n_walkers=n_walkers,
max_epochs=max_epochs,
reward_scale=reward_scale,
distance_scale=distance_scale,
minimize=minimize,
force_logging=True,
show_pbar=True,
report_interval=10,
)"""
37 changes: 21 additions & 16 deletions dqfd_atari.py
@@ -1,5 +1,7 @@
import argparse
import sys

import ray
import tensorflow.compat.v1 as tf
from tensorflow.python.keras.optimizers import Adam

@@ -19,21 +21,21 @@
def main():
# We downsize the atari frame to 84 x 84 and feed the model 4 frames at a time for
# a sense of direction and speed.
INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4
input_shape = (84, 84)
window_length = 4
# Runner parameters
EXPLORE_MEMORY_STEPS = 5
fractal_memory_size = 1000
fractal_memory_size = 200
n_walkers = 32
n_workers = 8
max_epochs_per_game = 2000
n_workers = 2
n_swarms = 2
max_epochs_per_game = 110
score_limit_per_game = 1000
# Training parameters
n_training_steps = 1000
pretraining_steps = 500
target_model_update = 1000
n_max_episode_steps = 200000
rl_training_memory_max_size = 10000
n_training_steps = 50
pretraining_steps = 50
target_model_update = 10
n_max_episode_steps = 20000
rl_training_memory_max_size = 1000
# testing
n_episodes_test = 10

@@ -43,6 +45,8 @@ def main():
parser.add_argument("--weights", type=str, default=None)
args = parser.parse_args()

ray.init(object_store_memory=78643200 * 100)  # ~7.5 GiB object store

# Get the environment and extract the number of actions.
env = create_plangym_env(args.env_name)
n_actions = env.action_space.n
@@ -56,26 +60,27 @@
args.env_name,
memory_size=fractal_memory_size,
n_walkers=n_walkers,
n_workers=n_workers,
n_workers_per_swarm=n_workers,
n_swarms=n_swarms,
max_epochs=max_epochs_per_game,
score_limit=score_limit_per_game,
)
explorer.run()
processed_memory = processor.process_demo_data(explorer.memory)
processed_explorer = processor.process_demo_data(explorer)

memory = PartitionedMemory(
limit=rl_training_memory_max_size,
swarm_memory=processed_memory,
runner=processed_explorer,
alpha=0.4,
start_beta=0.6,
end_beta=0.6,
window_length=WINDOW_LENGTH,
window_length=window_length,
)

policy = EpsGreedyQPolicy(0.01)

model = DQFDNeuralNet(
window_length=WINDOW_LENGTH, n_actions=explorer.n_actions, input_shape=INPUT_SHAPE
window_length=window_length, n_actions=explorer.n_actions, input_shape=input_shape
)
dqfd = DQfDAgent(
model=model,
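Taken together, the script changes wire the fractal explorer straight into the prioritized replay buffer. A condensed, non-authoritative sketch of the updated flow, using the default values shown in the diff; the env name and the `processor` object (constructed earlier in the script) are assumptions.

# Non-authoritative sketch of the rewired script, using the defaults shown above.
# `processor` is the demo-data processor built earlier in the script, and
# "MsPacman-v0" stands in for args.env_name.
import ray

ray.init(object_store_memory=78643200 * 100)  # shared object store for the swarms

explorer = FragileRunner(
    "MsPacman-v0",
    memory_size=200,
    n_walkers=32,
    n_workers_per_swarm=2,
    n_swarms=2,
    max_epochs=110,
    score_limit=1000,
)
explorer.run()                                          # fractal exploration
processed_explorer = processor.process_demo_data(explorer)

memory = PartitionedMemory(
    limit=1000,
    runner=processed_explorer,
    alpha=0.4,
    start_beta=0.6,
    end_beta=0.6,
    window_length=4,
)
# The demonstration transitions fill the pre-loaded partition of the replay
# buffer; the DQfD agent then pretrains on them before acting in the live env.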
4 changes: 2 additions & 2 deletions setup.py
@@ -10,8 +10,8 @@
url="https://github.com/Zeta36/FractalExplorationImitationLearning",
download_url="https://github.com/Zeta36/FractalExplorationImitationLearning",
install_requires=[
"plangym>=0.0.6",
"fragile>=0.0.40",
"plangym>=0.0.7",
"fragile>=0.0.41",
"numpy>=1.16.2",
"gym>=0.10.9",
"pillow-simd>=7.0.0.post3",