Skip to content

Commit

Permalink
[OneBot] Fix v11 custom segment init and resolve errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aicorein committed Dec 11, 2024
1 parent 4bddb6a commit f006ee0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
20 changes: 14 additions & 6 deletions src/melobot/protocols/onebot/v11/adapter/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,14 @@ def add_type(

seg_cls = type(
type_classname,
(Segment,),
(_CustomSegInterface,),
{
"Model": create_model(
type_dataname,
type=(seg_type_hint, ...),
data=(seg_data_hint, ...),
)
),
"SegTypeVal": type_name,
},
)
setattr(
Expand Down Expand Up @@ -342,7 +343,10 @@ def data(self) -> _SegDataT:
@classmethod
def resolve(cls, seg_type: Any, seg_data: Any) -> Segment:
cls_name = f"{seg_type.lower().capitalize()}Segment"
cls_map = {subcls.__name__: subcls for subcls in cls.__subclasses__()}
cls_map = {
subcls.__name__: subcls
for subcls in cls.__subclasses__() + _CustomSegInterface.__subclasses__()
}
if cls_name in cls_map:
return cls_map[cls_name].resolve(seg_type, seg_data)
return cls(seg_type, **seg_data)
Expand Down Expand Up @@ -370,9 +374,13 @@ def to_json(self, force_str: bool = False) -> str:


class _CustomSegInterface(Segment[_SegTypeT, _SegDataT]):
def __init__( # pylint: disable=super-init-not-called,unused-argument
self, **data: Any
) -> None: ...
SegTypeVal: str

def __init__(self, seg_type: _SegTypeT | None = None, **seg_data: _SegDataT) -> None:
if seg_type is None:
super().__init__(cast(_SegTypeT, self.__class__.SegTypeVal), **seg_data)
else:
super().__init__(seg_type, **seg_data)


class _TextData(TypedDict):
Expand Down
9 changes: 8 additions & 1 deletion tests/onebot/v11/test_adapter_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ async def test_add_type():
s = SType(key="123")
assert s.type == "MyS"
assert s.data == {"key": "123"}
assert SType in seg.Segment.__subclasses__()
assert s.raw == {"type": "MyS", "data": {"key": "123"}}

s2 = seg.Segment.resolve("MyS", {"key": "123"})
assert s.type == "MyS"
assert s.data == {"key": "123"}
assert s.raw == {"type": "MyS", "data": {"key": "123"}}
assert isinstance(s2, SType)

with pt.raises(ValidationError):
SType(key=123)

Expand Down

0 comments on commit f006ee0

Please sign in to comment.