Skip to content

Commit

Permalink
  Use a reference instead of UUID to identify variables
Browse files Browse the repository at this point in the history
  • Loading branch information
pityka committed Jul 9, 2024
1 parent 0791fc7 commit a188657
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
9 changes: 4 additions & 5 deletions lamp-core/src/main/scala/lamp/autograd/autograd.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
package lamp.autograd
import java.{util => ju}
import lamp.FloatingPointPrecision
import lamp.Scope
import lamp.Sc
Expand Down Expand Up @@ -208,8 +207,8 @@ sealed trait Variable {
/** Returns the shape of its value. */
def shape = sizes

/** Returns unique, stable and random UUID. */
val id = ju.UUID.randomUUID()
/** Returns unique, stable reference. */
val id = new AnyRef

/** Returns an other Variable wrapping the same value tensor, without any
* parent and with `needsGrad=false`.
Expand Down Expand Up @@ -491,8 +490,8 @@ object Autograd {
private[autograd] def topologicalSort[D](root: Variable): Seq[Variable] = {
type V = Variable
var order = List.empty[V]
var marks = Set.empty[ju.UUID]
var currentParents = Set.empty[ju.UUID]
var marks = Set.empty[AnyRef]
var currentParents = Set.empty[AnyRef]

def visit(n: V): Unit =
if (marks.contains(n.id)) ()
Expand Down
2 changes: 1 addition & 1 deletion lamp-onnx/src/main/scala/lamp/onnx/OpSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import lamp.Scope
import lamp.HalfPrecision

trait NameMap {
def apply(u: UUID): String
def apply(u: AnyRef): String
}

case class Converted(
Expand Down
5 changes: 2 additions & 3 deletions lamp-onnx/src/main/scala/lamp/onnx/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import java.nio.ByteBuffer
import java.io.File
import java.io.BufferedOutputStream
import java.io.FileOutputStream
import java.util.UUID

package object onnx {
def serializeToFile(
Expand Down Expand Up @@ -109,11 +108,11 @@ package object onnx {
val inputs = info.filter(_.input).map(_.variable.id)
val nameMap = info.map { input => input.variable.id -> input.name }.toMap

def makeName(u: UUID) =
def makeName(u: AnyRef) =
nameMap.get(u).getOrElse(u.toString.replace("-", "_"))

val namer = new NameMap {
def apply(u: UUID): String = makeName(u)
def apply(u: AnyRef): String = makeName(u)
}

val constantNodes = graph.collect { case x: ConstantWithoutGrad =>
Expand Down

0 comments on commit a188657

Please sign in to comment.