Skip to content

Commit

Permalink
Add RegEx supports using RE2 to sjsonnet (#244)
Browse files Browse the repository at this point in the history
With this PR, I'm adding a handful of methods to expose regular
expressions in jsonnet, through std.native()
I'm modeling them after [this open PR from
jsonnet](google/jsonnet#1039), which was ported
to jrsonnet.

For now, they are in std.native() as they are not part of the default
std package and use RE2 instead of the native regexp package (for
performance and compatibility reasons with a future go-jsonnet
implementation).

- regexFullMatch(pattern, str) -- Full match regex
- regexPartialMatch(pattern, str) -- Partial match regex
- regexReplace(str, pattern, to) -- Replace single occurance using regex
- regexGlobalReplace(str, pattern, to) -- Replace globally using regex

and the utility function:
- regexQuoteMeta(str) -- Escape regex metachararacters


Those functions return a object:
```
std.native("regexFullMatch")("h(?P<mid>.*)o", "hello")

{
   "captures": [
      "ell"
   ],
   "string": "hello"
}
```

This PR does not add support for the "namedCaptures" return field due to
some complications with scalajs and scalanative. Those language both use
the JDK Pattern class (js being powered by ECMA regex and Native being
powered by RE2(!)), but JDK<20 Pattern class does not have a
straightforward way to list the names of groups without some additional
hacks. This will be dealt with in a follow up PR.

This PR also adds the ability to cache patterns, and refactors all users
of regexes to use it.
  • Loading branch information
stephenamar-db authored Dec 31, 2024
1 parent 1497955 commit 729e0d8
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 36 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ lazy val main = (project in file("sjsonnet"))
"org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0",
"org.tukaani" % "xz" % "1.8",
"org.yaml" % "snakeyaml" % "1.33",
"com.google.re2j" % "re2j" % "1.7",
),
libraryDependencies ++= Seq(
"com.lihaoyi" %% "utest" % "0.8.2",
Expand Down
3 changes: 2 additions & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ object sjsonnet extends Module {
ivy"org.json:json:20240303",
ivy"org.tukaani:xz::1.10",
ivy"org.lz4:lz4-java::1.8.0",
ivy"org.yaml:snakeyaml::1.33"
ivy"org.yaml:snakeyaml::1.33",
ivy"com.google.re2j:re2j:1.7",
)
def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.**")

Expand Down
13 changes: 13 additions & 0 deletions sjsonnet/src-js/sjsonnet/Platform.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package sjsonnet

import java.io.File
import java.util
import java.util.regex.Pattern


object Platform {
def gzipBytes(s: Array[Byte]): String = {
throw new Exception("GZip not implemented in Scala.js")
Expand Down Expand Up @@ -34,4 +39,12 @@ object Platform {
def hashFile(file: File): String = {
throw new Exception("hashFile not implemented in Scala.js")
}

private val regexCache = new util.concurrent.ConcurrentHashMap[String, Pattern]

// scala.js does not rely on re2. Per https://www.scala-js.org/doc/regular-expressions.html.
// Expect to see some differences in behavior.
def getPatternFromCache(pat: String) : Pattern = regexCache.computeIfAbsent(pat, _ => Pattern.compile(pat))

def regexQuote(s: String): String = Pattern.quote(s)
}
8 changes: 8 additions & 0 deletions sjsonnet/src-jvm/sjsonnet/Platform.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package sjsonnet

import java.io.{BufferedInputStream, ByteArrayOutputStream, File, FileInputStream}
import java.util
import java.util.Base64
import java.util.zip.GZIPOutputStream
import com.google.re2j.Pattern
import net.jpountz.xxhash.{StreamingXXHash64, XXHashFactory}
import org.json.{JSONArray, JSONObject}
import org.tukaani.xz.LZMA2Options
import org.tukaani.xz.XZOutputStream
import org.yaml.snakeyaml.{LoaderOptions, Yaml}
import org.yaml.snakeyaml.constructor.SafeConstructor

import scala.jdk.CollectionConverters._

object Platform {
Expand Down Expand Up @@ -107,4 +110,9 @@ object Platform {

hash.getValue.toString
}

private val regexCache = new util.concurrent.ConcurrentHashMap[String, Pattern]
def getPatternFromCache(pat: String) : Pattern = regexCache.computeIfAbsent(pat, _ => Pattern.compile(pat))

def regexQuote(s: String): String = Pattern.quote(s)
}
9 changes: 9 additions & 0 deletions sjsonnet/src-native/sjsonnet/Platform.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package sjsonnet

import java.io.{ByteArrayOutputStream, File}
import java.util
import java.util.Base64
import java.util.zip.GZIPOutputStream
import java.util.regex.Pattern

object Platform {
def gzipBytes(b: Array[Byte]): String = {
Expand Down Expand Up @@ -50,4 +52,11 @@ object Platform {
// File hashes in Scala Native are just the file content
scala.io.Source.fromFile(file).mkString
}

private val regexCache = new util.concurrent.ConcurrentHashMap[String, Pattern]
// scala native is powered by RE2, per https://scala-native.org/en/latest/lib/javalib.html#regular-expressions-java-util-regexp
// It should perform similarly to the JVM implementation.
def getPatternFromCache(pat: String) : Pattern = regexCache.computeIfAbsent(pat, _ => Pattern.compile(pat))

def regexQuote(s: String): String = Pattern.quote(s)
}
37 changes: 17 additions & 20 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ import java.io.StringWriter
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64
import java.util
import java.util.regex.Pattern
import sjsonnet.Expr.Member.Visibility

import scala.collection.Searching._
import scala.collection.mutable
import scala.util.matching.Regex

/**
* The Jsonnet standard library, `std`, with each builtin function implemented
Expand All @@ -19,9 +17,9 @@ import scala.util.matching.Regex
class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.empty) {
private val dummyPos: Position = new Position(null, 0)
private val emptyLazyArray = new Array[Lazy](0)
private val leadingWhiteSpacePattern = Pattern.compile("^[ \t\n\f\r\u0085\u00A0']+")
private val trailingWhiteSpacePattern = Pattern.compile("[ \t\n\f\r\u0085\u00A0']+$")
private val oldNativeFunctions = Map(
private val leadingWhiteSpacePattern = Platform.getPatternFromCache("^[ \t\n\f\r\u0085\u00A0']+")
private val trailingWhiteSpacePattern = Platform.getPatternFromCache("[ \t\n\f\r\u0085\u00A0']+$")
private val builtinNativeFunctions = Map(
builtin("gzip", "v"){ (_, _, v: Val) =>
v match{
case Val.Str(_, value) => Platform.gzipString(value)
Expand All @@ -46,9 +44,9 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
case x => Error.fail("Cannot xz encode " + x.prettyName)
}
},
)
require(oldNativeFunctions.forall(k => !additionalNativeFunctions.contains(k._1)), "Conflicting native functions")
private val nativeFunctions = oldNativeFunctions ++ additionalNativeFunctions
) ++ StdRegex.functions
require(builtinNativeFunctions.forall(k => !additionalNativeFunctions.contains(k._1)), "Conflicting native functions")
private val nativeFunctions = builtinNativeFunctions ++ additionalNativeFunctions

private object AssertEqual extends Val.Builtin2("assertEqual", "a", "b") {
def evalRhs(v1: Val, v2: Val, ev: EvalScope, pos: Position): Val = {
Expand Down Expand Up @@ -474,26 +472,25 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
Val.Str(pos, str.asString.replaceAll(from.asString, to.asString))
override def specialize(args: Array[Expr]) = args match {
case Array(str, from: Val.Str, to) =>
try { (new SpecFrom(Pattern.compile(from.value)), Array(str, to)) } catch { case _: Exception => null }
try { (new SpecFrom(from.value), Array(str, to)) } catch { case _: Exception => null }
case _ => null
}
private class SpecFrom(from: Pattern) extends Val.Builtin2("strReplaceAll", "str", "to") {
private class SpecFrom(from: String) extends Val.Builtin2("strReplaceAll", "str", "to") {
private[this] val pattern = Platform.getPatternFromCache(from)
def evalRhs(str: Val, to: Val, ev: EvalScope, pos: Position): Val =
Val.Str(pos, from.matcher(str.asString).replaceAll(to.asString))
Val.Str(pos, pattern.matcher(str.asString).replaceAll(to.asString))
}
}

private object StripUtils {
private def getLeadingPattern(chars: String): Pattern =
Pattern.compile("^[" + Regex.quote(chars) + "]+")
private def getLeadingPattern(chars: String): String = "^[" + Platform.regexQuote(chars) + "]+"

private def getTrailingPattern(chars: String): Pattern =
Pattern.compile("[" + Regex.quote(chars) + "]+$")
private def getTrailingPattern(chars: String): String = "[" + Platform.regexQuote(chars) + "]+$"

def unspecializedStrip(str: String, chars: String, left: Boolean, right: Boolean): String = {
var s = str
if (right) s = getTrailingPattern(chars).matcher(s).replaceAll("")
if (left) s = getLeadingPattern(chars).matcher(s).replaceAll("")
if (right) s = Platform.getPatternFromCache(getTrailingPattern(chars)).matcher(s).replaceAll("")
if (left) s = Platform.getPatternFromCache(getLeadingPattern(chars)).matcher(s).replaceAll("")
s
}

Expand All @@ -503,8 +500,8 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
right: Boolean,
functionName: String
) extends Val.Builtin1(functionName, "str") {
private[this] val leftPattern = getLeadingPattern(chars)
private[this] val rightPattern = getTrailingPattern(chars)
private[this] val leftPattern = Platform.getPatternFromCache(getLeadingPattern(chars))
private[this] val rightPattern = Platform.getPatternFromCache(getTrailingPattern(chars))

def evalRhs(str: Val, ev: EvalScope, pos: Position): Val = {
var s = str.asString
Expand Down Expand Up @@ -1522,7 +1519,7 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
Error.fail("Native function " + name + " not found", pos)(ev)
}
},
) ++ oldNativeFunctions
) ++ builtinNativeFunctions

private def toSetArrOrString(args: Array[Val], idx: Int, pos: Position, ev: EvalScope) = {
args(idx) match {
Expand Down
89 changes: 89 additions & 0 deletions sjsonnet/src/sjsonnet/StdRegex.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package sjsonnet

import sjsonnet.Expr.Member.Visibility
import sjsonnet.Val.Obj

object StdRegex {
def functions: Map[String, Val.Builtin] = Map(
"regexPartialMatch" -> new Val.Builtin2("regexPartialMatch", "pattern", "str") {
override def evalRhs(pattern: Val, str: Val, ev: EvalScope, pos: Position): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern.asString)
val matcher = compiledPattern.matcher(str.asString)
var returnStr: Val = null
val captures = Array.newBuilder[Val]
val groupCount = matcher.groupCount()
while (matcher.find()) {
if (returnStr == null) {
val m = matcher.group(0)
if (m != null) {
returnStr = Val.Str(pos.noOffset, matcher.group(0))
} else {
returnStr = Val.Null(pos.noOffset)
}
}
for (i <- 1 to groupCount) {
val m = matcher.group(i)
if (m == null) {
captures += Val.Null(pos.noOffset)
} else {
captures += Val.Str(pos.noOffset, m)
}
}
}
val result = captures.result()
Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal,
if (returnStr == null) Val.Null(pos.noOffset) else returnStr),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, result))
)
}
},
"regexFullMatch" -> new Val.Builtin2("regexFullMatch", "pattern", "str") {
override def evalRhs(pattern: Val, str: Val, ev: EvalScope, pos: Position): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern.asString)
val matcher = compiledPattern.matcher(str.asString)
if (!matcher.matches()) {
Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal, Val.Null(pos.noOffset)),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, Array.empty[Lazy]))
)
} else {
val captures = Array.newBuilder[Val]
val groupCount = matcher.groupCount()
for (i <- 0 to groupCount) {
val m = matcher.group(i)
if (m == null) {
captures += Val.Null(pos.noOffset)
} else {
captures += Val.Str(pos.noOffset, m)
}
}
val result = captures.result()
Val.Obj.mk(pos.noOffset,
"string" -> new Obj.ConstMember(true, Visibility.Normal, result.head),
"captures" -> new Obj.ConstMember(true, Visibility.Normal, new Val.Arr(pos.noOffset, result.drop(1)))
)
}
}
},
"regexGlobalReplace" -> new Val.Builtin3("regexGlobalReplace", "str", "pattern", "to") {
override def evalRhs(str: Val, pattern: Val, to: Val, ev: EvalScope, pos: Position): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern.asString)
val matcher = compiledPattern.matcher(str.asString)
Val.Str(pos.noOffset, matcher.replaceAll(to.asString))
}
},
"regexReplace" -> new Val.Builtin3("regexReplace", "str", "pattern", "to") {
override def evalRhs(str: Val, pattern: Val, to: Val, ev: EvalScope, pos: Position): Val = {
val compiledPattern = Platform.getPatternFromCache(pattern.asString)
val matcher = compiledPattern.matcher(str.asString)
Val.Str(pos.noOffset, matcher.replaceFirst(to.asString))
}
},
"regexQuoteMeta" -> new Val.Builtin1("regexQuoteMeta", "str") {
override def evalRhs(str: Val, ev: EvalScope, pos: Position): Val = {
Val.Str(pos.noOffset, Platform.regexQuote(str.asString))
}
}
)
}
3 changes: 1 addition & 2 deletions sjsonnet/src/sjsonnet/TomlRenderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sjsonnet
import upickle.core.{ArrVisitor, CharBuilder, ObjVisitor, SimpleVisitor, Visitor}

import java.io.StringWriter
import java.util.regex.Pattern


class TomlRenderer(out: StringWriter = new java.io.StringWriter(), cumulatedIndent: String, indent: String) extends SimpleVisitor[StringWriter, StringWriter]{
Expand Down Expand Up @@ -117,7 +116,7 @@ class TomlRenderer(out: StringWriter = new java.io.StringWriter(), cumulatedInde
}

object TomlRenderer {
private val bareAllowed = Pattern.compile("[A-Za-z0-9_-]+")
private val bareAllowed = Platform.getPatternFromCache("[A-Za-z0-9_-]+")
def escapeKey(key: String): String = if (bareAllowed.matcher(key).matches()) key else {
val out = new StringWriter()
BaseRenderer.escape(out, key, unicode = true)
Expand Down
20 changes: 8 additions & 12 deletions sjsonnet/src/sjsonnet/YamlRenderer.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package sjsonnet

import java.io.StringWriter
import java.util.regex.Pattern
import upickle.core.{ArrVisitor, ObjVisitor, SimpleVisitor, Visitor}

import scala.util.Try



class YamlRenderer(_out: StringWriter = new java.io.StringWriter(), indentArrayInObject: Boolean = false,
quoteKeys: Boolean = true, indent: Int = 2) extends BaseCharRenderer(_out, indent){
Expand Down Expand Up @@ -52,7 +48,7 @@ class YamlRenderer(_out: StringWriter = new java.io.StringWriter(), indentArrayI
elemBuilder.append('"')
elemBuilder.append('"')
} else if (s.charAt(len - 1) == '\n') {
val splits = YamlRenderer.newlinePattern.split(s)
val splits = YamlRenderer.newlinePattern.split(s.toString)
elemBuilder.append('|')
depth += 1
splits.foreach { split =>
Expand Down Expand Up @@ -174,15 +170,15 @@ class YamlRenderer(_out: StringWriter = new java.io.StringWriter(), indentArrayI
}
}
object YamlRenderer{
val newlinePattern: Pattern = Pattern.compile("\n")
private val safeYamlKeyPattern = Pattern.compile("^[a-zA-Z0-9/._-]+$")
private[sjsonnet] val newlinePattern = Platform.getPatternFromCache("\n")
private val safeYamlKeyPattern = Platform.getPatternFromCache("^[a-zA-Z0-9/._-]+$")
private val yamlReserved = Set("true", "false", "null", "yes", "no", "on", "off", "y", "n", ".nan",
"+.inf", "-.inf", ".inf", "null", "-", "---", "''")
private val yamlTimestampPattern = Pattern.compile("^(?:[0-9]*-){2}[0-9]*$")
private val yamlBinaryPattern = Pattern.compile("^[-+]?0b[0-1_]+$")
private val yamlHexPattern = Pattern.compile("[-+]?0x[0-9a-fA-F_]+")
private val yamlFloatPattern = Pattern.compile( "^-?([0-9_]*)*(\\.[0-9_]*)?(e[-+][0-9_]+)?$" )
private val yamlIntPattern = Pattern.compile("^[-+]?[0-9_]+$")
private val yamlTimestampPattern = Platform.getPatternFromCache("^(?:[0-9]*-){2}[0-9]*$")
private val yamlBinaryPattern = Platform.getPatternFromCache("^[-+]?0b[0-1_]+$")
private val yamlHexPattern = Platform.getPatternFromCache("[-+]?0x[0-9a-fA-F_]+")
private val yamlFloatPattern = Platform.getPatternFromCache( "^-?([0-9_]*)*(\\.[0-9_]*)?(e[-+][0-9_]+)?$" )
private val yamlIntPattern = Platform.getPatternFromCache("^[-+]?[0-9_]+$")

private def isSafeBareKey(k: String) = {
val l = k.toLowerCase
Expand Down
2 changes: 1 addition & 1 deletion sjsonnet/test/src/sjsonnet/OldYamlRenderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OldYamlRenderer(out: StringWriter = new java.io.StringWriter(), indentArra
val len = s.length()
if (len == 0) out.append("\"\"")
else if (s.charAt(len - 1) == '\n') {
val splits = YamlRenderer.newlinePattern.split(s)
val splits = YamlRenderer.newlinePattern.split(s.toString)
out.append('|')
depth += 1
splits.foreach { split =>
Expand Down
38 changes: 38 additions & 0 deletions sjsonnet/test/src/sjsonnet/StdRegexTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package sjsonnet

import sjsonnet.TestUtils.eval
import utest._

object StdRegexTests extends TestSuite {
def tests: Tests = Tests {
test("std.native - regex") {
eval("""std.native("regexPartialMatch")("a(b)c", "cabc")""") ==> ujson.Obj(
"string" -> "abc",
"captures" -> ujson.Arr("b")
)
eval("""std.native("regexPartialMatch")("a(b)c", "def")""") ==> ujson.Obj(
"string" -> ujson.Null,
"captures" -> ujson.Arr()
)
eval("""std.native("regexPartialMatch")("a(b)c", "abcabc")""") ==> ujson.Obj(
"string" -> "abc",
"captures" -> ujson.Arr("b", "b")
)
eval("""std.native("regexFullMatch")("a(b)c", "abc")""") ==> ujson.Obj(
"string" -> "abc",
"captures" -> ujson.Arr("b")
)
eval("""std.native("regexFullMatch")("a(b)c", "cabc")""") ==> ujson.Obj(
"string" -> ujson.Null,
"captures" -> ujson.Arr()
)
eval("""std.native("regexFullMatch")("a(b)c", "def")""") ==> ujson.Obj(
"string" -> ujson.Null,
"captures" -> ujson.Arr()
)
eval("""std.native("regexGlobalReplace")("abcbbb", "b", "d")""") ==> ujson.Str("adcddd")
eval("""std.native("regexReplace")("abcbbb", "b", "d")""") ==> ujson.Str("adcbbb")
eval("""std.native("regexQuoteMeta")("a.b")""") ==> ujson.Str(Platform.regexQuote("a.b"))
}
}
}

0 comments on commit 729e0d8

Please sign in to comment.