Merge pull request #138 from Project-DC/state-experiments
State experiments and 2x faster training
frankhart2018 authored Oct 11, 2020
2 parents 84cee0a + 7fa5784 commit 5d10d45
Showing 5 changed files with 229 additions and 33 deletions.
53 changes: 44 additions & 9 deletions prima-vita-reinforce.ipynb
@@ -2,30 +2,66 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.0.0.dev10 (SDL 2.0.12, python 3.8.3)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
]
}
],
"source": [
"from pygeneses.envs.prima_vita import PrimaVita"
"from pygeneses.envs.prima_vita import PrimaVita\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"rl_model = PrimaVita()"
"rl_model = PrimaVita(log_dir_info=\"test_new_state_10k\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RIP 0-3, alive count = 10\n",
]
}
],
"source": [
"rl_model.run()"
"t1 = time.time()\n",
"rl_model.run(stop_at = 500)\n",
"print(time.time()-t1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -44,7 +80,6 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
11 changes: 5 additions & 6 deletions pygeneses/envs/prima_vita/player_class.py
@@ -116,7 +116,7 @@ def __init__(self, i, log_dir, tob, energy, x=None, y=None, mode="bot"):
self.fighting_with = -1
self.energy = energy
self.embeddings = np.array([0])
self.states = []
self.states = np.array([0])
self.mode = mode

# Add the initial x, y coordinates as first entry in logs
@@ -198,7 +198,6 @@ def update_history(
fight_with (int)
: Id of player with which current agent fought (optional)
"""

# If action number is less than or equal to 9 (i.e. movement in 8 directions, stay or ingestion) then
if action <= 9:
self.action_history.append(
@@ -210,7 +209,7 @@
self.energy,
self.playerX,
self.playerY,
self.states[-1],
self.states,
],
dtype=object,
)
@@ -228,7 +227,7 @@
np.array(offspring_ids),
self.playerX,
self.playerY,
self.states[-1],
self.states,
],
dtype=object,
)
@@ -247,7 +246,7 @@
mate_id,
self.playerX,
self.playerY,
self.states[-1],
self.states,
],
dtype=object,
)
@@ -264,7 +263,7 @@
fight_with,
self.playerX,
self.playerY,
self.states[-1],
self.states,
],
dtype=object,
)
90 changes: 72 additions & 18 deletions pygeneses/envs/prima_vita/prima_vita.py
@@ -9,6 +9,7 @@
import numpy as np
import time
import importlib
from itertools import cycle

# Import other classes
from .player_class import Player
@@ -131,7 +132,7 @@ def __init__(
self.current_population = 0
self.screen = None
self.number_of_particles = random.randint(70, 80)

self.leading_zeros = 0
# Can take values from user
self.initial_population = (
params_dic["initial_population"]
@@ -313,26 +314,44 @@ def get_current_state(self, idx=None):
env_player_index,
) = self.players_in_env(self.players[i], get_idx=True)

# Stack together the food and player vectors
temp_state = [env_food_vector, env_player_vector]
# Find count of particles and players
num_food_particles = len(env_food_vector) // 2
num_players = len(env_player_vector) // 3

# Compute the max of the two; this is the upper bound for the loop counter
max_count = max(num_food_particles, num_players)

# Iterate and stack food and player in alternate positions
temp_state = []
j = 0

while(j < max_count):
if j < num_food_particles:
temp_state.append(env_food_vector[(j*2):(j*2+2)])
else:
temp_state.append([0, 0])

if j < num_players:
temp_state.append(env_player_vector[(j*3):(j*3+3)])
else:
temp_state.append([0, 0, 0])
j += 1

# Update food_near and players_near for current player
self.players[i].food_near = env_particle_index
self.players[i].players_near = env_player_index

# Flatten the list
temp_state = sum(temp_state, [])

# Save this as state in current agent's object
self.players[i].states.append(
np.array(
self.players[i].states = np.array(
[
np.array(env_food_vector, dtype=object),
np.array(env_player_vector, dtype=object),
],
dtype=object,
)
)

# Update food_near and players_near for current player
self.players[i].food_near = env_particle_index
self.players[i].players_near = env_player_index

# Convert the food and player vectors stacked together into a single vector
temp_state = sum(temp_state, [])

# Pad to state_size - 1
temp_state = self.pad_state(
@@ -342,14 +361,41 @@
# Append energy to state
temp_state = np.append(temp_state, [self.players[i].energy])


# Append to initial_state
initial_state.append(temp_state)
# Otherwise copy old state
else:
# Convert state to 1D array
temp_state = np.hstack(self.players[i].states[-1])
env_food_vector = self.players[i].states[0]
env_player_vector = self.players[i].states[1]

# Find count of particles and players
num_food_particles = len(env_food_vector) // 2
num_players = len(env_player_vector) // 3

# Compute the max of the two; this is the upper bound for the loop counter
max_count = max(num_food_particles, num_players)

# Pad state to state_size - 1
# Iterate and stack food and player in alternate positions
temp_state = []
j = 0

while(j < max_count):
if j < num_food_particles:
temp_state.append(list(env_food_vector[(j*2):(j*2+2)]))
else:
temp_state.append([0, 0])

if j < num_players:
temp_state.append(list(env_player_vector[(j*3):(j*3+3)]))
else:
temp_state.append([0, 0, 0])
j += 1

# Flatten the list
temp_state = sum(temp_state, [])

# Pad to state_size - 1
temp_state = self.pad_state(
np.array(temp_state), self.state_size - 1
)
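
For reference, the new state layout interleaves one food (x, y) pair with one player triple per step, zero-padding whichever vector runs out first, then flattens the result and pads it to a fixed length. Below is a minimal standalone sketch of that pattern; the helper name stack_food_and_players and the example state_size are illustrative, not part of the committed code.

import numpy as np

def stack_food_and_players(env_food_vector, env_player_vector, state_size):
    # Each food particle contributes 2 values, each nearby player 3 values
    num_food_particles = len(env_food_vector) // 2
    num_players = len(env_player_vector) // 3
    max_count = max(num_food_particles, num_players)

    temp_state = []
    for j in range(max_count):
        # Food slot: real pair if one is left, otherwise a zero pair
        temp_state.append(list(env_food_vector[j*2:j*2+2]) if j < num_food_particles else [0, 0])
        # Player slot: real triple if one is left, otherwise a zero triple
        temp_state.append(list(env_player_vector[j*3:j*3+3]) if j < num_players else [0, 0, 0])

    flat = sum(temp_state, [])        # flatten the list of lists
    padded = np.zeros(state_size)     # assumes the flattened state fits within state_size
    padded[:len(flat)] = flat
    return padded

# Reproduces the worked example in test-alternate-stacking-algo.ipynb (plus padding):
print(stack_food_and_players([1, 2], [5, 6, 7, 8, 9, 10, 13, 14, 15], 20))
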
@@ -393,10 +439,18 @@ def run(self, stop_at=None):
break

# Loop through all the players
for i in range(len(self.players)):
i = 0
j = self.leading_zeros
while True:
j += 1
if type(self.players[j]) == int:
self.leading_zeros += 1
else:
break
for i in range(self.leading_zeros, len(self.players)):
if type(self.players[i]) != int:
# Take an action for current index
self.take_action(i, states[i].astype(np.uint8))
self.take_action(i, states[i])
idx = i if type(self.players[i]) != int else None

# Get updated state
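
The leading_zeros bookkeeping is what the speed-up in run() relies on: dead agents are stored as plain ints in self.players, and the loop now starts after the contiguous dead prefix instead of re-checking every slot on every iteration. The snippet below is a rough standalone sketch of that idea under those assumptions (class and method names are illustrative), not the committed implementation.

class DeadPrefixSkipper:
    def __init__(self, players):
        self.players = players      # mix of agent objects and ints (ints mark dead agents)
        self.leading_zeros = 0      # length of the dead prefix found so far

    def advance_dead_prefix(self):
        # Grow the pointer while the next slot still holds a dead marker
        while (self.leading_zeros < len(self.players)
               and isinstance(self.players[self.leading_zeros], int)):
            self.leading_zeros += 1

    def live_indices(self):
        self.advance_dead_prefix()
        return [i for i in range(self.leading_zeros, len(self.players))
                if not isinstance(self.players[i], int)]

env = DeadPrefixSkipper([0, 0, "agent-2", 0, "agent-4"])
print(env.live_indices())   # -> [2, 4]; the dead prefix [0, 0] is skipped from now on
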
3 changes: 3 additions & 0 deletions pygeneses/models/reinforce/reinforce.py
@@ -200,6 +200,9 @@ def update_single_agent(self, idx):
-(self.saved_log_probs[idx][j] * self.rewards[idx][j])
)

self.saved_log_probs[idx] = []
self.rewards[idx] = []

# Sum all the products
self.policy_loss[idx] = torch.cat(self.policy_loss[idx]).sum()

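
The two added lines empty each agent's log-probability and reward buffers as soon as they have been folded into the policy loss, so later updates do not re-walk transitions that were already consumed; this plausibly accounts for much of the reported speed-up. A hedged, self-contained sketch of the same pattern follows (the function name and signature are illustrative, not the project's API).

import torch

def reinforce_update(optimizer, saved_log_probs, rewards, gamma=0.99):
    # Discounted returns, computed newest-to-oldest (plain REINFORCE, no baseline)
    returns, R = [], 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)

    # Policy-gradient loss: sum of -log_prob * return over the collected steps
    policy_loss = torch.stack(
        [-log_prob * ret for log_prob, ret in zip(saved_log_probs, returns)]
    ).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    # Drop the consumed transitions, mirroring the added lines above,
    # so the next update only sees freshly collected data
    saved_log_probs.clear()
    rewards.clear()
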
105 changes: 105 additions & 0 deletions test-alternate-stacking-algo.ipynb
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"env_food_vector = [1, 2]\n",
"env_player_vector = [5, 6, 7, 8, 9, 10, 13, 14, 15]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"num_food_particles = len(env_food_vector) // 2\n",
"num_players = len(env_player_vector) // 3"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
],
"source": [
"max_count = max(num_food_particles, num_players)\n",
"print(max_count)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1, 2], [5, 6, 7], [0, 0], [8, 9, 10], [0, 0], [13, 14, 15]]\n",
"[1, 2, 5, 6, 7, 0, 0, 8, 9, 10, 0, 0, 13, 14, 15]\n"
]
}
],
"source": [
"temp_state = []\n",
"i = 0\n",
"\n",
"while(i < max_count):\n",
" if i < num_food_particles:\n",
" temp_state.append(env_food_vector[(i*2):(i*2+2)])\n",
" else:\n",
" temp_state.append([0, 0])\n",
"\n",
" if i < num_players:\n",
" temp_state.append(env_player_vector[(i*3):(i*3+3)])\n",
" else:\n",
" temp_state.append([0, 0, 0])\n",
" i += 1\n",
" \n",
"print(temp_state)\n",
"temp_state = sum(temp_state, [])\n",
"print(temp_state)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
