From d3dc445b35d839d5a1486f2dee14c031e0fc222f Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Sat, 9 Nov 2024 20:56:33 -0800 Subject: [PATCH] Router: no break after `=>` in lambda w/ type Otherwise, it looks like a "fewer braces" invocation and will be parsed as such. --- .../scala/org/scalafmt/internal/Router.scala | 44 ++++++++++--------- .../scala/org/scalafmt/util/TreeOps.scala | 32 +++++++++++--- .../scala3/CommunityScala3Suite.scala | 2 +- .../test/resources/scala3/OptionalBraces.stat | 8 ++-- .../resources/scala3/OptionalBraces_fold.stat | 23 +++++----- .../resources/scala3/OptionalBraces_keep.stat | 8 ++-- .../scala3/OptionalBraces_unfold.stat | 38 ++++++++-------- .../test/scala/org/scalafmt/FormatTests.scala | 2 +- 8 files changed, 91 insertions(+), 66 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index d8c9a6856..4a8c2746b 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -285,7 +285,7 @@ class Router(formatOps: FormatOps) { else Some(false) (arrow, 0, nlOnly) case (t: Term.FunctionTerm) :: Nil => - val arrow = getFuncArrow(lastLambda(t)).getOrElse(getLast(t)) + val arrow = lastLambda(t).flatMap(getFuncArrow).getOrElse(getLast(t)) val nlOnly = if (style.newlines.alwaysBeforeCurlyLambdaParams) Some(true) else if ( @@ -527,27 +527,29 @@ class Router(formatOps: FormatOps) { case _ => false }) => val leftFunc = leftOwner.asInstanceOf[Term.FunctionTerm] - val (afterCurlySpace, afterCurlyNewlines) = - getSpaceAndNewlineAfterCurlyLambda(newlines) def spaceSplitBase(implicit line: FileLine): Split = Split(Space, 0) - val spaceSplit = leftFunc.body match { - case _: Term.FunctionTerm => spaceSplitBase - case Term.Block((_: Term.FunctionTerm) :: Nil) - if !nextNonComment(ft).right.is[T.LeftBrace] => spaceSplitBase - case _ if afterCurlySpace && { - style.newlines.fold || !rightOwner.is[Defn] - } => - val exp = nextNonCommentSameLine(getLastNonTrivial(leftFunc.body)) - .left - spaceSplitBase.withSingleLine(exp, noSyntaxNL = true) - case _ => Split.ignored - } - val (endIndent, expiresOn) = functionExpire(leftFunc) - Seq( - spaceSplit, - Split(afterCurlyNewlines, 1) - .withIndent(style.indent.main, endIndent, expiresOn), - ) + if (canBreakAfterFuncArrow(leftFunc)) { + val (afterCurlySpace, afterCurlyNewlines) = + getSpaceAndNewlineAfterCurlyLambda(newlines) + val spaceSplit = leftFunc.body match { + case _: Term.FunctionTerm => spaceSplitBase + case Term.Block((_: Term.FunctionTerm) :: Nil) + if !nextNonComment(ft).right.is[T.LeftBrace] => spaceSplitBase + case _ if afterCurlySpace && { + style.newlines.fold || !rightOwner.is[Defn] + } => + val exp = nextNonCommentSameLine(getLastNonTrivial(leftFunc.body)) + .left + spaceSplitBase.withSingleLine(exp, noSyntaxNL = true) + case _ => Split.ignored + } + val (endIndent, expiresOn) = functionExpire(leftFunc) + Seq( + spaceSplit, + Split(afterCurlyNewlines, 1) + .withIndent(style.indent.main, endIndent, expiresOn), + ) + } else Seq(spaceSplitBase) case FormatToken(_: T.RightArrow | _: T.ContextArrow, right, _) if (leftOwner match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala index 38032ea85..f9e63e810 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala @@ -351,14 +351,36 @@ object TreeOps { math.max(res, treeDepth(t)) } + final def canBreakAfterFuncArrow(func: Term.FunctionTerm)(implicit + ftoks: FormatTokens, + style: ScalafmtConfig, + ): Boolean = !style.dialect.allowFewerBraces || { + val params = func.paramClause + params.values match { + case param :: Nil => param.decltpe match { + case Some(_: Type.Name) => ftoks.isEnclosedInMatching(params) + case _ => true + } + case _ => true + } + } + @tailrec final def lastLambda( first: Term.FunctionTerm, - )(implicit ftoks: FormatTokens): Term.FunctionTerm = first.body match { - case child: Term.FunctionTerm => lastLambda(child) - case b @ Term.Block((child: Term.FunctionTerm) :: Nil) - if !ftoks.getHead(b).left.is[Token.LeftBrace] => lastLambda(child) - case _ => first + res: Option[Term.FunctionTerm] = None, + )(implicit + ftoks: FormatTokens, + style: ScalafmtConfig, + ): Option[Term.FunctionTerm] = { + val nextres = if (canBreakAfterFuncArrow(first)) Some(first) else res + first.body match { + case child: Term.FunctionTerm => lastLambda(child, nextres) + case b @ Term.Block((child: Term.FunctionTerm) :: Nil) + if !ftoks.getHead(b).left.is[Token.LeftBrace] => + lastLambda(child, nextres) + case _ => nextres + } } @inline diff --git a/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala b/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala index 6cfba213b..4af39ab2b 100644 --- a/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala +++ b/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala @@ -17,7 +17,7 @@ class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") { class CommunityScala3_3Suite extends CommunityScala3Suite("scala-3.3") { - override protected def totalStatesVisited: Option[Int] = Some(34839832) + override protected def totalStatesVisited: Option[Int] = Some(34839737) override protected def builds = Seq(getBuild("3.3.3", dialects.Scala33, 861)) diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat index f3947de0b..0d96fbca5 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat @@ -7556,11 +7556,11 @@ object Build: .settings( generateScalaDocumentation := Def.inputTaskDyn { val outputDirOverride = - extraArgs.headOption.fold(identity[GenerationConfig](_)) { - newDir => config: GenerationConfig => config.add(OutputDir(newDir)) + extraArgs.headOption.fold(identity[GenerationConfig](_)) { newDir => + config: GenerationConfig => config.add(OutputDir(newDir)) } - val justAPI = justAPIArg.fold(identity[GenerationConfig](_)) { - _ => config: GenerationConfig => config.remove[SiteRoot] + val justAPI = justAPIArg.fold(identity[GenerationConfig](_)) { _ => + config: GenerationConfig => config.remove[SiteRoot] } }.evaluated ) diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat index 1aa74b2f0..595f389c1 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat @@ -7264,15 +7264,14 @@ object Build { ) } >>> -Idempotency violated -=> Diff (- obtained, + expected) - } -- val justAPI = justAPIArg.fold(identity[GenerationConfig](_)) { _ => -- config: GenerationConfig => -- config.remove[SiteRoot] -- } -+ val justAPI = justAPIArg -+ .fold(identity[GenerationConfig](_)) { _ => config: GenerationConfig => -+ config.remove[SiteRoot] -+ } - }.evaluated) +object Build: + lazy val scaladoc = project.in(file("scaladoc")) + .settings(generateScalaDocumentation := Def.inputTaskDyn { + val outputDirOverride = extraArgs.headOption + .fold(identity[GenerationConfig](_)) { newDir => + config: GenerationConfig => config.add(OutputDir(newDir)) + } + val justAPI = justAPIArg.fold(identity[GenerationConfig](_)) { _ => + config: GenerationConfig => config.remove[SiteRoot] + } + }.evaluated) diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat index 73a29a2e0..de21c14bf 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat @@ -7584,11 +7584,11 @@ object Build: settings( generateScalaDocumentation := Def.inputTaskDyn { val outputDirOverride = - extraArgs.headOption.fold(identity[GenerationConfig](_)) { - newDir => config: GenerationConfig => config.add(OutputDir(newDir)) + extraArgs.headOption.fold(identity[GenerationConfig](_)) { newDir => + config: GenerationConfig => config.add(OutputDir(newDir)) } - val justAPI = justAPIArg.fold(identity[GenerationConfig](_)) { - _ => config: GenerationConfig => config.remove[SiteRoot] + val justAPI = justAPIArg.fold(identity[GenerationConfig](_)) { _ => + config: GenerationConfig => config.remove[SiteRoot] } }.evaluated ) diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat index 1b8090d8d..c84a38f02 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat @@ -7859,21 +7859,23 @@ object Build { ) } >>> -Idempotency violated -=> Diff (- obtained, + expected) - .headOption -- .fold(identity[GenerationConfig](_)) { newDir => -- config: GenerationConfig => -- config.add(OutputDir(newDir)) -+ .fold(identity[GenerationConfig](_)) { -+ newDir => config: GenerationConfig => -+ config.add(OutputDir(newDir)) - } - val justAPI = -- justAPIArg.fold(identity[GenerationConfig](_)) { _ => -- config: GenerationConfig => -- config.remove[SiteRoot] -+ justAPIArg.fold(identity[GenerationConfig](_)) { -+ _ => config: GenerationConfig => -+ config.remove[SiteRoot] - } +object Build: + lazy val scaladoc = project + .in(file("scaladoc")) + .settings( + generateScalaDocumentation := + Def + .inputTaskDyn { + val outputDirOverride = + extraArgs + .headOption + .fold(identity[GenerationConfig](_)) { newDir => + config: GenerationConfig => config.add(OutputDir(newDir)) + } + val justAPI = + justAPIArg.fold(identity[GenerationConfig](_)) { _ => + config: GenerationConfig => config.remove[SiteRoot] + } + } + .evaluated + ) diff --git a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala index 6216504c6..372025f7a 100644 --- a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala @@ -137,7 +137,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { val explored = Debug.explored.get() logger.debug(s"Total explored: $explored") if (!onlyUnit && !onlyManual) - assertEquals(explored, 1119018, "total explored") + assertEquals(explored, 1116417, "total explored") val results = debugResults.result() // TODO(olafur) don't block printing out test results. // I don't want to deal with scalaz's Tasks :'(