diff --git a/emma-language/src/main/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaOptimizations.scala b/emma-language/src/main/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaOptimizations.scala index 96181a2f0..5f487ea3b 100644 --- a/emma-language/src/main/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaOptimizations.scala +++ b/emma-language/src/main/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaOptimizations.scala @@ -2,10 +2,17 @@ package eu.stratosphere.emma.compiler.ir.lnf import eu.stratosphere.emma.compiler.ir.CommonIR +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scalax.collection.GraphEdge.UnDiEdge +import scalax.collection.immutable.Graph + + trait SchemaOptimizations extends CommonIR { self: Language => import universe._ + import Tree._ // --------------------------------------------------------------------------- // Schema Analysis & Optimizations @@ -13,6 +20,8 @@ trait SchemaOptimizations extends CommonIR { object Schema { + import CaseClassMeta._ + /** Root trait for all fields. */ sealed trait Field @@ -29,7 +38,7 @@ trait SchemaOptimizations extends CommonIR { case class Info(fieldClasses: Set[FieldClass]) /** - * Compute the (local) schema information for a tree fragment. + * Compute the (global) schema information for a tree fragment. * * @param tree The ANF [[Tree]] to be analyzed. * @return The schema information for the input tree. @@ -39,14 +48,102 @@ trait SchemaOptimizations extends CommonIR { } /** - * Compute the (global) schema information for an anonymous function. + * Compute the (local) schema information for an anonymous function. * * @param tree The ANF [[Tree]] to be analyzed. * @return The schema information for the input tree. */ private[emma] def local(tree: Function): Schema.Info = { - Info(Set.empty[FieldClass]) + // initialize equivalences relation + val equivalences: mutable.Buffer[(Field, Field)] = ArrayBuffer() + + // traverse the function and collect equivalences + traverse(tree) { + // patterns of type `x = y.target` + case val_(x, Select(y, member@TermName(_)), _) => + equivalences += SimpleField(x) -> MemberField(Term.of(y), Term.member(y, member)) + equivalences ++= caseClassMemberEquivalences(x) + + // patterns of type `x = constructor (arg1, ..., argN)` + case val_(x, Apply(fun, args), _) if isCaseClassConstructor(fun.symbol) => + equivalences += SimpleField(x) -> SimpleField(x) + + // add constructor equivalences and members + val meta = CaseClassMeta(Type.of(x).typeSymbol) + equivalences ++= args.zip(meta.constructorProjections(fun.symbol)).map { + case (arg, param) => SimpleField(arg.symbol) -> MemberField(x, param) + } + equivalences ++= meta.memberEquivalences(x) + + // just a regular ValDef + case val_(x, _, _) => + equivalences += SimpleField(x) -> SimpleField(x) + equivalences ++= caseClassMemberEquivalences(x) + + } + + Info(equivalenceClasses(equivalences)) + } + + private def equivalenceClasses(equivalences: Seq[(Field, Field)]): Set[FieldClass] = { + val edges = equivalences.map { case (f, t) => UnDiEdge(f, t) } + val nodes = edges.flatMap(_.toSet) + val graph = Graph.from[Field, UnDiEdge](nodes, edges) + + // each connected component forms an equivalence class + graph.componentTraverser().map(_.nodes.map(_.value)).toSet + } + + class CaseClassMeta(sym: ClassSymbol) { + + // assert that the input argument is a case class symbol + assert(sym.isCaseClass) + + def constructorProjections(fun: Symbol): Seq[Symbol] = { + assert(isCaseClassConstructor(fun)) + val parameters = fun.asMethod.paramLists.head + parameters.map(p => sym.info.decl(p.name)) + } + + def members(): Set[Symbol] = { + Type.of(sym).members.filter { + case m: TermSymbol => m.isGetter && m.isCaseAccessor + }.toSet + } + + def memberFields(symbol: Symbol): Set[MemberField] = { + members().map(member => MemberField(symbol, member)) + } + + def memberEquivalences(symbol: Symbol): Set[(Field, Field)] = { + memberFields(symbol).map(f => f -> f) + } + } + + object CaseClassMeta { + + def apply(s: Symbol) = new CaseClassMeta(s.asClass) + + def isCaseClass(s: Symbol): Boolean = s.isClass && s.asClass.isCaseClass + + def isCaseClassConstructor(fun: Symbol): Boolean = { + assert(fun.isMethod) + val isCtr = fun.owner.companion.isModule && fun.isConstructor + val isApp = fun.owner.isModuleClass && fun.name == TermName("apply") && fun.isSynthetic + + isCtr || isApp + } + + def caseClassMemberEquivalences(s: Symbol): Set[(Field, Field)] = { + val tpe = Type.of(s).typeSymbol + if (isCaseClass(tpe)) { + CaseClassMeta(tpe).memberEquivalences(s) + } else { + Set.empty + } + } } + } } diff --git a/emma-language/src/test/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaInfoSpec.scala b/emma-language/src/test/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaInfoSpec.scala index 35c8184b0..2f8834ca1 100644 --- a/emma-language/src/test/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaInfoSpec.scala +++ b/emma-language/src/test/scala/eu/stratosphere/emma/compiler/ir/lnf/SchemaInfoSpec.scala @@ -24,9 +24,69 @@ class SchemaInfoSpec extends BaseCompilerSpec { case vd@ValDef(_, TermName(`name`), _, _) => vd.symbol }.head + "tuple tupes playground" - { + + import compiler.Type + + val examples = Seq( + typeCheck(reify { + (42, "foobar") + }), + typeCheck(reify { + new Tuple2(42, "foobar") + }), + typeCheck(reify { + Ad(1, "foobar", AdClass.FASHION) + }), + typeCheck(reify { + new Ad(1, "foobar", AdClass.FASHION) + }) + ) + + // extract the types of the objects constructed the examples + val types = for (e <- examples) yield { + Type.of(e).typeSymbol.asClass + } + + // assert that all are case classes + for (e <- types) assert(e.isCaseClass) + + // extract the symbols of the constructing functions + val funsyms = for (e <- examples) yield e match { + case Apply(fun, args) => fun.symbol + } + + // assert that the funsyms belong to recognized constructors + val res = for ((f, t) <- funsyms zip types) { + // the owning class of the function symbol should be the class of the target + val cmp = f.owner.companionSymbol + val cls = cmp.companion + assert(cls == t) + + // the function symbol should be either of a constructor of of a synthetic apply method + val isCtr = f.owner.companionSymbol.isModule && f.isConstructor + val isApp = f.owner.isModuleClass && f.name == TermName("apply") && f.isSynthetic + + assert(isCtr || isApp) + } + + // extract the projections associated with the function applications + val proj = for (f <- funsyms) yield { + // the owning class of the function symbol should be the class of the target + val cmp = f.owner.companionSymbol + val cls = cmp.companion + + for (param <- f.paramLists.head) yield { + val proj = cls.info.decl(param.name) + assert(proj.isAccessor) + (param, cls.info.decl(param.name)) + } + } + } + "local schema" - { "without control flow" in { - // ANF representation with `desugared` comprehensinos + // ANF representation with `desugared` comprehensions val fn = typeCheck(reify { (c: Click) => { val t = c.time @@ -61,10 +121,10 @@ class SchemaInfoSpec extends BaseCompilerSpec { val cls$04 /* */ = Set[Field](fld$c$time, fld$t) val cls$05 /* */ = Set[Field](fld$p, fld$a$_2) val cls$06 /* */ = Set[Field](fld$m, fld$a$_3) - val cls$v7 /* */ = Set[Field](fld$a) + val cls$07 /* */ = Set[Field](fld$a) // 3) construct the expected local schema information - val exp = Info(Set(cls$01, cls$02, cls$03, cls$01, cls$01, cls$01, cls$01)) + val exp = Info(Set(cls$01, cls$02, cls$03, cls$04, cls$05, cls$06, cls$07)) // compute actual local schema val act = compiler.Schema.local(fn)