Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

account for overrides when generating new token #421

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Changelog

`unreleased`_
-------------
nothing yet
* Added support for auto refreshing tokens if the OAuth model is overriden as in the multi-provider example

`7.0.0`_ (2023-05-10)
---------------------
Expand Down
18 changes: 18 additions & 0 deletions flask_dance/consumer/storage/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def set(self, blueprint, token, user=None, user_id=None):
has_user = hasattr(self.model, "user")
if has_user and u:
existing_query = existing_query.filter_by(user=u)

# grab the existing model before we delete so that we can copy overriden columns
existing = existing_query.first()
# queue up delete query -- won't be run until commit()
existing_query.delete()
# create a new model for this token
Expand All @@ -222,6 +225,21 @@ def set(self, blueprint, token, user=None, user_id=None):
kwargs["user_id"] = uid
if has_user and u:
kwargs["user"] = u

if existing:
EXCLUDE_COLS = ["created_at"]
EXCLUDE_COLS.extend(kwargs.keys())
# if the oauth model is overridden, make sure to copy the columns
column_names = [
col.name
for col in self.model.__table__.columns
if not col.nullable
and not col.primary_key
and col.name not in EXCLUDE_COLS
]
for name in column_names:
kwargs[name] = getattr(existing, name)

self.session.add(self.model(**kwargs))
# commit to delete and add simultaneously
self.session.commit()
Expand Down
19 changes: 11 additions & 8 deletions tests/consumer/storage/test_sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def done():
"/oauth_done",
)

assert len(queries) == 2
assert len(queries) == 3

# check the database
authorizations = OAuth.query.all()
Expand Down Expand Up @@ -211,7 +211,7 @@ def done():
"/oauth_done",
)

assert len(queries) == 3
assert len(queries) == 4

# check the database
alice = User.query.first()
Expand Down Expand Up @@ -351,7 +351,7 @@ def load_user(userid):
"/oauth_done",
)

assert len(queries) == 5
assert len(queries) == 6

# lets do it again, with Bob as the logged in user -- he gets a different token
if "_login_user" in flask.g:
Expand Down Expand Up @@ -379,7 +379,7 @@ def load_user(userid):
"/oauth_done",
)

assert len(queries) == 5
assert len(queries) == 6

# check the database
authorizations = OAuth.query.all()
Expand Down Expand Up @@ -519,7 +519,7 @@ def logged_in(sender, token):
"/oauth_done",
)

assert len(queries) == 5
assert len(queries) == 6

# check the database
users = User.query.all()
Expand Down Expand Up @@ -692,7 +692,8 @@ def done():

def test_sqla_overwrite_token(app, db, blueprint, request):
class OAuth(OAuthConsumerMixin, db.Model):
pass
provider_user_id = db.Column(db.String, nullable=False)
provider_user_login = db.Column(db.String, nullable=False)

blueprint.storage = SQLAlchemyStorage(OAuth, db.session)

Expand All @@ -708,6 +709,8 @@ def done():
existing = OAuth(
provider="test-service",
token={"access_token": "something", "token_type": "bearer", "scope": ["blah"]},
provider_user_id="some-hash",
provider_user_login="user.name",
)
db.session.add(existing)
db.session.commit()
Expand All @@ -730,7 +733,7 @@ def done():
"/oauth_done",
)

assert len(queries) == 2
assert len(queries) == 3

# check that the database record was overwritten
authorizations = OAuth.query.all()
Expand Down Expand Up @@ -778,7 +781,7 @@ def done():
"/oauth_done",
)

assert len(queries) == 2
assert len(queries) == 3

expected_token = {"access_token": "foobar", "token_type": "bearer", "scope": [""]}

Expand Down
Loading