From 1434e37067eff191f5c952da06e9b5f379466873 Mon Sep 17 00:00:00 2001 From: dtracers Date: Sat, 15 Sep 2018 14:07:19 -0600 Subject: [PATCH 1/6] working on making a shot trainer --- .../trainer}/base_trainer.py | 6 + framework/trainer/generated_data_trainer.py | 17 ++ trainer/binary_converter.py | 242 ------------------ trainer/download_trainer.py | 62 ----- trainer/downloader.py | 107 -------- trainer/parsed_download_trainer.py | 48 ---- trainer/shots/__init__.py | 0 trainer/shots/shot_trainer.py | 4 + 8 files changed, 27 insertions(+), 459 deletions(-) rename {trainer => framework/trainer}/base_trainer.py (62%) create mode 100644 framework/trainer/generated_data_trainer.py delete mode 100644 trainer/binary_converter.py delete mode 100644 trainer/download_trainer.py delete mode 100644 trainer/downloader.py delete mode 100644 trainer/parsed_download_trainer.py create mode 100644 trainer/shots/__init__.py create mode 100644 trainer/shots/shot_trainer.py diff --git a/trainer/base_trainer.py b/framework/trainer/base_trainer.py similarity index 62% rename from trainer/base_trainer.py rename to framework/trainer/base_trainer.py index 7544876..31f7795 100644 --- a/trainer/base_trainer.py +++ b/framework/trainer/base_trainer.py @@ -6,3 +6,9 @@ class BaseTrainer: def __init__(self, model_holder: BaseModelHolder): self.model_holder = model_holder model_holder.initialize_model(load=True) + + def initialize_training(self): + raise NotImplementedError() + + def train(self): + raise NotImplementedError() diff --git a/framework/trainer/generated_data_trainer.py b/framework/trainer/generated_data_trainer.py new file mode 100644 index 0000000..0fb14b2 --- /dev/null +++ b/framework/trainer/generated_data_trainer.py @@ -0,0 +1,17 @@ +from framework.data_generator.base_generator import BaseDataGenerator +from framework.model_holder.base_model_holder import BaseModelHolder +from framework.trainer.base_trainer import BaseTrainer + + +class GeneratedDataTrainer(BaseTrainer): + + def __init__(self, model_holder: BaseModelHolder, data_generator: BaseDataGenerator): + super().__init__(model_holder) + self.data_generator = data_generator + + def initialize_training(self, **kwargs): + self.data_generator.initialize(**kwargs) + + def train(self): + for data in self.data_generator.get_data(): + self.model_holder.train_step(data, data) diff --git a/trainer/binary_converter.py b/trainer/binary_converter.py deleted file mode 100644 index 5d5549d..0000000 --- a/trainer/binary_converter.py +++ /dev/null @@ -1,242 +0,0 @@ -import io -import os -import struct - -import numpy as np -import time -import logging - -import gzip - -EMPTY_FILE = 'empty' -NO_FILE_VERSION = -1 -NON_FLIPPED_FILE_VERSION = 0 -FLIPPED_FILE_VERSION = 1 -HASHED_NAME_FILE_VERSION = 2 -IS_EVAL_FILE_VERSION = 3 -BATCH_ARRAY_FILE_VERSION = 4 -TIME_ADDITION_FILE_VERSION = 5 - - -def get_latest_file_version(): - return TIME_ADDITION_FILE_VERSION - - -def get_state_dim(file_version): - if file_version == 4: - return 206 - elif file_version is get_latest_file_version(): - return 219 - # return input_formatter.get_state_dim() - - -def write_array_to_file(game_file, array): - """ - :param game_file: This is the file that the array will be written to. - :param array: A numpy array of any size. - """ - bytes = convert_numpy_array(array) - size_of_bytes = len(bytes.getvalue()) - game_file.write(struct.pack('i', size_of_bytes)) - game_file.write(bytes.getvalue()) - - -def convert_numpy_array(numpy_array): - """ - Converts a numpy array into compressed bytes - :param numpy_array: An array that is going to be converted into bytes - :return: A BytesIO object that contains compressed bytes - """ - compressed_array = io.BytesIO() # np.savez_compressed() requires a file-like object to write to - np.save(compressed_array, numpy_array, allow_pickle=False, fix_imports=False) - return compressed_array - - -def write_version_info(file, version_number): - file.write(struct.pack('i', version_number)) - - -def write_bot_hash(game_file, hashed_name): - game_file.write(struct.pack('Q', hashed_name)) - - -def write_is_eval(game_file, is_eval): - game_file.write(struct.pack('?', is_eval)) - - -def get_file_version(file, file_name=None): - """ - Gets file info from the file - :param file: - :return: a tuple containing - file_version: This is the version of a file represented as a number. - hashed_name: This is the hash of the model that was used to create this file. If it is a least version 2 - is_eval: This is used to decide if the file was created in eval mode - """ - if not isinstance(file, io.BytesIO): - file_name = os.path.basename(file.name).split('-')[0] - else: - file_name = 'ram' - - result = [] - - try: - chunk = file.read(4) - file_version = struct.unpack('i', chunk)[0] - if file_version > get_latest_file_version(): - file.seek(0, 0) - file_version = NO_FILE_VERSION - - result.append(file_version) - - if file_version < HASHED_NAME_FILE_VERSION: - result.append(file_name) - else: - chunk = file.read(8) - hashed_name = struct.unpack('Q', chunk)[0] - result.append(hashed_name) - if file_version < IS_EVAL_FILE_VERSION: - result.append(False) - else: - chunk = file.read(1) - is_eval = struct.unpack('?', chunk)[0] - result.append(is_eval) - except Exception as e: - result = [EMPTY_FILE, file_name, False] - print('file version was messed up', e) - finally: - return tuple(result) - - -def get_file_size(f): - """ - :param f: The file - :return: The size of the file in bytes. - """ - # f is a file-like object. - try: - old_file_position = f.tell() - f.seek(0, os.SEEK_END) - size = f.tell() - f.seek(old_file_position, os.SEEK_SET) - return size - except: - return 0 - - -def read_data(file, process_pair_function, batching=False): - """ - Reads a file. Quits if anything breaks. - :param file: A simple python file object that will be read - :param process_pair_function: A function that takes in an input array and an output array. - There is also an optional number saying how many times this has been called for a single file. - It always starts at 0 - :param batching: If more than one item in an array is read at the same time then we will batch - them instead of doing them one at a time - :return: None - """ - - file_version, hashed_name, is_eval = get_file_version(file) - if file_version == EMPTY_FILE: - return - - # print('replay version:', file_version) - # print('hashed name:', hashed_name) - - pair_number = 0 - totalbytes = 0 - total_time = 0 - counter = 0 - while True: - try: - start = time.time() - chunk = file.read(4) - if chunk == '': - totalbytes += 4 - break - input_array, num_bytes = get_array(file, chunk) - totalbytes += num_bytes + 4 - chunk = file.read(4) - if chunk == '': - totalbytes += 4 - break - output_array, num_bytes = get_array(file, chunk) - total_time += time.time() - start - batch_size = int(len(input_array) / get_state_dim(file_version)) - input_array = np.reshape(input_array, (batch_size, int(get_state_dim(file_version)))) - output_array = np.reshape(output_array, (batch_size, 8)) - if not batching: - for i in range(len(input_array)): - input_ = input_array[i] - if file_version is 4: - input_ = v4tov5(input_) - process_pair_function(input_, output_array[i], pair_number, hashed_name) - pair_number += 1 - else: - if file_version is 4: - input_array = v4tov5(input_array) - process_pair_function(input_array, output_array, pair_number, hashed_name, batch_size) - pair_number += batch_size - totalbytes += num_bytes + 4 - counter += 1 - except EOFError: - # print('reached end of file') - break - except Exception as e: - logging.exception('error occurred but not because of reading but something else') - # print('total batches [', counter, '] total pairs [', pair_number, ']') - # print('time reading', total_time) - file_size = get_file_size(file) - if file_size - totalbytes <= 4 + 4 + 8 + 1: - pass - # print('read: 100% of file') - else: - print('read: ' + str(totalbytes) + '/' + str(file_size) + ' bytes') - - -def v4tov5(input_array): - # Passed time (after game_info) 1 - input_array = np.insert(input_array, 1, 0.0, axis=1) - for i in range(6): - i = 22 if i is 0 else 43 + 20 * i - input_array = np.insert(input_array, i, 0, axis=1) - input_array = np.insert(input_array, i + 1, np.greater( - np.hypot(np.hypot(input_array[:, i - 6], input_array[:, i - 5]), input_array[:, i - 4]), 2200), axis=1) - return input_array - - -def get_array(file, chunk): - """ - Gets a compressed numpy array from a file. - - Throws an EOFError if it has problems loading the data. - - :param file: The file that is being read - :param chunk: A chunk representing a single number, this will be the number of bytes the array takes up. - :return: A numpy array - """ - try: - starting_byte = struct.unpack('i', chunk)[0] - except struct.error: - # print('struct error') - raise EOFError - numpy_bytes = file.read(starting_byte) - fake_file = io.BytesIO(numpy_bytes) - try: - result = np.load(fake_file, fix_imports=False) - except OSError: - print('numpy parse error') - raise EOFError - return result, starting_byte - - -def print_values(input_array, output_array, somevalue, anothervalue): - return - - -if __name__ == '__main__': - with gzip.open("path_to_file", 'rb') as f: - try: - read_data(f, print_values, batching=True) - except Exception as e: - print('error training on file ', e) diff --git a/trainer/download_trainer.py b/trainer/download_trainer.py deleted file mode 100644 index 29538aa..0000000 --- a/trainer/download_trainer.py +++ /dev/null @@ -1,62 +0,0 @@ -import gzip -import io - -from examples.autoencoder.autoencoder_model import AutoencoderModel -from examples.autoencoder.autoencoder_model_holder import AutoencoderModelHolder -from examples.autoencoder.autoencoder_output_formatter import AutoencoderOutputFormatter -from examples.autoencoder.variational_autoencoder_model import VariationalAutoencoderModel -from examples.legacy.legacy_input_formatter import LegacyInputFormatter -from examples.legacy.legacy_normalizer_input_formatter import LegacyNormalizerInputFormatter -from examples.legacy.legacy_output_formatter import LegacyOutputFormatter -from examples.multi_output_model import MultiOutputKerasModel -from framework.input_formatter.host_input_formatter import HostInputFormatter -from framework.model_holder.base_model_holder import BaseModelHolder -from examples.lstm.example_lstm_model import ExampleLSTMModel -from examples.example_model_holder import ExampleModelHolder -from examples.lstm.lstm_input_formatter import LSTMInputFormatter -from examples.lstm.lstm_output_formatter import LSTMOutputFormatter -from framework.output_formatter.host_output_formatter import HostOutputFormatter -from trainer.base_trainer import BaseTrainer -from trainer.downloader import Downloader -import trainer.binary_converter as bc - - -class DownloadTrainer(BaseTrainer): - def __init__(self, model_holder: BaseModelHolder): - super().__init__(model_holder) - self.downloader = Downloader() - - def train_on_file(self): - input_file = self.downloader.get_random_replay() - file_name = input_file[1] - input_file = input_file[0] - if isinstance(input_file, io.BytesIO): - input_file.seek(0) - with gzip.GzipFile(fileobj=input_file, mode='rb') as f: - bc.read_data(f, self.model_holder.process_pair, batching=True) - - def train_on_files(self): - input_file_list = self.downloader.get_replays(2000) - counter = 0 - for input_file in input_file_list: - file_name = input_file[1] - input_file = input_file[0] - if isinstance(input_file, io.BytesIO): - input_file.seek(0) - with gzip.GzipFile(fileobj=input_file, mode='rb') as f: - bc.read_data(f, self.model_holder.process_pair, batching=True) - counter += 1 - if counter % 10 == 0: - print('FILE', counter) - - def finish(self): - self.model_holder.finish_training() - - -if __name__ == '__main__': - input_formatter = LegacyNormalizerInputFormatter(LegacyInputFormatter()) - output_formatter = HostOutputFormatter(AutoencoderOutputFormatter(input_formatter)) - d = DownloadTrainer(AutoencoderModelHolder(AutoencoderModel(compressed_dim=50), - input_formatter, output_formatter)) - d.train_on_files() - d.finish() diff --git a/trainer/downloader.py b/trainer/downloader.py deleted file mode 100644 index abee400..0000000 --- a/trainer/downloader.py +++ /dev/null @@ -1,107 +0,0 @@ -import io -import json -import os - -import pandas -import pickle -import random -import zipfile - -import fs -import requests -import sys - -from requests.exceptions import ChunkedEncodingError - -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'framework', 'replayanalysis')) # dirty way to fix the path for the submodule pickling - - -class Downloader: - BASE_URL = 'http://138.197.6.71:5000' # for saltie replays/training - BASE_REPLAY_URL = "http://saltie.tk" # for replay parsing/training - API_KEY = '123456' - - def __init__(self, max_size_mb=100, path='mem://saltie'): - self.max_size_mb = max_size_mb - self.filesystem = fs.open_fs(path) - - @staticmethod - def unzip(in_memory_file: io.BytesIO): - in_memory_zip_file = zipfile.ZipFile(in_memory_file) - return [io.BytesIO(in_memory_zip_file.read(name)) for name in in_memory_zip_file.namelist()] - - @staticmethod - def create_in_memory_file(response: requests.Response) -> io.BytesIO: - in_memory_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - print('chunk') - if chunk: - in_memory_file.write(chunk) - return in_memory_file - - def get_random_replay(self): - js = requests.get(self.BASE_URL + '/replays/list?model_hash=rashbot0').json() - filename = random.choice(js) - return self.get_replay(filename), filename - - def get_replays(self, number=1, batch=50): - batch = min(number, batch) - js = requests.get(self.BASE_URL + '/replays/list?model_hash=rashbot0').json() - filenames = [] - file_list = [] - - total_filenames = random.sample(js, number) - for i in range(int(number / batch)): - sequence_filenames = total_filenames[i * batch: (i + 1) * batch] - file_list += self.get_replay(sequence_filenames) - filenames += sequence_filenames - print('downloaded', (batch * (i + 1.0)) / number * 100, '% of files') - return zip(file_list, filenames) - - def get_replay(self, filename_or_filenames: list or str): - if isinstance(filename_or_filenames, list): - try: - r = requests.post(self.BASE_URL + '/replays/download', - data={'files': json.dumps(filename_or_filenames)}) - except ChunkedEncodingError: - return [] - imf = self.create_in_memory_file(r) - return self.unzip(imf) - else: - r = requests.get(self.BASE_URL + '/replays/{}'.format(filename_or_filenames)) - imf = self.create_in_memory_file(r) - return imf - - def download_replays(self): - rpl, fn = self.get_random_replay() - success = self.filesystem.create(fn) - if success: - # file has been successfully created - self.filesystem.setfile(fn, rpl) - - def download_pandas_game(self, from_disk=False, hash=None) -> pandas.DataFrame: - if not from_disk: - if hash is None: - js = requests.get(self.BASE_REPLAY_URL + '/api/v1/parsed/list?key={}'.format(self.API_KEY)).json() - dl = random.choice(js) - else: - dl = hash + '.replay.pkl' - dl_url = self.BASE_REPLAY_URL + '/api/v1/parsed/{}?key={}'.format(dl, self.API_KEY) - r = requests.get(dl_url, stream=True) - r.raw.decode_content = True # Content-Encoding - r.raise_for_status() - try: - game = pickle.load(io.BytesIO(r.content)) - except (EOFError, ImportError): - return self.download_pandas_game(from_disk=False) - else: - game = pickle.load(open('test.pkl', 'rb')) - return game - - -if __name__ == '__main__': - dl = Downloader() - # dl.download_replays() - # print(dl.filesystem.listdir('/')) - game = dl.download_pandas_game(True) - print() diff --git a/trainer/parsed_download_trainer.py b/trainer/parsed_download_trainer.py deleted file mode 100644 index 763468a..0000000 --- a/trainer/parsed_download_trainer.py +++ /dev/null @@ -1,48 +0,0 @@ -from examples.autoencoder.autoencoder_model import AutoencoderModel -from examples.autoencoder.autoencoder_model_holder import AutoencoderModelHolder -from examples.autoencoder.autoencoder_output_formatter import AutoencoderOutputFormatter -from examples.legacy.legacy_input_formatter import LegacyInputFormatter -from examples.legacy.legacy_normalizer_input_formatter import LegacyNormalizerInputFormatter -from framework.model_holder.base_model_holder import BaseModelHolder -from framework.output_formatter.host_output_formatter import HostOutputFormatter -from trainer.base_trainer import BaseTrainer -from trainer.downloader import Downloader -import matplotlib.pyplot as plt - - -class ParsedDownloadTrainer(BaseTrainer): - def __init__(self, model_holder: BaseModelHolder): - super().__init__(model_holder) - self.downloader = Downloader() - - def process_file(self, input_file): - pass - - def train_on_file(self, name=None): - if name is None: - input_file = self.downloader.download_pandas_game(from_disk=False) - else: - input_file = self.downloader.download_pandas_game(hash=name) - self.process_file(input_file) - - def train_on_files(self, count=200): - counter = 0 - for i in range(count): - input_file = self.downloader.download_pandas_game(from_disk=False) - self.process_file(input_file) - counter += 1 - if counter % 10 == 0: - print('FILE', counter) - - def finish(self): - self.model_holder.finish_training() - - -if __name__ == '__main__': - input_formatter = LegacyNormalizerInputFormatter(LegacyInputFormatter()) - output_formatter = HostOutputFormatter(AutoencoderOutputFormatter(input_formatter)) - d = ParsedDownloadTrainer(AutoencoderModelHolder(AutoencoderModel(compressed_dim=50), - input_formatter, output_formatter)) - # d.train_on_files() - d.train_on_file() - d.finish() diff --git a/trainer/shots/__init__.py b/trainer/shots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trainer/shots/shot_trainer.py b/trainer/shots/shot_trainer.py new file mode 100644 index 0000000..e96b758 --- /dev/null +++ b/trainer/shots/shot_trainer.py @@ -0,0 +1,4 @@ +from framework.trainer.generated_data_trainer import GeneratedDataTrainer + +if __name__ == '__main__': + GeneratedDataTrainer(BaseMod) From a7d1297c550360314ec875ee063242c9a6236072 Mon Sep 17 00:00:00 2001 From: dtracers Date: Sat, 15 Sep 2018 17:14:10 -0600 Subject: [PATCH 2/6] working on training hits --- examples/base_keras_model.py | 1 - .../replays}/shots/__init__.py | 0 .../replays/shots/shot_input_formatter.py | 16 ++++++++++ examples/replays/shots/shot_model.py | 17 ++++++++++ .../replays/shots/shot_output_formatter.py | 11 +++++++ examples/replays/shots/shot_trainer.py | 13 ++++++++ .../data_generator/local_cache_creator.py | 31 +++++++++++++++++++ .../data_generator/replay/replay_generator.py | 15 ++++++--- framework/replay/replay_format.py | 5 ++- trainer/shots/shot_trainer.py | 4 --- 10 files changed, 103 insertions(+), 10 deletions(-) rename {trainer => examples/replays}/shots/__init__.py (100%) create mode 100644 examples/replays/shots/shot_input_formatter.py create mode 100644 examples/replays/shots/shot_model.py create mode 100644 examples/replays/shots/shot_output_formatter.py create mode 100644 examples/replays/shots/shot_trainer.py create mode 100644 framework/data_generator/local_cache_creator.py delete mode 100644 trainer/shots/shot_trainer.py diff --git a/examples/base_keras_model.py b/examples/base_keras_model.py index b1d5e94..f3cdab8 100644 --- a/examples/base_keras_model.py +++ b/examples/base_keras_model.py @@ -69,7 +69,6 @@ def write_log(self, callback, names, logs, batch_no, eval=False): def finalize_model(self, logname=str(int(random() * 1000))): - loss, loss_weights = self.create_loss() self.model.compile(tf.keras.optimizers.Nadam(lr=0.001), loss=loss, loss_weights=loss_weights, metrics=[tf.keras.metrics.mean_absolute_error, tf.keras.metrics.binary_accuracy]) diff --git a/trainer/shots/__init__.py b/examples/replays/shots/__init__.py similarity index 100% rename from trainer/shots/__init__.py rename to examples/replays/shots/__init__.py diff --git a/examples/replays/shots/shot_input_formatter.py b/examples/replays/shots/shot_input_formatter.py new file mode 100644 index 0000000..5d2a0a9 --- /dev/null +++ b/examples/replays/shots/shot_input_formatter.py @@ -0,0 +1,16 @@ +from framework.input_formatter.base_input_formatter import BaseInputFormatter +from framework.replay.replay_format import GeneratedHit + + +class ShotInputFormatter(BaseInputFormatter): + def get_input_state_dimension(self): + pass + + def create_input_array(self, input_data: GeneratedHit, batch_size=1): + result = [] + hit = input_data.get_hit() + frame = hit.frame_number + hit_frame = input_data.get_replay().get_pandas().loc[frame] + result = input_data.get_replay().get_pandas() + index = input_data.get_replay().get_pandas().loc[frame].name + return result diff --git a/examples/replays/shots/shot_model.py b/examples/replays/shots/shot_model.py new file mode 100644 index 0000000..2ff44dd --- /dev/null +++ b/examples/replays/shots/shot_model.py @@ -0,0 +1,17 @@ +from tensorflow.python.keras import Model + +from examples.base_keras_model import BaseKerasModel +from framework.output_formatter.base_output_formatter import BaseOutputFormatter +import tensorflow as tf + + +class ShotModel(BaseKerasModel): + + def create_output_layer(self, output_formatter: BaseOutputFormatter, hidden_layer=None): + # sigmoid/tanh all you want on self.model + if hidden_layer is None: + hidden_layer = self.hidden_layer + self.outputs = tf.keras.layers.Dense(output_formatter.get_model_output_dimension()[0], + activation='sigmoid')(hidden_layer) + self.model = Model(inputs=self.inputs, outputs=self.outputs) + return self.outputs diff --git a/examples/replays/shots/shot_output_formatter.py b/examples/replays/shots/shot_output_formatter.py new file mode 100644 index 0000000..cbd2877 --- /dev/null +++ b/examples/replays/shots/shot_output_formatter.py @@ -0,0 +1,11 @@ +from framework.output_formatter.base_output_formatter import BaseOutputFormatter +from framework.replay.replay_format import GeneratedHit + + +class ShotOutputFormatter(BaseOutputFormatter): + + def get_model_output_dimension(self): + return [1] + + def create_array_for_training(self, input_hit: GeneratedHit, batch_size=1): + return [int(input_hit.get_hit().goal)] diff --git a/examples/replays/shots/shot_trainer.py b/examples/replays/shots/shot_trainer.py new file mode 100644 index 0000000..7f8bb4e --- /dev/null +++ b/examples/replays/shots/shot_trainer.py @@ -0,0 +1,13 @@ +from examples.replays.shots.shot_input_formatter import ShotInputFormatter +from examples.replays.shots.shot_output_formatter import ShotOutputFormatter +from framework.data_generator.local_cache_creator import LocalCacheCreator +from framework.data_generator.replay.hit_generator import HitGenerator + +if __name__ == '__main__': + hit_generator = HitGenerator() + hit_generator.initialize(hit_filter={'shot': True}) + cache = LocalCacheCreator(ShotInputFormatter(), ShotOutputFormatter(), hit_generator) + + cache.create_cache() + + cache.save_cache('cache.ch') diff --git a/framework/data_generator/local_cache_creator.py b/framework/data_generator/local_cache_creator.py new file mode 100644 index 0000000..9b4054d --- /dev/null +++ b/framework/data_generator/local_cache_creator.py @@ -0,0 +1,31 @@ +import numpy as np +import pandas as pd +from carball.analysis.utils.pandas_manager import PandasManager + +from framework.data_generator.base_generator import BaseDataGenerator +from framework.input_formatter.base_input_formatter import BaseInputFormatter +from framework.output_formatter.base_output_formatter import BaseOutputFormatter + + +class LocalCacheCreator: + + def __init__(self, input_formatter: BaseInputFormatter, output_formatter: BaseOutputFormatter, data_generator: BaseDataGenerator): + self.data_generator = data_generator + self.output_formatter = output_formatter + self.input_formatter = input_formatter + self.cache = None + + def create_cache(self): + input_array = np.array([]) + output_array = np.array([]) + + for data in self.data_generator.get_data(): + np.append(input_array, np.array(self.input_formatter.create_input_array(data))) + np.append(output_array, np.array(self.output_formatter.create_array_for_training(data))) + + self.cache = pd.DataFrame(data={"input": input_array, "output": output_array}) + + def save_cache(self, file_path): + result = PandasManager.safe_write_pandas_to_memory(self.cache) + with open(file_path, 'w') as f: + f.write(result) diff --git a/framework/data_generator/replay/replay_generator.py b/framework/data_generator/replay/replay_generator.py index 916fcbd..35056c3 100644 --- a/framework/data_generator/replay/replay_generator.py +++ b/framework/data_generator/replay/replay_generator.py @@ -1,3 +1,4 @@ +import io import random import requests @@ -6,6 +7,11 @@ from framework.replay.replay_format import GeneratedReplay +def create_in_memory_file(response: bytes) -> io.BytesIO: + in_memory_file = io.BytesIO(response) + return in_memory_file + + class ReplayListGenerator(BaseDataGenerator): BASE_URL = "https://calculated.gg" @@ -20,11 +26,11 @@ def __init__(self, api_key=1, min_mmr=0, max_mmr=4000, num_players_on_team=-1, m self.shuffle = shuffle self.next_page = True - self.existing_url = '/api/v1/replays?page=1&key=' + str(api_key) + self.existing_url = None self.replays = [] - def initialize(self, **kwargs): - pass + def initialize(self, initial_page=1): + self.existing_url = '/api/v1/replays?page=' + str(initial_page) + '&key=' + str(self.api_key) def create_url(self, existing_url): return self.BASE_URL + existing_url +'&minmmr=' + str(self.min_mmr) + '&max_mmr=' + str(self.max_mmr) @@ -86,7 +92,8 @@ def download_replay(self, replay_hash): # get pts pts = requests.get(self.BASE_URL + self.DOWNLOAD_URL + replay_hash + '.replay.pts' + self.key_url) pandas = requests.get(self.BASE_URL + self.DOWNLOAD_URL + replay_hash + '.replay.gzip' + self.key_url) - replay = GeneratedReplay(protobuf=pts.content, pandas=pandas.content) + + replay = GeneratedReplay(protobuf=pts.content, pandas=create_in_memory_file(pandas.content)) return replay def __get_next_replay(self): diff --git a/framework/replay/replay_format.py b/framework/replay/replay_format.py index 31c4d43..5a73390 100644 --- a/framework/replay/replay_format.py +++ b/framework/replay/replay_format.py @@ -1,3 +1,4 @@ +import gzip import zlib from carball.analysis.utils.pandas_manager import PandasManager @@ -37,7 +38,9 @@ def get_pandas(self) -> pd.DataFrame: """ if self.decoded_pandas is not None: return self.decoded_pandas - self.decoded_pandas = PandasManager.safe_read_pandas_to_memory(zlib.decompress(self.pandas)) + self.pandas.seek(0) + with gzip.GzipFile(fileobj=self.pandas, mode='rb') as f: + self.decoded_pandas = PandasManager.safe_read_pandas_to_memory(f) self.pandas = None return self.decoded_pandas diff --git a/trainer/shots/shot_trainer.py b/trainer/shots/shot_trainer.py deleted file mode 100644 index e96b758..0000000 --- a/trainer/shots/shot_trainer.py +++ /dev/null @@ -1,4 +0,0 @@ -from framework.trainer.generated_data_trainer import GeneratedDataTrainer - -if __name__ == '__main__': - GeneratedDataTrainer(BaseMod) From c551b71f7cfee140caa8e6919228544e61a51289 Mon Sep 17 00:00:00 2001 From: dtracers Date: Sun, 16 Sep 2018 02:10:41 -0600 Subject: [PATCH 3/6] working on grabbing the correct frame --- examples/replays/shots/shot_input_formatter.py | 9 +++++++-- framework/replay/replay_format.py | 9 +++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/replays/shots/shot_input_formatter.py b/examples/replays/shots/shot_input_formatter.py index 5d2a0a9..274a02a 100644 --- a/examples/replays/shots/shot_input_formatter.py +++ b/examples/replays/shots/shot_input_formatter.py @@ -1,3 +1,4 @@ +import pandas as pd from framework.input_formatter.base_input_formatter import BaseInputFormatter from framework.replay.replay_format import GeneratedHit @@ -10,7 +11,11 @@ def create_input_array(self, input_data: GeneratedHit, batch_size=1): result = [] hit = input_data.get_hit() frame = hit.frame_number - hit_frame = input_data.get_replay().get_pandas().loc[frame] + df = input_data.get_replay().get_pandas() + df.set_index(('index',), inplace=True) + index = df.index + hit_frame = df.loc[frame] + result = input_data.get_replay().get_pandas() - index = input_data.get_replay().get_pandas().loc[frame].name + index = input_data.get_replay().get_pandas().loc[frame].index return result diff --git a/framework/replay/replay_format.py b/framework/replay/replay_format.py index 5a73390..ddce34a 100644 --- a/framework/replay/replay_format.py +++ b/framework/replay/replay_format.py @@ -41,6 +41,15 @@ def get_pandas(self) -> pd.DataFrame: self.pandas.seek(0) with gzip.GzipFile(fileobj=self.pandas, mode='rb') as f: self.decoded_pandas = PandasManager.safe_read_pandas_to_memory(f) + cols = [] + for tuple_str in self.decoded_pandas.columns.values: + cleaned_string = tuple_str.replace("'", "") + if '(' in cleaned_string: + split = cleaned_string.replace('(', '').replace(')', '').split(',') + cols.append(tuple(split)) + else: + cols.append((cleaned_string,)) + self.decoded_pandas.columns = pd.MultiIndex.from_tuples(cols) self.pandas = None return self.decoded_pandas From 0db06c7f5b1b9ca54650dfaeee368a3e4880c405 Mon Sep 17 00:00:00 2001 From: "HARRY\\Harry" Date: Sun, 16 Sep 2018 19:08:17 +0100 Subject: [PATCH 4/6] Decoding fix --- framework/replay/replay_format.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/framework/replay/replay_format.py b/framework/replay/replay_format.py index ddce34a..367e934 100644 --- a/framework/replay/replay_format.py +++ b/framework/replay/replay_format.py @@ -41,15 +41,12 @@ def get_pandas(self) -> pd.DataFrame: self.pandas.seek(0) with gzip.GzipFile(fileobj=self.pandas, mode='rb') as f: self.decoded_pandas = PandasManager.safe_read_pandas_to_memory(f) - cols = [] + + self.decoded_pandas.set_index('index', drop=True, inplace=True) + columns = [] for tuple_str in self.decoded_pandas.columns.values: - cleaned_string = tuple_str.replace("'", "") - if '(' in cleaned_string: - split = cleaned_string.replace('(', '').replace(')', '').split(',') - cols.append(tuple(split)) - else: - cols.append((cleaned_string,)) - self.decoded_pandas.columns = pd.MultiIndex.from_tuples(cols) + columns.append(eval(tuple_str)) + self.decoded_pandas.columns = pd.MultiIndex.from_tuples(columns) self.pandas = None return self.decoded_pandas From d84801c4722754417ef2b032e7a042283a4d0bb5 Mon Sep 17 00:00:00 2001 From: dtracers Date: Mon, 17 Sep 2018 00:06:10 -0600 Subject: [PATCH 5/6] started work on adding features --- .../replays/shots/shot_input_formatter.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/examples/replays/shots/shot_input_formatter.py b/examples/replays/shots/shot_input_formatter.py index 274a02a..093924d 100644 --- a/examples/replays/shots/shot_input_formatter.py +++ b/examples/replays/shots/shot_input_formatter.py @@ -1,4 +1,5 @@ import pandas as pd +import numpy as np from framework.input_formatter.base_input_formatter import BaseInputFormatter from framework.replay.replay_format import GeneratedHit @@ -12,10 +13,31 @@ def create_input_array(self, input_data: GeneratedHit, batch_size=1): hit = input_data.get_hit() frame = hit.frame_number df = input_data.get_replay().get_pandas() - df.set_index(('index',), inplace=True) index = df.index hit_frame = df.loc[frame] - result = input_data.get_replay().get_pandas() - index = input_data.get_replay().get_pandas().loc[frame].index + new_frame = hit_frame['ball'] + + proto = input_data.get_replay().get_proto() + hit_player = hit.player_id + blue_team = [] + orange_team = [] + for player in proto.players: + if player.is_orange: + orange_team.append(player) + else: + blue_team.append(player) + + return result + + def get_speed(self, frame): + return np.sqrt(frame['vel_x']**2 + frame['vel_y']**2 + frame['vel_z']**2) + + def get_distance_from_goal(self, frame, player): + + + def get_player_data(self, frame): + return [frame['pos_x'], frame['pos_y'], frame['pos_z'], + frame['rot_x'], frame['rot_y'], frame['rot_z'], + frame['vel_x'], frame['vel_y'], frame['vel_z']] From 5127b4f2a4898c747c16ef2ffb77193719d71963 Mon Sep 17 00:00:00 2001 From: dtracers Date: Wed, 10 Oct 2018 19:31:25 -0600 Subject: [PATCH 6/6] added shot training --- examples/replays/shots/shot_input_formatter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/replays/shots/shot_input_formatter.py b/examples/replays/shots/shot_input_formatter.py index 093924d..7edd3fb 100644 --- a/examples/replays/shots/shot_input_formatter.py +++ b/examples/replays/shots/shot_input_formatter.py @@ -28,14 +28,14 @@ def create_input_array(self, input_data: GeneratedHit, batch_size=1): else: blue_team.append(player) - return result def get_speed(self, frame): return np.sqrt(frame['vel_x']**2 + frame['vel_y']**2 + frame['vel_z']**2) - def get_distance_from_goal(self, frame, player): - + def get_distance_from_goal(self, frame, player, team): + if team == 0: + return np.sqrt(frame['pos_x']**2 + (frame['pos_y'] - (6000 * (1 - team)))**2 + frame['vel_z']**2) def get_player_data(self, frame): return [frame['pos_x'], frame['pos_y'], frame['pos_z'],