Skip to content

Commit

Permalink
Add TableNameResolver to tempest2
Browse files Browse the repository at this point in the history
  • Loading branch information
szabado-faire committed Apr 25, 2024
1 parent 899f19f commit 4e446a6
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package app.cash.tempest2.testing

import app.cash.tempest2.LogicalDb
import app.cash.tempest2.TableNameResolver
import com.google.common.util.concurrent.Service
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedAsyncClient
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedClient
Expand Down Expand Up @@ -55,12 +56,20 @@ interface TestDynamoDbClient : Service {
fun <DB : LogicalDb> logicalDb(
type: KClass<DB>,
extensions: List<DynamoDbEnhancedClientExtension>
): DB {
return logicalDb(type, extensions, tableNameResolver = null)
}

fun <DB : LogicalDb> logicalDb(
type: KClass<DB>,
extensions: List<DynamoDbEnhancedClientExtension>,
tableNameResolver: TableNameResolver? = null
): DB {
val enhancedClient = DynamoDbEnhancedClient.builder()
.dynamoDbClient(dynamoDb)
.extensions(extensions)
.build()
return LogicalDb.create(type, enhancedClient)
return LogicalDb.create(type, enhancedClient, tableNameResolver)
}

fun <DB : LogicalDb> logicalDb(type: Class<DB>): DB {
Expand Down Expand Up @@ -89,12 +98,20 @@ interface TestDynamoDbClient : Service {
fun <DB : AsyncLogicalDb> asyncLogicalDb(
type: KClass<DB>,
extensions: List<DynamoDbEnhancedClientExtension>
): DB {
return asyncLogicalDb(type, extensions, tableNameResolver = null)
}

fun <DB : AsyncLogicalDb> asyncLogicalDb(
type: KClass<DB>,
extensions: List<DynamoDbEnhancedClientExtension>,
tableNameResolver: TableNameResolver?
): DB {
val enhancedClient = DynamoDbEnhancedAsyncClient.builder()
.dynamoDbClient(asyncDynamoDb)
.extensions(extensions)
.build()
return app.cash.tempest2.AsyncLogicalDb.create(type, enhancedClient)
return AsyncLogicalDb.create(type, enhancedClient)
}

fun <DB : AsyncLogicalDb> asyncLogicalDb(type: Class<DB>): DB {
Expand All @@ -113,18 +130,30 @@ interface TestDynamoDbClient : Service {
}
}

inline fun <reified DB : LogicalDb> TestDynamoDbClient.logicalDb(vararg extensions: DynamoDbEnhancedClientExtension): DB {
return logicalDb(extensions.toList())
inline fun <reified DB : LogicalDb> TestDynamoDbClient.logicalDb(
vararg extensions: DynamoDbEnhancedClientExtension,
tableNameResolver: TableNameResolver? = null
): DB {
return logicalDb(extensions.toList(), tableNameResolver)
}

inline fun <reified DB : LogicalDb> TestDynamoDbClient.logicalDb(extensions: List<DynamoDbEnhancedClientExtension>): DB {
return logicalDb(DB::class, extensions)
inline fun <reified DB : LogicalDb> TestDynamoDbClient.logicalDb(
extensions: List<DynamoDbEnhancedClientExtension>,
tableNameResolver: TableNameResolver? = null
): DB {
return logicalDb(DB::class, extensions, tableNameResolver)
}

inline fun <reified DB : AsyncLogicalDb> TestDynamoDbClient.asyncLogicalDb(vararg extensions: DynamoDbEnhancedClientExtension): DB {
return asyncLogicalDb(extensions.toList())
inline fun <reified DB : AsyncLogicalDb> TestDynamoDbClient.asyncLogicalDb(
vararg extensions: DynamoDbEnhancedClientExtension,
tableNameResolver: TableNameResolver? = null
): DB {
return asyncLogicalDb(extensions.toList(), tableNameResolver)
}

inline fun <reified DB : AsyncLogicalDb> TestDynamoDbClient.asyncLogicalDb(extensions: List<DynamoDbEnhancedClientExtension>): DB {
return asyncLogicalDb(DB::class, extensions)
inline fun <reified DB : AsyncLogicalDb> TestDynamoDbClient.asyncLogicalDb(
extensions: List<DynamoDbEnhancedClientExtension>,
tableNameResolver: TableNameResolver? = null
): DB {
return asyncLogicalDb(DB::class, extensions, tableNameResolver)
}
21 changes: 15 additions & 6 deletions tempest2/src/main/kotlin/app/cash/tempest2/AsyncLogicalDb.kt
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,18 @@ interface AsyncLogicalDb : AsyncLogicalTable.Factory {

companion object {
inline operator fun <reified DB : AsyncLogicalDb> invoke(
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient,
tableNameResolver: TableNameResolver? = null
): DB {
return create(DB::class, dynamoDbEnhancedClient)
return create(DB::class, dynamoDbEnhancedClient, tableNameResolver)
}

fun <DB : AsyncLogicalDb> create(
dbType: KClass<DB>,
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient,
tableNameResolver: TableNameResolver? = null,
): DB {
return AsyncLogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType)
return AsyncLogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType, tableNameResolver)
}

// Overloaded functions for Java callers (Kotlin interface companion objects do not support
Expand All @@ -152,8 +154,15 @@ interface AsyncLogicalDb : AsyncLogicalTable.Factory {
@JvmStatic
fun <DB : AsyncLogicalDb> create(
dbType: Class<DB>,
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient
) = create(dbType.kotlin, dynamoDbEnhancedClient)
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient,
) = create(dbType, dynamoDbEnhancedClient, tableNameResolver = null)

@JvmStatic
fun <DB : AsyncLogicalDb> create(
dbType: Class<DB>,
dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient,
tableNameResolver: TableNameResolver?
) = create(dbType.kotlin, dynamoDbEnhancedClient, tableNameResolver)
}

// Overloaded functions for Java callers (Kotlin interfaces do not support `@JvmOverloads`).
Expand Down
19 changes: 14 additions & 5 deletions tempest2/src/main/kotlin/app/cash/tempest2/LogicalDb.kt
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,18 @@ interface LogicalDb : LogicalTable.Factory {

companion object {
inline operator fun <reified DB : LogicalDb> invoke(
dynamoDbEnhancedClient: DynamoDbEnhancedClient
dynamoDbEnhancedClient: DynamoDbEnhancedClient,
tableNameResolver: TableNameResolver? = null
): DB {
return create(DB::class, dynamoDbEnhancedClient)
return create(DB::class, dynamoDbEnhancedClient, tableNameResolver)
}

fun <DB : LogicalDb> create(
dbType: KClass<DB>,
dynamoDbEnhancedClient: DynamoDbEnhancedClient
dynamoDbEnhancedClient: DynamoDbEnhancedClient,
tableNameResolver: TableNameResolver? = null
): DB {
return LogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType)
return LogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType, tableNameResolver)
}

// Overloaded functions for Java callers (Kotlin interface companion objects do not support
Expand All @@ -180,7 +182,14 @@ interface LogicalDb : LogicalTable.Factory {
fun <DB : LogicalDb> create(
dbType: Class<DB>,
dynamoDbEnhancedClient: DynamoDbEnhancedClient
) = create(dbType.kotlin, dynamoDbEnhancedClient)
) = create(dbType, dynamoDbEnhancedClient, tableNameResolver = null)

@JvmStatic
fun <DB : LogicalDb> create(
dbType: Class<DB>,
dynamoDbEnhancedClient: DynamoDbEnhancedClient,
tableNameResolver: TableNameResolver?
) = create(dbType.kotlin, dynamoDbEnhancedClient, tableNameResolver)
}
}

Expand Down
10 changes: 10 additions & 0 deletions tempest2/src/main/kotlin/app/cash/tempest2/TableNameResolver.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package app.cash.tempest2

/**
* Resolves the table name for a given [LogicalTable] class.
*
* This allows table names to be overridden at runtime.
*/
interface TableNameResolver {
fun resolveTableName(clazz: Class<*>, tableNameFromAnnotation: String?): String
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import app.cash.tempest2.AsyncQueryable
import app.cash.tempest2.AsyncScannable
import app.cash.tempest2.AsyncSecondaryIndex
import app.cash.tempest2.AsyncView
import app.cash.tempest2.TableNameResolver
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbAsyncTable
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedAsyncClient
import software.amazon.awssdk.enhanced.dynamodb.TableSchema
Expand All @@ -51,7 +52,7 @@ internal class AsyncLogicalDbFactory(
V2RawItemTypeFactory()
)

fun <DB : AsyncLogicalDb> logicalDb(dbType: KClass<DB>): DB {
fun <DB : AsyncLogicalDb> logicalDb(dbType: KClass<DB>, tableNameResolver: TableNameResolver?): DB {
val logicalDb = DynamoDbLogicalDb(
DynamoDbLogicalDb.MappedTableResourceFactory.simple(dynamoDbEnhancedClient::table),
schema,
Expand All @@ -62,7 +63,7 @@ internal class AsyncLogicalDbFactory(
continue
}
val tableType = member.returnType.jvmErasure as KClass<AsyncLogicalTable<Any>>
val tableName = getTableName(member, dbType)
val tableName = getTableName(member, dbType, tableType, tableNameResolver)
val table = logicalTable(tableName, tableType)
methodHandlers[member.javaMethod] = GetterMethodHandler(table)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import app.cash.tempest2.LogicalTable
import app.cash.tempest2.Queryable
import app.cash.tempest2.Scannable
import app.cash.tempest2.SecondaryIndex
import app.cash.tempest2.TableNameResolver
import app.cash.tempest2.View
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedClient
import software.amazon.awssdk.enhanced.dynamodb.DynamoDbTable
Expand All @@ -51,7 +52,7 @@ internal class LogicalDbFactory(
V2RawItemTypeFactory()
)

fun <DB : LogicalDb> logicalDb(dbType: KClass<DB>): DB {
fun <DB : LogicalDb> logicalDb(dbType: KClass<DB>, tableNameResolver: TableNameResolver?): DB {
val logicalDb = DynamoDbLogicalDb(
DynamoDbLogicalDb.MappedTableResourceFactory.simple(dynamoDbEnhancedClient::table),
schema,
Expand All @@ -62,7 +63,7 @@ internal class LogicalDbFactory(
continue
}
val tableType = member.returnType.jvmErasure as KClass<LogicalTable<Any>>
val tableName = getTableName(member, dbType)
val tableName = getTableName(member, dbType, tableType, tableNameResolver)
val table = logicalTable(tableName, tableType)
methodHandlers[member.javaMethod] = GetterMethodHandler(table)
}
Expand Down
12 changes: 10 additions & 2 deletions tempest2/src/main/kotlin/app/cash/tempest2/internal/V2.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import app.cash.tempest.internal.StringAttributeValue
import app.cash.tempest2.Attribute
import app.cash.tempest2.ForIndex
import app.cash.tempest2.TableName
import app.cash.tempest2.TableNameResolver
import software.amazon.awssdk.enhanced.dynamodb.TableMetadata
import software.amazon.awssdk.enhanced.dynamodb.TableSchema
import software.amazon.awssdk.services.dynamodb.model.AttributeValue
Expand Down Expand Up @@ -87,8 +88,15 @@ internal class V2RawItemTypeFactory : RawItemType.Factory {
}
}

internal fun getTableName(member: ClassMember, dbType: KClass<*>): String {
val tableName = member.annotations.filterIsInstance<TableName>().singleOrNull()?.value
internal fun getTableName(
member: ClassMember,
dbType: KClass<*>,
tableType: KClass<*>,
tableNameResolver: TableNameResolver?
): String {
val tableNameFromAnnotation = member.annotations.filterIsInstance<TableName>().singleOrNull()?.value
val tableNameFromResolver = tableNameResolver?.resolveTableName(tableType.java, tableNameFromAnnotation)
val tableName = tableNameFromResolver ?: tableNameFromAnnotation
requireNotNull(tableName) {
"Please annotate ${member.javaMethod} in $dbType with `@TableName`"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2021 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package app.cash.tempest2

import app.cash.tempest2.musiclibrary.AlbumInfo
import app.cash.tempest2.musiclibrary.AsyncMusicDb
import app.cash.tempest2.musiclibrary.AsyncMusicTable
import app.cash.tempest2.musiclibrary.testDb
import app.cash.tempest2.testing.asyncLogicalDb
import java.time.LocalDate
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.RegisterExtension

class DynamoDbAsyncTableNameResolverTest {

@RegisterExtension
@JvmField
val db = testDb("custom_table_name")

private object TestTableNameResolver : TableNameResolver {
override fun resolveTableName(clazz: Class<*>, tableNameFromAnnotation: String?): String {
check(clazz == AsyncMusicTable::class.java)

return "custom_table_name"
}
}

private val musicTable = db.asyncLogicalDb<AsyncMusicDb>(extensions = listOf(), TestTableNameResolver).music

@Test
fun loadAfterSave() = runBlockingTest {
val albumInfo = AlbumInfo(
"ALBUM_1",
"after hours - EP",
"53 Thieves",
LocalDate.of(2020, 2, 21),
"Contemporary R&B"
)
musicTable.albumInfo.save(albumInfo)

// Query the movies created.
val loadedAlbumInfo = musicTable.albumInfo.load(albumInfo.key)!!

assertThat(loadedAlbumInfo.album_token).isEqualTo(albumInfo.album_token)
assertThat(loadedAlbumInfo.artist_name).isEqualTo(albumInfo.artist_name)
assertThat(loadedAlbumInfo.release_date).isEqualTo(albumInfo.release_date)
assertThat(loadedAlbumInfo.genre_name).isEqualTo(albumInfo.genre_name)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2021 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package app.cash.tempest2

import app.cash.tempest2.musiclibrary.AlbumInfo
import app.cash.tempest2.musiclibrary.MusicDb
import app.cash.tempest2.musiclibrary.MusicTable
import app.cash.tempest2.musiclibrary.testDb
import app.cash.tempest2.testing.logicalDb
import java.time.LocalDate
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.RegisterExtension

class DynamoDbTableNameResolverTest {

@RegisterExtension
@JvmField
val db = testDb("custom_table_name")

private object TestTableNameResolver : TableNameResolver {
override fun resolveTableName(clazz: Class<*>, tableNameFromAnnotation: String?): String {
check(clazz == MusicTable::class.java)

return "custom_table_name"
}
}

private val musicTable by lazy { db.logicalDb<MusicDb>(extensions = listOf(), TestTableNameResolver).music }

@Test
fun loadAfterSave() {
val albumInfo = AlbumInfo(
"ALBUM_1",
"after hours - EP",
"53 Thieves",
LocalDate.of(2020, 2, 21),
"Contemporary R&B"
)
musicTable.albumInfo.save(albumInfo)

// Query the movies created.
val loadedAlbumInfo = musicTable.albumInfo.load(albumInfo.key)!!

assertThat(loadedAlbumInfo.album_token).isEqualTo(albumInfo.album_token)
assertThat(loadedAlbumInfo.artist_name).isEqualTo(albumInfo.artist_name)
assertThat(loadedAlbumInfo.release_date).isEqualTo(albumInfo.release_date)
assertThat(loadedAlbumInfo.genre_name).isEqualTo(albumInfo.genre_name)
}
}
Loading

0 comments on commit 4e446a6

Please sign in to comment.