Skip to content

Commit

Permalink
Support multiple pointsTo for a column (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindberg authored May 31, 2024
1 parent 781cfc6 commit fe2ad81
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 105 deletions.
4 changes: 2 additions & 2 deletions typo/src/scala/typo/MetaDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ object MetaDb {
val decomposedSql = DecomposedSql.parse(sqlContent)
val Right(jdbcMetadata) = JdbcMetadata.from(sqlContent): @unchecked
val nullabilityInfo = NullabilityFromExplain.from(decomposedSql, Nil).nullableIndices
val deps: Map[db.ColName, (db.RelationName, db.ColName)] =
val deps: Map[db.ColName, List[(db.RelationName, db.ColName)]] =
jdbcMetadata.columns match {
case MaybeReturnsRows.Query(columns) =>
columns.toList.flatMap(col => col.baseRelationName.zip(col.baseColumnName).map(col.name -> _)).toMap
columns.toList.flatMap(col => col.baseRelationName.zip(col.baseColumnName).map(t => col.name -> List(t))).toMap
case MaybeReturnsRows.Update =>
Map.empty
}
Expand Down
2 changes: 1 addition & 1 deletion typo/src/scala/typo/db.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ object db {
comment: Option[String],
decomposedSql: DecomposedSql,
cols: NonEmptyList[(db.Col, ParsedName)],
deps: Map[db.ColName, (db.RelationName, db.ColName)],
deps: Map[db.ColName, List[(db.RelationName, db.ColName)]],
isMaterialized: Boolean
) extends Relation
}
2 changes: 1 addition & 1 deletion typo/src/scala/typo/internal/ComputedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package typo
package internal

case class ComputedColumn(
pointsTo: Option[(Source.Relation, db.ColName)],
pointsTo: List[(Source.Relation, db.ColName)],
name: sc.Ident,
tpe: sc.Type,
dbCol: db.Col
Expand Down
46 changes: 19 additions & 27 deletions typo/src/scala/typo/internal/ComputedSqlFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ case class ComputedSqlFile(
) {
val source: Source.SqlFile = Source.SqlFile(sqlFile.relPath)

val deps: Map[db.ColName, (db.RelationName, db.ColName)] =
val deps: Map[db.ColName, List[(db.RelationName, db.ColName)]] =
sqlFile.jdbcMetadata.columns match {
case MaybeReturnsRows.Query(columns) =>
columns.toList.flatMap(col => col.baseRelationName.zip(col.baseColumnName).map(col.name -> _)).toMap
columns.toList.flatMap(col => col.baseRelationName.zip(col.baseColumnName).map(t => col.name -> List(t))).toMap
case MaybeReturnsRows.Update =>
Map.empty
}
Expand All @@ -38,36 +38,28 @@ case class ComputedSqlFile(
logger.warn(s"Couldn't translate type from file ${sqlFile.relPath} column ${col.name.value} with type ${col.columnTypeName}. Falling back to text")
}

// we let types flow through constraints down to this column, the point is to reuse id types downstream
val typeFromFk: Option[sc.Type] =
deps.get(col.name) match {
case Some((otherTableName, otherColName)) =>
for {
existingTable <- eval(otherTableName)
nonCircular <- existingTable.get
col <- nonCircular.cols.find(_.dbName == otherColName)
} yield col.tpe
case _ => None
val pointsTo: List[(Source.Relation, db.ColName)] =
deps.getOrElse(col.name, Nil).flatMap { case (relName, colName) =>
eval(relName).flatMap(_.get).map(x => x.source -> colName)
}

// we let types flow through constraints down to this column, the point is to reuse id types downstream
val typeFromFk: Option[sc.Type] = findTypeFromFk(logger, source, col.name, pointsTo, eval)(_ => None)

val tpe = scalaTypeMapper.sqlFile(col.parsedColumnName.overriddenType.orElse(typeFromFk), dbType, nullability)

ComputedColumn(
pointsTo = deps.get(col.name).flatMap { case (relName, colName) => eval(relName).flatMap(_.get.map(foo => foo.source -> colName)) },
name = naming.field(col.name),
tpe = tpe,
dbCol = db.Col(
parsedName = col.parsedColumnName,
tpe = dbType,
udtName = None,
nullability = nullability,
columnDefault = None,
identity = None,
comment = None,
constraints = Nil,
jsonDescription = DebugJson(col)
)
val dbCol = db.Col(
parsedName = col.parsedColumnName,
tpe = dbType,
udtName = None,
nullability = nullability,
columnDefault = None,
identity = None,
comment = None,
constraints = Nil,
jsonDescription = DebugJson(col)
)
ComputedColumn(pointsTo = pointsTo, name = naming.field(col.name), tpe = tpe, dbCol = dbCol)
}
}

Expand Down
94 changes: 39 additions & 55 deletions typo/src/scala/typo/internal/ComputedTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,30 @@ case class ComputedTable(
eval: Eval[db.RelationName, HasSource]
) extends HasSource {
override val source: Source.Table = Source.Table(dbTable.name)
val pointsTo: Map[db.ColName, (Source.Relation, db.ColName)] = {
val (fkSelf, fkOther) = dbTable.foreignKeys.partition { fk => fk.otherTable == dbTable.name }

val fromSelf = fkSelf.flatMap(fk => fk.cols.zip(fk.otherCols.map(cn => (source, cn))).toList).toMap
val deps: Map[db.ColName, List[(Source.Relation, db.ColName)]] = {
val (fkSelf, fkOther) = dbTable.foreignKeys.partition { fk => fk.otherTable == dbTable.name }

val fromOthers = fkOther.flatMap { fk =>
eval(fk.otherTable).get match {
case None =>
options.logger.warn(s"Circular: ${dbTable.name.value} => ${fk.otherTable.value}")
Nil
case Some(otherTable) =>
val value = fk.otherCols.map(cn => (otherTable.source, cn))
fk.cols.zip(value).toList
val fromSelf: List[(db.ColName, (Source.Relation, db.ColName))] =
fkSelf.flatMap(fk => fk.cols.zip(fk.otherCols.map(cn => (source, cn))).toList)

val fromOthers: List[(db.ColName, (Source.Relation, db.ColName))] =
fkOther.flatMap { fk =>
eval(fk.otherTable).get match {
case None =>
options.logger.warn(s"Circular: ${dbTable.name.value} => ${fk.otherTable.value}")
Nil
case Some(otherTable) =>
fk.cols.zip(fk.otherCols.map(cn => (otherTable.source, cn))).toList
}
}
}.toMap

// prefer inferring types from outside the table
fromSelf ++ fromOthers
(fromSelf ++ fromOthers).groupBy(_._1).map { case (colName, tuples) =>
val sorted = tuples
.map { case (_, other) => other }
.sortBy { case (rel, colName) => (rel.name.value, colName.value) }
colName -> sorted
}
}

val dbColsByName: Map[db.ColName, db.Col] =
Expand All @@ -41,82 +47,60 @@ case class ComputedTable(
pk.colNames match {
case NonEmptyList(colName, Nil) =>
val dbCol = dbColsByName(colName)
val underlying = scalaTypeMapper.col(dbTable.name, dbCol, None)
val col = ComputedColumn(
pointsTo = pointsTo.get(dbCol.name),
name = naming.field(dbCol.name),
tpe = underlying,
dbCol = dbCol
)
col.pointsTo match {
case Some((relationSource, colName)) =>
val cols = eval(relationSource.name).forceGet.cols
val tpe = cols.find(_.dbName == colName).get.tpe
val pointsTo = deps.getOrElse(dbCol.name, Nil)

findTypeFromFk(options.logger, source, dbCol.name, pointsTo, eval.asMaybe)(_ => None) match {
case Some(tpe) =>
val col = ComputedColumn(pointsTo = pointsTo, name = naming.field(dbCol.name), tpe = tpe, dbCol = dbCol)
Some(IdComputed.UnaryInherited(col, tpe))
case None =>
val underlying = scalaTypeMapper.col(dbTable.name, dbCol, None)
val col = ComputedColumn(pointsTo = pointsTo, name = naming.field(dbCol.name), tpe = underlying, dbCol = dbCol)
if (sc.Type.containsUserDefined(underlying))
Some(IdComputed.UnaryUserSpecified(col, underlying))
else if (!options.enablePrimaryKeyType.include(dbTable.name))
Some(IdComputed.UnaryNoIdType(col, underlying))
else
Some(IdComputed.UnaryNormal(col, tpe))

}

case colNames =>
val cols: NonEmptyList[ComputedColumn] =
colNames.map { colName =>
val dbCol = dbColsByName(colName)
ComputedColumn(
pointsTo = None,
pointsTo = Nil,
name = naming.field(colName),
tpe = deriveType(dbCol),
dbCol = dbCol
tpe = deriveType(colName),
dbCol = dbColsByName(colName)
)
}
Some(IdComputed.Composite(cols, tpe, paramName = sc.Ident("compositeId")))
}
}

val cols: NonEmptyList[ComputedColumn] = {
val cols: NonEmptyList[ComputedColumn] =
dbTable.cols.map { dbCol =>
val tpe = deriveType(dbCol)

ComputedColumn(
pointsTo = pointsTo.get(dbCol.name),
pointsTo = deps.getOrElse(dbCol.name, Nil),
name = naming.field(dbCol.name),
tpe = tpe,
tpe = deriveType(dbCol.name),
dbCol = dbCol
)
}
}

def deriveType(dbCol: db.Col): sc.Type = {
def deriveType(colName: db.ColName): sc.Type = {
val dbCol = dbColsByName(colName)
// we let types flow through constraints down to this column, the point is to reuse id types downstream
val typeFromFk: Option[sc.Type] =
pointsTo.get(dbCol.name).flatMap { case (otherTableSource, otherColName) =>
if (otherTableSource.name == dbTable.name)
if (otherColName == dbCol.name) None
else Some(deriveType(dbColsByName(otherColName)))
else
eval(otherTableSource.name).get match {
case Some(otherTable) =>
otherTable.cols.find(_.dbName == otherColName).map(_.tpe)
case None =>
options.logger.warn(s"Unexpected circular dependency involving ${dbTable.name.value} => ${otherTableSource.name.value}")
None
}
}
findTypeFromFk(options.logger, source, colName, deps.getOrElse(colName, Nil), eval.asMaybe)(otherColName => Some(deriveType(otherColName)))

val typeFromId: Option[sc.Type] =
maybeId match {
case Some(id: IdComputed.Unary) if id.col.dbName == dbCol.name => Some(id.tpe)
case _ => None
case Some(id: IdComputed.Unary) if id.col.dbName == colName => Some(id.tpe)
case _ => None
}

val tpe = scalaTypeMapper.col(dbTable.name, dbCol, typeFromFk.orElse(typeFromId))

tpe
scalaTypeMapper.col(dbTable.name, dbCol, typeFromFk.orElse(typeFromId))
}

val names = ComputedNames(naming, source, maybeId, options.enableFieldValue.include(dbTable.name), options.enableDsl)
Expand Down
41 changes: 24 additions & 17 deletions typo/src/scala/typo/internal/ComputedView.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package typo
package internal

import typo.internal.analysis.ParsedName
import typo.internal.rewriteDependentData.Eval

case class ComputedView(
logger: TypoLogger,
view: db.View,
naming: Naming,
typeMapperDb: TypeMapperDb,
Expand All @@ -14,26 +16,31 @@ case class ComputedView(
) extends HasSource {
val source: Source.View = Source.View(view.name, view.isMaterialized)

val pointsToByColName: Map[db.ColName, List[(Source.Relation, db.ColName)]] =
view.cols.map { case (col, _) =>
col.name -> view.deps.getOrElse(col.name, Nil).flatMap { case (relName, colName) => eval(relName).get.map(_.source -> colName) }
}.toMap

val colsByName: Map[db.ColName, (db.Col, ParsedName)] =
view.cols.map { case t @ (col, _) => col.name -> t }.toMap

val cols: NonEmptyList[ComputedColumn] =
view.cols.map { case (col, parsedName) =>
// we let types flow through constraints down to this column, the point is to reuse id types downstream
val typeFromFk: Option[sc.Type] =
view.deps.get(col.name) match {
case Some((otherTableName, otherColName)) =>
val existingTable = eval(otherTableName)
for {
nonCircular <- existingTable.get
col <- nonCircular.cols.find(_.dbName == otherColName)
} yield col.tpe
case _ => None
}

val tpe = scalaTypeMapper.sqlFile(parsedName.overriddenType.orElse(typeFromFk), col.tpe, col.nullability)

val pointsTo = view.deps.get(col.name).flatMap { case (relName, colName) => eval(relName).get.map(_.source -> colName) }
ComputedColumn(pointsTo = pointsTo, name = naming.field(col.name), tpe = tpe, dbCol = col)
view.cols.map { case (col, _) =>
ComputedColumn(
pointsTo = pointsToByColName(col.name),
name = naming.field(col.name),
tpe = inferType(col.name),
dbCol = col
)
}

def inferType(colName: db.ColName): sc.Type = {
val (col, parsedName) = colsByName(colName)
val typeFromFk: Option[sc.Type] =
findTypeFromFk(logger, source, col.name, pointsToByColName(col.name), eval.asMaybe)(otherColName => Some(inferType(otherColName)))
scalaTypeMapper.sqlFile(parsedName.overriddenType.orElse(typeFromFk), col.tpe, col.nullability)
}

val names = ComputedNames(naming, source, maybeId = None, enableFieldValue, enableDsl = enableDsl)

val repoMethods: NonEmptyList[RepoMethod] = {
Expand Down
39 changes: 39 additions & 0 deletions typo/src/scala/typo/internal/findTypeFromFk.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package typo
package internal

import typo.internal.rewriteDependentData.EvalMaybe

// we let types flow through constraints down to this column, the point is to reuse id types downstream
object findTypeFromFk {
def apply(
typoLogger: TypoLogger,
source: Source,
colName: db.ColName,
pointsTo: List[(Source.Relation, db.ColName)],
eval: EvalMaybe[db.RelationName, HasSource]
)(inSameSource: db.ColName => Option[sc.Type]): Option[sc.Type] = {
val all: List[Either[sc.Type, sc.Type]] =
pointsTo.flatMap { case (otherTableSource, otherColName) =>
if (otherTableSource == source)
if (colName == otherColName) None
else inSameSource(otherColName).map(Left.apply)
else
for {
existingTable <- eval(otherTableSource.name)
nonCircular <- existingTable.get
otherCol <- nonCircular.cols.find(_.dbName == otherColName)
} yield Right(otherCol.tpe)
}

all.distinctBy { e => sc.Type.base(e.merge) } match {
case Nil => None
case e :: Nil => Some(e.merge)
case all =>
val fromSelf = all.collectFirst { case Left(tpe) => tpe }
val fromOthers = all.collectFirst { case Right(tpe) => tpe }
val renderedTypes = all.map { e => sc.renderTree(e.merge) }
typoLogger.warn(s"Multiple distinct types inherited for column ${colName.value} in $source: ${renderedTypes.mkString(", ")}")
fromOthers.orElse(fromSelf)
}
}
}
2 changes: 1 addition & 1 deletion typo/src/scala/typo/internal/generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ object generate {
case (_, dbTable: db.Table, eval) =>
ComputedTable(options, default, dbTable, naming, scalaTypeMapper, eval)
case (_, dbView: db.View, eval) =>
ComputedView(dbView, naming, metaDb.typeMapperDb, scalaTypeMapper, eval, options.enableFieldValue.include(dbView.name), options.enableDsl)
ComputedView(options.logger, dbView, naming, metaDb.typeMapperDb, scalaTypeMapper, eval, options.enableFieldValue.include(dbView.name), options.enableDsl)
}

// note, these statements will force the evaluation of some of the lazy values
Expand Down
3 changes: 2 additions & 1 deletion typo/src/scala/typo/internal/rewriteDependentData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import scala.collection.immutable.SortedMap
// where processing one thing requires having processed its dependencies first
object rewriteDependentData {
@FunctionalInterface
trait Eval[K, V] {
trait Eval[K, V] { self =>
def apply(key: K): Lazy[V]
def asMaybe: EvalMaybe[K, V] = key => Some(apply(key))
}
@FunctionalInterface
trait EvalMaybe[K, V] {
Expand Down

0 comments on commit fe2ad81

Please sign in to comment.