Skip to content

Commit

Permalink
AvoidInfix: apply to postfix select as well
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Nov 8, 2024
1 parent e7215ef commit d8db7d7
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 45 deletions.
2 changes: 2 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -3023,6 +3023,8 @@ The rule takes the following parameters under `rewrite.avoidInfix`:
- if unspecified, and `project.layout` determines that the file being
formatted is not a test file, then these test assert methods will not
be excluded
- (since 3.8.4) `excludePostfix`, unless set to `true` explicitly, will also
apply the rule to `Term.Select` trees specified without a dot

```scala mdoc:scalafmt
rewrite.rules = [AvoidInfix]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ case class AvoidInfixSettings(
private[config] val excludeFilters: Seq[AvoidInfixSettings.Filter],
private val excludeScalaTest: Option[Boolean] = None,
excludePlaceholderArg: Option[Boolean] = None,
excludePostfix: Boolean = false,
) {
// if the user completely redefined (rather than appended), we don't touch
@inline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object AvoidInfix extends RewriteFactory {

class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {

private val matcher = ctx.style.rewrite.avoidInfix
private val cfg = ctx.style.rewrite.avoidInfix

// In a perfect world, we could just use
// Tree.transform {
Expand All @@ -31,14 +31,17 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {

override def rewrite(tree: Tree): Unit = tree match {
case x: Term.ApplyInfix => rewriteImpl(x.lhs, x.op, x.arg, x.targClause)
case x: Term.Select if !cfg.excludePostfix =>
val maybeDot = ctx.tokenTraverser.prevNonTrivialToken(x.name.tokens.head)
if (!maybeDot.forall(_.is[Token.Dot])) rewriteImpl(x.qual, x.name)
case _ =>
}

private def rewriteImpl(
lhs: Term,
op: Name,
rhs: Tree,
targs: Member.SyntaxValuesClause,
rhs: Tree = null,
targs: Member.SyntaxValuesClause = null,
): Unit = {
val (lhsHead, lhsLast) = ends(lhs)
val beforeLhsHead = ctx.tokenTraverser.prevNonTrivialToken(lhsHead)
Expand All @@ -59,27 +62,29 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {
val (opHead, opLast) = ends(op)
builder += TokenPatch.AddLeft(opHead, ".", keepTok = true)

def moveOpenDelim(prev: Token, open: Token): Unit = {
// move delimiter (before comment or newline)
builder += TokenPatch.AddRight(prev, open.text, keepTok = true)
builder += TokenPatch.Remove(open)
}
if (rhs ne null) {
def moveOpenDelim(prev: Token, open: Token): Unit = {
// move delimiter (before comment or newline)
builder += TokenPatch.AddRight(prev, open.text, keepTok = true)
builder += TokenPatch.Remove(open)
}

// move the left bracket if targs
val beforeLp =
if ((targs eq null) || targs.values.isEmpty) opLast
// move the left bracket if targs
val beforeLp =
if ((targs eq null) || targs.values.isEmpty) opLast
else {
val (targsHead, targsLast) = ends(targs)
moveOpenDelim(opLast, targsHead)
targsLast
}
// move the left paren if enclosed, else enclose
val (argsHead, argsLast) = ends(rhs)
if (ctx.getMatchingOpt(argsHead).exists(argsLast.end <= _.end))
moveOpenDelim(beforeLp, argsHead)
else {
val (targsHead, targsLast) = ends(targs)
moveOpenDelim(opLast, targsHead)
targsLast
builder += TokenPatch.AddRight(beforeLp, "(", keepTok = true)
builder += TokenPatch.AddRight(argsLast, ")", keepTok = true)
}
// move the left paren if enclosed, else enclose
val (argsHead, argsLast) = ends(rhs)
if (ctx.getMatchingOpt(argsHead).exists(argsLast.end <= _.end))
moveOpenDelim(beforeLp, argsHead)
else {
builder += TokenPatch.AddRight(beforeLp, "(", keepTok = true)
builder += TokenPatch.AddRight(argsLast, ")", keepTok = true)
}

val shouldWrapLhs = !lhsIsWrapped &&
Expand All @@ -106,7 +111,7 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {
lhsIsOK: Option[Boolean] = None,
): Boolean = {
val op = name.value
InfixApp.isLeftAssoc(op) && matcher.matches(lhs.text, op) &&
InfixApp.isLeftAssoc(op) && cfg.matches(lhs.text, op) &&
(rhs match {
case ac @ Term.ArgClause(arg :: Nil, _) if !isWrapped(ac) =>
!hasPlaceholder(arg, ctx.style.rewrite.isAllowInfixPlaceholderArg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract class CommunityIntellijScalaSuite(name: String)
class CommunityIntellijScala_2024_2_Suite
extends CommunityIntellijScalaSuite("intellij-scala-2024.2") {

override protected def totalStatesVisited: Option[Int] = Some(47100766)
override protected def totalStatesVisited: Option[Int] = Some(47100854)

override protected def builds = Seq(getBuild(
"2024.2.28",
Expand Down Expand Up @@ -51,7 +51,7 @@ class CommunityIntellijScala_2024_2_Suite
class CommunityIntellijScala_2024_3_Suite
extends CommunityIntellijScalaSuite("intellij-scala-2024.3") {

override protected def totalStatesVisited: Option[Int] = Some(47277877)
override protected def totalStatesVisited: Option[Int] = Some(47277965)

override protected def builds = Seq(getBuild(
"2024.3.4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract class CommunityScala2Suite(name: String)

class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") {

override protected def totalStatesVisited: Option[Int] = Some(34655844)
override protected def totalStatesVisited: Option[Int] = Some(34656507)

override protected def builds =
Seq(getBuild("v2.12.20", dialects.Scala212, 1277))
Expand Down
30 changes: 11 additions & 19 deletions scalafmt-tests/shared/src/test/resources/rewrite/AvoidInfix.stat
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ val future = (actor1 ? "Hello").flatMap { case i: Int ⇒ actor2 ? i }
<<< excluded combinators in lhs
terminated should be theSameInstanceAs Await.result(f, 10 seconds)
>>>
(terminated should be).theSameInstanceAs(Await.result(f, 10 seconds))
(terminated should be).theSameInstanceAs(Await.result(f, 10.seconds))
<<< excluded combinators in lhs 2
seq should have length 5
>>>
Expand Down Expand Up @@ -746,29 +746,25 @@ def templates = symbols filter (x => x.isClass || x.isTrait || x == AnyRefClass/
>>>
def templates = symbols.filter(x =>
x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */
) toSet
).toSet
<<< #4133 infix, with comment before closing paren, and postfix; overflow slc
maxColumn = 78
newlines.source = fold
newlines.avoidForSimpleOverflow = [slc]
===
def templates = symbols filter (x => x.isClass || x.isTrait || x == AnyRefClass/* which is now a type alias */) toSet
>>>
def templates = symbols
.filter(x =>
x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */
) toSet
def templates = symbols.filter(x => x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */ )
.toSet
<<< #4133 infix, with comment before closing paren, and postfix; overflow all
maxColumn = 78
newlines.source = fold
newlines.avoidForSimpleOverflow = all
===
def templates = symbols filter (x => x.isClass || x.isTrait || x == AnyRefClass/* which is now a type alias */) toSet
>>>
def templates = symbols
.filter(x =>
x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */
) toSet
def templates = symbols.filter(x => x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */ )
.toSet
<<< #4133 select, apply with comment before closing paren, and postfix; overflow punct
maxColumn = 78
newlines.source = fold
Expand All @@ -778,29 +774,25 @@ def templates = symbols.filter (x => x.isClass || x.isTrait || x == AnyRefClass/
>>>
def templates = symbols.filter(x =>
x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */
) toSet
).toSet
<<< #4133 select, apply with comment before closing paren, and postfix; overflow slc
maxColumn = 78
newlines.source = fold
newlines.avoidForSimpleOverflow = [slc]
===
def templates = symbols.filter (x => x.isClass || x.isTrait || x == AnyRefClass/* which is now a type alias */) toSet
>>>
def templates = symbols
.filter(x =>
x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */
) toSet
def templates = symbols.filter(x => x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */ )
.toSet
<<< #4133 select, apply with comment before closing paren, and postfix; overflow all
maxColumn = 78
newlines.source = fold
newlines.avoidForSimpleOverflow = all
===
def templates = symbols.filter (x => x.isClass || x.isTrait || x == AnyRefClass/* which is now a type alias */) toSet
>>>
def templates = symbols
.filter(x =>
x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */
) toSet
def templates = symbols.filter(x => x.isClass || x.isTrait || x == AnyRefClass /* which is now a type alias */ )
.toSet
<<< #4133 select, apply with comment before closing paren, and select; overflow punct
maxColumn = 78
newlines.source = fold
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions {
val explored = Debug.explored.get()
logger.debug(s"Total explored: $explored")
if (!onlyUnit && !onlyManual)
assertEquals(explored, 1114827, "total explored")
assertEquals(explored, 1114157, "total explored")
val results = debugResults.result()
// TODO(olafur) don't block printing out test results.
// I don't want to deal with scalaz's Tasks :'(
Expand Down

0 comments on commit d8db7d7

Please sign in to comment.