diff --git a/tempest2-testing/src/main/kotlin/app/cash/tempest2/testing/TestDynamoDbClient.kt b/tempest2-testing/src/main/kotlin/app/cash/tempest2/testing/TestDynamoDbClient.kt index 3a7e16492..ec2b9e371 100644 --- a/tempest2-testing/src/main/kotlin/app/cash/tempest2/testing/TestDynamoDbClient.kt +++ b/tempest2-testing/src/main/kotlin/app/cash/tempest2/testing/TestDynamoDbClient.kt @@ -17,6 +17,7 @@ package app.cash.tempest2.testing import app.cash.tempest2.LogicalDb +import app.cash.tempest2.TableNameResolver import com.google.common.util.concurrent.Service import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedAsyncClient import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedClient @@ -55,12 +56,20 @@ interface TestDynamoDbClient : Service { fun logicalDb( type: KClass, extensions: List + ): DB { + return logicalDb(type, extensions, tableNameResolver = null) + } + + fun logicalDb( + type: KClass, + extensions: List, + tableNameResolver: TableNameResolver? = null ): DB { val enhancedClient = DynamoDbEnhancedClient.builder() .dynamoDbClient(dynamoDb) .extensions(extensions) .build() - return LogicalDb.create(type, enhancedClient) + return LogicalDb.create(type, enhancedClient, tableNameResolver) } fun logicalDb(type: Class): DB { @@ -89,12 +98,20 @@ interface TestDynamoDbClient : Service { fun asyncLogicalDb( type: KClass, extensions: List + ): DB { + return asyncLogicalDb(type, extensions, tableNameResolver = null) + } + + fun asyncLogicalDb( + type: KClass, + extensions: List, + tableNameResolver: TableNameResolver? ): DB { val enhancedClient = DynamoDbEnhancedAsyncClient.builder() .dynamoDbClient(asyncDynamoDb) .extensions(extensions) .build() - return app.cash.tempest2.AsyncLogicalDb.create(type, enhancedClient) + return AsyncLogicalDb.create(type, enhancedClient) } fun asyncLogicalDb(type: Class): DB { @@ -113,18 +130,30 @@ interface TestDynamoDbClient : Service { } } -inline fun TestDynamoDbClient.logicalDb(vararg extensions: DynamoDbEnhancedClientExtension): DB { - return logicalDb(extensions.toList()) +inline fun TestDynamoDbClient.logicalDb( + vararg extensions: DynamoDbEnhancedClientExtension, + tableNameResolver: TableNameResolver? = null +): DB { + return logicalDb(extensions.toList(), tableNameResolver) } -inline fun TestDynamoDbClient.logicalDb(extensions: List): DB { - return logicalDb(DB::class, extensions) +inline fun TestDynamoDbClient.logicalDb( + extensions: List, + tableNameResolver: TableNameResolver? = null +): DB { + return logicalDb(DB::class, extensions, tableNameResolver) } -inline fun TestDynamoDbClient.asyncLogicalDb(vararg extensions: DynamoDbEnhancedClientExtension): DB { - return asyncLogicalDb(extensions.toList()) +inline fun TestDynamoDbClient.asyncLogicalDb( + vararg extensions: DynamoDbEnhancedClientExtension, + tableNameResolver: TableNameResolver? = null +): DB { + return asyncLogicalDb(extensions.toList(), tableNameResolver) } -inline fun TestDynamoDbClient.asyncLogicalDb(extensions: List): DB { - return asyncLogicalDb(DB::class, extensions) +inline fun TestDynamoDbClient.asyncLogicalDb( + extensions: List, + tableNameResolver: TableNameResolver? = null +): DB { + return asyncLogicalDb(DB::class, extensions, tableNameResolver) } diff --git a/tempest2/src/main/kotlin/app/cash/tempest2/AsyncLogicalDb.kt b/tempest2/src/main/kotlin/app/cash/tempest2/AsyncLogicalDb.kt index e5a8c7f2d..76cacb5a2 100644 --- a/tempest2/src/main/kotlin/app/cash/tempest2/AsyncLogicalDb.kt +++ b/tempest2/src/main/kotlin/app/cash/tempest2/AsyncLogicalDb.kt @@ -133,16 +133,18 @@ interface AsyncLogicalDb : AsyncLogicalTable.Factory { companion object { inline operator fun invoke( - dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient + dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient, + tableNameResolver: TableNameResolver? = null ): DB { - return create(DB::class, dynamoDbEnhancedClient) + return create(DB::class, dynamoDbEnhancedClient, tableNameResolver) } fun create( dbType: KClass, - dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient + dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient, + tableNameResolver: TableNameResolver? = null, ): DB { - return AsyncLogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType) + return AsyncLogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType, tableNameResolver) } // Overloaded functions for Java callers (Kotlin interface companion objects do not support @@ -152,8 +154,15 @@ interface AsyncLogicalDb : AsyncLogicalTable.Factory { @JvmStatic fun create( dbType: Class, - dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient - ) = create(dbType.kotlin, dynamoDbEnhancedClient) + dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient, + ) = create(dbType, dynamoDbEnhancedClient, tableNameResolver = null) + + @JvmStatic + fun create( + dbType: Class, + dynamoDbEnhancedClient: DynamoDbEnhancedAsyncClient, + tableNameResolver: TableNameResolver? + ) = create(dbType.kotlin, dynamoDbEnhancedClient, tableNameResolver) } // Overloaded functions for Java callers (Kotlin interfaces do not support `@JvmOverloads`). diff --git a/tempest2/src/main/kotlin/app/cash/tempest2/LogicalDb.kt b/tempest2/src/main/kotlin/app/cash/tempest2/LogicalDb.kt index 39df00cc1..52fcbf040 100644 --- a/tempest2/src/main/kotlin/app/cash/tempest2/LogicalDb.kt +++ b/tempest2/src/main/kotlin/app/cash/tempest2/LogicalDb.kt @@ -160,16 +160,18 @@ interface LogicalDb : LogicalTable.Factory { companion object { inline operator fun invoke( - dynamoDbEnhancedClient: DynamoDbEnhancedClient + dynamoDbEnhancedClient: DynamoDbEnhancedClient, + tableNameResolver: TableNameResolver? = null ): DB { - return create(DB::class, dynamoDbEnhancedClient) + return create(DB::class, dynamoDbEnhancedClient, tableNameResolver) } fun create( dbType: KClass, - dynamoDbEnhancedClient: DynamoDbEnhancedClient + dynamoDbEnhancedClient: DynamoDbEnhancedClient, + tableNameResolver: TableNameResolver? = null ): DB { - return LogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType) + return LogicalDbFactory(dynamoDbEnhancedClient).logicalDb(dbType, tableNameResolver) } // Overloaded functions for Java callers (Kotlin interface companion objects do not support @@ -180,7 +182,14 @@ interface LogicalDb : LogicalTable.Factory { fun create( dbType: Class, dynamoDbEnhancedClient: DynamoDbEnhancedClient - ) = create(dbType.kotlin, dynamoDbEnhancedClient) + ) = create(dbType, dynamoDbEnhancedClient, tableNameResolver = null) + + @JvmStatic + fun create( + dbType: Class, + dynamoDbEnhancedClient: DynamoDbEnhancedClient, + tableNameResolver: TableNameResolver? + ) = create(dbType.kotlin, dynamoDbEnhancedClient, tableNameResolver) } } diff --git a/tempest2/src/main/kotlin/app/cash/tempest2/TableNameResolver.kt b/tempest2/src/main/kotlin/app/cash/tempest2/TableNameResolver.kt new file mode 100644 index 000000000..e17d2c378 --- /dev/null +++ b/tempest2/src/main/kotlin/app/cash/tempest2/TableNameResolver.kt @@ -0,0 +1,10 @@ +package app.cash.tempest2 + +/** + * Resolves the table name for a given [LogicalTable] class. + * + * This allows table names to be overridden at runtime. + */ +interface TableNameResolver { + fun resolveTableName(clazz: Class<*>, tableNameFromAnnotation: String?): String +} diff --git a/tempest2/src/main/kotlin/app/cash/tempest2/internal/AsyncLogicalDbFactory.kt b/tempest2/src/main/kotlin/app/cash/tempest2/internal/AsyncLogicalDbFactory.kt index bd9dddb2d..384eafb23 100644 --- a/tempest2/src/main/kotlin/app/cash/tempest2/internal/AsyncLogicalDbFactory.kt +++ b/tempest2/src/main/kotlin/app/cash/tempest2/internal/AsyncLogicalDbFactory.kt @@ -32,6 +32,7 @@ import app.cash.tempest2.AsyncQueryable import app.cash.tempest2.AsyncScannable import app.cash.tempest2.AsyncSecondaryIndex import app.cash.tempest2.AsyncView +import app.cash.tempest2.TableNameResolver import software.amazon.awssdk.enhanced.dynamodb.DynamoDbAsyncTable import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedAsyncClient import software.amazon.awssdk.enhanced.dynamodb.TableSchema @@ -51,7 +52,7 @@ internal class AsyncLogicalDbFactory( V2RawItemTypeFactory() ) - fun logicalDb(dbType: KClass): DB { + fun logicalDb(dbType: KClass, tableNameResolver: TableNameResolver?): DB { val logicalDb = DynamoDbLogicalDb( DynamoDbLogicalDb.MappedTableResourceFactory.simple(dynamoDbEnhancedClient::table), schema, @@ -62,7 +63,7 @@ internal class AsyncLogicalDbFactory( continue } val tableType = member.returnType.jvmErasure as KClass> - val tableName = getTableName(member, dbType) + val tableName = getTableName(member, dbType, tableType, tableNameResolver) val table = logicalTable(tableName, tableType) methodHandlers[member.javaMethod] = GetterMethodHandler(table) } diff --git a/tempest2/src/main/kotlin/app/cash/tempest2/internal/LogicalDbFactory.kt b/tempest2/src/main/kotlin/app/cash/tempest2/internal/LogicalDbFactory.kt index 2cf8614c1..720011d23 100644 --- a/tempest2/src/main/kotlin/app/cash/tempest2/internal/LogicalDbFactory.kt +++ b/tempest2/src/main/kotlin/app/cash/tempest2/internal/LogicalDbFactory.kt @@ -31,6 +31,7 @@ import app.cash.tempest2.LogicalTable import app.cash.tempest2.Queryable import app.cash.tempest2.Scannable import app.cash.tempest2.SecondaryIndex +import app.cash.tempest2.TableNameResolver import app.cash.tempest2.View import software.amazon.awssdk.enhanced.dynamodb.DynamoDbEnhancedClient import software.amazon.awssdk.enhanced.dynamodb.DynamoDbTable @@ -51,7 +52,7 @@ internal class LogicalDbFactory( V2RawItemTypeFactory() ) - fun logicalDb(dbType: KClass): DB { + fun logicalDb(dbType: KClass, tableNameResolver: TableNameResolver?): DB { val logicalDb = DynamoDbLogicalDb( DynamoDbLogicalDb.MappedTableResourceFactory.simple(dynamoDbEnhancedClient::table), schema, @@ -62,7 +63,7 @@ internal class LogicalDbFactory( continue } val tableType = member.returnType.jvmErasure as KClass> - val tableName = getTableName(member, dbType) + val tableName = getTableName(member, dbType, tableType, tableNameResolver) val table = logicalTable(tableName, tableType) methodHandlers[member.javaMethod] = GetterMethodHandler(table) } diff --git a/tempest2/src/main/kotlin/app/cash/tempest2/internal/V2.kt b/tempest2/src/main/kotlin/app/cash/tempest2/internal/V2.kt index b8930422f..9ffd0651c 100644 --- a/tempest2/src/main/kotlin/app/cash/tempest2/internal/V2.kt +++ b/tempest2/src/main/kotlin/app/cash/tempest2/internal/V2.kt @@ -27,6 +27,7 @@ import app.cash.tempest.internal.StringAttributeValue import app.cash.tempest2.Attribute import app.cash.tempest2.ForIndex import app.cash.tempest2.TableName +import app.cash.tempest2.TableNameResolver import software.amazon.awssdk.enhanced.dynamodb.TableMetadata import software.amazon.awssdk.enhanced.dynamodb.TableSchema import software.amazon.awssdk.services.dynamodb.model.AttributeValue @@ -87,8 +88,15 @@ internal class V2RawItemTypeFactory : RawItemType.Factory { } } -internal fun getTableName(member: ClassMember, dbType: KClass<*>): String { - val tableName = member.annotations.filterIsInstance().singleOrNull()?.value +internal fun getTableName( + member: ClassMember, + dbType: KClass<*>, + tableType: KClass<*>, + tableNameResolver: TableNameResolver? +): String { + val tableNameFromAnnotation = member.annotations.filterIsInstance().singleOrNull()?.value + val tableNameFromResolver = tableNameResolver?.resolveTableName(tableType.java, tableNameFromAnnotation) + val tableName = tableNameFromResolver ?: tableNameFromAnnotation requireNotNull(tableName) { "Please annotate ${member.javaMethod} in $dbType with `@TableName`" } diff --git a/tempest2/src/test/kotlin/app/cash/tempest2/DynamoDbAsyncTableNameResolverTest.kt b/tempest2/src/test/kotlin/app/cash/tempest2/DynamoDbAsyncTableNameResolverTest.kt new file mode 100644 index 000000000..32463f579 --- /dev/null +++ b/tempest2/src/test/kotlin/app/cash/tempest2/DynamoDbAsyncTableNameResolverTest.kt @@ -0,0 +1,64 @@ +/* + * Copyright 2021 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package app.cash.tempest2 + +import app.cash.tempest2.musiclibrary.AlbumInfo +import app.cash.tempest2.musiclibrary.AsyncMusicDb +import app.cash.tempest2.musiclibrary.AsyncMusicTable +import app.cash.tempest2.musiclibrary.testDb +import app.cash.tempest2.testing.asyncLogicalDb +import java.time.LocalDate +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class DynamoDbAsyncTableNameResolverTest { + + @RegisterExtension + @JvmField + val db = testDb("custom_table_name") + + private object TestTableNameResolver : TableNameResolver { + override fun resolveTableName(clazz: Class<*>, tableNameFromAnnotation: String?): String { + check(clazz == AsyncMusicTable::class.java) + + return "custom_table_name" + } + } + + private val musicTable = db.asyncLogicalDb(extensions = listOf(), TestTableNameResolver).music + + @Test + fun loadAfterSave() = runBlockingTest { + val albumInfo = AlbumInfo( + "ALBUM_1", + "after hours - EP", + "53 Thieves", + LocalDate.of(2020, 2, 21), + "Contemporary R&B" + ) + musicTable.albumInfo.save(albumInfo) + + // Query the movies created. + val loadedAlbumInfo = musicTable.albumInfo.load(albumInfo.key)!! + + assertThat(loadedAlbumInfo.album_token).isEqualTo(albumInfo.album_token) + assertThat(loadedAlbumInfo.artist_name).isEqualTo(albumInfo.artist_name) + assertThat(loadedAlbumInfo.release_date).isEqualTo(albumInfo.release_date) + assertThat(loadedAlbumInfo.genre_name).isEqualTo(albumInfo.genre_name) + } +} diff --git a/tempest2/src/test/kotlin/app/cash/tempest2/DynamoDbTableNameResolverTest.kt b/tempest2/src/test/kotlin/app/cash/tempest2/DynamoDbTableNameResolverTest.kt new file mode 100644 index 000000000..ab7de63c1 --- /dev/null +++ b/tempest2/src/test/kotlin/app/cash/tempest2/DynamoDbTableNameResolverTest.kt @@ -0,0 +1,64 @@ +/* + * Copyright 2021 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package app.cash.tempest2 + +import app.cash.tempest2.musiclibrary.AlbumInfo +import app.cash.tempest2.musiclibrary.MusicDb +import app.cash.tempest2.musiclibrary.MusicTable +import app.cash.tempest2.musiclibrary.testDb +import app.cash.tempest2.testing.logicalDb +import java.time.LocalDate +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class DynamoDbTableNameResolverTest { + + @RegisterExtension + @JvmField + val db = testDb("custom_table_name") + + private object TestTableNameResolver : TableNameResolver { + override fun resolveTableName(clazz: Class<*>, tableNameFromAnnotation: String?): String { + check(clazz == MusicTable::class.java) + + return "custom_table_name" + } + } + + private val musicTable by lazy { db.logicalDb(extensions = listOf(), TestTableNameResolver).music } + + @Test + fun loadAfterSave() { + val albumInfo = AlbumInfo( + "ALBUM_1", + "after hours - EP", + "53 Thieves", + LocalDate.of(2020, 2, 21), + "Contemporary R&B" + ) + musicTable.albumInfo.save(albumInfo) + + // Query the movies created. + val loadedAlbumInfo = musicTable.albumInfo.load(albumInfo.key)!! + + assertThat(loadedAlbumInfo.album_token).isEqualTo(albumInfo.album_token) + assertThat(loadedAlbumInfo.artist_name).isEqualTo(albumInfo.artist_name) + assertThat(loadedAlbumInfo.release_date).isEqualTo(albumInfo.release_date) + assertThat(loadedAlbumInfo.genre_name).isEqualTo(albumInfo.genre_name) + } +} diff --git a/tempest2/src/test/kotlin/app/cash/tempest2/musiclibrary/TestUtils.kt b/tempest2/src/test/kotlin/app/cash/tempest2/musiclibrary/TestUtils.kt index f68fbf907..8467bf660 100644 --- a/tempest2/src/test/kotlin/app/cash/tempest2/musiclibrary/TestUtils.kt +++ b/tempest2/src/test/kotlin/app/cash/tempest2/musiclibrary/TestUtils.kt @@ -26,9 +26,9 @@ import software.amazon.awssdk.enhanced.dynamodb.model.EnhancedLocalSecondaryInde import software.amazon.awssdk.services.dynamodb.model.Projection import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput -fun testDb() = TestDynamoDb.Builder(JvmDynamoDbServer.Factory) +fun testDb(tableName: String = "music_items") = TestDynamoDb.Builder(JvmDynamoDbServer.Factory) .addTable( - TestTable.create("music_items") { + TestTable.create(tableName) { it.toBuilder() .globalSecondaryIndices( EnhancedGSI("genre_album_index"),