Skip to content

Commit

Permalink
Fix sorting algorithm to work for all array types (#211)
Browse files Browse the repository at this point in the history
Fixes #204, by really calling keyF for all types in the array.
  • Loading branch information
stephenamar-db authored Nov 5, 2024
1 parent fe13fa4 commit cbef90b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 46 deletions.
91 changes: 51 additions & 40 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,7 @@ class Std {
}

private object SetInter extends Val.Builtin3("a", "b", "keyF", Array(null, null, Val.False(dummyPos))) {
def isStr(a: Val.Arr) = a.forall(_.isInstanceOf[Val.Str])
def isNum(a: Val.Arr) = a.forall(_.isInstanceOf[Val.Num])
private def isStr(a: Val.Arr) = a.forall(_.isInstanceOf[Val.Str])

override def specialize(args: Array[Expr]): (Val.Builtin, Array[Expr]) = args match {
case Array(a: Val.Arr, b) if isStr(a) => (new Spec1Str(a), Array(b))
Expand Down Expand Up @@ -750,8 +749,8 @@ class Std {
case (l: Val.Obj, r: Val.Obj) =>
val kvs = for {
k <- (l.visibleKeyNames ++ r.visibleKeyNames).distinct
val lValue = Option(l.valueRaw(k, l, pos)(ev))
val rValue = Option(r.valueRaw(k, r, pos)(ev))
lValue = Option(l.valueRaw(k, l, pos)(ev))
rValue = Option(r.valueRaw(k, r, pos)(ev))
if !rValue.exists(_.isInstanceOf[Val.Null])
} yield (lValue, rValue) match{
case (Some(lChild), None) => k -> createMember{lChild}
Expand All @@ -767,7 +766,7 @@ class Std {
case obj: Val.Obj =>
val kvs = for{
k <- obj.visibleKeyNames
val value = obj.value(k, pos, obj)(ev)
value = obj.value(k, pos, obj)(ev)
if !value.isInstanceOf[Val.Null]
} yield (k, createMember{recSingle(value)})

Expand Down Expand Up @@ -1367,10 +1366,10 @@ class Std {
}
}

def uniqArr(pos: Position, ev: EvalScope, arr: Val, keyF: Val) = {
private def uniqArr(pos: Position, ev: EvalScope, arr: Val, keyF: Val) = {
val arrValue = arr match {
case arr: Val.Arr => arr.asLazyArray
case str: Val.Str => stringChars(pos, str.value).asLazyArray
case arr: Val.Arr => arr
case str: Val.Str => stringChars(pos, str.asString)
case _ => Error.fail("Argument must be either array or string")
}

Expand Down Expand Up @@ -1404,38 +1403,50 @@ class Std {
new Val.Arr(pos, out.toArray)
}

def sortArr(pos: Position, ev: EvalScope, arr: Val, keyF: Val) = {
arr match{
case vs: Val.Arr =>
new Val.Arr(
pos,
if (vs.forall(_.isInstanceOf[Val.Str])){
vs.asStrictArray.map(_.cast[Val.Str]).sortBy(_.value)
}else if (vs.forall(_.isInstanceOf[Val.Num])) {
vs.asStrictArray.map(_.cast[Val.Num]).sortBy(_.value)
}else if (vs.forall(_.isInstanceOf[Val.Obj])){
if (keyF == null || keyF.isInstanceOf[Val.False]) {
Error.fail("Unable to sort array of objects without key function")
} else {
val objs = vs.asStrictArray.map(_.cast[Val.Obj])

val keyFFunc = keyF.asInstanceOf[Val.Func]
val keys = objs.map((v) => keyFFunc(Array(v), null, pos.noOffset)(ev))

if (keys.forall(_.isInstanceOf[Val.Str])){
objs.sortBy((v) => keyFFunc(Array(v), null, pos.noOffset)(ev).cast[Val.Str].value)
} else if (keys.forall(_.isInstanceOf[Val.Num])) {
objs.sortBy((v) => keyFFunc(Array(v), null, pos.noOffset)(ev).cast[Val.Num].value)
} else {
Error.fail("Cannot sort with key values that are " + keys(0).prettyName + "s")
}
}
}else {
???
}
)
case Val.Str(pos, s) => new Val.Arr(pos, s.sorted.map(c => Val.Str(pos, c.toString)).toArray)
case x => Error.fail("Cannot sort " + x.prettyName)
private def sortArr(pos: Position, ev: EvalScope, arr: Val, keyF: Val) = {
val vs = arr match {
case arr: Val.Arr => arr
case str: Val.Str => stringChars(pos, str.asString)
case _ => Error.fail("Cannot sort " + arr.prettyName)
}
if (vs.length <= 1) {
arr
} else {
val keyFFunc = if (keyF == null || keyF.isInstanceOf[Val.False]) null else keyF.asInstanceOf[Val.Func]
new Val.Arr(pos, if (keyFFunc != null) {
val keys = new Val.Arr(pos.noOffset, vs.asStrictArray.map((v) => keyFFunc(Array(v), null, pos.noOffset)(ev)))
val keyTypes = keys.iterator.map(_.getClass).toSet
if (keyTypes.size != 1) {
Error.fail("Cannot sort with key values that are not all the same type")
}

if (keyTypes.contains(classOf[Val.Str])) {
vs.asStrictArray.sortBy((v) => keyFFunc(Array(v), null, pos.noOffset)(ev).cast[Val.Str].asString)
} else if (keyTypes.contains(classOf[Val.Num])) {
vs.asStrictArray.sortBy((v) => keyFFunc(Array(v), null, pos.noOffset)(ev).cast[Val.Num].asDouble)
} else if (keyTypes.contains(classOf[Val.Bool])) {
vs.asStrictArray.sortBy((v) => keyFFunc(Array(v), null, pos.noOffset)(ev).cast[Val.Bool].asBoolean)
} else {
Error.fail("Cannot sort with key values that are " + keys.force(0).prettyName + "s")
}
} else {
val keyTypes = vs.iterator.map(_.getClass).toSet
if (keyTypes.size != 1) {
Error.fail("Cannot sort with values that are not all the same type")
}

if (keyTypes.contains(classOf[Val.Str])) {
vs.asStrictArray.map(_.cast[Val.Str]).sortBy(_.asString)
} else if (keyTypes.contains(classOf[Val.Num])) {
vs.asStrictArray.map(_.cast[Val.Num]).sortBy(_.asDouble)
} else if (keyTypes.contains(classOf[Val.Bool])) {
vs.asStrictArray.map(_.cast[Val.Bool]).sortBy(_.asBoolean)
} else if (keyTypes.contains(classOf[Val.Obj])) {
Error.fail("Unable to sort array of objects without key function")
} else {
Error.fail("Cannot sort array of " + vs.force(0).prettyName)
}
})
}
}

Expand Down
10 changes: 6 additions & 4 deletions sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ object PrettyNamed{
implicit def strName: PrettyNamed[Val.Str] = new PrettyNamed("string")
implicit def numName: PrettyNamed[Val.Num] = new PrettyNamed("number")
implicit def arrName: PrettyNamed[Val.Arr] = new PrettyNamed("array")
implicit def boolName: PrettyNamed[Val.Bool] = new PrettyNamed("boolean")
implicit def objName: PrettyNamed[Val.Obj] = new PrettyNamed("object")
implicit def funName: PrettyNamed[Val.Func] = new PrettyNamed("function")
implicit def nullName: PrettyNamed[Val.Null] = new PrettyNamed("null")
}
object Val{

Expand All @@ -69,7 +71,7 @@ object Val{
override def asBoolean: Boolean = this.isInstanceOf[True]
}

def bool(pos: Position, b: Boolean) = if (b) True(pos) else False(pos)
def bool(pos: Position, b: Boolean): Bool = if (b) True(pos) else False(pos)

case class True(pos: Position) extends Bool {
def prettyName = "boolean"
Expand All @@ -92,19 +94,19 @@ object Val{

class Arr(val pos: Position, private val value: Array[_ <: Lazy]) extends Literal {
def prettyName = "array"

override def asArr: Arr = this
def length: Int = value.length
def force(i: Int) = value(i).force
def force(i: Int): Val = value(i).force

def asLazy(i: Int) = value(i)
def asLazyArray: Array[Lazy] = value.asInstanceOf[Array[Lazy]]
def asStrictArray: Array[Val] = value.map(_.force)

def concat(newPos: Position, rhs: Arr): Arr =
new Arr(newPos, value ++ rhs.value)

def iterator: Iterator[Val] = value.iterator.map(_.force)
def foreach[U](f: Val => U) = {
def foreach[U](f: Val => U): Unit = {
var i = 0
while(i < value.length) {
f(value(i).force)
Expand Down
13 changes: 11 additions & 2 deletions sjsonnet/test/src/sjsonnet/StdWithKeyFTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sjsonnet

import utest._
import TestUtils.eval
import TestUtils.{eval, evalErr}
object StdWithKeyFTests extends TestSuite {

def tests = Tests {
Expand Down Expand Up @@ -47,7 +47,16 @@ object StdWithKeyFTests extends TestSuite {
""") ==> ujson.True
}
test("stdSortWithKeyF") {
eval("std.sort([\"c\", \"a\", \"b\"])").toString() ==> """["a","b","c"]"""
eval("""std.sort(["a","b","c"])""").toString() ==> """["a","b","c"]"""
eval("""std.sort([1, 2, 3])""").toString() ==> """[1,2,3]"""
eval("""std.sort([1,2,3], keyF=function(x) -x)""").toString() ==> """[3,2,1]"""
eval("""std.sort([1,2,3], function(x) -x)""").toString() ==> """[3,2,1]"""
assert(
evalErr("""std.sort([1,2,3], keyF=function(x) error "foo")""").startsWith("sjsonnet.Error: foo"))
assert(
evalErr("""std.sort([1,2, error "foo"])""").startsWith("sjsonnet.Error: foo"))
assert(
evalErr("""std.sort([1, [error "foo"]])""").startsWith("sjsonnet.Error: Cannot sort with values that are not all the same type"))

eval(
"""local arr = [
Expand Down

0 comments on commit cbef90b

Please sign in to comment.