diff --git a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs index ffc6538..6b51273 100644 --- a/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs +++ b/src/Abc.Zebus.MessageDsl.Tests/MessageDsl/ParsedContractsTests.cs @@ -671,6 +671,16 @@ public void should_parse_nested_classes() message.ContainingClasses.ShouldEqual(new TypeName[] { "Foo", "Bar" }); } + [Test] + public void should_detect_inheritance_loops() + { + ParseInvalid(@"Foo() : Foo;"); + ParseInvalid(@"Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar() : Foo;"); + ParseInvalid(@"Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar() : Baz; [ProtoInclude(1, typeof(Bar))] Baz() : Foo;"); + + ParseValid(@"Foo() : Bar; [ProtoInclude(1, typeof(Foo))] Bar();"); + } + private static ParsedContracts ParseValid(string definitionText) { var contracts = Parse(definitionText); diff --git a/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs b/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs index d325e15..b72a1e9 100644 --- a/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs +++ b/src/Abc.Zebus.MessageDsl/Analysis/AstValidator.cs @@ -75,6 +75,8 @@ private void ValidateMessage(MessageDefinition message) foreach (var baseType in message.BaseTypes) ValidateType(baseType, message.ParseContext); + + ValidateInheritance(message); } private void ValidateTags(MessageDefinition message) @@ -146,6 +148,35 @@ private void ValidateType(TypeName type, ParserRuleContext? context) _contracts.AddError(context, "Invalid type: {0}", type.NetType); } + private void ValidateInheritance(MessageDefinition message) + { + if (message.BaseTypes.Count == 0) + return; + + var seenTypes = new HashSet + { + message.Name + }; + + var currentMessage = message; + + while (true) + { + if (currentMessage.BaseTypes.Count == 0) + break; + + currentMessage = _contracts.Messages.FirstOrDefault(m => m.Name == currentMessage.BaseTypes[0].NetType); + if (currentMessage is null) + break; + + if (!seenTypes.Add(currentMessage.Name)) + { + _contracts.AddError(message.ParseContext, "There is a loop in the inheritance chain"); + break; + } + } + } + private void DetectDuplicateTypes() { var seenTypes = new HashSet();