Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gemenerik committed Mar 14, 2024
1 parent 23165d2 commit d124052
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
14 changes: 8 additions & 6 deletions cflib/localization/param_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import yaml

from cflib.crazyflie.param import PersistentParamState


class ParamFileManager():
"""Reads and writes parameter configurations from file"""
TYPE_ID = 'type'
Expand All @@ -40,7 +40,8 @@ def write(file_name, params={}):
for id, param in params.items():
assert isinstance(param, PersistentParamState)
if isinstance(param, PersistentParamState):
file_params[id] = {'is_stored': param.is_stored, 'default_value': param.default_value, 'stored_value': param.stored_value}
file_params[id] = {'is_stored': param.is_stored,
'default_value': param.default_value, 'stored_value': param.stored_value}

data = {
ParamFileManager.TYPE_ID: ParamFileManager.TYPE,
Expand All @@ -49,7 +50,7 @@ def write(file_name, params={}):
}

yaml.dump(data, file)

@staticmethod
def read(file_name):
file = open(file_name, 'r')
Expand All @@ -71,13 +72,14 @@ def read(file_name):

if data[ParamFileManager.VERSION_ID] != ParamFileManager.VERSION:
raise Exception('Unsupported file version')

def get_data(input_data):
persistent_params = {}
for id, param in input_data.items():
persistent_params[id] = PersistentParamState(param['is_stored'], param['default_value'], param['stored_value'])
persistent_params[id] = PersistentParamState(
param['is_stored'], param['default_value'], param['stored_value'])
return persistent_params

if ParamFileManager.PARAMS_ID in data:
return get_data(data[ParamFileManager.PARAMS_ID])
else:
Expand Down
12 changes: 6 additions & 6 deletions test/localization/test_param_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import unittest
from unittest.mock import ANY
from unittest.mock import mock_open
Expand All @@ -29,6 +28,7 @@

from cflib.localization import ParamFileManager


class TestParamFileManager(unittest.TestCase):
def setUp(self):
self.data = {
Expand All @@ -48,7 +48,7 @@ def test_that_read_open_correct_file(self, mock_yaml_load):

# Assert
mock_file.assert_called_with(file_name, 'r')

@patch('yaml.safe_load')
def test_that_missing_file_type_raises(self, mock_yaml_load):
# Fixture
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_that_missing_version_raises(self, mock_yaml_load):
with self.assertRaises(Exception):
with patch('builtins.open', new_callable=mock_open()):
ParamFileManager.read('some/name.yaml')

@patch('yaml.safe_load')
def test_that_wrong_version_raises(self, mock_yaml_load):
# Fixture
Expand All @@ -106,10 +106,10 @@ def test_that_no_data_returns_empty_default_data(self, mock_yaml_load):
# Test
with patch('builtins.open', new_callable=mock_open()):
actual_params = ParamFileManager.read('some/name.yaml')

# Assert
self.assertEqual(0, len(actual_params))

@patch('yaml.dump')
def test_file_end_to_end_write_read(self, mock_yaml_dump):
# Fixture
Expand All @@ -126,7 +126,7 @@ def test_file_end_to_end_write_read(self, mock_yaml_dump):

# Assert
mock_yaml_dump.assert_called_with(expected, ANY)

@patch('yaml.dump')
def test_file_write_to_correct_file(self, mock_yaml_dump):
# Fixture
Expand Down

0 comments on commit d124052

Please sign in to comment.