Skip to content

Commit

Permalink
Treat some GraphQL Java types as pseudo sealed types
Browse files Browse the repository at this point in the history
  • Loading branch information
gnawf committed Dec 2, 2024
1 parent bc72914 commit b850c10
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package graphql.nadel.definition

import graphql.nadel.engine.util.whenType
import graphql.schema.GraphQLDirective
import graphql.schema.GraphQLEnumType
import graphql.schema.GraphQLFieldsContainer
Expand Down Expand Up @@ -150,21 +151,19 @@ fun GraphQLDirective.coordinates(): NadelSchemaMemberCoordinates.Directive {
}

fun GraphQLFieldsContainer.coordinates(): NadelSchemaMemberCoordinates.FieldContainer {
return when (this) {
is GraphQLObjectType -> coordinates()
is GraphQLInterfaceType -> coordinates()
else -> throw IllegalArgumentException(javaClass.name)
}
return whenType(
interfaceType = GraphQLInterfaceType::coordinates,
objectType = GraphQLObjectType::coordinates,
)
}

fun GraphQLNamedType.coordinates(): NadelSchemaMemberCoordinates.Type {
return when (this) {
is GraphQLUnionType -> coordinates()
is GraphQLInterfaceType -> coordinates()
is GraphQLEnumType -> coordinates()
is GraphQLInputObjectType -> coordinates()
is GraphQLObjectType -> coordinates()
is GraphQLScalarType -> coordinates()
else -> throw IllegalArgumentException(javaClass.name)
}
return whenType(
enumType = GraphQLEnumType::coordinates,
inputObjectType = GraphQLInputObjectType::coordinates,
interfaceType = GraphQLInterfaceType::coordinates,
objectType = GraphQLObjectType::coordinates,
scalarType = GraphQLScalarType::coordinates,
unionType = GraphQLUnionType::coordinates,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ internal class NadelSchemaTraverser {
) {
val queue: MutableList<NadelSchemaTraverserElement> = roots
.mapNotNullTo(mutableListOf()) { typeName ->
val type = schema.typeMap[typeName] ?: schema.getDirective(typeName)
val type = schema.typeMap[typeName]
// Types can be deleted by transformer, so they may not exist in end schema
if (type == null) {
null
val directive = schema.getDirective(typeName)
if (directive == null) {
null
} else {
NadelSchemaTraverserElement.from(directive)
}
} else {
NadelSchemaTraverserElement.from(type)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql.nadel.engine.blueprint

import graphql.nadel.engine.util.unwrapAll
import graphql.nadel.engine.util.whenType
import graphql.nadel.engine.util.whenUnmodifiedType
import graphql.schema.GraphQLAppliedDirective
import graphql.schema.GraphQLAppliedDirectiveArgument
import graphql.schema.GraphQLArgument
Expand Down Expand Up @@ -42,27 +43,25 @@ internal sealed interface NadelSchemaTraverserElement {
sealed interface OutputType : NadelSchemaTraverserElement, Type {
companion object {
fun from(type: GraphQLOutputType): OutputType {
return when (val node = type.unwrapAll()) {
is GraphQLEnumType -> EnumType(node)
is GraphQLInterfaceType -> InterfaceType(node)
is GraphQLObjectType -> ObjectType(node)
is GraphQLScalarType -> ScalarType(node)
is GraphQLUnionType -> UnionType(node)
else -> throw UnsupportedOperationException(type.javaClass.name)
}
return type.whenUnmodifiedType(
enumType = ::EnumType,
interfaceType = ::InterfaceType,
objectType = ::ObjectType,
scalarType = ::ScalarType,
unionType = ::UnionType,
)
}
}
}

sealed interface InputType : NadelSchemaTraverserElement, Type {
companion object {
fun from(type: GraphQLInputType): InputType {
return when (val node = type.unwrapAll()) {
is GraphQLEnumType -> EnumType(node)
is GraphQLInputObjectType -> InputObjectType(node)
is GraphQLScalarType -> ScalarType(node)
else -> throw UnsupportedOperationException(type.javaClass.name)
}
return type.whenUnmodifiedType(
enumType = ::EnumType,
inputObjectType = ::InputObjectType,
scalarType = ::ScalarType,
)
}
}
}
Expand Down Expand Up @@ -215,33 +214,19 @@ internal sealed interface NadelSchemaTraverserElement {
}

companion object {
fun from(type: GraphQLNamedSchemaElement): NadelSchemaTraverserElement {
return when (type) {
is GraphQLEnumType -> {
EnumType(type)
}
is GraphQLInputObjectType -> {
InputObjectType(type)
}
is GraphQLInterfaceType -> {
InterfaceType(type)
}
is GraphQLObjectType -> {
ObjectType(type)
}
is GraphQLScalarType -> {
ScalarType(type)
}
is GraphQLUnionType -> {
UnionType(type)
}
is GraphQLDirective -> {
Directive(type)
}
else -> {
throw UnsupportedOperationException(type.javaClass.name)
}
}
fun from(type: GraphQLNamedType): NadelSchemaTraverserElement {
return type.whenType(
enumType = ::EnumType,
inputObjectType = ::InputObjectType,
interfaceType = ::InterfaceType,
objectType = ::ObjectType,
scalarType = ::ScalarType,
unionType = ::UnionType,
)
}

fun from(type: GraphQLDirective): NadelSchemaTraverserElement {
return Directive(type)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Util functions to treat some GraphQL Java classes _like_ sealed types.
*
* There are tests in `NadelPseudoSealedTypeKtTest` to ensure these assumptions are correct.
*/
package graphql.nadel.engine.util

import graphql.schema.GraphQLEnumType
import graphql.schema.GraphQLFieldsContainer
import graphql.schema.GraphQLInputObjectType
import graphql.schema.GraphQLInputType
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLNamedType
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLOutputType
import graphql.schema.GraphQLScalarType
import graphql.schema.GraphQLUnionType

internal inline fun <T> GraphQLFieldsContainer.whenType(
interfaceType: (GraphQLInterfaceType) -> T,
objectType: (GraphQLObjectType) -> T,
): T {
return when (this) {
is GraphQLInterfaceType -> interfaceType(this)
is GraphQLObjectType -> objectType(this)
else -> throw IllegalStateException("Should never happen")
}
}

internal inline fun <T> GraphQLNamedType.whenType(
enumType: (GraphQLEnumType) -> T,
inputObjectType: (GraphQLInputObjectType) -> T,
interfaceType: (GraphQLInterfaceType) -> T,
objectType: (GraphQLObjectType) -> T,
scalarType: (GraphQLScalarType) -> T,
unionType: (GraphQLUnionType) -> T,
): T {
return when (this) {
is GraphQLEnumType -> enumType(this)
is GraphQLInputObjectType -> inputObjectType(this)
is GraphQLInterfaceType -> interfaceType(this)
is GraphQLObjectType -> objectType(this)
is GraphQLScalarType -> scalarType(this)
is GraphQLUnionType -> unionType(this)
else -> throw IllegalStateException("Should never happen")
}
}

internal inline fun <T> GraphQLOutputType.whenUnmodifiedType(
enumType: (GraphQLEnumType) -> T,
interfaceType: (GraphQLInterfaceType) -> T,
objectType: (GraphQLObjectType) -> T,
scalarType: (GraphQLScalarType) -> T,
unionType: (GraphQLUnionType) -> T,
): T {
return when (val unmodifiedType = this.unwrapAll()) {
is GraphQLEnumType -> enumType(unmodifiedType)
is GraphQLInterfaceType -> interfaceType(unmodifiedType)
is GraphQLObjectType -> objectType(unmodifiedType)
is GraphQLScalarType -> scalarType(unmodifiedType)
is GraphQLUnionType -> unionType(unmodifiedType)
else -> throw IllegalStateException("Should never happen")
}
}

internal inline fun <T> GraphQLInputType.whenUnmodifiedType(
enumType: (GraphQLEnumType) -> T,
inputObjectType: (GraphQLInputObjectType) -> T,
scalarType: (GraphQLScalarType) -> T,
): T {
return when (val unmodifiedType = this.unwrapAll()) {
is GraphQLEnumType -> enumType(unmodifiedType)
is GraphQLInputObjectType -> inputObjectType(unmodifiedType)
is GraphQLScalarType -> scalarType(unmodifiedType)
else -> throw IllegalStateException("Should never happen")
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package graphql.nadel.archunit

import com.tngtech.archunit.core.domain.JavaClass
import com.tngtech.archunit.lang.ArchCondition
import com.tngtech.archunit.lang.ConditionEvents
import com.tngtech.archunit.lang.SimpleConditionEvent
import kotlin.reflect.KClass

class ArchUnitEqualsClassesExactly<T : Any>(
private val requiredClasses: List<Class<out T>>,
) : ArchCondition<JavaClass>("equal exactly the given classes") {
constructor(vararg requiredClasses: Class<out T>) : this(requiredClasses.toList())

constructor(vararg requiredClasses: KClass<out T>) : this(requiredClasses.map { it.java })

private val actualClasses: MutableList<JavaClass> = mutableListOf()

override fun check(item: JavaClass, events: ConditionEvents) {
actualClasses.add(item)
}

override fun finish(events: ConditionEvents) {
val requiredClassNames = requiredClasses.mapTo(LinkedHashSet()) { it.name }
val actualClassNames = actualClasses.mapTo(LinkedHashSet()) { it.fullName }

val isValid = requiredClassNames == actualClassNames
if (isValid) {
events.add(
SimpleConditionEvent(
requiredClassNames,
true,
"Classes match $requiredClassNames",
),
)
} else {
val extraClassNames = actualClassNames - requiredClassNames
val missingClassNames = requiredClassNames - actualClassNames

if (extraClassNames.isNotEmpty()) {
events.add(
SimpleConditionEvent(
requiredClassNames,
false,
"Found extra classes $extraClassNames",
),
)
}
if (missingClassNames.isNotEmpty()) {
events.add(
SimpleConditionEvent(
requiredClassNames,
false,
"Missing classes $missingClassNames"
),
)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package graphql.nadel
package graphql.nadel.archunit

import com.tngtech.archunit.base.DescribedPredicate
import com.tngtech.archunit.core.domain.JavaAccess
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package graphql.nadel
package graphql.nadel.archunit

import com.tngtech.archunit.base.DescribedPredicate
import com.tngtech.archunit.core.domain.JavaClass
Expand Down
Loading

0 comments on commit b850c10

Please sign in to comment.