Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #459 - Check inner type of the scalar list #464

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions sqlalchemy_utils/types/scalar_list.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import six
import sqlalchemy as sa
from sqlalchemy import types
Expand All @@ -11,8 +13,7 @@ class ScalarListType(types.TypeDecorator):
"""
ScalarListType type provides convenient way for saving multiple scalar
values in one column. ScalarListType works like list on python side and
saves the result as comma-separated list in the database (custom separators
can also be used).
saves the result as comma-separated list in the database.

Example ::

Expand Down Expand Up @@ -50,13 +51,31 @@ class Player(Base):
session.commit()


:param inner_type:
The type of the values. Default is ``str``.
:param separator:
Separator of the values. Default is ``,``.
:param coerce_func:
Custom function to coerce values when read from database.
By default ``inner_type`` is used instead.
"""

impl = sa.UnicodeText()

def __init__(self, coerce_func=six.text_type, separator=u','):
def __init__(self, inner_type=six.text_type, separator=u',',
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not keeping the parameter order as it was and adding the new parameter as the last parameter? That way we wouldn't need to throw any warnings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about that. I believe, having inner_type defined is the only way I can ensure validation of the values before they're written to database. So I made inner_type primary and coerce_func optional. It would be weird to have an optional parameter first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kvesteri Any decision?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kvesteri I'd like to move this forward. Can I keep the change, or should I add the inner_type as the last parameter?

coerce_func=None):
self.separator = six.text_type(separator)
self.coerce_func = coerce_func
if not isinstance(inner_type, type) and coerce_func is None:
warn_msg = (
"ScalarListType has new required argument 'inner_type'. "
"Provide the type of the values and if required, "
"pass coerce func as a keyword argument.")
warnings.warn(warn_msg, DeprecationWarning)
self.inner_type = None
self.coerce_func = inner_type
else:
self.inner_type = inner_type
self.coerce_func = coerce_func

def process_bind_param(self, value, dialect):
# Convert list of values to unicode separator-separated list
Expand All @@ -69,6 +88,10 @@ def process_bind_param(self, value, dialect):
"these strings, use a different separator string.)"
% self.separator
)
if self.inner_type is not None:
if any(not isinstance(i, self.inner_type) for i in value):
msg = "Not all items in value {} match the type {}"
raise ValueError(msg.format(value, self.inner_type))
return self.separator.join(
map(six.text_type, value)
)
Expand All @@ -78,6 +101,7 @@ def process_result_value(self, value, dialect):
if value == u'':
return []
# coerce each value
coerce_func = self.coerce_func or self.inner_type
return list(map(
self.coerce_func, value.split(self.separator)
coerce_func, value.split(self.separator)
))
54 changes: 54 additions & 0 deletions tests/types/test_scalar_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def test_save_integer_list(self, session, User):
user = session.query(User).first()
assert user.some_list == [1, 2, 3, 4]

def test_save_integer_list_invalid(self, session, User):
user = User(
some_list=[1, 2, 'invalid', 4]
)

session.add(user)
with pytest.raises(sa.exc.StatementError):
session.commit()


class TestScalarUnicodeList(object):

Expand Down Expand Up @@ -92,3 +101,48 @@ def test_save_and_retrieve_empty_list(self, session, User):

user = session.query(User).first()
assert user.some_list == []


def custom_int(value):
return int(value)


@pytest.mark.filterwarnings(
"ignore:ScalarListType has new required argument 'inner_type'")
class TestScalarListCoerceFunc(object):
"""Test deprecated behaviour with single argument which is not a type."""

@pytest.fixture
def User(self, Base):
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
some_list = sa.Column(ScalarListType(custom_int))

def __repr__(self):
return 'User(%r)' % self.id

return User

@pytest.fixture
def init_models(self, User):
pass

def test_save_integer_list(self, session, User):
user = User(some_list=[1, 2, 3, 4])

session.add(user)
session.commit()

user = session.query(User).first()
assert user.some_list == [1, 2, 3, 4]

def test_save_integer_list_invalid(self, session, User):
user = User(some_list=[1, 2, 'invalid', 4])

session.add(user)
session.commit()

# It stores invalid value to database and fails on coerce after read.
with pytest.raises(ValueError, match='invalid literal for int'):
session.query(User).first()