Flatland 3
Observing the Flatland 3 Environment
A brief examination of the global and the tree observation types.
In this notebook, I look at the observation outputs for small maps to understand their contents.
In [2]:
# install Flatland, printing only output lines that mention an error
!pip install flatland-rl | grep error
In [3]:
# basic:
import numpy as np
from flatland.envs.rail_env import RailEnv
# for rendering:
import PIL.Image
from flatland.utils.rendertools import RenderTool
from IPython.display import clear_output, display
# for plotting
import matplotlib.pyplot as plt
In [4]:
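def render_env(env, wait=True):
    # draw the current state of the environment with the PIL/SVG renderer
    env_renderer = RenderTool(env, gl="PILSVG")
    env_renderer.render_env()
    image = env_renderer.get_image()
    pil_image = PIL.Image.fromarray(image)
    # replace the previous cell output instead of stacking images
    clear_output(wait=wait)
    display(pil_image)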
Global observation
In [5]:
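env = RailEnv(width=40, height=40)  # uses the default observation builder (GlobalObsForRailEnv)
obs, info = env.reset()  # reset() returns the observations and an info dict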
Take a few random steps (one random action from 0 to 4 per agent):
In [6]:
for t in range(10):
    # np.random.randint(0, 5) draws an action code in 0..4 for each of the two agents
    obs, rew, done, info = env.step({
        0: np.random.randint(0, 5),
        1: np.random.randint(0, 5)
    })
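The action dict above is hard-coded for two agents. A minimal sketch of a helper that draws one random action per handle, so the same call works for any number of agents. It assumes the integer codes 0 to 4 correspond to Flatland's `RailEnvActions` members in the order shown; `random_actions` is a hypothetical name, not part of Flatland.

```python
import numpy as np

# Assumed mapping of action codes to RailEnvActions member names
ACTION_NAMES = {
    0: "DO_NOTHING",
    1: "MOVE_LEFT",
    2: "MOVE_FORWARD",
    3: "MOVE_RIGHT",
    4: "STOP_MOVING",
}


def random_actions(handles, rng=None):
    """Hypothetical helper: one uniformly random action code per agent handle."""
    rng = np.random.default_rng() if rng is None else rng
    # integers(0, 5) draws from {0, 1, 2, 3, 4}, the same range as randint(0, 5)
    return {handle: int(rng.integers(0, 5)) for handle in handles}


actions = random_actions([0, 1, 2], np.random.default_rng(0))
print({handle: ACTION_NAMES[code] for handle, code in actions.items()})
```

In the notebook this could be used as `env.step(random_actions(env.get_agent_handles()))`.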
In [7]:
render_env(env, True)
In [8]:
agent_handle = env.get_agent_handles()[0]
print('Observations for agent {}:'.format(agent_handle))
agent_obs = obs[agent_handle]
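The global observation for one agent is a tuple of three arrays. A small sketch that reports their shapes; the expected layout, (height, width, 16) for transitions, (height, width, 5) for agent states and (height, width, 2) for targets, is an assumption based on the `GlobalObsForRailEnv` documentation, and `describe_global_obs` is a hypothetical name. The demo uses zero-filled stand-in arrays rather than a real environment.

```python
import numpy as np


def describe_global_obs(agent_obs):
    """Hypothetical helper: name and shape of each part of a global observation."""
    transitions, agent_states, targets = agent_obs
    return {
        "transition_map": transitions.shape,  # assumed (height, width, 16) transition bits
        "agent_states": agent_states.shape,   # assumed (height, width, 5) agent channels
        "targets": targets.shape,             # assumed (height, width, 2) target flags
    }


# Stand-in arrays with the assumed shapes for a 40 x 40 map
fake_obs = (np.zeros((40, 40, 16)), np.zeros((40, 40, 5)), np.zeros((40, 40, 2)))
print(describe_global_obs(fake_obs))
```

With the real environment, `describe_global_obs(agent_obs)` would show whether the shapes match these assumptions.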
Transition map:
(The 16 transition bits of each cell are summed, so directions are ignored; non-zero cells contain rail.)
In [9]:
rail_map = agent_obs[0].sum(axis=2)  # number of set transition bits per cell
plt.matshow(rail_map)
plt.show()
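Summing the bits shows where rail exists but discards which moves are allowed. A sketch that decodes one cell's 16-element vector back into allowed (heading, exit) pairs. It assumes the bits are stored most-significant first in Flatland's documented order NN, NE, NS, NW, EN, ..., WW, where the first letter is the train's heading and the second the direction it may leave in; `decode_cell` is a hypothetical helper.

```python
import numpy as np

DIRECTIONS = "NESW"  # Flatland orders directions North, East, South, West


def decode_cell(bits):
    """Hypothetical helper: list the (heading, exit) moves allowed in one cell.

    Assumes bit index 4 * heading + exit, i.e. the order NN NE NS NW EN ... WW.
    """
    grid = np.asarray(bits).reshape(4, 4)  # rows: heading, columns: exit direction
    return [
        DIRECTIONS[heading] + DIRECTIONS[exit_dir]
        for heading, exit_dir in zip(*np.nonzero(grid))
    ]


# A straight north-south track: 0b1000000000100000 (NN and SS set)
straight = [int(b) for b in format(0b1000000000100000, "016b")]
print(decode_cell(straight))  # ['NN', 'SS']
```

In the notebook, `decode_cell(agent_obs[0][row, col])` would decode the cell at `(row, col)`.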
Agent states:
In [10]:
agent_states = np.transpose(agent_obs[1], (2, 0, 1))
print('- Agent position\n')
plt.matshow(agent_states[0])
plt.show()
print('- Other agent positions\n')
plt.matshow(agent_states[1])
plt.show()
print('- Malfunctions\n')
plt.matshow(agent_states[2])
plt.show()
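Only the first three channels are plotted above; the remaining two are assumed to hold fractional speeds and counts of agents ready to depart. The first channel can also be read numerically. A sketch that assumes, as in `GlobalObsForRailEnv`, that channel 0 is filled with -1 except at the agent's own cell, which holds its direction (0 = North to 3 = West). `find_agent` is hypothetical and the demo uses a synthetic array in the untransposed (height, width, channel) layout.

```python
import numpy as np

DIRECTION_NAMES = ["North", "East", "South", "West"]


def find_agent(agent_states):
    """Hypothetical helper: (row, col, heading) of the observing agent, or None.

    Assumes channel 0 is -1 everywhere except the agent's cell, which holds
    its direction; an agent that has not entered the map has no such cell.
    """
    cells = np.argwhere(agent_states[:, :, 0] >= 0)
    if len(cells) == 0:
        return None
    row, col = cells[0]
    heading = int(agent_states[row, col, 0])
    return int(row), int(col), DIRECTION_NAMES[heading]


# Synthetic 5 x 5 observation with the agent at (2, 3) facing East
states = np.full((5, 5, 5), -1.0)
states[2, 3, 0] = 1
print(find_agent(states))  # (2, 3, 'East')
```

On the real data this would be `find_agent(agent_obs[1])`, i.e. before the transpose used for plotting.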
Agent targets:
In [11]:
agent_targets = np.transpose(agent_obs[2], (2, 0, 1))
print('- Current agent:')
plt.matshow(agent_targets[0])
plt.show()
print('- All agents:')
plt.matshow(agent_targets[1])
plt.show()
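The target maps are 0/1 flags, so coordinates can be pulled out with `np.argwhere`. A sketch on synthetic data; `target_cells` is a hypothetical name, and the channel meanings (own target, all agents' targets) follow the plots above.

```python
import numpy as np


def target_cells(targets):
    """Hypothetical helper: flagged cells in each target channel as (row, col) tuples."""
    own = [tuple(map(int, rc)) for rc in np.argwhere(targets[:, :, 0] == 1)]
    all_agents = [tuple(map(int, rc)) for rc in np.argwhere(targets[:, :, 1] == 1)]
    return own, all_agents


# Synthetic 6 x 6 map: own target at (1, 4), another agent's target at (5, 0)
targets = np.zeros((6, 6, 2))
targets[1, 4, 0] = 1
targets[1, 4, 1] = 1
targets[5, 0, 1] = 1
print(target_cells(targets))  # ([(1, 4)], [(1, 4), (5, 0)])
```

Since the flags are assumed to carry no count, two agents sharing a target would appear as a single cell.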
Tree observation
In [12]:
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.observations import Node
# for graphs:
import networkx as nx
In [13]:
env = RailEnv(width=30,
              height=30,
              number_of_agents=3,
              rail_generator=sparse_rail_generator(),
              obs_builder_object=TreeObsForRailEnv(max_depth=4)
              )
obs, info = env.reset()
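With `max_depth=4` the tree can get large. Each node has at most four children (keys 'L', 'F', 'R', 'B'), so, assuming the root sits at depth 0, the node count is bounded by the geometric sum 4^0 + 4^1 + ... + 4^max_depth. A quick sketch of that arithmetic; multiplying by the 12 `Node` fields other than `childs` gives the length of a flattened feature vector if every branch existed. `max_tree_nodes` is a hypothetical helper.

```python
def max_tree_nodes(max_depth, branching=4):
    """Upper bound on nodes in a tree of the given depth (root counts as depth 0)."""
    return sum(branching ** level for level in range(max_depth + 1))


NUM_FEATURES = 12  # Node fields excluding `childs`

nodes = max_tree_nodes(4)
print(nodes)                 # 341
print(nodes * NUM_FEATURES)  # 4092
```

In practice many branches are missing, so the actual tree is usually much smaller than this bound.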
In [14]:
obs, rew, done, info = env.step({
    0: np.random.randint(0, 5),
    1: np.random.randint(0, 5),
    2: np.random.randint(0, 5)
})
In [15]:
render_env(env, True)
The overlaid shapes appear only when the tree observation is used. They mark the shortest path of each train, and each color-symbol combination belongs to a different train.
Agent 0's obs tree
In [16]:
obs[0].childs
In [17]:
obs[0].childs['F'].childs
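The `childs` dict is keyed by the branch taken relative to the train ('L' left, 'F' forward, 'R' right, 'B' back), and branches that cannot be taken hold `-inf` instead of a node, which is why the notebook filters children with a type check before adding them to the graph. A sketch that lists which branches exist, using a small stand-in namedtuple instead of Flatland's `Node`; `existing_branches` and `FakeNode` are hypothetical.

```python
import math
from collections import namedtuple

# Stand-in for Flatland's Node namedtuple, reduced to the fields used here
FakeNode = namedtuple("FakeNode", ["dist_min_to_target", "childs"])


def existing_branches(node):
    """Hypothetical helper: branch keys whose child is a real node, not -inf."""
    return [key for key, child in node.childs.items()
            if not (isinstance(child, float) and math.isinf(child))]


leaf = FakeNode(dist_min_to_target=7, childs={})
root = FakeNode(
    dist_min_to_target=9,
    childs={"L": -math.inf, "F": leaf, "R": -math.inf, "B": -math.inf},
)
print(existing_branches(root))  # ['F']
```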
In [18]:
def node_data(node, direction):
    """ copy the data except the childs """
    return {
        "direction" : direction,
        "dist_own_target_encountered" : node.dist_own_target_encountered,
        "dist_other_target_encountered" : node.dist_other_target_encountered,
        "dist_other_agent_encountered" : node.dist_other_agent_encountered,
        "dist_potential_conflict" : node.dist_potential_conflict,
        "dist_unusable_switch" : node.dist_unusable_switch,
        "dist_to_next_branch" : node.dist_to_next_branch,
        "dist_min_to_target" : node.dist_min_to_target,
        "num_agents_same_direction" : node.num_agents_same_direction,
        "num_agents_opposite_direction" : node.num_agents_opposite_direction,
        "num_agents_malfunctioning" : node.num_agents_malfunctioning,
        "speed_min_fractional" : node.speed_min_fractional,
        "num_agents_ready_to_depart" : node.num_agents_ready_to_depart
    }
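Because `Node` is a namedtuple, the same copy can be written without listing every field, using `_asdict()` and dropping `childs`. A sketch with a stand-in namedtuple (`FakeNode` and `node_data_compact` are hypothetical); it assumes the real `Node` keeps the standard namedtuple interface.

```python
from collections import namedtuple

# Stand-in with a few of Node's fields; the real Node has twelve plus `childs`
FakeNode = namedtuple(
    "FakeNode", ["dist_min_to_target", "num_agents_same_direction", "childs"]
)


def node_data_compact(node, direction):
    """Copy every field except `childs`, then record the branch direction."""
    data = dict(node._asdict())  # field name -> value
    data.pop("childs")
    data["direction"] = direction
    return data


n = FakeNode(dist_min_to_target=5, num_agents_same_direction=0, childs={"F": -1})
print(node_data_compact(n, "F"))
```

Unlike the explicit version, this picks up any fields a different Flatland release adds, at the cost of not documenting them in the code.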
In [19]:
G = nx.DiGraph()
G.add_node(1, object=node_data(obs[0], None))
In [ ]:
type(obs[0].childs['F'])
Build a NetworkX graph from the tree (node 1 is the root):
In [20]:
last_id = 1
fringe = [(obs[0], last_id)]  # stack of (tree node, graph id) pairs still to expand
while len(fringe) > 0:
    parent, parent_id = fringe.pop()
    for key, node in parent.childs.items():
        # branches that cannot be taken hold -inf, so keep only real Node children
        if isinstance(node, Node):
            last_id += 1
            G.add_node(last_id, object=node_data(node, key))
            G.add_edge(parent_id, last_id)
            fringe.append((node, last_id))
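Because `fringe.pop()` takes the most recently added entry, the loop explores depth-first, so graph ids do not follow tree levels. If level-by-level ids are preferred, a `collections.deque` with `popleft()` turns it into a breadth-first traversal. A sketch on stand-in nodes that returns an edge list instead of building a NetworkX graph; `bfs_edges` and `FakeNode` are hypothetical.

```python
import math
from collections import deque, namedtuple

FakeNode = namedtuple("FakeNode", ["name", "childs"])  # stand-in for Flatland's Node


def bfs_edges(root):
    """Number nodes breadth-first (root = 1) and return (parent_id, child_id, key) edges."""
    last_id = 1
    queue = deque([(root, last_id)])
    edges = []
    while queue:
        parent, parent_id = queue.popleft()  # FIFO: finish one level before the next
        for key, child in parent.childs.items():
            if isinstance(child, FakeNode):  # skip -inf placeholders
                last_id += 1
                edges.append((parent_id, last_id, key))
                queue.append((child, last_id))
    return edges


def leaf(name):
    return FakeNode(name, {})


root = FakeNode("root", {
    "L": FakeNode("a", {"F": leaf("c")}),
    "F": FakeNode("b", {"L": leaf("d"), "R": -math.inf}),
    "R": -math.inf,
})
print(bfs_edges(root))  # [(1, 2, 'L'), (1, 3, 'F'), (2, 4, 'F'), (3, 5, 'L')]
```

With the depth-first stack, the same tree would give node "d" id 4 and node "c" id 5.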
In [21]:
G.nodes
In [22]:
plt.figure()
nx.draw_planar(G, with_labels=True)
In [23]:
G.nodes[4]
Shortest way
In [24]:
def shortest_way(G, node, path):
    """ greedily follows the child closest to the target, recursively """
    d = {neighbor: G.nodes[neighbor]['object']['dist_min_to_target'] for neighbor in G.neighbors(node)}
    print(node, '->', d)
    # exhausted the depth?
    if len(d) == 0:
        return path
    # choose the neighbor closest to the target and continue recursively
    best_neighbor = min(d, key=d.get)
    return shortest_way(G, best_neighbor, path + [best_neighbor])
In [25]:
shortest_way(G, 1, [])
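The recursive `shortest_way` can also be written as a loop over plain dicts, which makes the greedy rule easy to test in isolation. A sketch with hypothetical values (not Flatland output); `greedy_path` is a hypothetical name. Like the notebook version, the returned path excludes the start node, `inf` distances are never chosen while a finite one exists, and `min` keeps the first child on ties.

```python
import math


def greedy_path(children, dist, start):
    """Iterative sketch of shortest_way on plain dicts (hypothetical helper).

    children: node id -> list of child ids; dist: node id -> dist_min_to_target.
    Follows the child with the smallest distance until a leaf is reached.
    """
    path, node = [], start
    while children.get(node):
        # min() keeps the first child on ties, like min(d, key=d.get) in the notebook
        node = min(children[node], key=lambda child: dist[child])
        path.append(node)
    return path


# Hypothetical tree: node 1 has children 2 and 3, node 3 has children 4 and 5
children = {1: [2, 3], 3: [4, 5]}
dist = {2: math.inf, 3: 12, 4: 9, 5: 9}
print(greedy_path(children, dist, 1))  # [3, 4]
```

With the notebook's graph, `children` could be built as `{n: list(G.successors(n)) for n in G}` and `dist` as `{n: G.nodes[n]['object']['dist_min_to_target'] for n in G}`.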