Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[record] defend against problematic collisions #22860

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python_modules/dagster/dagster/_record/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from abc import ABC
from functools import partial
from typing import (
Expand Down Expand Up @@ -46,7 +47,24 @@ def _namedtuple_model_transform(
* creates a run time checked __new__ (optional).
"""
field_set = getattr(cls, "__annotations__", {})
defaults = {name: getattr(cls, name) for name in field_set.keys() if hasattr(cls, name)}

defaults = {}
for name in field_set.keys():
if hasattr(cls, name):
attr_val = getattr(cls, name)
check.invariant(
not isinstance(attr_val, property),
f"Conflicting @property for field {name} on record {cls.__name__}."
"If you are trying to declare an abstract property "
"you will have to use a class attribute instead.",
)
check.invariant(
not inspect.isfunction(attr_val),
f"Conflicting function for field {name} on record {cls.__name__}. "
"If you are trying to set a function as a default value "
"you will have to override __new__.",
)
defaults[name] = attr_val

base = NamedTuple(f"_{cls.__name__}", field_set.items())
nt_new = base.__new__
Expand Down Expand Up @@ -83,9 +101,8 @@ def _namedtuple_model_transform(
check.failed(f"Expected __new__ on {cls}, add it or switch from the _with_new decorator.")

# clear default values
for name in field_set.keys():
if hasattr(cls, name):
delattr(cls, name)
for name in defaults.keys():
delattr(cls, name)

new_type = type(
cls.__name__,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import pytest
Expand Down Expand Up @@ -436,3 +437,42 @@ def test_pickle():

a2 = Agent(name="mr. clean")
assert a2 == pickle.loads(pickle.dumps(a2))


def test_default_collision() -> None:
class BadBase(ABC):
@property
@abstractmethod
def abstract_prop(self): ...

def some_method(self): ...

with pytest.raises(check.CheckError, match="Conflicting @property"):

@record
class _(BadBase):
abstract_prop: Any

with pytest.raises(check.CheckError, match="Conflicting function"):

@record
class _(BadBase):
some_method: Any

class Base(ABC):
thing: Any

@record
class Impl(Base):
thing: Any

assert Impl(thing=3).thing == 3

with pytest.raises(check.CheckError, match="will have to override __new__"):

def _some_func():
return 4

@record
class _(Base):
thing: Any = _some_func