From 23f79a8592f16a0275c2c6a959d17f7edae91090 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:24:50 -0400 Subject: [PATCH 01/34] first draft of core functionality --- src/choice_map2/array_interface.jl | 92 ++++++++++ src/choice_map2/choice_map.jl | 247 ++++++++++++++++++++++++++ src/choice_map2/dynamic_choice_map.jl | 153 ++++++++++++++++ src/choice_map2/nested_view.jl | 81 +++++++++ src/choice_map2/static_choice_map.jl | 131 ++++++++++++++ 5 files changed, 704 insertions(+) create mode 100644 src/choice_map2/array_interface.jl create mode 100644 src/choice_map2/choice_map.jl create mode 100644 src/choice_map2/dynamic_choice_map.jl create mode 100644 src/choice_map2/nested_view.jl create mode 100644 src/choice_map2/static_choice_map.jl diff --git a/src/choice_map2/array_interface.jl b/src/choice_map2/array_interface.jl new file mode 100644 index 00000000..f88c5b11 --- /dev/null +++ b/src/choice_map2/array_interface.jl @@ -0,0 +1,92 @@ +### interface for to_array and fill_array ### + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: ChoiceMap` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) <: start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. + +# Implementation + +To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) + ValueChoiceMap(arr[start_idx]) +end \ No newline at end of file diff --git a/src/choice_map2/choice_map.jl b/src/choice_map2/choice_map.jl new file mode 100644 index 00000000..d7e7101f --- /dev/null +++ b/src/choice_map2/choice_map.jl @@ -0,0 +1,247 @@ +######################### +# choice map interface # +######################### + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `ValueChoiceMap`s.) +""" +function get_submaps_shallow end + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +function get_submap end + +# provide _get_submap so when users overwrite get_submap(choices::CustomChoiceMap, addr::Pair) +# they can just call _get_submap for convenience if they want +@inline function _get_submap(choices::ChoiceMap, addr::Pair) + (first, rest) = addr + submap = get_submap(choices, first) + get_submap(submap, rest) +end +@inline get_submap(choices::ChoiceMap, addr::Pair) = _get_submap(choices, addr) + +""" + has_value(choices::ChoiceMap) + +Returns true if `choices` is a `ValueChoiceMap`. + + has_value(choices::ChoiceMap, addr) + +Returns true if `choices` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(::ChoiceMap) = false +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; +throws a `KeyError` if `choices` is not a `ValueChoiceMap`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `KeyError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +get_value(::ChoiceMap) = throw(KeyError(nothing)) +get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +# get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `ValueChoiceMap`. +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + filter(! ∘ has_value, get_submaps_shallow(choices)) +end + +# a choicemap is empty if it has no submaps and no value +Base.isempty(c::ChoiceMap) = isempty(get_submaps_shallow(c)) && !has_value(c) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + +""" + EmptyChoiceMap + +A choicemap with no submaps or values. +""" +struct EmptyChoiceMap <: ChoiceMap end + +@inline has_value(::EmptyChoiceMap, addr...) = false +@inline get_value(::EmptyChoiceMap) = throw(KeyError(nothing)) +@inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() +@inline Base.isempty(::EmptyChoiceMap) = true +@inline get_submaps_shallow(::EmptyChoiceMap) = () + +""" + ValueChoiceMap + +A leaf-node choicemap. Stores a single value. +""" +struct ValueChoiceMap{T} <: ChoiceMap + val::T +end + +@inline has_value(choices::ValueChoiceMap) = true +@inline get_value(choices::ValueChoiceMap) = choices.val +@inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() +@inline get_submaps_shallow(choices::ValueChoiceMap) = () +Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) + +""" + choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + +Merge two choice maps. + +It is an error if the choice maps both have values at the same address, or if +one choice map has a value at an address that is the prefix of the address of a +value in the other choice map. +""" +function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + choices = DynamicChoiceMap() + for (key, submap) in get_submaps_shallow(choices1) + set_submap!(choices, key, merge(submap, get_submap(choices2, key))) + end + choices +end +Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c +Base.merge(c::ValueChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ValueChoiceMap) = c +Base.merge(::ValueChoiceMap, ::ChoiceMap) = error("ValueChoiceMaps cannot be merged") +Base.merge(::ChoiceMap, ::ValueChoiceMap) = error("ValueChoiceMaps cannot be merged") + +""" +Variadic merge of choice maps. +""" +function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) + reduce(Base.merge, choices_rest; init=choices1) +end + +function Base.:(==)(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if get_submap(b, addr) != submap + return false + end + end + return true +end + +function Base.isapprox(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if !isapprox(get_submap(b, addr), submap) + return false + end + end + return true +end + +""" + selected_choices = get_selected(choices::ChoiceMap, selection::Selection) + +Filter the choice map to include only choices in the given selection. + +Returns a new choice map. +""" +function get_selected( + choices::ChoiceMap, selection::Selection) + # TODO: return a `FilteringChoiceMap` which does this filtering lazily! + output = choicemap() + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) && addr in selection + output[addr] = get_value(submap) + else + subselection = selection[addr] + set_submap!(output, addr, get_selected(submap, subselection)) + end + end + output +end + +function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_values = collect(get_values_shallow(choices)) + key_and_submaps = collect(get_nonvalue_submaps_shallow(choices)) + n = length(key_and_values) + length(key_and_submaps) + cur = 1 + for (key, value) in key_and_values + # For strings, `print` is what we want; `Base.show` includes quote marks. + # https://docs.julialang.org/en/v1/base/io-network/#Base.print + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + for (key, submap) in key_and_submaps + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) + _show_pretty(io, choices, 0, ()) +end + +export ChoiceMap, ValueChoiceMap, EmptyChoiceMap +export get_submap, get_submaps_shallow +export get_value, has_value +export get_values_shallow, get_nonvalue_submaps_shallow + +include("array_interface.jl") +include("dynamic_choice_map.jl") +include("static_choice_map.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map2/dynamic_choice_map.jl new file mode 100644 index 00000000..a93a4902 --- /dev/null +++ b/src/choice_map2/dynamic_choice_map.jl @@ -0,0 +1,153 @@ +####################### +# dynamic assignment # +####################### + +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, <:ChoiceMap} +end + +""" + struct DynamicChoiceMap <: ChoiceMap .. end + +A mutable map from arbitrary hierarchical addresses to values. + + choices = DynamicChoiceMap() + +Construct an empty map. + + choices = DynamicChoiceMap(tuples...) + +Construct a map containing each of the given (addr, value) tuples. +""" +function DynamicChoiceMap() + DynamicChoiceMap(Dict()) +end + +function DynamicChoiceMap(tuples...) + choices = DynamicChoiceMap() + for (addr, value) in tuples + choices[addr] = value + end + choices +end + +""" + choices = DynamicChoiceMap(other::ChoiceMap) + +Copy a choice map, returning a mutable choice map. +""" +function DynamicChoiceMap(other::ChoiceMap) + choices = DynamicChoiceMap() + for (addr, submap) in get_submaps_shallow(other) + if choices isa ValueChoiceMap + set_submap!(choices, addr, submap) + else + set_submap!(choices, addr, DynamicChoiceMap(submap)) + end + end +end + +DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given address, value tuples. +""" +function choicemap(tuples...) + DynamicChoiceMap(tuples...) +end + +get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +function get_submap(choices::DynamicChoiceMap, addr) + if haskey(choices.submaps, addr) + choices.submaps[addr] + else + EmptyChoiceMap() + end +end +get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) + +# mutation (not part of the assignment interface) + +""" + set_value!(choices::DynamicChoiceMap, addr, value) + +Set the given value for the given address. + +Will cause any previous value or sub-assignment at this address to be deleted. +It is an error if there is already a value present at some prefix of the given address. + +The following syntactic sugar is provided: + + choices[addr] = value +""" +function set_value!(choices::DynamicChoiceMap, addr, value) + delete!(choices.submaps, addr) + choices.submaps[addr] = ValueChoiceMap(value) +end + +function set_value!(choices::DynamicChoiceMap, addr::Pair, value) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_value!(choices.submaps[first], rest, value) +end + +""" + set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) + +Replace the sub-assignment rooted at the given address with the given sub-assignment. +Set the given value for the given address. + +Will cause any previous value or sub-assignment at the given address to be deleted. +It is an error if there is already a value present at some prefix of address. +""" +function set_submap!(choices::DynamicChoiceMap, addr, new_node) + delete!(choices.submaps, addr) + if !isempty(new_node) + choices.submaps[addr] = new_node + end +end + +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_submap!(choices.submaps[first], rest, new_node) +end + +Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) + +function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + choices = DynamicChoiceMap() + keys_sorted = sort(collect(keys(choices.submaps))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) + idx += n_read + choices.submaps[key] = submap + end + (idx - start_idx, choices) +end + +export DynamicChoiceMap +export choicemap +export set_value! +export set_submap! \ No newline at end of file diff --git a/src/choice_map2/nested_view.jl b/src/choice_map2/nested_view.jl new file mode 100644 index 00000000..6693234f --- /dev/null +++ b/src/choice_map2/nested_view.jl @@ -0,0 +1,81 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) +ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) + a.choice_map = b.choice_map +end +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map2/static_choice_map.jl new file mode 100644 index 00000000..e5e2d89e --- /dev/null +++ b/src/choice_map2/static_choice_map.jl @@ -0,0 +1,131 @@ +###################### +# static assignment # +###################### + +struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap + submaps::NamedTuple{Addrs, SubmapTypes} +end + +@inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) +@inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) + +# TODO: profiling! +@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} + if A in Addrs + quote choices.submaps[A] end + else + quote EmptyChoiceMap() end + end +end + +static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +function StaticChoiceMap(other::ChoiceMap) + keys_and_nodes = get_submaps_shallow(other) + (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + StaticChoiceMap(NamedTuple{addrs}(submaps)) +end +StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") + +# TODO: deep conversion to static choicemap + +""" + choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + +Return an assignment that contains `choices1` as a sub-assignment under `key1` +and `choices2` as a sub-assignment under `key2`. +""" +function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + StaticChoiceMap(NamedTuple{(key1, key2)}((choices1, choices2))) +end + +""" + (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + +Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any submaps at keys other than `key1` and `key2`. +""" +function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + if length(collect(get_submaps_shallow(choices))) != 2 + error("Not a pair") + end + (get_submap(choices, key1), get_submap(choices, key2)) +end + +@generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, + choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, ::Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, ::Type{<:ChoiceMap}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyChoiceMap) + type2 = get(addr_to_type2, addr, EmptyChoiceMap) + if ((type1 <: ValueChoiceMap && type2 != EmptyChoiceMap) + || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) + error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") + end + if type1 <: ValueChoiceMap + push!(submap_exprs, + quote choices1.submaps[$addr] end + ) + elseif type2 <: ValueChoiceMap + push!(submap_exprs, + quote choices2.submaps[$addr] end + ) + else + push!(submap_exprs, + quote merge(choices1.submaps[$addr], choices2.submaps[$addr]) end + ) + end + end + + quote + StaticChoiceMap{$merged_addrs}(submap_exprs...) + end +end + +@generated function _from_array!(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} + + perm = sortperm(Addrs) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, submap_var_name = _from_array(proto_choices.submaps[$addr], arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...) )) + choices = StaticChoiceMap{Addrs, SubmapTypes}(submaps) + (idx - start_idx, choices) + end +end + +export StaticChoiceMap +export pair, unpair +export static_get_submap, static_get_value \ No newline at end of file From c9b1d4982e5f8f4903254adc982d4d5a216c5580 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:33:00 -0400 Subject: [PATCH 02/34] add support for address schemas --- src/choice_map2/choice_map.jl | 9 +++++++++ src/choice_map2/dynamic_choice_map.jl | 2 ++ src/choice_map2/static_choice_map.jl | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/src/choice_map2/choice_map.jl b/src/choice_map2/choice_map.jl index d7e7101f..0ebb19f0 100644 --- a/src/choice_map2/choice_map.jl +++ b/src/choice_map2/choice_map.jl @@ -60,6 +60,13 @@ get_value(::ChoiceMap) = throw(KeyError(nothing)) get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) +""" +schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} + +Return the (top-level) address schema for the given choice map. +""" +function get_address_schema end + # get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow """ get_values_shallow(choices::ChoiceMap) @@ -108,6 +115,7 @@ struct EmptyChoiceMap <: ChoiceMap end @inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() @inline Base.isempty(::EmptyChoiceMap) = true @inline get_submaps_shallow(::EmptyChoiceMap) = () +@inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() """ ValueChoiceMap @@ -124,6 +132,7 @@ end @inline get_submaps_shallow(choices::ValueChoiceMap) = () Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() """ choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map2/dynamic_choice_map.jl index a93a4902..5dfca0b5 100644 --- a/src/choice_map2/dynamic_choice_map.jl +++ b/src/choice_map2/dynamic_choice_map.jl @@ -147,6 +147,8 @@ function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx: (idx - start_idx, choices) end +get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() + export DynamicChoiceMap export choicemap export set_value! diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map2/static_choice_map.jl index e5e2d89e..3508762d 100644 --- a/src/choice_map2/static_choice_map.jl +++ b/src/choice_map2/static_choice_map.jl @@ -126,6 +126,10 @@ end end end +function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} + StaticAddressSchema(set(Addrs)) +end + export StaticChoiceMap export pair, unpair export static_get_submap, static_get_value \ No newline at end of file From 1e0a58997d4717eb687aba6306b8b556108475bb Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:42:52 -0400 Subject: [PATCH 03/34] update choicemap docs --- docs/src/ref/choice_maps.md | 43 +++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index c065b1b3..8d3f4200 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -8,13 +8,20 @@ ChoiceMap Choice maps are constructed by users to express observations and/or constraints on the traces of generative functions. Choice maps are also returned by certain Gen inference methods, and are used internally by various Gen inference methods. +A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses +for sub-choicemaps. Leaf nodes have type: +```@docs +ValueChoiceMap +``` + Choice maps provide the following methods: ```@docs +get_submap +get_submaps_shallow has_value get_value -get_submap get_values_shallow -get_submaps_shallow +get_nonvalue_submaps_shallow to_array from_array get_selected @@ -50,3 +57,35 @@ choicemap set_value! set_submap! ``` + +## Implementing custom choicemap types + +To implement a custom choicemap, one must implement +`get_submap` and `get_submaps_shallow`. +To avoid method ambiguity with the default +`get_submap(::ChoiceMap, ::Pair)`, one must implement both +```julia +get_submap(::CustomChoiceMap, addr) +``` +and +```julia +get_submap(::CustomChoiceMap, addr::Pair) +``` +To use the default implementation of `get_submap(_, ::Pair)`, +one may define +```julia +get_submap(c::CustomChoiceMap, addr::Pair) = _get_choicemap(c, addr) +``` + +Once `get_submap` and `get_submaps_shallow` are defined, default +implementations are provided for: +- `has_value` +- `get_value` +- `get_values_shallow` +- `get_nonvalue_submaps_shallow` +- `to_array` +- `get_selected` + +If one wishes to support `from_array`, they must implement +`_from_array`, as described in the documentation for +[`from_array`](@ref). \ No newline at end of file From 623bc8fcba7fc81eecb039a13d861baf06102d57 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 16:57:46 -0400 Subject: [PATCH 04/34] refactoring and tests --- docs/src/ref/choice_maps.md | 2 +- src/Gen.jl | 2 +- src/choice_map.jl | 1009 ----------------- .../array_interface.jl | 20 +- src/{choice_map2 => choice_map}/choice_map.jl | 57 +- .../dynamic_choice_map.jl | 20 +- .../nested_view.jl | 7 +- .../static_choice_map.jl | 52 +- src/dynamic/dynamic.jl | 31 +- src/dynamic/generate.jl | 2 +- src/dynamic/trace.jl | 36 +- src/dynamic/update.jl | 26 +- src/inference/kernel_dsl.jl | 11 +- src/modeling_library/call_at/call_at.jl | 5 +- src/modeling_library/choice_at/choice_at.jl | 4 +- src/modeling_library/recurse/recurse.jl | 18 +- src/modeling_library/vector.jl | 4 - src/static_ir/backprop.jl | 21 +- src/static_ir/trace.jl | 75 +- src/static_ir/update.jl | 7 +- test/assignment.jl | 224 ++-- test/benchmark.md | 21 + test/dynamic_dsl.jl | 14 +- test/modeling_library/call_at.jl | 26 +- test/modeling_library/choice_at.jl | 26 +- test/modeling_library/recurse.jl | 4 +- test/modeling_library/unfold.jl | 6 +- test/optional_args.jl | 2 +- test/runtests.jl | 2 +- test/static_ir/static_ir.jl | 10 +- test/tilde_sugar.jl | 2 +- 31 files changed, 347 insertions(+), 1399 deletions(-) delete mode 100644 src/choice_map.jl rename src/{choice_map2 => choice_map}/array_interface.jl (83%) rename src/{choice_map2 => choice_map}/choice_map.jl (82%) rename src/{choice_map2 => choice_map}/dynamic_choice_map.jl (93%) rename src/{choice_map2 => choice_map}/nested_view.jl (93%) rename src/{choice_map2 => choice_map}/static_choice_map.jl (68%) create mode 100644 test/benchmark.md diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 8d3f4200..6c445df6 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -30,7 +30,7 @@ Note that none of these methods mutate the choice map. Choice maps also implement: -- `Base.isempty`, which tests of there are no random choices in the choice map +- `Base.isempty`, which returns `false` if the choicemap contains no value or submaps, and `true` otherwise. - `Base.merge`, which takes two choice maps, and returns a new choice map containing all random choices in either choice map. It is an error if the choice maps both have values at the same address, or if one choice map has a value at an address that is the prefix of the address of a value in the other choice map. diff --git a/src/Gen.jl b/src/Gen.jl index 9f3da9e3..fa239359 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -37,7 +37,7 @@ include("backprop.jl") include("address.jl") # abstract and built-in concrete choice map data types -include("choice_map.jl") +include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") diff --git a/src/choice_map.jl b/src/choice_map.jl deleted file mode 100644 index b7891b40..00000000 --- a/src/choice_map.jl +++ /dev/null @@ -1,1009 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -""" - submap = get_submap(choices::ChoiceMap, addr) - -Return the sub-assignment containing all choices whose address is prefixed by addr. - -It is an error if the assignment contains a value at the given address. If -there are no choices whose address is prefixed by addr then return an -`EmptyChoiceMap`. -""" -function get_submap end - -""" - value = get_value(choices::ChoiceMap, addr) - -Return the value at the given address in the assignment, or throw a KeyError if -no value exists. A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end - -""" - key_submap_iterable = get_submaps_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, submap::ChoiceMap)` for each top-level key -that has a non-empty sub-assignment. -""" -function get_submaps_shallow end - -""" - has_value(choices::ChoiceMap, addr) - -Return true if there is a value at the given address. -""" -function has_value end - -""" - key_submap_iterable = get_values_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, value)` for each -top-level key associated with a value. -""" -function get_values_shallow end - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - Base.isempty(choices::ChoiceMap) - -Return true if there are no values in the assignment. -""" -function Base.isempty(::ChoiceMap) - true -end - -@inline get_submap(choices::ChoiceMap, addr) = EmptyChoiceMap() -@inline has_value(choices::ChoiceMap, addr) = false -@inline get_value(choices::ChoiceMap, addr) = throw(KeyError(addr)) -@inline Base.getindex(choices::ChoiceMap, addr) = get_value(choices, addr) - -@inline function _has_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - has_value(submap, rest) -end - -@inline function _get_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_value(submap, rest) -end - -@inline function _get_submap(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -# assignments that have static address schemas should also support faster -# accessors, which make the address explicit in the type (Val(:foo) instaed of -# :foo) -function static_get_value end -function static_get_submap end - -function _fill_array! end -function _from_array end - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -# Implementation - -To support `to_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(value::T, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = value - 1 -end - -function _fill_array!(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -The order in which addresses are populated is determined by the prototype -assignment. It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -# Implementation - -To support `from_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`, and the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::T, arr::Vector{T}, start_idx::Int) where {T} - (1, arr[start_idx]) -end - -function _from_array(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(value) - (n_read, arr[start_idx:start_idx+n_read-1]) -end - - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. -""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, value) in get_values_shallow(choices1) - choices.leaf_nodes[key] = value - end - for (key, node1) in get_submaps_shallow(choices1) - node2 = get_submap(choices2, key) - node = merge(node1, node2) - choices.internal_nodes[key] = node - end - for (key, value) in get_values_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has leaf node at $key") - end - if haskey(choices.internal_nodes, key) - error("choices1 has internal node at $key and choices2 has leaf node at $key") - end - choices.leaf_nodes[key] = value - end - for (key, node) in get_submaps_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has internal node at $key") - end - if !haskey(choices.internal_nodes, key) - # otherwise it should already be included - choices.internal_nodes[key] = node - end - end - return choices -end - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || (get_value(b, addr) != value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || (get_value(a, addr) != value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if submap != get_submap(b, addr) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if submap != get_submap(a, addr) - return false - end - end - return true -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || !isapprox(get_value(b, addr), value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || !isapprox(get_value(a, addr), value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(submap, get_submap(b, addr)) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if !isapprox(submap, get_submap(a, addr)) - return false - end - end - return true -end - - -export ChoiceMap -export get_address_schema -export get_submap -export get_value -export has_value -export get_submaps_shallow -export get_values_shallow -export static_get_value -export static_get_submap -export to_array, from_array - - -###################### -# static assignment # -###################### - -struct StaticChoiceMap{R,S,T,U} <: ChoiceMap - leaf_nodes::NamedTuple{R,S} - internal_nodes::NamedTuple{T,U} - isempty::Bool -end - -function StaticChoiceMap{R,S,T,U}(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - -function StaticChoiceMap(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - - -# invariant: all internal_nodes are nonempty - -function get_address_schema(::Type{StaticChoiceMap{R,S,T,U}}) where {R,S,T,U} - keys = Set{Symbol}() - for (key, _) in zip(R, S.parameters) - push!(keys, key) - end - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -function Base.isempty(choices::StaticChoiceMap) - choices.isempty -end - -get_values_shallow(choices::StaticChoiceMap) = pairs(choices.leaf_nodes) -get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.internal_nodes) -has_value(choices::StaticChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::StaticChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# NOTE: there is no static_has_value because this is known from the static -# address schema - -## has_value ## - -function has_value(choices::StaticChoiceMap, key::Symbol) - haskey(choices.leaf_nodes, key) -end - -## get_submap ## - -function get_submap(choices::StaticChoiceMap, key::Symbol) - if haskey(choices.internal_nodes, key) - choices.internal_nodes[key] - elseif haskey(choices.leaf_nodes, key) - throw(KeyError(key)) - else - EmptyChoiceMap() - end -end - -function static_get_submap(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.internal_nodes[A] -end - -## get_value ## - -function get_value(choices::StaticChoiceMap, key::Symbol) - choices.leaf_nodes[key] -end - -function static_get_value(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.leaf_nodes[A] -end - -# convert from any other schema that has only Val{:foo} addresses -function StaticChoiceMap(other::ChoiceMap) - leaf_keys_and_nodes = collect(get_values_shallow(other)) - internal_keys_and_nodes = collect(get_submaps_shallow(other)) - if length(leaf_keys_and_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(leaf_keys_and_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) - end - if length(internal_keys_and_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(internal_keys_and_nodes...)) - else - (internal_keys, internal_nodes) = ((), ()) - end - StaticChoiceMap( - NamedTuple{leaf_keys}(leaf_nodes), - NamedTuple{internal_keys}(internal_nodes), - isempty(other)) -end - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple(), NamedTuple{(key1,key2)}((choices1, choices2)), - isempty(choices1) && isempty(choices2)) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any top-level values, or any non-empty top-level -sub-assignments at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if !isempty(get_values_shallow(choices)) || length(collect(get_submaps_shallow(choices))) > 2 - error("Not a pair") - end - a = get_submap(choices, key1) - b = get_submap(choices, key2) - (a, b) -end - -# TODO make a generated function? -function _fill_array!(choices::StaticChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for value in choices.leaf_nodes - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for node in choices.internal_nodes - n_written = _fill_array!(node, arr, idx) - idx += n_written - end - idx - start_idx -end - -@generated function _from_array( - proto_choices::StaticChoiceMap{R,S,T,U}, arr::Vector{V}, start_idx::Int) where {R,S,T,U,V} - leaf_node_keys = proto_choices.parameters[1] - leaf_node_types = proto_choices.parameters[2].parameters - internal_node_keys = proto_choices.parameters[3] - internal_node_types = proto_choices.parameters[4].parameters - - exprs = [quote idx = start_idx end] - leaf_node_names = [] - internal_node_names = [] - - # leaf nodes - for key in leaf_node_keys - value = gensym() - push!(leaf_node_names, value) - push!(exprs, quote - (n_read, $value) = _from_array(proto_choices.leaf_nodes.$key, arr, idx) - idx += n_read - end) - end - - # internal nodes - for key in internal_node_keys - node = gensym() - push!(internal_node_names, node) - push!(exprs, quote - (n_read, $node) = _from_array(proto_choices.internal_nodes.$key, arr, idx) - idx += n_read - end) - end - - quote - $(exprs...) - leaf_nodes_field = NamedTuple{R,S}(($(leaf_node_names...),)) - internal_nodes_field = NamedTuple{T,U}(($(internal_node_names...),)) - choices = StaticChoiceMap{R,S,T,U}(leaf_nodes_field, internal_nodes_field) - (idx - start_idx, choices) - end -end - -@generated function Base.merge(choices1::StaticChoiceMap{R,S,T,U}, - choices2::StaticChoiceMap{W,X,Y,Z}) where {R,S,T,U,W,X,Y,Z} - - # unpack first assignment type parameters - leaf_node_keys1 = choices1.parameters[1] - leaf_node_types1 = choices1.parameters[2].parameters - internal_node_keys1 = choices1.parameters[3] - internal_node_types1 = choices1.parameters[4].parameters - keys1 = (leaf_node_keys1..., internal_node_keys1...,) - - # unpack second assignment type parameters - leaf_node_keys2 = choices2.parameters[1] - leaf_node_types2 = choices2.parameters[2].parameters - internal_node_keys2 = choices2.parameters[3] - internal_node_types2 = choices2.parameters[4].parameters - keys2 = (leaf_node_keys2..., internal_node_keys2...,) - - # leaf vs leaf collision is an error - colliding_leaf_leaf_keys = intersect(leaf_node_keys1, leaf_node_keys2) - if !isempty(colliding_leaf_leaf_keys) - error("choices1 and choices2 both have leaf nodes at key(s): $colliding_leaf_leaf_keys") - end - - # leaf vs internal collision is an error - colliding_leaf_internal_keys = intersect(leaf_node_keys1, internal_node_keys2) - if !isempty(colliding_leaf_internal_keys) - error("choices1 has leaf node and choices2 has internal node at key(s): $colliding_leaf_internal_keys") - end - - # internal vs leaf collision is an error - colliding_internal_leaf_keys = intersect(internal_node_keys1, leaf_node_keys2) - if !isempty(colliding_internal_leaf_keys) - error("choices1 has internal node and choices2 has leaf node at key(s): $colliding_internal_leaf_keys") - end - - # internal vs internal collision is not an error, recursively call merge - colliding_internal_internal_keys = (intersect(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys1_exclusive = (setdiff(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys2_exclusive = (setdiff(internal_node_keys2, internal_node_keys1)...,) - - # leaf nodes named tuple - leaf_node_keys = (leaf_node_keys1..., leaf_node_keys2...,) - leaf_node_types = map(QuoteNode, (leaf_node_types1..., leaf_node_types2...,)) - leaf_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys1]..., - [Expr(:(.), :(choices2.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys2]...) - leaf_nodes = Expr(:call, - Expr(:curly, :NamedTuple, - QuoteNode(leaf_node_keys), - Expr(:curly, :Tuple, leaf_node_types...)), - leaf_node_values) - - # internal nodes named tuple - internal_node_keys = (internal_node_keys1_exclusive..., - internal_node_keys2_exclusive..., - colliding_internal_internal_keys...) - internal_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)) - for key in internal_node_keys1_exclusive]..., - [Expr(:(.), :(choices2.internal_nodes), QuoteNode(key)) - for key in internal_node_keys2_exclusive]..., - [Expr(:call, :merge, - Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)), - Expr(:(.), :(choices2.internal_nodes), QuoteNode(key))) - for key in colliding_internal_internal_keys]...) - internal_nodes = Expr(:call, - Expr(:curly, :NamedTuple, QuoteNode(internal_node_keys)), - internal_node_values) - - # construct assignment from named tuples - Expr(:call, :StaticChoiceMap, leaf_nodes, internal_nodes) -end - -export StaticChoiceMap -export pair, unpair - -####################### -# dynamic assignment # -####################### - -struct DynamicChoiceMap <: ChoiceMap - leaf_nodes::Dict{Any,Any} - internal_nodes::Dict{Any,Any} - function DynamicChoiceMap(leaf_nodes::Dict{Any,Any}, internal_nodes::Dict{Any,Any}) - new(leaf_nodes, internal_nodes) - end -end - -# invariant: all internal nodes are nonempty - -""" - struct DynamicChoiceMap <: ChoiceMap .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -function DynamicChoiceMap() - DynamicChoiceMap(Dict(), Dict()) -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for (addr, value) in tuples - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, val) in get_values_shallow(other) - choices[addr] = val - end - for (addr, submap) in get_submaps_shallow(other) - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - choices -end - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) -end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -get_values_shallow(choices::DynamicChoiceMap) = choices.leaf_nodes - -get_submaps_shallow(choices::DynamicChoiceMap) = choices.internal_nodes - -has_value(choices::DynamicChoiceMap, addr::Pair) = _has_value(choices, addr) - -get_value(choices::DynamicChoiceMap, addr::Pair) = _get_value(choices, addr) - -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.internal_nodes, addr) - choices.internal_nodes[addr] - elseif haskey(choices.leaf_nodes, addr) - throw(KeyError(addr)) - else - EmptyChoiceMap() - end -end - -has_value(choices::DynamicChoiceMap, addr) = haskey(choices.leaf_nodes, addr) - -get_value(choices::DynamicChoiceMap, addr) = choices.leaf_nodes[addr] - -function Base.isempty(choices::DynamicChoiceMap) - isempty(choices.leaf_nodes) && isempty(choices.internal_nodes) -end - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.internal_nodes, addr) - choices.leaf_nodes[addr] = value -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - node = choices.internal_nodes[first] - set_value!(node, rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) - delete!(choices.leaf_nodes, addr) - delete!(choices.internal_nodes, addr) - if !isempty(new_node) - choices.internal_nodes[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - set_submap!(node, rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _fill_array!(choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - leaf_keys_sorted = sort(collect(keys(choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - value = choices.leaf_nodes[key] - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for key in internal_node_keys_sorted - n_written = _fill_array!(get_submap(choices, key), arr, idx) - idx += n_written - end - idx - start_idx -end - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - @assert length(arr) >= start_idx - choices = DynamicChoiceMap() - leaf_keys_sorted = sort(collect(keys(proto_choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(proto_choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - (n_read, value) = _from_array(proto_choices.leaf_nodes[key], arr, idx) - idx += n_read - choices.leaf_nodes[key] = value - end - for key in internal_node_keys_sorted - (n_read, node) = _from_array(get_submap(proto_choices, key), arr, idx) - idx += n_read - choices.internal_nodes[key] = node - end - (idx - start_idx, choices) -end - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! - - -####################################### -## vector combinator for assignments # -####################################### - -# TODO implement LeafVectorChoiceMap, which stores a vector of leaf nodes - -struct InternalVectorChoiceMap{T} <: ChoiceMap - internal_nodes::Vector{T} - is_empty::Bool -end - -function vectorize_internal(nodes::Vector{T}) where {T} - is_empty = all(map(isempty, nodes)) - InternalVectorChoiceMap(nodes, is_empty) -end - -# note some internal nodes may be empty - -get_address_schema(::Type{InternalVectorChoiceMap}) = VectorAddressSchema() - -Base.isempty(choices::InternalVectorChoiceMap) = choices.is_empty -has_value(choices::InternalVectorChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::InternalVectorChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::InternalVectorChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::InternalVectorChoiceMap, addr::Int) - if addr > 0 && addr <= length(choices.internal_nodes) - choices.internal_nodes[addr] - else - EmptyChoiceMap() - end -end - -function get_submaps_shallow(choices::InternalVectorChoiceMap) - ((i, choices.internal_nodes[i]) - for i=1:length(choices.internal_nodes) - if !isempty(choices.internal_nodes[i])) -end - -get_values_shallow(::InternalVectorChoiceMap) = () - -function _fill_array!(choices::InternalVectorChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for key=1:length(choices.internal_nodes) - n = _fill_array!(choices.internal_nodes[key], arr, idx) - idx += n - end - idx - start_idx -end - -function _from_array(proto_choices::InternalVectorChoiceMap{U}, arr::Vector{T}, start_idx::Int) where {T,U} - @assert length(arr) >= start_idx - nodes = Vector{U}(undef, length(proto_choices.internal_nodes)) - idx = start_idx - for key=1:length(proto_choices.internal_nodes) - (n_read, nodes[key]) = _from_array(proto_choices.internal_nodes[key], arr, idx) - idx += n_read - end - choices = InternalVectorChoiceMap(nodes, proto_choices.is_empty) - (idx - start_idx, choices) -end - -export InternalVectorChoiceMap -export vectorize_internal - - -#################### -# empty assignment # -#################### - -struct EmptyChoiceMap <: ChoiceMap end - -Base.isempty(::EmptyChoiceMap) = true -get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -get_submaps_shallow(::EmptyChoiceMap) = () -get_values_shallow(::EmptyChoiceMap) = () - -_fill_array!(::EmptyChoiceMap, arr::Vector, start_idx::Int) = 0 -_from_array(::EmptyChoiceMap, arr::Vector, start_idx::Int) = (0, EmptyChoiceMap()) - -export EmptyChoiceMap - -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -function Base.getindex(choices::ChoiceMapNestedView, addr) - if has_value(choices.choice_map, addr) - return get_value(choices.choice_map, addr) - end - submap = get_submap(choices.choice_map, addr) - if isempty(submap) - throw(KeyError(addr)) - end - ChoiceMapNestedView(submap) -end - -function Base.iterate(c::ChoiceMapNestedView) - inner_iterator = Base.Iterators.flatten(( - get_values_shallow(c.choice_map), - ((k, ChoiceMapNestedView(v)) - for (k, v) in get_submaps_shallow(c.choice_map)))) - r = Base.iterate(inner_iterator) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (inner_iterator, inner_state) = state - r = Base.iterate(inner_iterator, inner_state) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::Gen.ChoiceMapNestedView) = (k for (k, v) in cv) - -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map == b.choice_map -end - -# Length of a `ChoiceMapNestedView` is number of leaf values + number of -# submaps. Motivation: This matches what `length` would return for the -# equivalent nested dict. -function Base.length(cv::ChoiceMapNestedView) - +(get_values_shallow(cv.choice_map) |> collect |> length, - get_submaps_shallow(cv.choice_map) |> collect |> length) -end - -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. - -export nested_view - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - output = choicemap() - for (key, value) in get_values_shallow(choices) - if (key in selection) - output[key] = value - end - end - for (key, submap) in get_submaps_shallow(choices) - subselection = selection[key] - set_submap!(output, key, get_selected(submap, subselection)) - end - output -end - -export get_selected diff --git a/src/choice_map2/array_interface.jl b/src/choice_map/array_interface.jl similarity index 83% rename from src/choice_map2/array_interface.jl rename to src/choice_map/array_interface.jl index f88c5b11..cf9d0bd0 100644 --- a/src/choice_map2/array_interface.jl +++ b/src/choice_map/array_interface.jl @@ -34,12 +34,20 @@ function to_array(choices::ChoiceMap, ::Type{T}) where {T} end function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) <: start_idx + if length(arr) < start_idx resize!(arr, 2 * start_idx) end arr[start_idx] = get_value(c) 1 end +function _fill_array!(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end # default _fill_array! implementation function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} @@ -88,5 +96,11 @@ function from_array(proto_choices::ChoiceMap, arr::Vector) end function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) - ValueChoiceMap(arr[start_idx]) -end \ No newline at end of file + (1, ValueChoiceMap(arr[start_idx])) +end +function _from_array(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, ValueChoiceMap(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/choice_map2/choice_map.jl b/src/choice_map/choice_map.jl similarity index 82% rename from src/choice_map2/choice_map.jl rename to src/choice_map/choice_map.jl index 0ebb19f0..402cefa3 100644 --- a/src/choice_map2/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -2,6 +2,22 @@ # choice map interface # ######################### +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + """ get_submaps_shallow(choices::ChoiceMap) @@ -26,7 +42,6 @@ function get_submap end submap = get_submap(choices, first) get_submap(submap, rest) end -@inline get_submap(choices::ChoiceMap, addr::Pair) = _get_submap(choices, addr) """ has_value(choices::ChoiceMap) @@ -45,18 +60,18 @@ function has_value end get_value(choices::ChoiceMap) Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; -throws a `KeyError` if `choices` is not a `ValueChoiceMap`. +throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. get_value(choices::ChoiceMap, addr) Returns the value stored in the submap with address `addr` or throws -a `KeyError` if no value exists at this address. +a `ChoiceMapGetValueError` if no value exists at this address. A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -get_value(::ChoiceMap) = throw(KeyError(nothing)) +get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) @@ -73,6 +88,8 @@ function get_address_schema end Returns an iterable collection of tuples `(address, value)` for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) """ function get_values_shallow(choices::ChoiceMap) ( @@ -88,20 +105,15 @@ end Returns an iterable collection of tuples `(address, submap)` for every top-level submap stored in `choices` which is not a `ValueChoiceMap`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) """ function get_nonvalue_submaps_shallow(choices::ChoiceMap) - filter(! ∘ has_value, get_submaps_shallow(choices)) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) end # a choicemap is empty if it has no submaps and no value -Base.isempty(c::ChoiceMap) = isempty(get_submaps_shallow(c)) && !has_value(c) - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end +Base.isempty(c::ChoiceMap) = all(((addr, submap),) -> isempty(submap), get_submaps_shallow(c)) && !has_value(c) """ EmptyChoiceMap @@ -111,11 +123,14 @@ A choicemap with no submaps or values. struct EmptyChoiceMap <: ChoiceMap end @inline has_value(::EmptyChoiceMap, addr...) = false -@inline get_value(::EmptyChoiceMap) = throw(KeyError(nothing)) +@inline get_value(::EmptyChoiceMap) = throw(ChoiceMapGetValueError()) @inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() @inline Base.isempty(::EmptyChoiceMap) = true @inline get_submaps_shallow(::EmptyChoiceMap) = () @inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() +@inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true +@inline Base.:(==)(::ChoiceMap, ::EmptyChoiceMap) = false +@inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false """ ValueChoiceMap @@ -148,6 +163,11 @@ function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) for (key, submap) in get_submaps_shallow(choices1) set_submap!(choices, key, merge(submap, get_submap(choices2, key))) end + for (key, submap) in get_submaps_shallow(choices2) + if isempty(get_submap(choices1, key)) + set_submap!(choices, key, submap) + end + end choices end Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c @@ -170,6 +190,11 @@ function Base.:(==)(a::ChoiceMap, b::ChoiceMap) return false end end + for (addr, submap) in get_submaps_shallow(b) + if get_submap(a, addr) != submap + return false + end + end return true end @@ -246,9 +271,11 @@ function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) end export ChoiceMap, ValueChoiceMap, EmptyChoiceMap -export get_submap, get_submaps_shallow +export _get_submap, get_submap, get_submaps_shallow export get_value, has_value export get_values_shallow, get_nonvalue_submaps_shallow +export get_address_schema, get_selected +export ChoiceMapGetValueError include("array_interface.jl") include("dynamic_choice_map.jl") diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl similarity index 93% rename from src/choice_map2/dynamic_choice_map.jl rename to src/choice_map/dynamic_choice_map.jl index 5dfca0b5..a3403307 100644 --- a/src/choice_map2/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -2,10 +2,6 @@ # dynamic assignment # ####################### -struct DynamicChoiceMap <: ChoiceMap - submaps::Dict{Any, <:ChoiceMap} -end - """ struct DynamicChoiceMap <: ChoiceMap .. end @@ -19,8 +15,11 @@ Construct an empty map. Construct a map containing each of the given (addr, value) tuples. """ -function DynamicChoiceMap() - DynamicChoiceMap(Dict()) +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, ChoiceMap} + function DynamicChoiceMap() + new(Dict()) + end end function DynamicChoiceMap(tuples...) @@ -39,12 +38,13 @@ Copy a choice map, returning a mutable choice map. function DynamicChoiceMap(other::ChoiceMap) choices = DynamicChoiceMap() for (addr, submap) in get_submaps_shallow(other) - if choices isa ValueChoiceMap + if submap isa ValueChoiceMap set_submap!(choices, addr, submap) else set_submap!(choices, addr, DynamicChoiceMap(submap)) end end + choices end DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") @@ -116,14 +116,14 @@ Set the given value for the given address. Will cause any previous value or sub-assignment at the given address to be deleted. It is an error if there is already a value present at some prefix of address. """ -function set_submap!(choices::DynamicChoiceMap, addr, new_node) +function set_submap!(choices::DynamicChoiceMap, addr, new_node::ChoiceMap) delete!(choices.submaps, addr) if !isempty(new_node) choices.submaps[addr] = new_node end end -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node::ChoiceMap) (first, rest) = addr if !haskey(choices.submaps, first) choices.submaps[first] = DynamicChoiceMap() @@ -137,7 +137,7 @@ Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, add function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} choices = DynamicChoiceMap() - keys_sorted = sort(collect(keys(choices.submaps))) + keys_sorted = sort(collect(keys(proto_choices.submaps))) idx = start_idx for key in keys_sorted (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) diff --git a/src/choice_map2/nested_view.jl b/src/choice_map/nested_view.jl similarity index 93% rename from src/choice_map2/nested_view.jl rename to src/choice_map/nested_view.jl index 6693234f..68add0a0 100644 --- a/src/choice_map2/nested_view.jl +++ b/src/choice_map/nested_view.jl @@ -33,7 +33,7 @@ ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") function Base.getindex(choices::ChoiceMapNestedView, addr) - ChoiceMapNestedView(get_submap(choices, addr)) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) end function Base.iterate(c::ChoiceMapNestedView) @@ -62,9 +62,8 @@ end # exists. Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map = b.choice_map -end +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + function Base.length(cv::ChoiceMapNestedView) length(collect(get_submaps_shallow(cv.choice_map))) end diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map/static_choice_map.jl similarity index 68% rename from src/choice_map2/static_choice_map.jl rename to src/choice_map/static_choice_map.jl index 3508762d..1f75b3bc 100644 --- a/src/choice_map2/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -4,13 +4,21 @@ struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap submaps::NamedTuple{Addrs, SubmapTypes} + function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} + new{Addrs, SubmapTypes}(submaps) + end +end + +function StaticChoiceMap(;addrs_to_vals_and_maps...) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) + maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : ValueChoiceMap(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) + StaticChoiceMap(NamedTuple{addrs}(maps)) end @inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) @inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) @inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) -# TODO: profiling! @generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} if A in Addrs quote choices.submaps[A] end @@ -18,17 +26,25 @@ end quote EmptyChoiceMap() end end end +static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = get_submaps_shallow(other) - (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + keys_and_nodes = collect(get_submaps_shallow(other)) + if length(keys_and_nodes) > 0 + (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + else + addrs = () + submaps = () + end StaticChoiceMap(NamedTuple{addrs}(submaps)) end StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") +StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() # TODO: deep conversion to static choicemap @@ -58,9 +74,9 @@ end @generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} - - addr_to_type1 = Dict{Symbol, ::Type{<:ChoiceMap}}() - addr_to_type2 = Dict{Symbol, ::Type{<:ChoiceMap}}() + + addr_to_type1 = Dict{Symbol, Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, Type{<:ChoiceMap}}() for (i, addr) in enumerate(Addrs1) addr_to_type1[addr] = SubmapTypes1.parameters[i] end @@ -78,30 +94,30 @@ end || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") end - if type1 <: ValueChoiceMap + if type1 <: EmptyChoiceMap push!(submap_exprs, - quote choices1.submaps[$addr] end + quote choices2.submaps.$addr end ) - elseif type2 <: ValueChoiceMap + elseif type2 <: EmptyChoiceMap push!(submap_exprs, - quote choices2.submaps[$addr] end + quote choices1.submaps.$addr end ) else push!(submap_exprs, - quote merge(choices1.submaps[$addr], choices2.submaps[$addr]) end + quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end ) end end quote - StaticChoiceMap{$merged_addrs}(submap_exprs...) + StaticChoiceMap(NamedTuple{$merged_addrs}(($(submap_exprs...),))) end end -@generated function _from_array!(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, +@generated function _from_array(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} - perm = sortperm(Addrs) + perm = sortperm(collect(Addrs)) sorted_addrs = Addrs[perm] submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) @@ -112,7 +128,7 @@ end submap_var_names[idx] = submap_var_name push!(exprs, quote - (n_read, submap_var_name = _from_array(proto_choices.submaps[$addr], arr, idx) + (n_read, $submap_var_name) = _from_array(proto_choices.submaps.$addr, arr, idx) idx += n_read end ) @@ -120,14 +136,14 @@ end quote $(exprs...) - submaps = NamedTuple{Addrs}(( $(submap_var_names...) )) - choices = StaticChoiceMap{Addrs, SubmapTypes}(submaps) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticChoiceMap(submaps) (idx - start_idx, choices) end end function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} - StaticAddressSchema(set(Addrs)) + StaticAddressSchema(Set(Addrs)) end export StaticChoiceMap diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index c6f09374..73f22159 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -124,42 +124,33 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end +all_visited(::Selection, ::ValueChoiceMap) = false +all_visited(::AllSelection, ::ValueChoiceMap) = true function all_visited(visited::Selection, choices::ChoiceMap) - allvisited = true - for (key, _) in get_values_shallow(choices) - allvisited = allvisited && (key in visited) - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - allvisited = allvisited && all_visited(subvisited, submap) + if !all_visited(visited[key], submap) + return false end end - allvisited + return true end +get_unvisited(::Selection, v::ValueChoiceMap) = v +get_unvisited(::AllSelection, v::ValueChoiceMap) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() - for (key, _) in get_values_shallow(choices) - if !(key in visited) - set_value!(unvisited, key, get_value(choices, key)) - end - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - sub_unvisited = get_unvisited(subvisited, submap) - set_submap!(unvisited, key, sub_unvisited) - end + sub_unvisited = get_unvisited(visited[key], submap) + set_submap!(unvisited, key, sub_unvisited) end unvisited end get_visited(visitor) = visitor.visited -function check_no_submap(constraints::ChoiceMap, addr) +function check_is_empty(constraints::ChoiceMap, addr) if !isempty(get_submap(constraints, addr)) - error("Expected a value at address $addr but found a sub-assignment") + error("Expected a value or EmptyChoiceMap at address $addr but found a sub-assignment") end end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index df6a5f46..970dac42 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -20,7 +20,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # get return value if constrained diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb..882297e4 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -119,9 +119,6 @@ struct DynamicDSLChoiceMap <: ChoiceMap end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -Base.isempty(::DynamicDSLChoiceMap) = false # TODO not necessarily true -has_value(choices::DynamicDSLChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::DynamicDSLChoiceMap, addr::Pair) = _get_value(choices, addr) get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_submap(choices::DynamicDSLChoiceMap, addr) @@ -130,9 +127,10 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) # leaf node, must be a call call = trie[addr] if call.is_choice - throw(KeyError(addr)) + ValueChoiceMap(call.subtrace_or_retval) + else + get_choices(call.subtrace_or_retval) end - get_choices(call.subtrace_or_retval) elseif has_internal_node(trie, addr) # internal node subtrie = get_internal_node(trie, addr) @@ -142,32 +140,12 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) end end -function has_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - has_leaf_node(trie, addr) && trie[addr].is_choice -end - -function get_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - choice = trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - choice.subtrace_or_retval -end - -function get_values_shallow(choices::DynamicDSLChoiceMap) - ((key, choice.subtrace_or_retval) - for (key, choice) in get_leaf_nodes(choices.trie) - if choice.is_choice) -end - function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ((key, get_choices(call.subtrace_or_retval)) + calls_iter = ( + (key, call.is_choice ? ValueChoiceMap(call.subtrace_or_retval) : get_choices(call.subtrace_or_retval)) for (key, call) in get_leaf_nodes(choices.trie) - if !call.is_choice) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) - for (key, trie) in get_internal_nodes(choices.trie)) + ) + internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) Iterators.flatten((calls_iter, internal_nodes_iter)) end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f2..7acc1630 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -35,7 +35,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # record the previous value as discarded if it is replaced if constrained && has_previous @@ -149,32 +149,22 @@ end function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) - for (key, value) in get_values_shallow(prev_choices) + for (key, submap) in get_submaps_shallow(prev_choices) + # if key IS in visited, + # the recursive call to update already handled the discard + # for this entire submap; else we need to handle it if !(key in visited) - @assert !has_value(discard, key) @assert isempty(get_submap(discard, key)) - set_value!(discard, key, value) - end - end - for (key, submap) in get_submaps_shallow(prev_choices) - @assert !has_value(discard, key) - if key in visited - # the recursive call to update already handled the discard - # for this entire submap - continue - else subvisited = visited[key] if isempty(subvisited) # none of this submap was visited, so we discard the whole thing - @assert isempty(get_submap(discard, key)) set_submap!(discard, key, submap) else subdiscard = get_submap(discard, key) - add_unvisited_to_discard!( - isempty(subdiscard) ? choicemap() : subdiscard, - subvisited, submap) + subdiscard = isempty(subdiscard) ? choicemap() : subdiscard + add_unvisited_to_discard!(subdiscard, subvisited, submap) set_submap!(discard, key, subdiscard) - end + end end end end diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index a231f03a..d662dbb7 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -1,12 +1,13 @@ import MacroTools function check_observations(choices::ChoiceMap, observations::ChoiceMap) - for (key, value) in get_values_shallow(observations) - !has_value(choices, key) && error("Check failed: observed choice at $key not found") - choices[key] != value && error("Check failed: value of observed choice at $key changed") - end for (key, submap) in get_submaps_shallow(observations) - check_observations(get_submap(choices, key), submap) + if has_value(submap) + !has_value(choices, key) && error("Check failed: observed choice at $key not found") + choices[key] != value && error("Check failed: value of observed choice at $key changed") + else + check_observations(get_submap(choices, key), submap) + end end end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 23411697..f17d061f 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -14,10 +14,7 @@ function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} end get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_value(choices::CallAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::CallAtChoiceMap, addr::Pair) = _has_value(choices, addr) get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) -get_values_shallow(::CallAtChoiceMap) = () # TODO optimize CallAtTrace using type parameters @@ -69,7 +66,7 @@ unpack_call_at_args(args) = (args[end], args[1:end-1]) function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) - if length(get_submaps_shallow(choices)) > 1 || length(get_values_shallow(choices)) > 0 + if length(get_submaps_shallow(choices)) > 1 error("Not all constraints were consumed") end submap = get_submap(choices, key) diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 69bb4851..f3875895 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -25,10 +25,12 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} end get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) +get_submap(choices::ChoiceAtChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} choices.key == addr ? choices.value : throw(KeyError(choices, addr)) end -get_submaps_shallow(choices::ChoiceAtChoiceMap) = () +get_submap(choices::ChoiceAtChoiceMap, addr) = addr == choices.key ? ValueChoiceMap(choices.value) : EmptyChoiceMap() +get_submaps_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, ValueChoiceMap(choices.value)),) get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 71580073..1f101725 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -84,17 +84,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, addr::Pair) - _get_submap(choices, addr) -end - -function has_value(choices::RecurseTraceChoiceMap, addr::Pair) - _has_value(choices, addr) -end - -function get_value(choices::RecurseTraceChoiceMap, addr::Pair) - _get_value(choices, addr) -end +get_submap(choices::RecurseTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () @@ -333,6 +323,9 @@ function recurse_unpack_constraints(constraints::ChoiceMap) production_constraints = Dict{Int, Any}() aggregation_constraints = Dict{Int, Any}() for (addr, node) in get_submaps_shallow(constraints) + if has_value(node) + error("Unknown address: $(addr)") + end idx::Int = addr[1] if addr[2] == Val(:production) production_constraints[idx] = node @@ -342,9 +335,6 @@ function recurse_unpack_constraints(constraints::ChoiceMap) error("Unknown address: $addr") end end - if length(get_values_shallow(constraints)) > 0 - error("Unknown address: $(first(get_values_shallow(constraints))[1])") - end return (production_constraints, aggregation_constraints) end diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 9b0eb763..3af416ef 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -92,10 +92,6 @@ end end @inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_value(choices::VectorTraceChoiceMap, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::VectorTraceChoiceMap, addr::Pair) = _has_value(choices, addr) -@inline get_values_shallow(::VectorTraceChoiceMap) = () - ############################################ # code shared by vector-shaped combinators # diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 7a0fe384..b352d3ca 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -330,21 +330,22 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) - leaf_gradients = map((node) -> gradient_var(node), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(trace.$(get_value_fieldname(node)))), selected_choices_vec) + leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) - internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), + internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec) + internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) + + quoted_all_keys = Iterators.flatten((quoted_leaf_keys, quoted_internal_keys)) + all_value_choicemaps = Iterators.flatten((leaf_value_choicemaps, internal_value_choicemaps)) + all_gradient_choicemaps = Iterators.flatten((leaf_gradient_choicemaps, internal_gradient_choicemaps)) + quote - $value_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_gradient_choicemaps...),))) end end diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 713c0863..5ac3ced1 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -9,25 +9,8 @@ end function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) - @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) - -@inline static_has_value(choices::StaticIRTraceAssmt, key) = false - -@inline function get_value(choices::StaticIRTraceAssmt, key::Symbol) - static_get_value(choices, Val(key)) -end - -@inline function has_value(choices::StaticIRTraceAssmt, key::Symbol) - static_has_value(choices, Val(key)) -end - -@inline function get_submap(choices::StaticIRTraceAssmt, key::Symbol) - static_get_submap(choices, Val(key)) -end - -@inline get_value(choices::StaticIRTraceAssmt, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::StaticIRTraceAssmt, addr::Pair) = _has_value(choices, addr) +@inline get_submap(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) @inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) ######################### @@ -36,16 +19,13 @@ end abstract type StaticIRTrace <: Trace end -@inline function static_get_subtrace(trace::StaticIRTrace, addr) - error("Not implemented") -end +@inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline function Base.getindex(trace::StaticIRTrace, addr) - Gen.static_getindex(trace, Val(addr)) -end +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] @@ -161,21 +141,13 @@ function generate_get_choices(trace_struct_name::Symbol) :($(QuoteNode(EmptyChoiceMap))()))) end -function generate_get_values_shallow(ir::StaticIR, trace_struct_name::Symbol) +function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes addr = node.addr value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), $value))) + push!(elements, :(($(QuoteNode(addr)), ValueChoiceMap($value)))) end - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_values_shallow)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), - Expr(:block, Expr(:tuple, elements...))) -end - -function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) - elements = [] for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) @@ -224,30 +196,6 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] end -function generate_static_get_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(choices.trace.$(get_value_fieldname(node)))))) - end - methods -end - -function generate_static_has_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_has_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(true)))) - end - methods -end - function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes @@ -259,13 +207,13 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) end - # throw a KeyError if get_submap is run on an address containing a value + # return a ValueChoiceMap if get_submap is run on an address containing a value for node in ir.choice_nodes push!(methods, Expr(:function, Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(throw(KeyError($(QuoteNode(node.addr)))))))) + Expr(:block, :(ValueChoiceMap(choices.trace.$(get_value_fieldname(node))))))) end methods end @@ -290,18 +238,13 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_retval_expr = generate_get_retval(ir, trace_struct_name) get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) - static_get_value_exprs = generate_static_get_value(ir, trace_struct_name) - static_has_value_exprs = generate_static_has_value(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_values_shallow_expr, - get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + get_choices_expr, get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index dc4fddf3..c806bba3 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -454,9 +454,10 @@ function generate_discard!(stmts::Vector{Expr}, end leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - expr = :($(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(leaf_keys...),)}(($(leaf_nodes...),)), - $(QuoteNode(NamedTuple)){($(internal_keys...),)}(($(internal_nodes...),)))) + all_keys = (leaf_keys..., internal_keys...) + all_nodes = ([:($(QuoteNode(ValueChoiceMap))($node)) for node in leaf_nodes]..., internal_nodes...) + expr = quote $(QuoteNode(StaticChoiceMap))( + $(QuoteNode(NamedTuple)){($(all_keys...),)}(($(all_nodes...),))) end push!(stmts, :($discard = $expr)) end diff --git a/test/assignment.jl b/test/assignment.jl index 1bba754a..1d7e48a8 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,6 +1,46 @@ +@testset "ValueChoiceMap" begin + vcm1 = ValueChoiceMap(2) + vcm2 = ValueChoiceMap(2.) + vcm3 = ValueChoiceMap([1,2]) + @test vcm1 isa ValueChoiceMap{Int} + @test vcm2 isa ValueChoiceMap{Float64} + @test vcm3 isa ValueChoiceMap{Vector{Int}} + + @test !isempty(vcm1) + @test has_value(vcm1) + @test get_value(vcm1) == 2 + @test vcm1 == vcm2 + @test isempty(get_submaps_shallow(vcm1)) + @test isempty(get_values_shallow(vcm1)) + @test isempty(get_nonvalue_submaps_shallow(vcm1)) + @test to_array(vcm1, Int) == [2] + @test from_array(vcm1, [4]) == ValueChoiceMap(4) + @test from_array(vcm3, [4, 5]) == ValueChoiceMap([4, 5]) + @test_throws Exception merge(vcm1, vcm2) + @test_throws Exception merge(vcm1, choicemap(:a, 5)) + @test merge(vcm1, EmptyChoiceMap()) == vcm1 + @test merge(EmptyChoiceMap(), vcm1) == vcm1 + @test get_submap(vcm1, :addr) == EmptyChoiceMap() + @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) + @test !has_value(vcm1, :addr) + @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) + @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) + @test get_address_schema(typeof(vcm1)) == EmptyAddressSchema() + @test get_address_schema(ValueChoiceMap) == EmptyAddressSchema() + @test nested_view(vcm1) == 2 +end + +@testset "static choicemap constructor" begin + @test StaticChoiceMap((a=ValueChoiceMap(5), b=ValueChoiceMap(6))) == StaticChoiceMap(a=5, b=6) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + @test submap == StaticChoiceMap((a=ValueChoiceMap(1.), b=ValueChoiceMap([2., 2.5]))) + outer = StaticChoiceMap(c=3, d=submap, e=submap) + @test outer == StaticChoiceMap((c=ValueChoiceMap(3), d=submap, e=submap)) +end + @testset "static assignment to/from array" begin - submap = StaticChoiceMap((a=1., b=[2., 2.5]),NamedTuple()) - outer = StaticChoiceMap((c=3.,), (d=submap, e=submap)) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + outer = StaticChoiceMap(c=3., d=submap, e=submap) arr = to_array(outer, Float64) @test to_array(outer, Float64) == Float64[3.0, 1.0, 2.0, 2.5, 1.0, 2.0, 2.5] @@ -11,14 +51,16 @@ @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment to/from array" begin @@ -39,14 +81,18 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test get_submap(choices, :c) == ValueChoiceMap(1.0) + @test get_submap(choices, :d => :b) == ValueChoiceMap([3.0, 4.0]) + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment copy constructor" begin @@ -64,25 +110,6 @@ end @test choices[:u => :w] == 4 end -@testset "internal vector assignment to/from array" begin - inner = choicemap() - set_value!(inner, :a, 1.) - set_value!(inner, :b, 2.) - outer = vectorize_internal([inner, inner, inner]) - - arr = to_array(outer, Float64) - @test to_array(outer, Float64) == Float64[1, 2, 1, 2, 1, 2] - - choices = from_array(outer, Float64[1, 2, 3, 4, 5, 6]) - @test choices[1 => :a] == 1.0 - @test choices[1 => :b] == 2.0 - @test choices[2 => :a] == 3.0 - @test choices[2 => :b] == 4.0 - @test choices[3 => :a] == 5.0 - @test choices[3 => :b] == 6.0 - @test length(collect(get_submaps_shallow(choices))) == 3 -end - @testset "dynamic assignment merge" begin submap = choicemap() set_value!(submap, :x, 1) @@ -107,7 +134,7 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @@ -125,8 +152,8 @@ end set_value!(submap, :x, 1) submap2 = choicemap() set_value!(submap2, :y, 4.) - choices1 = StaticChoiceMap((a=1., b=2.), (c=submap, shared=submap)) - choices2 = StaticChoiceMap((d=3.,), (e=submap, f=submap, shared=submap2)) + choices1 = StaticChoiceMap(a=1., b=2., c=submap, shared=submap) + choices2 = StaticChoiceMap(d=3., e=submap, f=submap, shared=submap2) choices = merge(choices1, choices2) @test choices[:a] == 1. @test choices[:b] == 2. @@ -136,124 +163,91 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @testset "static assignment variadic merge" begin - choices1 = StaticChoiceMap((a=1,), NamedTuple()) - choices2 = StaticChoiceMap((b=2,), NamedTuple()) - choices3 = StaticChoiceMap((c=3,), NamedTuple()) - choices_all = StaticChoiceMap((a=1, b=2, c=3), NamedTuple()) + choices1 = StaticChoiceMap(a=1) + choices2 = StaticChoiceMap(b=2) + choices3 = StaticChoiceMap(c=3) + choices_all = StaticChoiceMap(a=1, b=2, c=3) @test merge(choices1) == choices1 @test merge(choices1, choices2, choices3) == choices_all end +# TODO: in changing a lot of these to reflect the new behavior of choicemap, +# they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; +# should we relabel this testset? @testset "static assignment errors" begin + # get_choices on an address that returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x) == ValueChoiceMap(1) + + # static_get_submap on an address that contains a value returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test static_get_submap(choices, Val(:x)) == ValueChoiceMap(1) - # get_choices on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw - - # static_get_submap on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # get_submap on an address whose prefix contains a value returns EmptyChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) + choices = StaticChoiceMap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # static_get_choices on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_choices on an address that contains nothing returns an EmptyChoiceMap + choices = StaticChoiceMap() + @test static_get_submap(choices, Val(:x)) == EmptyChoiceMap() - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # static_get_value on an address that contains a submap throws a KeyError + # static_get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) + + # get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) + + # static_get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) end @testset "dynamic assignment errors" begin - - # get_choices on an address that contains a value throws a KeyError + # get_choices on an address that contains a value returns a ValueChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(1) - # get_choices on an address whose prefix contains a value throws a KeyError + # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment choices = choicemap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError choices = choicemap() choices[:x => :y] = 1 - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # get_value on an address that contains nothing throws a KeyError + # get_value on an address that contains nothing throws a ChoiceMapGetValueError choices = choicemap() - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) end @testset "dynamic assignment overwrite" begin @@ -276,9 +270,7 @@ end choices = choicemap() choices[:x => :y] = 1 choices[:x] = 2 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment @@ -293,17 +285,13 @@ end # illegal set value under existing value choices = choicemap() choices[:x] = 1 - threw = false - try set_value!(choices, :x => :y, 2) catch KeyError threw = true end - @test threw + @test_throws Exception set_value!(choices, :x => :y, 2) # illegal set submap under existing value choices = choicemap() choices[:x] = 1 submap = choicemap(); choices[:z] = 2 - threw = false - try set_submap!(choices, :x => :y, submap) catch KeyError threw = true end - @test threw + @test_throws Exception set_submap!(choices, :x => :y, submap) end @testset "dynamic assignment constructor" begin diff --git a/test/benchmark.md b/test/benchmark.md new file mode 100644 index 00000000..adabb8a5 --- /dev/null +++ b/test/benchmark.md @@ -0,0 +1,21 @@ +NEW version: +static choicemap nonnested lookup: + 0.728112 seconds (149.59 k allocations: 4.259 MiB) + 0.785652 seconds (100.00 k allocations: 1.526 MiB) + 0.693433 seconds (100.00 k allocations: 1.526 MiB) + 0.660211 seconds (100.00 k allocations: 1.526 MiB) +static choicemap nested lookup: + 0.680497 seconds (49.59 k allocations: 2.732 MiB) + 0.665768 seconds (1 allocation: 32 bytes) + 0.666708 seconds (1 allocation: 32 bytes) + 0.671009 seconds (1 allocation: 32 bytes) +static gen function choicemap nonnested lookup: + 0.701754 seconds (62.76 k allocations: 3.415 MiB) + 0.662916 seconds + 0.659019 seconds + 0.663398 seconds +static gen function choicemap nested lookup: + 1.338034 seconds (172.13 k allocations: 5.352 MiB) + 1.311123 seconds (100.00 k allocations: 1.526 MiB) + 1.311800 seconds (100.00 k allocations: 1.526 MiB) + 1.310289 seconds (100.00 k allocations: 1.526 MiB) \ No newline at end of file diff --git a/test/dynamic_dsl.jl b/test/dynamic_dsl.jl index 35f81703..5561ae54 100644 --- a/test/dynamic_dsl.jl +++ b/test/dynamic_dsl.jl @@ -119,7 +119,7 @@ end @test get_value(discard, :x) == x @test get_value(discard, :u => :a) == a @test length(collect(get_values_shallow(discard))) == 2 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 # test new trace new_assignment = get_choices(new_trace) @@ -127,7 +127,7 @@ end @test get_value(new_assignment, :y) == y @test get_value(new_assignment, :v => :b) == b @test length(collect(get_values_shallow(new_assignment))) == 2 - @test length(collect(get_submaps_shallow(new_assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(new_assignment))) == 1 # test score and weight prev_score = ( @@ -242,7 +242,7 @@ end @test !isempty(get_submap(assignment, :v)) end @test length(collect(get_values_shallow(assignment))) == 2 - @test length(collect(get_submaps_shallow(assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(assignment))) == 1 # test weight if assignment[:branch] == prev_assignment[:branch] @@ -332,11 +332,11 @@ end @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @test !has_value(choices, :b) # was not selected - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 # check gradient trie - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected @test isapprox(get_value(gradients, :bar => :z), @@ -431,14 +431,14 @@ end @test choices[:x => 2] == 2 @test choices[:x => 3 => :z] == 3 @test length(collect(get_values_shallow(choices))) == 1 # :y - @test length(collect(get_submaps_shallow(choices))) == 1 # :x + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # :x submap = get_submap(choices, :x) @test submap[1] == 1 @test submap[2] == 2 @test submap[3 => :z] == 3 @test length(collect(get_values_shallow(submap))) == 2 # 1, 2 - @test length(collect(get_submaps_shallow(submap))) == 1 # 3 + @test length(collect(get_nonvalue_submaps_shallow(submap))) == 1 # 3 bar_submap = get_submap(submap, 3) @test bar_submap[:z] == 3 diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130..607eb61f 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -20,7 +20,7 @@ y = choices[3 => :y] @test isapprox(weight, logpdf(normal, y, 0.4, 1)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end @testset "generate" begin @@ -32,7 +32,7 @@ y = choices[3 => :y] @test get_retval(trace) == 0.4 + y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # with constraints y = 1.234 @@ -44,7 +44,7 @@ @test get_retval(trace) == 0.4 + y @test isapprox(weight, logpdf(normal, y, 0.4, 1.)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end function get_trace() @@ -71,7 +71,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isempty(discard) @@ -86,12 +86,12 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) # change kernel_args, different key, with constraint @@ -103,12 +103,12 @@ choices = get_choices(new_trace) @test choices[4 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) end @@ -121,7 +121,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isapprox(get_score(new_trace), logpdf(normal, y, 0.2, 1)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) y_new = choices[3 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -144,7 +144,7 @@ choices = get_choices(new_trace) y_new = choices[4 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -171,9 +171,9 @@ @test choices[3 => :y] == y @test isapprox(gradients[3 => :y], logpdf_grad(normal, y, 0.4, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 0 - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(input_grads) == 2 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.4, 1.0)[2] + retval_grad) @test input_grads[2] == nothing # the key has no gradient diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b46..4f524138 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -15,7 +15,7 @@ @test isapprox(weight, value ? log(0.4) : log(0.6)) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end @testset "generate" begin @@ -27,7 +27,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 # with constraints constraints = choicemap() @@ -39,7 +39,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end function get_trace() @@ -65,7 +65,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isempty(discard) @@ -78,12 +78,12 @@ choices = get_choices(new_trace) @test choices[3] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 # change kernel_args, different key, with constraint constraints = choicemap() @@ -93,12 +93,12 @@ choices = get_choices(new_trace) @test choices[4] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 end @testset "regenerate" begin @@ -110,7 +110,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isapprox(get_score(new_trace), log(0.2)) @@ -122,7 +122,7 @@ choices = get_choices(new_trace) value = choices[3] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) value = choices[4] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -163,9 +163,9 @@ @test choices[3] == y @test isapprox(gradients[3], logpdf_grad(normal, y, 0.0, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 1 - @test length(collect(get_submaps_shallow(gradients))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 0 @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test length(input_grads) == 3 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.0, 1.0)[2]) @test isapprox(input_grads[2], logpdf_grad(normal, y, 0.0, 1.0)[3]) diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 46954e3b..b440a44f 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -197,9 +197,9 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false @test discard[(3, Val(:aggregation)) => :prefix] == true - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 @test length(collect(get_values_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 1 @test retdiff == UnknownChange() diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index ba748453..0f3a5618 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -28,7 +28,7 @@ x3 = trace[3 => :x] choices = get_choices(trace) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 expected_score = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x2, x1 * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -55,7 +55,7 @@ @test choices[1 => :x] == x1 @test choices[3 => :x] == x3 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x2 = choices[2 => :x] expected_weight = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -77,7 +77,7 @@ beta = 0.3 (choices, weight, retval) = propose(foo, (3, x_init, alpha, beta)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x1 = choices[1 => :x] x2 = choices[2 => :x] x3 = choices[3 => :x] diff --git a/test/optional_args.jl b/test/optional_args.jl index fd6c4ea7..b0fb821b 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -1,4 +1,4 @@ -using Gen +#using Gen @testset "optional positional args (calling + GFI)" begin diff --git a/test/runtests.jl b/test/runtests.jl index a67a5f78..74923603 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,4 +74,4 @@ include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 91c6c320..9e2cecf3 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -1,4 +1,6 @@ using Gen: generate_generative_function +using Test +using Gen @testset "static IR" begin @@ -362,12 +364,12 @@ end @test get_value(value_trie, :out) == out @test get_value(value_trie, :bar => :z) == z @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(value_trie))) == 1 + @test length(collect(get_values_shallow(value_trie))) == 2 # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(gradient_trie))) == 1 + @test length(collect(get_values_shallow(gradient_trie))) == 2 @test !has_value(gradient_trie, :b) # was not selected @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index fbd528b7..8396fe51 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,4 +1,4 @@ -using Gen +using .Gen import MacroTools normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) From 83349c7d4a320e028c9b24e26da4c3b44066fce9 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:13:03 -0400 Subject: [PATCH 05/34] performance improvements and benchmarking --- src/choice_map/choice_map.jl | 8 ++--- src/choice_map/static_choice_map.jl | 6 ++-- src/static_ir/trace.jl | 15 +++------ test/static_choicemap_benchmark.jl | 50 +++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 18 deletions(-) create mode 100644 test/static_choicemap_benchmark.jl diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index 402cefa3..213bc5f8 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -71,8 +71,8 @@ A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) -get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) """ @@ -145,8 +145,8 @@ end @inline get_value(choices::ValueChoiceMap) = choices.val @inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() @inline get_submaps_shallow(choices::ValueChoiceMap) = () -Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val -Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +@inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) @inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() """ diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 1f75b3bc..58ef57d3 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -26,10 +26,10 @@ end quote EmptyChoiceMap() end end end -static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() +@inline static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() -static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) -static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) +@inline static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +@inline static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 5ac3ced1..168ccf50 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -23,7 +23,7 @@ abstract type StaticIRTrace <: Trace end @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false - Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) +@inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) @inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) @@ -31,6 +31,8 @@ abstract type StaticIRTrace <: Trace end return Gen.static_get_subtrace(trace, Val(first))[rest] end +@inline get_choices(trace::T) where {T <: StaticIRTrace} = StaticIRTraceAssmt{T}(trace) + const arg_prefix = gensym("arg") const choice_value_prefix = gensym("choice_value") const choice_score_prefix = gensym("choice_score") @@ -133,14 +135,6 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_choices(trace_struct_name::Symbol) - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_choices)), :(trace::$trace_struct_name)), - Expr(:if, :(!isempty(trace)), - :($(QuoteNode(StaticIRTraceAssmt))(trace)), - :($(QuoteNode(EmptyChoiceMap))()))) -end - function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes @@ -236,7 +230,6 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) - get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) @@ -244,7 +237,7 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) + get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/test/static_choicemap_benchmark.jl b/test/static_choicemap_benchmark.jl new file mode 100644 index 00000000..1e62b9a8 --- /dev/null +++ b/test/static_choicemap_benchmark.jl @@ -0,0 +1,50 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +scm = StaticChoiceMap(a=1, b=StaticChoiceMap(c=2)) + +println("static choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(scm) +end + +println("static choicemap nested lookup:") +for _=1:4 + @time many_nested(scm) +end + +@gen (static) function inner() + c ~ normal(0, 1) +end +@gen (static) function outer() + a ~ normal(0, 1) + b ~ inner() +end + +load_generated_functions() + +tr, _ = generate(outer, ()) +choices = get_choices(tr) + +println("static gen function choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(choices) +end + +println("static gen function choicemap nested lookup:") +for _=1:4 + @time many_nested(choices) +end From b9b5312e990fc49b08611b7077b7c6f3aa5d99ee Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:21:45 -0400 Subject: [PATCH 06/34] benchmark for dynamic choicemap lookups --- test/dynamic_choicemap_benchmark.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 test/dynamic_choicemap_benchmark.jl diff --git a/test/dynamic_choicemap_benchmark.jl b/test/dynamic_choicemap_benchmark.jl new file mode 100644 index 00000000..3724e44d --- /dev/null +++ b/test/dynamic_choicemap_benchmark.jl @@ -0,0 +1,27 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +cm = choicemap((:a, 1), (:b => :c, 2)) + +println("dynamic choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(cm) +end + +println("dynamic choicemap nested lookup:") +for _=1:4 + @time many_nested(cm) +end \ No newline at end of file From bce5e7724db64175bf2fd0f15fe25a4dc68af13e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:30:40 -0400 Subject: [PATCH 07/34] inline dynamicchoicemap methods --- src/choice_map/dynamic_choice_map.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/choice_map/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl index a3403307..0f27c89d 100644 --- a/src/choice_map/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -67,16 +67,10 @@ function choicemap(tuples...) DynamicChoiceMap(tuples...) end -get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.submaps, addr) - choices.submaps[addr] - else - EmptyChoiceMap() - end -end -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) -Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) +@inline get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +@inline get_submap(choices::DynamicChoiceMap, addr) = get(choices.submaps, addr, EmptyChoiceMap()) +@inline get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) # mutation (not part of the assignment interface) From a985f9bd3dc3f8806e2da1e7c81fbe891334bac9 Mon Sep 17 00:00:00 2001 From: georgematheos Date: Tue, 19 May 2020 09:13:32 -0400 Subject: [PATCH 08/34] remove old version benchmark file --- test/benchmark.md | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 test/benchmark.md diff --git a/test/benchmark.md b/test/benchmark.md deleted file mode 100644 index adabb8a5..00000000 --- a/test/benchmark.md +++ /dev/null @@ -1,21 +0,0 @@ -NEW version: -static choicemap nonnested lookup: - 0.728112 seconds (149.59 k allocations: 4.259 MiB) - 0.785652 seconds (100.00 k allocations: 1.526 MiB) - 0.693433 seconds (100.00 k allocations: 1.526 MiB) - 0.660211 seconds (100.00 k allocations: 1.526 MiB) -static choicemap nested lookup: - 0.680497 seconds (49.59 k allocations: 2.732 MiB) - 0.665768 seconds (1 allocation: 32 bytes) - 0.666708 seconds (1 allocation: 32 bytes) - 0.671009 seconds (1 allocation: 32 bytes) -static gen function choicemap nonnested lookup: - 0.701754 seconds (62.76 k allocations: 3.415 MiB) - 0.662916 seconds - 0.659019 seconds - 0.663398 seconds -static gen function choicemap nested lookup: - 1.338034 seconds (172.13 k allocations: 5.352 MiB) - 1.311123 seconds (100.00 k allocations: 1.526 MiB) - 1.311800 seconds (100.00 k allocations: 1.526 MiB) - 1.310289 seconds (100.00 k allocations: 1.526 MiB) \ No newline at end of file From 1f5029cfc1637d4d3ac257cd46835312131c6ee2 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 09:15:55 -0400 Subject: [PATCH 09/34] minor testing cleanup --- test/optional_args.jl | 2 +- test/static_inference_benchmark.jl | 23 +++++++++++++++++++++++ test/static_ir/static_ir.jl | 2 -- test/tilde_sugar.jl | 2 +- 4 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 test/static_inference_benchmark.jl diff --git a/test/optional_args.jl b/test/optional_args.jl index b0fb821b..fd6c4ea7 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -1,4 +1,4 @@ -#using Gen +using Gen @testset "optional positional args (calling + GFI)" begin diff --git a/test/static_inference_benchmark.jl b/test/static_inference_benchmark.jl new file mode 100644 index 00000000..b70d08be --- /dev/null +++ b/test/static_inference_benchmark.jl @@ -0,0 +1,23 @@ +using Gen + +@gen (static, diffs) function foo() + a ~ normal(0, 1) + b ~ normal(a, 1) + c ~ normal(b, 1) +end + +@load_generated_functions + +observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) +tr, _ = generate(foo, (), observations) + +function run_inference(trace) + tr = trace + for _=1:10^3 + tr, acc = mh(tr, select(:a)) + end +end + +for _=1:4 + @time run_inference(tr) +end \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 9e2cecf3..1b594d39 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -1,6 +1,4 @@ using Gen: generate_generative_function -using Test -using Gen @testset "static IR" begin diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index 8396fe51..fbd528b7 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,4 +1,4 @@ -using .Gen +using Gen import MacroTools normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) From eb6adf7a76c5975fa20d7567a560175588aafed4 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:33:16 -0400 Subject: [PATCH 10/34] ensure valuechoicemap[] syntax works --- test/assignment.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/assignment.jl b/test/assignment.jl index 1d7e48a8..69890297 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -5,6 +5,8 @@ @test vcm1 isa ValueChoiceMap{Int} @test vcm2 isa ValueChoiceMap{Float64} @test vcm3 isa ValueChoiceMap{Vector{Int}} + @test vcm1[] == 2 + @test vcm1[] == get_value(vcm1) @test !isempty(vcm1) @test has_value(vcm1) From eef941776857c50d8ad93ead2ee0d164d60f737e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:43:49 -0400 Subject: [PATCH 11/34] provide some examples in the documentation --- docs/src/ref/choice_maps.md | 38 ++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 6c445df6..bf742f90 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -14,6 +14,42 @@ for sub-choicemaps. Leaf nodes have type: ValueChoiceMap ``` +### Example Usage Overview + +Choicemaps store values nested in a tree where each node posesses an address for each subtree. +A leaf-node choicemap simply contains a value, and has it's value looked up via: +```julia +value = choicemap[] +``` +If a choicemap has a value choicemap at address `:a`, it is looked up via: +```julia +value = choicemap[:a] +``` +And a choicemap may also have a non-value choicemap stored at a value. For instance, +if a choicemap has another choicemap stored at address `:a`, and this internal choicemap +has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: +```julia +value1 = choicemap[:a => :b] +value2 = choicemap[:a => :c] +``` +Nesting can be arbitrarily deep, and the keys can be arbitrary values; for instance +choicemaps can be constructed with values at the following nested addresses: +```julia +value = choicemap[:a => :b => :c => 4 => 1.63 => :e] +value = choicemap[:a => :b => :a => 2 => "alphabet" => :e] +``` +To get a sub-choicemap, use `get_submap`: +```julia +value1 = choicemap[:a => :b] +submap = get_submap(choicemap, :a) +value1 == submap[:b] # is true + +value_submap = get_submap(choicemap, :a => :b) +value_submap[] == value1 # is true +``` + +### Interface + Choice maps provide the following methods: ```@docs get_submap @@ -58,7 +94,7 @@ set_value! set_submap! ``` -## Implementing custom choicemap types +### Implementing custom choicemap types To implement a custom choicemap, one must implement `get_submap` and `get_submaps_shallow`. From a83adfbc2d02bed4e9c0a78163151c742cc660f8 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:50:25 -0400 Subject: [PATCH 12/34] fix some typos --- docs/src/ref/choice_maps.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index bf742f90..2963d304 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -17,15 +17,15 @@ ValueChoiceMap ### Example Usage Overview Choicemaps store values nested in a tree where each node posesses an address for each subtree. -A leaf-node choicemap simply contains a value, and has it's value looked up via: +A leaf-node choicemap simply contains a value, and has its value looked up via: ```julia value = choicemap[] ``` -If a choicemap has a value choicemap at address `:a`, it is looked up via: +If a choicemap has a value choicemap at address `:a`, the value it stores is looked up via: ```julia value = choicemap[:a] ``` -And a choicemap may also have a non-value choicemap stored at a value. For instance, +A choicemap may also have a non-value choicemap stored at an address. For instance, if a choicemap has another choicemap stored at address `:a`, and this internal choicemap has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: ```julia From 1bd705f101bb7c783aedad30fe442f864bcec625 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:54:25 -0400 Subject: [PATCH 13/34] add phrase 'nesting level zero' to docs --- docs/src/ref/choice_maps.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 2963d304..4a23b7cf 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -47,6 +47,8 @@ value1 == submap[:b] # is true value_submap = get_submap(choicemap, :a => :b) value_submap[] == value1 # is true ``` +One can think of `ValueChoiceMap`s at storing being a choicemap which has a value at "nesting level zero", +while other choicemaps have values at "nesting level" one or higher. ### Interface From 676828b0d16872f35ee2a327e2d445df9a449269 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 15:29:17 -0400 Subject: [PATCH 14/34] distribution <: GenFn; dynamic DSL simplification --- src/Gen.jl | 9 +- src/distribution.jl | 123 ++++++++++++++++++++++ src/dynamic/assess.jl | 16 --- src/dynamic/backprop.jl | 4 +- src/dynamic/dynamic.jl | 6 -- src/dynamic/generate.jl | 32 ------ src/dynamic/project.jl | 14 +-- src/dynamic/propose.jl | 19 ---- src/dynamic/regenerate.jl | 58 ++--------- src/dynamic/simulate.jl | 18 ---- src/dynamic/trace.jl | 127 ++++------------------- src/dynamic/update.jl | 64 ++---------- src/modeling_library/modeling_library.jl | 48 --------- 13 files changed, 166 insertions(+), 372 deletions(-) create mode 100644 src/distribution.jl diff --git a/src/Gen.jl b/src/Gen.jl index fa239359..fe3cdaa1 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -42,13 +42,16 @@ include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") +# built-in data types for arg-diff and ret-diff values +include("diff.jl") + # generative function interface include("gen_fn_interface.jl") -# built-in data types for arg-diff and ret-diff values -include("diff.jl") +# distribution abstract type +include("distribution.jl") -# built-in probability disributions +# built-in probability disributions; distribution dsl; combinators include("modeling_library/modeling_library.jl") # optimization of trainable parameters diff --git a/src/distribution.jl b/src/distribution.jl new file mode 100644 index 00000000..d72b21a6 --- /dev/null +++ b/src/distribution.jl @@ -0,0 +1,123 @@ +############################### +# Core Distribution Interface # +############################### + +struct DistributionTrace{T, Dist} <: Trace + val::T + args + dist::Dist +end + +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end + +function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.dist) +end + +""" + val::T = random(dist::Distribution{T}, args...) + +Sample a random choice from the given distribution with the given arguments. +""" +function random end + +""" + lpdf = logpdf(dist::Distribution{T}, value::T, args...) + +Evaluate the log probability (density) of the value. +""" +function logpdf end + +""" + has::Bool = has_output_grad(dist::Distribution) + +Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. +""" +function has_output_grad end + +""" + grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) + +Compute the gradient of the logpdf with respect to the value, and each of the arguments. + +If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. +Otherwise, the first element of the tuple is the gradient with respect to the value. +If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. +Otherwise, this element contains the gradient with respect to the `i`th argument. +""" +function logpdf_grad end + +function is_discrete end + +# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl + +get_return_type(::Distribution{T}) where {T} = T + + +############################## +# Distribution GFI Interface # +############################## + +@inline Base.getindex(trace::DistributionTrace) = trace.val +@inline Gen.get_args(trace::DistributionTrace) = trace.args +@inline Gen.get_choices(trace::DistributionTrace) = ValueChoiceMap(trace.val) # should be able to get type of val +@inline Gen.get_retval(trace::DistributionTrace) = trace.val +@inline Gen.get_gen_fn(trace::DistributionTrace) = trace.dist + +# TODO: for performance would it be better to store the score in the trace? +@inline Gen.get_score(trace::DistributionTrace) = logpdf(trace.dist, trace.val, trace.args...) +@inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. +@inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) + +@inline function Gen.simulate(dist::Distribution, args::Tuple) + val = random(dist, args...) + DistributionTrace(val, args, dist) +end +@inline Gen.generate(dist::Distribution, args::Tuple, ::EmptyChoiceMap) = (simulate(dist, args), 0.) +@inline function Gen.generate(dist::Distribution, args::Tuple, constraints::ValueChoiceMap) + tr = DistributionTrace(get_value(constraints), args, dist) + weight = get_score(tr) + (tr, weight) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::ValueChoiceMap) + new_tr = DistributionTrace(get_value(constraints), args, tr.dist) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, UnknownChange(), get_choices(tr)) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::EmptyChoiceMap) + new_tr = DistributionTrace(tr.val, args, tr.dist) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange(), EmptyChoiceMap()) +end +# TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? +@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) +@inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) + new_tr = DistributionTrace(tr.val, args, tr.dist) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange()) +end +@inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::AllSelection) + new_val = random(tr.dist, args...) + new_tr = DistributionTrace(new_val, args, tr.dist) + (new_tr, 0., UnknownChange()) +end +@inline function Gen.propose(dist::Distribution, args::Tuple) + val = random(dist, args...) + score = logpdf(dist, val, args...) + (ValueChoiceMap(val), score, val) +end +@inline function Gen.assess(dist::Distribution, args::Tuple, choices::ValueChoiceMap) + weight = logpdf(dist, choices.val, args...) + (weight, choices.val) +end + +########### +# Exports # +########### + +export Distribution +export random +export logpdf +export logpdf_grad +export has_output_grad +export is_discrete diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index c583d507..0bf37a07 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -9,22 +9,6 @@ function GFAssessState(choices, params::Dict{Symbol,Any}) GFAssessState(choices, 0., AddressVisitor(), params) end -function traceat(state::GFAssessState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # get return value - retval = get_value(state.choices, key) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 6a7278a0..e870d0c7 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -74,7 +74,7 @@ function traceat(state::GFBackpropParamsState, dist::Distribution{T}, args_maybe_tracked, key) where {T} local retval::T visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval + retval = get_retval(get_call(state.trace, key).subtrace) args = map(value, args_maybe_tracked) score_tracked = track(logpdf(dist, retval, args...), state.tape) record!(state.tape, ReverseDiff.SpecialInstruction, dist, @@ -275,7 +275,7 @@ function traceat(state::GFBackpropTraceState, dist::Distribution{T}, args_maybe_tracked, key) where {T} local retval::T visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval + retval = get_retval(get_call(state.trace, key).subtrace) args = map(value, args_maybe_tracked) score_tracked = track(logpdf(dist, retval, args...), state.tape) if key in state.selection diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 73f22159..0d8e03b4 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -154,12 +154,6 @@ function check_is_empty(constraints::ChoiceMap, addr) end end -function check_no_value(constraints::ChoiceMap, addr) - if has_value(constraints, addr) - error("Expected a sub-assignment at address $addr but found a value") - end -end - function gen_fn_changed_error(addr) error("Generative function changed at address: $addr") end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index 970dac42..4a5796aa 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -11,38 +11,6 @@ function GFGenerateState(gen_fn, args, constraints, params) GFGenerateState(trace, constraints, 0., AddressVisitor(), params) end -function traceat(state::GFGenerateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_is_empty(state.constraints, key) - - # get return value - if constrained - retval = get_value(state.constraints, key) - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - # increment weight - if constrained - state.weight += score - end - - retval -end - function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/project.jl b/src/dynamic/project.jl index 358398e5..81b45b3c 100644 --- a/src/dynamic/project.jl +++ b/src/dynamic/project.jl @@ -1,15 +1,9 @@ -function project_recurse(trie::Trie{Any,ChoiceOrCallRecord}, +function project_recurse(trie::Trie{Any, CallRecord}, selection::Selection) weight = 0. - for (key, choice_or_call) in get_leaf_nodes(trie) - if choice_or_call.is_choice - if key in selection - weight += choice_or_call.score - end - else - subselection = selection[key] - weight += project(choice_or_call.subtrace_or_retval, subselection) - end + for (key, call) in get_leaf_nodes(trie) + subselection = selection[key] + weight += project(call.subtrace, subselection) end for (key, subtrie) in get_internal_nodes(trie) subselection = selection[key] diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index e4281f49..32fc95da 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -9,25 +9,6 @@ function GFProposeState(params::Dict{Symbol,Any}) GFProposeState(choicemap(), 0., AddressVisitor(), params) end -function traceat(state::GFProposeState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # sample return value - retval = random(dist, args...) - - # update assignment - set_value!(state.choices, key, retval) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 81ba8b3c..a4006a6c 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -14,48 +14,6 @@ function GFRegenerateState(gen_fn, args, prev_trace, 0., visitor, params) end -function traceat(state::GFRegenerateState, dist::Distribution{T}, - args, key) where {T} - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check whether the key was selected - in_selection = key in state.selection - - # get return value - if has_previous && in_selection - retval = random(dist, args...) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update weight - if has_previous && !in_selection - state.weight += score - prev_score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local prev_retval::T @@ -101,13 +59,11 @@ function splice(state::GFRegenerateState, gen_fn::DynamicDSLFunction, retval end -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function regenerate_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::EmptySelection) noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !choice_or_call.is_choice - noise += choice_or_call.noise - end + for (key, call) in get_leaf_nodes(prev_trie) + noise += call.noise end for (key, subtrie) in get_internal_nodes(prev_trie) noise += regenerate_delete_recurse(subtrie, EmptySelection()) @@ -115,12 +71,12 @@ function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, noise end -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function regenerate_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::DynamicSelection) noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !(key in visited) && !choice_or_call.is_choice - noise += choice_or_call.noise + for (key, call) in get_leaf_nodes(prev_trie) + if !(key in visited) + noise += call.noise end end for (key, subtrie) in get_internal_nodes(prev_trie) diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 0addd8bf..57f709dc 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -9,24 +9,6 @@ function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) GFSimulateState(trace, AddressVisitor(), params) end -function traceat(state::GFSimulateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - retval = random(dist, args...) - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 882297e4..0a169d73 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -1,96 +1,37 @@ -struct ChoiceRecord{T} - retval::T - score::Float64 -end - struct CallRecord{T} subtrace::T score::Float64 noise::Float64 end -struct ChoiceOrCallRecord{T} - subtrace_or_retval::T - score::Float64 - noise::Float64 # if choice then NaN - is_choice::Bool -end - -function ChoiceRecord(record::ChoiceOrCallRecord) - if !record.is_choice - error("Found call but expected choice") - end - ChoiceRecord(record.subtrace_or_retval, record.score) -end - -function CallRecord(record::ChoiceOrCallRecord) - if record.is_choice - error("Found choice but expected call") - end - CallRecord(record.subtrace_or_retval, record.score, record.noise) -end - mutable struct DynamicDSLTrace{T} <: Trace gen_fn::T - trie::Trie{Any,ChoiceOrCallRecord} - isempty::Bool + trie::Trie{Any,CallRecord} score::Float64 noise::Float64 args::Tuple retval::Any function DynamicDSLTrace{T}(gen_fn::T, args) where {T} - trie = Trie{Any,ChoiceOrCallRecord}() + trie = Trie{Any,CallRecord}() # retval is not known yet - new(gen_fn, trie, true, 0, 0, args) + new(gen_fn, trie, 0, 0, args) end end set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) -function has_choice(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && trace.trie[addr].is_choice -end - -function has_call(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && !trace.trie[addr].is_choice -end - -function get_choice(trace::DynamicDSLTrace, addr) - choice = trace.trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - ChoiceRecord(choice) -end - -function get_call(trace::DynamicDSLTrace, addr) - call = trace.trie[addr] - if call.is_choice - throw(KeyError(addr)) - end - CallRecord(call) -end - -function add_choice!(trace::DynamicDSLTrace, addr, retval, score) - if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. - The same address cannot be reused for multiple random choices.") - end - trace.trie[addr] = ChoiceOrCallRecord(retval, score, NaN, true) - trace.score += score - trace.isempty = false -end +has_call(trace::DynamicDSLTrace, addr) = haskey(trace.trie, addr) +get_call(trace::DynamicDSLTrace, addr) = trace.trie[addr] function add_call!(trace::DynamicDSLTrace, addr, subtrace) if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. + error("Subtrace already present at address $addr. The same address cannot be reused for multiple random choices.") end score = get_score(subtrace) noise = project(subtrace, EmptySelection()) submap = get_choices(subtrace) - trace.isempty = trace.isempty && isempty(submap) - trace.trie[addr] = ChoiceOrCallRecord(subtrace, score, noise, false) + trace.trie[addr] = CallRecord(subtrace, score, noise) trace.score += score trace.noise += noise end @@ -106,47 +47,28 @@ get_gen_fn(trace::DynamicDSLTrace) = trace.gen_fn ## get_choices ## -function get_choices(trace::DynamicDSLTrace) - if !trace.isempty - DynamicDSLChoiceMap(trace.trie) # see below - else - EmptyChoiceMap() - end -end +get_choices(trace::DynamicDSLTrace) = DynamicDSLChoiceMap(trace.trie) struct DynamicDSLChoiceMap <: ChoiceMap - trie::Trie{Any,ChoiceOrCallRecord} + trie::Trie{Any,CallRecord} end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) - function get_submap(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - if has_leaf_node(trie, addr) - # leaf node, must be a call - call = trie[addr] - if call.is_choice - ValueChoiceMap(call.subtrace_or_retval) - else - get_choices(call.subtrace_or_retval) - end - elseif has_internal_node(trie, addr) - # internal node - subtrie = get_internal_node(trie, addr) - DynamicDSLChoiceMap(subtrie) # see below + if haskey(choices.trie.leaf_nodes, addr) + get_choices(choices.trie[addr].subtrace) + elseif haskey(choices.trie.internal_nodes, addr) + DynamicDSLChoiceMap(choices.trie.internal_nodes[addr]) else EmptyChoiceMap() end end function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ( - (key, call.is_choice ? ValueChoiceMap(call.subtrace_or_retval) : get_choices(call.subtrace_or_retval)) - for (key, call) in get_leaf_nodes(choices.trie) - ) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) - Iterators.flatten((calls_iter, internal_nodes_iter)) + leafs = ((key, get_choices(record.subtrace)) for (key, record) in get_leaf_nodes(choices.trie)) + internals = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) + Iterators.flatten((leafs, internals)) end ## Base.getindex ## @@ -154,13 +76,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr::Pair) (first, rest) = addr if haskey(trie.leaf_nodes, first) - choice_or_call = trie.leaf_nodes[first] - if choice_or_call.is_choice - error("Unknown address $addr; random choice at $first") - else - subtrace = choice_or_call.subtrace_or_retval - return subtrace[rest] - end + return trie.leaf_nodes[first].subtrace[rest] elseif haskey(trie.internal_nodes, first) return _getindex(trace, trie.internal_nodes[first], rest) else @@ -170,14 +86,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr) if haskey(trie.leaf_nodes, addr) - choice_or_call = trie.leaf_nodes[addr] - if choice_or_call.is_choice - # the value of the random choice - return choice_or_call.subtrace_or_retval - else - # the return value of the generative function call - return get_retval(choice_or_call.subtrace_or_retval) - end + return get_retval(trie.leaf_nodes[addr].subtrace) else error("No random choice or generative function call at address $addr") end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 7acc1630..94f442ac 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -16,57 +16,6 @@ function GFUpdateState(gen_fn, args, prev_trace, constraints, params) 0., visitor, params, discard) end -function traceat(state::GFUpdateState, dist::Distribution{T}, - args::Tuple, key) where {T} - - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_is_empty(state.constraints, key) - - # record the previous value as discarded if it is replaced - if constrained && has_previous - set_value!(state.discard, key, prev_retval) - end - - # get return value - if constrained - retval = get_value(state.constraints, key) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update the weight - if has_previous - state.weight += score - prev_score - elseif constrained - state.weight += score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, args::Tuple, key) where {T,U} @@ -78,7 +27,6 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # check for constraints at this key - check_no_value(state.constraints, key) constraints = get_submap(state.constraints, key) # get subtrace @@ -119,11 +67,11 @@ function splice(state::GFUpdateState, gen_fn::DynamicDSLFunction, retval end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::EmptySelection) score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - score += choice_or_call.score + for (key, call) in get_leaf_nodes(prev_trie) + score += call.score end for (key, subtrie) in get_internal_nodes(prev_trie) score += update_delete_recurse(subtrie, EmptySelection()) @@ -131,12 +79,12 @@ function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, score end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, visited::DynamicSelection) score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) + for (key, call) in get_leaf_nodes(prev_trie) if !(key in visited) - score += choice_or_call.score + score += call.score end end for (key, subtrie) in get_internal_nodes(prev_trie) diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index d0797426..6200e5b5 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -5,54 +5,6 @@ import Distributions using SpecialFunctions: loggamma, logbeta, digamma -abstract type Distribution{T} end - -""" - val::T = random(dist::Distribution{T}, args...) - -Sample a random choice from the given distribution with the given arguments. -""" -function random end - -""" - lpdf = logpdf(dist::Distribution{T}, value::T, args...) - -Evaluate the log probability (density) of the value. -""" -function logpdf end - -""" - has::Bool = has_output_grad(dist::Distribution) - -Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. -""" -function has_output_grad end - -""" - grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) - -Compute the gradient of the logpdf with respect to the value, and each of the arguments. - -If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. -Otherwise, the first element of the tuple is the gradient with respect to the value. -If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. -Otherwise, this element contains the gradient with respect to the `i`th argument. -""" -function logpdf_grad end - -function is_discrete end - -# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl - -get_return_type(::Distribution{T}) where {T} = T - -export Distribution -export random -export logpdf -export logpdf_grad -export has_output_grad -export is_discrete - # built-in distributions include("distributions/distributions.jl") From 5bf4207c6fb7ae23208b7559ba866b9fdb1bb717 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 21:37:06 -0400 Subject: [PATCH 15/34] simplify static ir code --- src/address.jl | 3 +- src/choice_map/choice_map.jl | 2 +- src/distribution.jl | 3 + src/static_ir/backprop.jl | 230 +++++++++++++++++------------------ src/static_ir/dag.jl | 34 +----- src/static_ir/generate.jl | 21 ---- src/static_ir/project.jl | 8 -- src/static_ir/render_ir.jl | 11 +- src/static_ir/simulate.jl | 13 -- src/static_ir/trace.jl | 44 +------ src/static_ir/update.jl | 203 +++++-------------------------- test/runtests.jl | 12 +- test/static_dsl.jl | 36 +++--- 13 files changed, 186 insertions(+), 434 deletions(-) diff --git a/src/address.jl b/src/address.jl index 2d6499a6..ad33cfe7 100644 --- a/src/address.jl +++ b/src/address.jl @@ -151,6 +151,7 @@ A hierarchical selection whose keys are among its type parameters. struct StaticSelection{T,U} <: HierarchicalSelection subselections::NamedTuple{T,U} end +StaticSelection(::NamedTuple{(), Tuple{}}) = EmptySelection() function Base.isempty(selection::StaticSelection{T,U}) where {T,U} length(R) == 0 && all(isempty(node) for node in selection.subselections) @@ -208,7 +209,7 @@ function StaticSelection(other::HierarchicalSelection) (keys, subselections) = ((), ()) end types = map(typeof, subselections) - StaticSelection{keys,Tuple{types...}}(NamedTuple{keys}(subselections)) + StaticSelection(NamedTuple{keys}(subselections)) end export StaticSelection diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index 213bc5f8..a1ca2eae 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -147,7 +147,7 @@ end @inline get_submaps_shallow(choices::ValueChoiceMap) = () @inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val @inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) -@inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() +@inline get_address_schema(::Type{<:ValueChoiceMap}) = AllAddressSchema() """ choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) diff --git a/src/distribution.jl b/src/distribution.jl index d72b21a6..a1880304 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -91,6 +91,9 @@ end end # TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? @inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) +# TODO: this next regenerate method is here because StaticSelections can have this sort of empty leaf node; choicemaps +# cannot right now and only have empty ones; we should try to fix this if possible. +#@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, ::StaticSelection{(), Tuple{}}) where {n} = (tr, 0., NoChange()) @inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) new_tr = DistributionTrace(tr.val, args, tr.dist) weight = get_score(new_tr) - get_score(tr) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index b352d3ca..eba97a82 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -36,15 +36,15 @@ function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode) - if node in selected_choices - push!(fwd_marked, node) - end -end - function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) - if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) - push!(fwd_marked, node) + if node.generative_function isa Distribution + if node in selected_choices + push!(fwd_marked, node) + end + else + if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) + push!(fwd_marked, node) + end end end @@ -60,20 +60,15 @@ function back_pass!(back_marked, node::JuliaNode) end end -function back_pass!(back_marked, node::RandomChoiceNode) - # the logpdf of every random choice is a SINK - for input_node in node.inputs - push!(back_marked, input_node) - end - # the value of every random choice is in back_marked, since it affects its logpdf - push!(back_marked, node) -end - function back_pass!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK for input_node in node.inputs push!(back_marked, input_node) end + if node.generative_function isa Distribution + # the value of every random choice is in back_marked, since it affects its logpdf + push!(back_marked, node) + end end function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) @@ -134,35 +129,35 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - - # every random choice is in back_marked, since it affects it logpdf, but - # also possibly due to other downstream usage of the value - @assert node in back_marked +function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + # for reference by other nodes during back_codegen! + # could performance optimize this away + push!(stmts, :($(node.name) = get_retval(trace.$(get_subtrace_fieldname(node))))) - if node in fwd_marked - # the only way we are fwd_marked is if this choice was selected + # every random choice is in back_marked, since it affects it logpdf, but + # also possibly due to other downstream usage of the value + @assert node in back_marked - # initialize gradient with respect to the value of the random choice to zero - # it will be a runtime error, thrown here, if there is no zero() method - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - end -end + if node in fwd_marked + # the only way we are fwd_marked is if this choice was selected -function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) + # initialize gradient with respect to the value of the random choice to zero + # it will be a runtime error, thrown here, if there is no zero() method + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + end + else + # for reference by other nodes during back_codegen! + # could performance optimize this away + subtrace_fieldname = get_subtrace_fieldname(node) + push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) - # NOTE: we will still potentially run choice_gradients recursively on the generative function, - # we just might not use its return value gradient. - if node in fwd_marked && node in back_marked - # we are fwd_marked if an input was fwd_marked, or if we were selected internally - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + # NOTE: we will still potentially run choice_gradients recursively on the generative function, + # we just might not use its return value gradient. + if node in fwd_marked && node in back_marked + # we are fwd_marked if an input was fwd_marked, or if we were selected internally + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + end end end @@ -217,19 +212,19 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: end function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, - node::RandomChoiceNode, logpdf_grad::Symbol) + node::GenerativeFunctionCallNode, logpdf_grad::Symbol) # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) - push!(stmts, :($logpdf_grad = logpdf_grad($(node.dist), $(node.name), $(args...)))) + push!(stmts, :($logpdf_grad = logpdf_grad($(node.generative_function), $(node.name), $(args...)))) end # increment gradients of input nodes that are in fwd_marked for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized - if !has_argument_grads(node.dist)[i] - error("Distribution $(node.dist) does not have logpdf gradient for argument $i") + if !has_argument_grads(node.generative_function)[i] + error("Distribution $(node.generative_function) does not have logpdf gradient for argument $i") end push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) end @@ -243,94 +238,91 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropTraceMode) - logpdf_grad = gensym("logpdf_grad") + node::GenerativeFunctionCallNode, mode::BackpropTraceMode) + if node.generative_function isa Distribution + logpdf_grad = gensym("logpdf_grad") - # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + # backpropagate to the inputs + back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) - # backpropagate to the value (if it was selected) - if node in fwd_marked - if !has_output_grad(node.dist) - error("Distribution $dist does not logpdf gradient for its output value") + # backpropagate to the value (if it was selected) + if node in fwd_marked + if !has_output_grad(node.generative_function) + error("Distribution $(node.generative_function) does not logpdf gradient for its output value") + end + push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) + end + else + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) - end -end - -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropParamsMode) - logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) -end - -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::GenerativeFunctionCallNode, mode::BackpropTraceMode) - - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end - if node in fwd_marked - input_grads = gensym("call_input_grads") - value_trie = value_trie_var(node) - gradient_trie = gradient_trie_var(node) - subtrace_fieldname = get_subtrace_fieldname(node) - call_selection = gensym("call_selection") - if node in selected_calls - push!(stmts, :($call_selection = $qn_static_getindex(selection, $(QuoteNode(Val(node.addr)))))) - else - push!(stmts, :($call_selection = EmptySelection())) + if node in fwd_marked + input_grads = gensym("call_input_grads") + value_trie = value_trie_var(node) + gradient_trie = gradient_trie_var(node) + subtrace_fieldname = get_subtrace_fieldname(node) + call_selection = gensym("call_selection") + if node in selected_calls + push!(stmts, :($call_selection = $qn_static_getindex(selection, $(QuoteNode(Val(node.addr)))))) + else + push!(stmts, :($call_selection = EmptySelection())) + end + retval_grad = node in back_marked ? gradient_var(node) : :(nothing) + push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + trace.$subtrace_fieldname, $call_selection, $retval_grad))) end - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( - trace.$subtrace_fieldname, $call_selection, $retval_grad))) - end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + end end - end - # NOTE: the value_trie and gradient_trie are dealt with later + # NOTE: the value_trie and gradient_trie are dealt with later + end end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropParamsMode) - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end + if node.generative_function isa Distribution + logpdf_grad = gensym("logpdf_grad") + back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + else + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) + end - if node in fwd_marked - input_grads = gensym("call_input_grads") - subtrace_fieldname = get_subtrace_fieldname(node) - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) - end + if node in fwd_marked + input_grads = gensym("call_input_grads") + subtrace_fieldname = get_subtrace_fieldname(node) + retval_grad = node in back_marked ? gradient_var(node) : :(nothing) + push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) + end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + end end end end -function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, +function generate_value_gradient_trie(selected_choices::Set{GenerativeFunctionCallNode}, selected_calls::Set{GenerativeFunctionCallNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(trace.$(get_value_fieldname(node)))), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(get_retval(trace.$(get_subtrace_fieldname(node))))), selected_choices_vec) leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) @@ -350,18 +342,18 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, end function get_selected_choices(::EmptyAddressSchema, ::StaticIR) - Set{RandomChoiceNode}() + Set{GenerativeFunctionCallNode}() end function get_selected_choices(::AllAddressSchema, ir::StaticIR) - Set{RandomChoiceNodes}(ir.choice_nodes) + Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if node.generative_function isa Distribution]...) end function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) selected_choice_addrs = Set(keys(schema)) - selected_choices = Set{RandomChoiceNode}() - for node in ir.choice_nodes - if node.addr in selected_choice_addrs + selected_choices = Set{GenerativeFunctionCallNode}() + for node in ir.call_nodes + if node.generative_function isa Distribution && node.addr in selected_choice_addrs push!(selected_choices, node) end end @@ -373,14 +365,14 @@ function get_selected_calls(::EmptyAddressSchema, ::StaticIR) end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}(ir.call_nodes) + Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if !(node.generative_function isa Distribution)]...) end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) selected_call_addrs = Set(keys(schema)) selected_calls = Set{GenerativeFunctionCallNode}() for node in ir.call_nodes - if node.addr in selected_call_addrs + if !(node.generative_function isa Distribution) && node.addr in selected_call_addrs push!(selected_calls, node) end end @@ -452,7 +444,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, ir = get_ir(gen_fn_type) # unlike choice_gradients we don't take gradients w.r.t. the value of random choices - selected_choices = Set{RandomChoiceNode}() + selected_choices = Set{GenerativeFunctionCallNode}() # we need to guarantee that we visit every generative function call, # because we need to backpropagate to its trainable parameters diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index c8265889..de6acf6e 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -18,14 +18,6 @@ struct JuliaNode <: StaticIRNode typ::Union{Symbol,Expr,QuoteNode} end -struct RandomChoiceNode <: StaticIRNode - dist::Distribution - inputs::Vector{StaticIRNode} - addr::Symbol - name::Symbol - typ::Union{Symbol,Expr,QuoteNode} -end - struct GenerativeFunctionCallNode <: StaticIRNode generative_function::GenerativeFunction inputs::Vector{StaticIRNode} @@ -38,7 +30,6 @@ struct StaticIR nodes::Vector{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::StaticIRNode @@ -50,12 +41,10 @@ mutable struct StaticIRBuilder node_set::Set{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::Union{Nothing,StaticIRNode} vars::Set{Symbol} - addrs_to_choice_nodes::Dict{Symbol,RandomChoiceNode} addrs_to_call_nodes::Dict{Symbol,GenerativeFunctionCallNode} accepts_output_grad::Bool end @@ -65,17 +54,15 @@ function StaticIRBuilder() node_set = Set{StaticIRNode}() trainable_param_nodes = Vector{TrainableParameterNode}() arg_nodes = Vector{ArgumentNode}() - choice_nodes = Vector{RandomChoiceNode}() call_nodes = Vector{GenerativeFunctionCallNode}() julia_nodes = Vector{JuliaNode}() return_node = nothing vars = Set{Symbol}() - addrs_to_choice_nodes = Dict{Symbol,RandomChoiceNode}() addrs_to_call_nodes = Dict{Symbol,GenerativeFunctionCallNode}() accepts_output_grad = false - StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, choice_nodes, call_nodes, + StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, call_nodes, julia_nodes, - return_node, vars, addrs_to_choice_nodes, addrs_to_call_nodes, + return_node, vars, addrs_to_call_nodes, accepts_output_grad) end @@ -87,7 +74,6 @@ function build_ir(builder::StaticIRBuilder) builder.nodes, builder.trainable_param_nodes, builder.arg_nodes, - builder.choice_nodes, builder.call_nodes, builder.julia_nodes, builder.return_node, @@ -109,7 +95,7 @@ function check_inputs_exist(builder::StaticIRBuilder, input_nodes) end function check_addr_unique(builder::StaticIRBuilder, addr::Symbol) - if haskey(builder.addrs_to_choice_nodes, addr) || haskey(builder.addrs_to_call_nodes, addr) + if haskey(builder.addrs_to_call_nodes, addr) error("Address $addr was not unique") end end @@ -164,20 +150,6 @@ function add_constant_node!(builder::StaticIRBuilder, val, node end -function add_addr_node!(builder::StaticIRBuilder, dist::Distribution; - inputs::Vector=[], addr::Symbol=gensym(), - name::Symbol=gensym()) - check_unique_var(builder, name) - check_addr_unique(builder, addr) - check_inputs_exist(builder, inputs) - typ = QuoteNode(get_return_type(dist)) - node = RandomChoiceNode(dist, inputs, addr, name, typ) - _add_node!(builder, node) - builder.addrs_to_choice_nodes[addr] = node - push!(builder.choice_nodes, node) - node -end - function add_addr_node!(builder::StaticIRBuilder, gen_fn::GenerativeFunction; inputs::Vector=[], addr::Symbol=gensym(), name::Symbol=gensym()) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index b53eae95..2beecca5 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -21,27 +21,6 @@ function process!(state::StaticIRGenerateState, node::JuliaNode, options) end end -function process!(state::StaticIRGenerateState, node::RandomChoiceNode, options) - schema = state.schema - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) - if isa(schema, StaticAddressSchema) && (node.addr in keys(schema)) - push!(state.stmts, :($(node.name) = $qn_static_get_value(constraints, Val($addr)))) - push!(state.stmts, :($incr = $qn_logpdf($dist, $(node.name), $(args...)))) - push!(state.stmts, :($weight += $incr)) - else - push!(state.stmts, :($(node.name) = $qn_random($dist, $(args...)))) - push!(state.stmts, :($incr = $qn_logpdf($dist, $(node.name), $(args...)))) - end - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode, options) schema = state.schema args = map((input_node) -> input_node.name, node.inputs) diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index 2f65ecc2..ed14a2f2 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -5,14 +5,6 @@ end function process!(state::StaticIRProjectState, node) end -function process!(state::StaticIRProjectState, node::RandomChoiceNode) - schema = state.schema - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($weight += trace.$(get_score_fieldname(node)))) - end -end - function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) schema = state.schema addr = QuoteNode(node.addr) diff --git a/src/static_ir/render_ir.jl b/src/static_ir/render_ir.jl index 22e7b362..880fec50 100644 --- a/src/static_ir/render_ir.jl +++ b/src/static_ir/render_ir.jl @@ -1,7 +1,12 @@ label(node::ArgumentNode) = String(node.name) label(node::JuliaNode) = String(node.name) -label(node::RandomChoiceNode) = "$(node.dist) $(node.addr) $(node.name)" -label(node::GenerativeFunctionCallNode) = "$(node.addr) $(node.name)" +function label(node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + "$(node.generative_function) $(node.addr) $(node.name)" + else + "$(node.addr) $(node.name)" + end +end function draw_graph(ir::StaticIR, graphviz, fname) dot = graphviz.Digraph() @@ -14,7 +19,7 @@ function draw_graph(ir::StaticIR, graphviz, fname) shape = "diamond" color = "white" parents = [] - elseif isa(node, RandomChoiceNode) + elseif isa(node, GenerativeFunctionCallNode) && node.generative_function isa Distribution shape = "ellipse" color = "white" parents = node.inputs diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index b2d5429e..267183ac 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -20,19 +20,6 @@ function process!(state::StaticIRSimulateState, node::JuliaNode, options) end end -function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options) - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - push!(state.stmts, :($(node.name) = $qn_random($dist, $(args...)))) - push!(state.stmts, :($incr = $qn_logpdf($dist, $(node.name), $(args...)))) - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode, options) args = map((input_node) -> input_node.name, node.inputs) args_tuple = Expr(:tuple, args...) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 168ccf50..358a01d7 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -43,18 +43,10 @@ function get_value_fieldname(node::ArgumentNode) Symbol("$(arg_prefix)_$(node.name)") end -function get_value_fieldname(node::RandomChoiceNode) - Symbol("$(choice_value_prefix)_$(node.addr)") -end - function get_value_fieldname(node::JuliaNode) Symbol("$(julia_prefix)_$(node.name)") end -function get_score_fieldname(node::RandomChoiceNode) - Symbol("$(choice_score_prefix)_$(node.addr)") -end - function get_subtrace_fieldname(node::GenerativeFunctionCallNode) Symbol("$(subtrace_prefix)_$(node.addr)") end @@ -75,12 +67,6 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio fieldname = get_value_fieldname(node) push!(fields, TraceField(fieldname, node.typ)) end - for node in ir.choice_nodes - value_fieldname = get_value_fieldname(node) - push!(fields, TraceField(value_fieldname, node.typ)) - score_fieldname = get_score_fieldname(node) - push!(fields, TraceField(score_fieldname, QuoteNode(Float64))) - end for node in ir.call_nodes subtrace_fieldname = get_subtrace_fieldname(node) subtrace_type = QuoteNode(get_trace_type(node.generative_function)) @@ -137,11 +123,6 @@ end function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] - for node in ir.choice_nodes - addr = node.addr - value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), ValueChoiceMap($value)))) - end for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) @@ -175,19 +156,8 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) end ) end - - choice_getindex_exprs = Expr[] - for node in ir.choice_nodes - push!(choice_getindex_exprs, - quote - function Gen.static_getindex(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) - return trace.$(get_value_fieldname(node)) - end - end - ) - end - return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] + return [get_subtrace_exprs; call_getindex_exprs] end function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) @@ -201,21 +171,11 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) end - # return a ValueChoiceMap if get_submap is run on an address containing a value - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(ValueChoiceMap(choices.trace.$(get_value_fieldname(node))))))) - end methods end function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol) - choice_addrs = [QuoteNode(node.addr) for node in ir.choice_nodes] - call_addrs = [QuoteNode(node.addr) for node in ir.call_nodes] - addrs = vcat(choice_addrs, call_addrs) + addrs = [QuoteNode(node.addr) for node in ir.call_nodes] Expr(:function, Expr(:call, Expr(:(.), Gen, QuoteNode(:get_schema)), :(::Type{$trace_struct_name})), Expr(:block, diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index c806bba3..03f6d8c8 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -9,7 +9,6 @@ const calldiff_prefix = gensym("calldiff") calldiff_var(node::GenerativeFunctionCallNode) = Symbol("$(calldiff_prefix)_$(node.addr)") const choice_discard_prefix = gensym("choice_discard") -choice_discard_var(node::RandomChoiceNode) = Symbol("$(choice_discard_prefix)_$(node.addr)") const call_discard_prefix = gensym("call_discard") call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_prefix)_$(node.addr)") @@ -19,21 +18,18 @@ call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_pref ######################## struct ForwardPassState - input_changed::Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}} + input_changed::Set{GenerativeFunctionCallNode} value_changed::Set{StaticIRNode} - constrained_or_selected_choices::Set{RandomChoiceNode} constrained_or_selected_calls::Set{GenerativeFunctionCallNode} discard_calls::Set{GenerativeFunctionCallNode} end function ForwardPassState() - input_changed = Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}}() + input_changed = Set{GenerativeFunctionCallNode}() value_changed = Set{StaticIRNode}() - constrained_or_selected_choices = Set{RandomChoiceNode}() constrained_or_selected_calls = Set{GenerativeFunctionCallNode}() discard_calls = Set{GenerativeFunctionCallNode}() - ForwardPassState(input_changed, value_changed, constrained_or_selected_choices, - constrained_or_selected_calls, discard_calls) + ForwardPassState(input_changed, value_changed, constrained_or_selected_calls, discard_calls) end function forward_pass_argdiff!(state::ForwardPassState, @@ -46,30 +42,19 @@ function forward_pass_argdiff!(state::ForwardPassState, end end -function process_forward!(::AddressSchema, ::ForwardPassState, ::TrainableParameterNode) end +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, ::ForwardPassState, ::TrainableParameterNode) end -function process_forward!(::AddressSchema, ::ForwardPassState, node::ArgumentNode) end +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, ::ForwardPassState, node::ArgumentNode) end -function process_forward!(::AddressSchema, state::ForwardPassState, node::JuliaNode) +function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, state::ForwardPassState, node::JuliaNode) if any(input_node in state.value_changed for input_node in node.inputs) push!(state.value_changed, node) end end -function process_forward!(schema::AddressSchema, state::ForwardPassState, - node::RandomChoiceNode) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.constrained_or_selected_choices, node) - push!(state.value_changed, node) - end - if any(input_node in state.value_changed for input_node in node.inputs) - push!(state.input_changed, node) - end -end - -function process_forward!(schema::AddressSchema, state::ForwardPassState, +function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, Selection}}, state::ForwardPassState, node::GenerativeFunctionCallNode) + schema = get_address_schema(constraint_type) @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) push!(state.constrained_or_selected_calls, node) @@ -78,8 +63,20 @@ function process_forward!(schema::AddressSchema, state::ForwardPassState, end if any(input_node in state.value_changed for input_node in node.inputs) push!(state.input_changed, node) - push!(state.value_changed, node) # TODO can check whether the node is satically absorbing push!(state.discard_calls, node) + + ## check if we can statically guarantee that this generative function has a `NoChange` diff ## + update_fn = constraint_type <: ChoiceMap ? Gen.update : Gen.regenerate + + trace_type = get_trace_type(node.generative_function) + update_rettype = Core.Compiler.return_type( + update_fn, + Tuple{trace_type, Tuple, Tuple, constraint_type} + ) + guaranteed_returns_nochange = update_rettype <: Tuple && update_rettype != Union{} && update_rettype.parameters[3] == NoChange + if !guaranteed_returns_nochange + push!(state.value_changed, node) + end end end @@ -113,15 +110,6 @@ function process_backward!(fwd::ForwardPassState, back::BackwardPassState, end end -function process_backward!(fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, options) - if node in fwd.input_changed || node in fwd.constrained_or_selected_choices - for input_node in node.inputs - push!(back.marked, input_node) - end - end -end - function process_backward!(fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, options) if node in fwd.input_changed || node in fwd.constrained_or_selected_calls @@ -189,118 +177,6 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, end end -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::UpdateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $qn_Diffed($qn_static_get_value(constraints, Val($addr)), $qn_unknown_change))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), NoChange()))) - end - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $qn_strip_diff($(node.name)), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), $qn_no_change))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $qn_static_get_value(constraints, Val($addr)))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - end - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::RegenerateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - output_value = Expr(:call, qn_strip_diff, node.name) - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $qn_Diffed($qn_random($dist, $(arg_values...)), UnknownChange()))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $output_value, $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $output_value, $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $qn_random($dist, $(arg_values...)))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $(node.name), $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($new_logpdf = $qn_logpdf($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, ::UpdateMode, options) @@ -431,33 +307,20 @@ function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) end end -function generate_discard!(stmts::Vector{Expr}, - constrained_choices::Set{RandomChoiceNode}, - discard_calls::Set{GenerativeFunctionCallNode}) - discard_leaf_nodes = Dict{Symbol,Symbol}() - for node in constrained_choices - discard_leaf_nodes[node.addr] = choice_discard_var(node) - end - discard_internal_nodes = Dict{Symbol,Symbol}() +function generate_discard!(stmts::Vector{Expr}, discard_calls::Set{GenerativeFunctionCallNode}) + discard_nodes = Dict{Symbol,Symbol}() for node in discard_calls - discard_internal_nodes[node.addr] = call_discard_var(node) - end - if length(discard_leaf_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(discard_leaf_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) + discard_nodes[node.addr] = call_discard_var(node) end - if length(discard_internal_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(discard_internal_nodes...)) + + if length(discard_nodes) > 0 + (keys, nodes) = collect(zip(discard_nodes...)) else - (internal_keys, internal_nodes) = ((), ()) + (keys, nodes) = ((), ()) end - leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) - internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - all_keys = (leaf_keys..., internal_keys...) - all_nodes = ([:($(QuoteNode(ValueChoiceMap))($node)) for node in leaf_nodes]..., internal_nodes...) + keys = map((key::Symbol) -> QuoteNode(key), keys) expr = quote $(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(all_keys...),)}(($(all_nodes...),))) end + $(QuoteNode(NamedTuple)){($(keys...),)}(($(nodes...),))) end push!(stmts, :($discard = $expr)) end @@ -482,7 +345,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(constraints_type, fwd_state, node) end # backward marking pass @@ -505,7 +368,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ end generate_return_value!(stmts, fwd_state, ir.return_node, options) generate_new_trace!(stmts, trace_type, options) - generate_discard!(stmts, fwd_state.constrained_or_selected_choices, fwd_state.discard_calls) + generate_discard!(stmts, fwd_state.discard_calls) # return trace and weight and discard and retdiff push!(stmts, :(return ($trace, $weight, $retdiff, $discard))) @@ -530,7 +393,7 @@ function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type: fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(selection_type, fwd_state, node) end # backward marking pass diff --git a/test/runtests.jl b/test/runtests.jl index 74923603..5ad3a2d1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,12 +64,12 @@ end const dx = 1e-6 -include("autodiff.jl") -include("diff.jl") -include("selection.jl") -include("assignment.jl") -include("dynamic_dsl.jl") -include("optional_args.jl") +# include("autodiff.jl") +# include("diff.jl") +# include("selection.jl") +# include("assignment.jl") +# include("dynamic_dsl.jl") +# include("optional_args.jl") include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 317a8427..5576aa7d 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -117,14 +117,13 @@ params = ir.arg_nodes[2] @test params.compute_grad # choice nodes and call nodes -@test length(ir.choice_nodes) == 2 -@test length(ir.call_nodes) == 0 +@test length(ir.call_nodes) == 2 # is_outlier -is_outlier = ir.choice_nodes[1] +is_outlier = ir.call_nodes[1] @test is_outlier.addr == :z @test is_outlier.typ == QuoteNode(Bool) -@test is_outlier.dist == bernoulli +@test is_outlier.generative_function == bernoulli @test length(is_outlier.inputs) == 1 # std @@ -138,10 +137,10 @@ in2 = std.inputs[2] @test (in1 === is_outlier && in2 === params) || (in2 === is_outlier && in1 === params) # y -y = ir.choice_nodes[2] +y = ir.call_nodes[2] @test y.addr == :y @test y.typ == QuoteNode(Float64) -@test y.dist == normal +@test y.generative_function == normal @test length(y.inputs) == 2 @test y.inputs[2] === std @@ -174,40 +173,39 @@ xs = ir.arg_nodes[1] @test xs.typ == :(Vector{Float64}) @test !xs.compute_grad -# choice nodes and call nodes -@test length(ir.choice_nodes) == 4 -@test length(ir.call_nodes) == 1 +# call nodes +@test length(ir.call_nodes) == 5 # inlier_std -inlier_std = ir.choice_nodes[1] +inlier_std = ir.call_nodes[1] @test inlier_std.addr == :inlier_std @test inlier_std.typ == QuoteNode(Float64) -@test inlier_std.dist == gamma +@test inlier_std.generative_function == gamma @test length(inlier_std.inputs) == 2 # outlier_std -outlier_std = ir.choice_nodes[2] +outlier_std = ir.call_nodes[2] @test outlier_std.addr == :outlier_std @test outlier_std.typ == QuoteNode(Float64) -@test outlier_std.dist == gamma +@test outlier_std.generative_function == gamma @test length(outlier_std.inputs) == 2 # slope -slope = ir.choice_nodes[3] +slope = ir.call_nodes[3] @test slope.addr == :slope @test slope.typ == QuoteNode(Float64) -@test slope.dist == normal +@test slope.generative_function == normal @test length(slope.inputs) == 2 # intercept -intercept = ir.choice_nodes[4] +intercept = ir.call_nodes[4] @test intercept.addr == :intercept @test intercept.typ == QuoteNode(Float64) -@test intercept.dist == normal +@test intercept.generative_function == normal @test length(intercept.inputs) == 2 # data -ys = ir.call_nodes[1] +ys = ir.call_nodes[5] @test ys.addr == :data @test ys.typ == QuoteNode(PersistentVector{Float64}) @test ys.generative_function == data_fn @@ -376,7 +374,7 @@ ir2 = Gen.get_ir(typeof(f2)) return_node1 = ir1.return_node return_node2 = ir2.return_node @test isa(return_node2, typeof(return_node1)) -@test return_node2.dist == return_node1.dist +@test return_node2.generative_function == return_node1.generative_function inputs1 = return_node1.inputs inputs2 = return_node2.inputs From 61673a46cef1af989b003e226e762fc21c1814eb Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 21:59:39 -0400 Subject: [PATCH 16/34] brief documentation for Dist <: GenFn --- docs/src/ref/distributions.md | 6 ++++++ docs/src/ref/extending.md | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/src/ref/distributions.md b/docs/src/ref/distributions.md index c81801d4..f4b6d768 100644 --- a/docs/src/ref/distributions.md +++ b/docs/src/ref/distributions.md @@ -1,5 +1,11 @@ # Probability Distributions +In Gen, a probability distribution is a generative function which makes a single random choice +and returns the value of this choice. The choicemap for a probability distribution +is always a [`ValueChoiceMap`](@ref). In addition to supporting the regular `GFI` methods, +every distribution supports the methods [`random`](@ref) and [`logpdf`](@ref), described +in the [Distribution API](@ref custom_distributions). + Gen provides a library of built-in probability distributions, and two ways of writing custom distributions, both of which are explained below: diff --git a/docs/src/ref/extending.md b/docs/src/ref/extending.md index 7f9dfd48..b1d759b3 100644 --- a/docs/src/ref/extending.md +++ b/docs/src/ref/extending.md @@ -110,7 +110,7 @@ Gen's Distribution interface directly, as defined below. Probability distributions are singleton types whose supertype is `Distribution{T}`, where `T` indicates the data type of the random sample. ```julia -abstract type Distribution{T} end +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace} end ``` A new Distribution type must implement the following methods: @@ -146,6 +146,9 @@ has_output_grad logpdf_grad ``` +Any custom distribution will automatically be a `GenerativeFunction` since `Distribution <: GenerativeFunction`; +implementations of all GFI methods are automatically provided in terms of `random` and `logpdf`. + ## Custom generative functions We recommend the following steps for implementing a new type of generative function, and also looking at the implementation for the [`DynamicDSLFunction`](@ref) type as an example. From 298a333fc4a7646a8f5ec64c00174453a55fff56 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 17 Jun 2020 21:59:50 -0400 Subject: [PATCH 17/34] short map over distribution test --- test/assignment.jl | 4 ++-- test/modeling_library/map.jl | 15 +++++++++++++++ test/runtests.jl | 12 ++++++------ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/test/assignment.jl b/test/assignment.jl index 69890297..7485f92c 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -27,8 +27,8 @@ @test !has_value(vcm1, :addr) @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) - @test get_address_schema(typeof(vcm1)) == EmptyAddressSchema() - @test get_address_schema(ValueChoiceMap) == EmptyAddressSchema() + @test get_address_schema(typeof(vcm1)) == AllAddressSchema() + @test get_address_schema(ValueChoiceMap) == AllAddressSchema() @test nested_view(vcm1) == 2 end diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index bfb13eb4..ffe07d77 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -402,4 +402,19 @@ @test isapprox(get_param_grad(foo, :std), expected_std_grad) end + @testset "map over distribution" begin + flip_coins = Map(bernoulli) + coinflips_tr, weight = generate(flip_coins, (fill(0.4, 100),)) + @test weight == 0. + @test coinflips_tr[20] isa Bool + choices = get_choices(coinflips_tr) + @test get_submap(choices, 42) isa ValueChoiceMap{Bool} + val42 = get_value(choices, 42) + new_tr, weight, retdiff, discard = update(coinflips_tr, (fill(0.4, 100),), (NoChange(),), choicemap((42, !val42))) + @test new_tr[42] == !val42 + expected_score_change = logpdf(bernoulli, !val42, 0.4) - logpdf(bernoulli, val42, 0.4) + @test isapprox(get_score(new_tr) - get_score(coinflips_tr), expected_score_change) + @test isapprox(weight, expected_score_change) + end + end diff --git a/test/runtests.jl b/test/runtests.jl index 5ad3a2d1..74923603 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,12 +64,12 @@ end const dx = 1e-6 -# include("autodiff.jl") -# include("diff.jl") -# include("selection.jl") -# include("assignment.jl") -# include("dynamic_dsl.jl") -# include("optional_args.jl") +include("autodiff.jl") +include("diff.jl") +include("selection.jl") +include("assignment.jl") +include("dynamic_dsl.jl") +include("optional_args.jl") include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") From e34875a84a9b31f1f799b40ef582ee28308d4ffe Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:26:33 -0400 Subject: [PATCH 18/34] default static_get_submap = EmptyChoiceMap --- src/static_ir/trace.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 358a01d7..4ebe2800 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -20,6 +20,7 @@ function get_schema end abstract type StaticIRTrace <: Trace end @inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(trace::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false From 972d4555907813ec7fd77a2b202fdfdacf4d5f79 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:29:06 -0400 Subject: [PATCH 19/34] default static_get_submap = EmptyChoiceMap --- src/static_ir/trace.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 168ccf50..a79ed539 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -20,6 +20,7 @@ function get_schema end abstract type StaticIRTrace <: Trace end @inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false From ee64d12fad1b25d645d642ec2422b8dc1a62ae6f Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:54:13 -0400 Subject: [PATCH 20/34] dist performance improvements --- src/distribution.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/distribution.jl b/src/distribution.jl index a1880304..878fb7ad 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -5,13 +5,15 @@ struct DistributionTrace{T, Dist} <: Trace val::T args - dist::Dist + score::Float64 end +@inline dist(::DistributionTrace{T, Dist}) where {T, Dist} = Dist() abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end +@inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.dist) + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, dist(tr)) end """ @@ -62,10 +64,8 @@ get_return_type(::Distribution{T}) where {T} = T @inline Gen.get_args(trace::DistributionTrace) = trace.args @inline Gen.get_choices(trace::DistributionTrace) = ValueChoiceMap(trace.val) # should be able to get type of val @inline Gen.get_retval(trace::DistributionTrace) = trace.val -@inline Gen.get_gen_fn(trace::DistributionTrace) = trace.dist - -# TODO: for performance would it be better to store the score in the trace? -@inline Gen.get_score(trace::DistributionTrace) = logpdf(trace.dist, trace.val, trace.args...) +@inline Gen.get_gen_fn(trace::DistributionTrace) = dist(trace) +@inline Gen.get_score(trace::DistributionTrace) = trace.score @inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. @inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) @@ -80,28 +80,26 @@ end (tr, weight) end @inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::ValueChoiceMap) - new_tr = DistributionTrace(get_value(constraints), args, tr.dist) + new_tr = DistributionTrace(get_value(constraints), args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, UnknownChange(), get_choices(tr)) end @inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::EmptyChoiceMap) - new_tr = DistributionTrace(tr.val, args, tr.dist) + new_tr = DistributionTrace(tr.val, args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, NoChange(), EmptyChoiceMap()) end +@inline Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, constraints::EmptyChoiceMap) where {n} = (tr, 0., NoChange()) # TODO: do I need an update method to handle empty choicemaps which are not `EmptyChoiceMap`s? @inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, selection::EmptySelection) where {n} = (tr, 0., NoChange()) -# TODO: this next regenerate method is here because StaticSelections can have this sort of empty leaf node; choicemaps -# cannot right now and only have empty ones; we should try to fix this if possible. -#@inline Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::NTuple{n, NoChange}, ::StaticSelection{(), Tuple{}}) where {n} = (tr, 0., NoChange()) @inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::EmptySelection) - new_tr = DistributionTrace(tr.val, args, tr.dist) + new_tr = DistributionTrace(tr.val, args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, NoChange()) end @inline function Gen.regenerate(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, selection::AllSelection) - new_val = random(tr.dist, args...) - new_tr = DistributionTrace(new_val, args, tr.dist) + new_val = random(dist(tr), args...) + new_tr = DistributionTrace(new_val, args, dist(tr)) (new_tr, 0., UnknownChange()) end @inline function Gen.propose(dist::Distribution, args::Tuple) @@ -110,7 +108,7 @@ end (ValueChoiceMap(val), score, val) end @inline function Gen.assess(dist::Distribution, args::Tuple, choices::ValueChoiceMap) - weight = logpdf(dist, choices.val, args...) + weight = logpdf(dist, get_value(choices), args...) (weight, choices.val) end From fd1991ff3df029224ddc464642abb0ec15c5ead3 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 10:53:28 -0400 Subject: [PATCH 21/34] minor performance improvement --- src/choice_map/static_choice_map.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 58ef57d3..1aa40d4f 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -36,7 +36,7 @@ end function StaticChoiceMap(other::ChoiceMap) keys_and_nodes = collect(get_submaps_shallow(other)) if length(keys_and_nodes) > 0 - (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + (addrs::NTuple{n, Symbol} where {n}, submaps) = zip(keys_and_nodes...) else addrs = () submaps = () From c3d5db029e57d7bcb381a113dca1fa3659983296 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 12:35:10 -0400 Subject: [PATCH 22/34] performance improvement related to zip bug --- src/choice_map/static_choice_map.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 1aa40d4f..ff8c01a7 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -34,9 +34,10 @@ end # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = collect(get_submaps_shallow(other)) + keys_and_nodes = get_submaps_shallow(other) if length(keys_and_nodes) > 0 - (addrs::NTuple{n, Symbol} where {n}, submaps) = zip(keys_and_nodes...) + addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) else addrs = () submaps = () From 8a43845bb1e1925125e28ad9131fd5621a9f9d5d Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sat, 20 Jun 2020 10:12:00 -0400 Subject: [PATCH 23/34] better static retdiff checking --- src/static_ir/update.jl | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 03f6d8c8..e3baf39a 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -52,29 +52,39 @@ function process_forward!(::Type{<:Union{ChoiceMap, Selection}}, state::ForwardP end end -function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, Selection}}, state::ForwardPassState, +function cannot_statically_guarantee_nochange_retdiff(constraint_type, node, state) + update_fn = constraint_type <: ChoiceMap ? Gen.update : Gen.regenerate + + trace_type = get_trace_type(node.generative_function) + argdiff_types = map(input_node -> input_node in state.value_changed ? UnknownChange : NoChange, node.inputs) + argdiff_type = Tuple{argdiff_types...} + # TODO: can we know the arg type statically? + update_rettype = Core.Compiler.return_type( + update_fn, + Tuple{trace_type, Tuple, argdiff_type, constraint_type} + ) + has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) > 3 + guaranteed_returns_nochange = has_static_retdiff && update_rettype.parameters[3] == NoChange + + return !guaranteed_returns_nochange +end + +function process_forward!(constraint_type::Type{<:Union{<:ChoiceMap, <:Selection}}, state::ForwardPassState, node::GenerativeFunctionCallNode) schema = get_address_schema(constraint_type) + will_run_update = false @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) push!(state.constrained_or_selected_calls, node) - push!(state.value_changed, node) - push!(state.discard_calls, node) + will_run_update = true end if any(input_node in state.value_changed for input_node in node.inputs) push!(state.input_changed, node) + will_run_update = true + end + if will_run_update push!(state.discard_calls, node) - - ## check if we can statically guarantee that this generative function has a `NoChange` diff ## - update_fn = constraint_type <: ChoiceMap ? Gen.update : Gen.regenerate - - trace_type = get_trace_type(node.generative_function) - update_rettype = Core.Compiler.return_type( - update_fn, - Tuple{trace_type, Tuple, Tuple, constraint_type} - ) - guaranteed_returns_nochange = update_rettype <: Tuple && update_rettype != Union{} && update_rettype.parameters[3] == NoChange - if !guaranteed_returns_nochange + if cannot_statically_guarantee_nochange_retdiff(constraint_type, node, state) push!(state.value_changed, node) end end From ffd9373c243593fbf46ef9b4331a93b9a787fdfd Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 10:39:43 -0400 Subject: [PATCH 24/34] add static info for dist trace type --- src/distribution.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distribution.jl b/src/distribution.jl index 878fb7ad..7dd5aa11 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -12,6 +12,9 @@ end abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end @inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) +# we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully +@inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} + function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, dist(tr)) end From 67d5e120c07e7a0b32a658e45698cc674b304334 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 15:09:06 -0400 Subject: [PATCH 25/34] don't use static get_submap for staticchoicemap --- src/choice_map/static_choice_map.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index ff8c01a7..587fc6ee 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -17,7 +17,15 @@ end @inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) @inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) + +# TODO: would it be faster to do static_get_submap? +function get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, addr::Symbol) where {Addrs, SubmapTypes} + if addr in Addrs + choices.submaps[addr] + else + EmptyChoiceMap() + end +end @generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} if A in Addrs From 4966ea9ee0f10f633d854f5c0bef3d17adaf712e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 15:27:59 -0400 Subject: [PATCH 26/34] some simple MH benchmarks --- test/benchmarks/dynamic_mh.jl | 77 +++++++++++++++++++++++++++ test/benchmarks/run_benchmarks.jl | 2 + test/benchmarks/static_mh.jl | 87 +++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+) create mode 100644 test/benchmarks/dynamic_mh.jl create mode 100644 test/benchmarks/run_benchmarks.jl create mode 100644 test/benchmarks/static_mh.jl diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl new file mode 100644 index 00000000..fb7661a7 --- /dev/null +++ b/test/benchmarks/dynamic_mh.jl @@ -0,0 +1,77 @@ +module DynamicMHBenchmark +using Gen +import Random + +include("../../examples/regression/dynamic_model.jl") +include("../../examples/regression/dataset.jl") + +@gen function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen function is_outlier_proposal(trace, i::Int) + prev = trace[:data => i => :z] + @trace(bernoulli(prev ? 0.0 : 1.0), :data => i => :z) +end + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +println("Simple dynamic DSL MH on regression model:") +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 50) +println() + +end \ No newline at end of file diff --git a/test/benchmarks/run_benchmarks.jl b/test/benchmarks/run_benchmarks.jl new file mode 100644 index 00000000..13a4754e --- /dev/null +++ b/test/benchmarks/run_benchmarks.jl @@ -0,0 +1,2 @@ +include("static_mh.jl") +include("dynamic_mh.jl") \ No newline at end of file diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl new file mode 100644 index 00000000..0e801631 --- /dev/null +++ b/test/benchmarks/static_mh.jl @@ -0,0 +1,87 @@ +module StaticMHBenchmark +using Gen +import Random + +include("../../examples/regression/static_model.jl") +include("../../examples/regression/dataset.jl") + +@gen (static) function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen (static) function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen (static) function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen (static) function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen (static) function flip_z(z::Bool) + @trace(bernoulli(z ? 0.0 : 1.0), :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +Gen.load_generated_functions() + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +println("Simple static DSL (including CallAt nodes) MH on regression model:") +@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 50) +println() +end \ No newline at end of file From 0909a5b3eb9e99829dd3bfe99cf3612051b7c543 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 15:41:21 -0400 Subject: [PATCH 27/34] bug fix --- src/distribution.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution.jl b/src/distribution.jl index 7dd5aa11..ce5455b7 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -16,7 +16,7 @@ abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end @inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, dist(tr)) + DistributionTrace(convert(U, tr.val), tr.args, dist(tr)) end """ From 47cca5980c8b7be38b6215b4b49a69101c2a8173 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 16:42:28 -0400 Subject: [PATCH 28/34] remove ChoiceAt; bug fixes --- src/distribution.jl | 5 +- src/dsl/static.jl | 10 +- src/modeling_library/call_at/call_at.jl | 50 +++++- src/modeling_library/choice_at/choice_at.jl | 177 -------------------- src/modeling_library/modeling_library.jl | 1 - src/static_ir/update.jl | 4 +- test/modeling_library/call_at.jl | 2 +- test/modeling_library/choice_at.jl | 6 +- test/static_dsl.jl | 12 +- 9 files changed, 60 insertions(+), 207 deletions(-) delete mode 100644 src/modeling_library/choice_at/choice_at.jl diff --git a/src/distribution.jl b/src/distribution.jl index ce5455b7..6354558d 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -10,13 +10,14 @@ end @inline dist(::DistributionTrace{T, Dist}) where {T, Dist} = Dist() abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end +DistributionTrace{T, Dist}(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) @inline DistributionTrace(val::T, args::Tuple, dist::Dist) where {T, Dist <: Distribution} = DistributionTrace{T, Dist}(val, args, logpdf(dist, val, args...)) # we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully @inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} -function Base.convert(::Type{DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} - DistributionTrace(convert(U, tr.val), tr.args, dist(tr)) +function Base.convert(::Type{<:DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} + DistributionTrace{U, Dist}(convert(U, tr.val), tr.args, tr.score) end """ diff --git a/src/dsl/static.jl b/src/dsl/static.jl index 8cafd9f9..c9ed6958 100644 --- a/src/dsl/static.jl +++ b/src/dsl/static.jl @@ -51,10 +51,6 @@ end split_addr!(keys, addr_expr::QuoteNode) = push!(keys, addr_expr) split_addr!(keys, addr_expr::Symbol) = push!(keys, addr_expr) -"Construct choice-at or call-at combinator depending on type." -choice_or_call_at(gen_fn::GenerativeFunction, addr_typ) = call_at(gen_fn, addr_typ) -choice_or_call_at(dist::Distribution, addr_typ) = choice_at(dist, addr_typ) - "Generate informative node name for a Julia expression." gen_node_name(arg::Any) = gensym(string(arg)) gen_node_name(arg::Expr) = gensym(arg.head) @@ -78,12 +74,12 @@ function parse_trace_expr!(stmts, bindings, fn, args, addr) end addr = keys[1].value # Get top level address if length(keys) > 1 - # For each nesting level, wrap gen_fn_or_dist within choice_at / call_at + # For each nesting level, wrap gen_fn_or_dist within call_at for key in keys[2:end] push!(stmts, :($(esc(gen_fn_or_dist)) = - choice_or_call_at($(esc(gen_fn_or_dist)), Any))) + call_at($(esc(gen_fn_or_dist)), Any))) end - # Append the nested addresses as arguments to choice_at / call_at + # Append the nested addresses as arguments to call_at args = [args; reverse(keys[2:end])] end # Handle arguments to the traced call diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index f17d061f..f997ba39 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -140,18 +140,50 @@ function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, end function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) - subselection = selection[trace.key] - (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( - trace.subtrace, subselection, retval_grad) - input_grads = (kernel_input_grads..., nothing) - value_choices = CallAtChoiceMap(trace.key, value_submap) - gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) - (input_grads, value_choices, gradient_choices) + if trace.subtrace isa DistributionTrace + if retval_grad !== nothing && !has_output_grad(get_gen_fn(trace.subtrace)) + error("return value gradient not accepted but one was provided") + end + kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) + if trace.key in selection + value_choices = CallAtChoiceMap(trace.key, get_choices(trace.subtrace)) + choice_grad = kernel_arg_grads[1] + if choice_grad === nothing + error("gradient not available for selected choice") + end + if retval_grad !== nothing + choice_grad += retval_grad + end + gradient_choices = CallAtChoiceMap(trace.key, ValueChoiceMap(choice_grad)) + else + value_choices = EmptyChoiceMap() + gradient_choices = EmptyChoiceMap() + end + input_grads = (kernel_arg_grads[2:end]..., nothing) + return (input_grads, value_choices, gradient_choices) + else + subselection = selection[trace.key] + (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( + trace.subtrace, subselection, retval_grad) + input_grads = (kernel_input_grads..., nothing) + value_choices = CallAtChoiceMap(trace.key, value_submap) + gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) + return (input_grads, value_choices, gradient_choices) + end end function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) - kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) - (kernel_input_grads..., nothing) + if trace.subtrace isa DistributionTrace + if retval_grad !== nothing && !has_output_grad(trace.gen_fn.dist) + error("return value gradient not accepted but one was provided") + end + kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) + return (kernel_arg_grads[2:end]..., nothing) + else + kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) + return (kernel_input_grads..., nothing) + end + end export call_at diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl deleted file mode 100644 index f3875895..00000000 --- a/src/modeling_library/choice_at/choice_at.jl +++ /dev/null @@ -1,177 +0,0 @@ -# TODO optimize ChoiceAtTrace using type parameters - -struct ChoiceAtTrace <: Trace - gen_fn::GenerativeFunction # the ChoiceAtCombinator (not the kernel) - value::Any - key::Any - kernel_args::Tuple - score::Float64 -end - -get_args(trace::ChoiceAtTrace) = (trace.kernel_args..., trace.key) -get_retval(trace::ChoiceAtTrace) = trace.value -get_score(trace::ChoiceAtTrace) = trace.score -get_gen_fn(trace::ChoiceAtTrace) = trace.gen_fn - -struct ChoiceAtChoiceMap{T,K} <: ChoiceMap - key::K - value::T -end - -get_choices(trace::ChoiceAtTrace) = ChoiceAtChoiceMap(trace.key, trace.value) -Base.isempty(::ChoiceAtChoiceMap) = false -function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} - SingleDynamicKeyAddressSchema() -end -get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) -get_submap(choices::ChoiceAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} - choices.key == addr ? choices.value : throw(KeyError(choices, addr)) -end -get_submap(choices::ChoiceAtChoiceMap, addr) = addr == choices.key ? ValueChoiceMap(choices.value) : EmptyChoiceMap() -get_submaps_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, ValueChoiceMap(choices.value)),) -get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) - -struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} - dist::Distribution{T} -end - -accepts_output_grad(gen_fn::ChoiceAtCombinator) = has_output_grad(gen_fn.dist) - -# TODO -# accepts_output_grad is true if the return value is dependent on the 'gradient source elements' -# if the random choice itself is not a 'gradient source element' then it is independent (false) -# if the random choice is a 'gradient source element', then the return value is dependent (true) -# we will consider the random choice as a gradient source element if the -# distribution has has_output_grad = true) - -function choice_at(dist::Distribution{T}, ::Type{K}) where {T,K} - ChoiceAtCombinator{T,K}(dist) -end - -unpack_choice_at_args(args) = (args[end], args[1:end-1]) - -function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = get_value(choices, key) - weight = logpdf(gen_fn.dist, value, kernel_args...) - (weight, value) -end - -function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - choices = ChoiceAtChoiceMap(key, value) - (choices, score, value) -end - -function simulate(gen_fn::ChoiceAtCombinator, args::Tuple) - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - ChoiceAtTrace(gen_fn, value, key, kernel_args, score) -end - -function generate(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - constrained = has_value(choices, key) - value = constrained ? get_value(choices, key) : random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - trace = ChoiceAtTrace(gen_fn, value, key, kernel_args, score) - weight = constrained ? score : 0. - (trace, weight) -end - -function project(trace::ChoiceAtTrace, selection::Selection) - (trace.key in selection) ? trace.score : 0. -end - -function update(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - constrained = has_value(choices, key) - if key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(trace.key, trace.value) - elseif !key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(key, trace.value) - elseif !key_changed && !constrained - new_value = trace.value - discard = EmptyChoiceMap() - else - error("New address $key not constrained in update") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - weight = new_score - trace.score - (new_trace, weight, UnknownChange(), discard) -end - -function regenerate(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - selected = key in selection - if !key_changed && selected - new_value = random(trace.gen_fn.dist, kernel_args...) - elseif !key_changed && !selected - new_value = trace.value - elseif key_changed && !selected - new_value = random(trace.gen_fn.dist, kernel_args...) - else - error("Cannot select new address $key in regenerate") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - if !key_changed && selected - weight = 0. - elseif !key_changed && !selected - weight = new_score - trace.score - elseif key_changed && !selected - weight = 0. - end - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - (new_trace, weight, UnknownChange()) -end - -function choice_gradients(trace::ChoiceAtTrace, selection::Selection, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - if trace.key in selection - value_choices = ChoiceAtChoiceMap(trace.key, trace.value) - choice_grad = kernel_arg_grads[1] - if choice_grad == nothing - error("gradient not available for selected choice") - end - if retval_grad != nothing - choice_grad += retval_grad - end - gradient_choices = ChoiceAtChoiceMap(trace.key, choice_grad) - else - value_choices = EmptyChoiceMap() - gradient_choices = EmptyChoiceMap() - end - input_grads = (kernel_arg_grads[2:end]..., nothing) - (input_grads, value_choices, gradient_choices) -end - -function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - (kernel_arg_grads[2:end]..., nothing) -end - -export choice_at diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index 6200e5b5..5d9e287b 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -19,7 +19,6 @@ include("dist_dsl/dist_dsl.jl") include("vector.jl") # built-in generative function combinators -include("choice_at/choice_at.jl") include("call_at/call_at.jl") include("map/map.jl") include("unfold/unfold.jl") diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index e3baf39a..ea4b17f9 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -63,9 +63,11 @@ function cannot_statically_guarantee_nochange_retdiff(constraint_type, node, sta update_fn, Tuple{trace_type, Tuple, argdiff_type, constraint_type} ) - has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) > 3 + has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) >= 3 guaranteed_returns_nochange = has_static_retdiff && update_rettype.parameters[3] == NoChange + # println("$trace_type, Tuple, $argdiff_type, $constraint_type >> $update_rettype : $has_static_retdiff") + return !guaranteed_returns_nochange end diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index 607eb61f..1985c610 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -1,4 +1,4 @@ -@testset "call_at combinator" begin +@testset "call_at combinator on non-distribution" begin @gen (grad) function foo((grad)(x::Float64)) return x + @trace(normal(x, 1), :y) diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 4f524138..69eb5249 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -1,6 +1,6 @@ -@testset "choice_at combinator" begin +@testset "call_at combinator on distribution" begin - at = choice_at(bernoulli, Int) + at = call_at(bernoulli, Int) @testset "assess" begin choices = choicemap() @@ -143,7 +143,7 @@ y = 1.2 constraints = choicemap() set_value!(constraints, 3, y) - (trace, _) = generate(choice_at(normal, Int), (0.0, 1.0, 3), constraints) + (trace, _) = generate(call_at(normal, Int), (0.0, 1.0, 3), constraints) # not selected (input_grads, choices, gradients) = choice_gradients( diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 5576aa7d..5d311df6 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -40,13 +40,13 @@ end ret = @trace(bernoulli(0.5), :x => i) end -# @trace(choice_at(bernoulli)(0.5, i), :x) +# @trace(call_at(bernoulli)(0.5, i), :x) @gen (static) function at_choice_example_2(i::Int) ret = @trace(bernoulli(0.5), :x => i => :y) end -# @trace(call_at(choice_at(bernoulli))(0.5, i, :y), :x) +# @trace(call_at(call_at(bernoulli))(0.5, i, :y), :x) @gen function foo(mu) @trace(normal(mu, 1), :y) @@ -255,8 +255,8 @@ ret = get_node_by_addr(ir, :x) @test isa(ret.inputs[1], Gen.JuliaNode) # () -> 0.5 @test ret.inputs[2] === i at = ret.generative_function -@test isa(at, Gen.ChoiceAtCombinator) -@test at.dist == bernoulli +@test isa(at, Gen.CallAtCombinator) +@test at.kernel == bernoulli # at_choice_example_2 ir = Gen.get_ir(typeof(at_choice_example_2)) @@ -271,8 +271,8 @@ ret = get_node_by_addr(ir, :x) at = ret.generative_function @test isa(at, Gen.CallAtCombinator) at2 = at.kernel -@test isa(at2, Gen.ChoiceAtCombinator) -@test at2.dist == bernoulli +@test isa(at2, Gen.CallAtCombinator) +@test at2.kernel == bernoulli end From 10df9520aa91cd7dd4846058628ecf82d4adb168 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 25 Jun 2020 17:06:22 -0400 Subject: [PATCH 29/34] decrease iters on benchmark --- test/benchmarks/dynamic_mh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl index fb7661a7..88392cc0 100644 --- a/test/benchmarks/dynamic_mh.jl +++ b/test/benchmarks/dynamic_mh.jl @@ -70,8 +70,8 @@ end println("Simple dynamic DSL MH on regression model:") (xs, ys) = make_data_set(200) do_inference(xs, ys, 10) -@time do_inference(xs, ys, 50) -@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 20) +@time do_inference(xs, ys, 20) println() end \ No newline at end of file From a79390e1fdd255283637a81ee9826c813adcc538 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 3 Jul 2020 12:46:42 -0400 Subject: [PATCH 30/34] merge in updated master --- README.md | 2 +- src/dsl/dsl.jl | 83 ++++++++++++++++++++++++++++++----- src/dsl/static.jl | 12 +++-- src/static_ir/backprop.jl | 14 +++--- src/static_ir/generate.jl | 18 ++++---- src/static_ir/project.jl | 12 ++--- src/static_ir/simulate.jl | 10 ++--- src/static_ir/static_ir.jl | 41 ++++------------- src/static_ir/trace.jl | 28 ++++++------ src/static_ir/update.jl | 78 ++++++++++++++++---------------- test/benchmarks/dynamic_mh.jl | 2 +- test/benchmarks/static_mh.jl | 2 +- test/static_dsl.jl | 21 +++++++++ test/tilde_sugar.jl | 46 ++++++++++++------- 14 files changed, 222 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index 95a8cbfc..c2989f85 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Gen.jl -[![Build Status](https://travis-ci.org/probcomp/Gen.svg?branch=master)](https://travis-ci.org/probcomp/Gen.jl) +[![Build Status](https://travis-ci.org/probcomp/Gen.jl.svg?branch=master)](https://travis-ci.org/probcomp/Gen.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://probcomp.github.io/Gen.jl/stable) [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://probcomp.github.io/Gen.jl/dev) diff --git a/src/dsl/dsl.jl b/src/dsl/dsl.jl index f3f5bdab..a4f3c6b2 100644 --- a/src/dsl/dsl.jl +++ b/src/dsl/dsl.jl @@ -5,6 +5,7 @@ const DSL_ARG_GRAD_ANNOTATION = :grad const DSL_RET_GRAD_ANNOTATION = :grad const DSL_TRACK_DIFFS_ANNOTATION = :diffs const DSL_NO_JULIA_CACHE_ANNOTATION = :nojuliacache +const DSL_MACROS = Set([Symbol("@trace"), Symbol("@param")]) struct Argument name::Symbol @@ -71,21 +72,81 @@ function address_from_expression(lhs) end function desugar_tildes(expr) + trace_ref = GlobalRef(@__MODULE__, Symbol("@trace")) + line_num = LineNumberNode(1, :none) MacroTools.postwalk(expr) do e + # Replace with globally referenced macrocalls if MacroTools.@capture(e, {*} ~ rhs_) - :(@trace($rhs)) + Expr(:macrocall, trace_ref, line_num, rhs) elseif MacroTools.@capture(e, {addr_} ~ rhs_) - :(@trace($rhs, $(addr))) + Expr(:macrocall, trace_ref, line_num, rhs, addr) elseif MacroTools.@capture(e, lhs_ ~ rhs_) - addr_expr = address_from_expression(lhs) - :($lhs = @trace($rhs, $(addr_expr))) + addr = address_from_expression(lhs) + Expr(:(=), lhs, Expr(:macrocall, trace_ref, line_num, rhs, addr)) else e end end end -function parse_gen_function(ast, annotations) +function resolve_gen_macros(expr, __module__) + MacroTools.postwalk(expr) do e + # Resolve Gen macros to globally referenced macrocalls + if (MacroTools.@capture(e, @namespace_.m_(args__)) && + m in DSL_MACROS && __module__.eval(namespace) == @__MODULE__) + macro_ref = GlobalRef(@__MODULE__, m) + line_num = e.args[2] + Expr(:macrocall, macro_ref, line_num, args...) + elseif (MacroTools.@capture(e, @m_(args__)) && + m in DSL_MACROS && isdefined(__module__, m) && + getfield(__module__, m) == getfield(@__MODULE__, m)) + macro_ref = GlobalRef(@__MODULE__, m) + line_num = e.args[2] + Expr(:macrocall, macro_ref, line_num, args...) + else + e + end + end +end + +function extract_quoted_exprs(expr) + quoted_exprs = [] + expr = MacroTools.prewalk(expr) do e + if MacroTools.@capture(e, :(quoted_)) && !isa(quoted, Symbol) + push!(quoted_exprs, e) + Expr(:placeholder, length(quoted_exprs)) + else + e + end + end + return expr, quoted_exprs +end + +function insert_quoted_exprs(expr, quoted_exprs) + expr = MacroTools.prewalk(expr) do e + if MacroTools.@capture(e, p_placeholder) + idx = p.args[1] + quoted_exprs[idx] + else + e + end + end + return expr +end + +function preprocess_body(expr, __module__) + # Protect quoted expressions from pre-processing by extracting them + expr, quoted_exprs = extract_quoted_exprs(expr) + # Desugar tilde calls to globally referenced @trace calls + expr = desugar_tildes(expr) + # Also resolve Gen macros to GlobalRefs for consistent downstream parsing + expr = resolve_gen_macros(expr, __module__) + # Reinsert quoted expressions after pre-processing + expr = insert_quoted_exprs(expr, quoted_exprs) + return expr +end + +function parse_gen_function(ast, annotations, __module__) ast = MacroTools.longdef(ast) if ast.head != :function error("syntax error at $ast in $(ast.head)") @@ -94,7 +155,6 @@ function parse_gen_function(ast, annotations) error("syntax error at $ast in $(ast.args)") end signature = ast.args[1] - body = desugar_tildes(ast.args[2]) if signature.head == :(::) (call_signature, return_type) = signature.args elseif signature.head == :call @@ -102,6 +162,7 @@ function parse_gen_function(ast, annotations) else error("syntax error at $(signature)") end + body = preprocess_body(ast.args[2], __module__) name = call_signature.args[1] args = map(parse_arg, call_signature.args[2:end]) static = DSL_STATIC_ANNOTATION in annotations @@ -112,15 +173,13 @@ function parse_gen_function(ast, annotations) end end -macro gen(annotations_expr, ast) - +macro gen(annotations_expr, ast::Expr) # parse the annotations annotations = parse_annotations(annotations_expr) - # parse the function definition - parse_gen_function(ast, annotations) + parse_gen_function(ast, annotations, __module__) end -macro gen(ast) - parse_gen_function(ast, Set{Symbol}()) +macro gen(ast::Expr) + parse_gen_function(ast, Set{Symbol}(), __module__) end diff --git a/src/dsl/static.jl b/src/dsl/static.jl index c9ed6958..812aad78 100644 --- a/src/dsl/static.jl +++ b/src/dsl/static.jl @@ -202,10 +202,12 @@ end "Parse and rewrite expression if it matches an @trace call." function parse_and_rewrite_trace!(stmts, bindings, expr) - if MacroTools.@capture(expr, @m_(f_(xs__), addr_)) && m == STATIC_DSL_TRACE + if (MacroTools.@capture(expr, @m_(f_(xs__), addr_)) && isa(m, GlobalRef) && + m.name == STATIC_DSL_TRACE && m.mod == @__MODULE__) # Parse "@trace(f(xs...), addr)" and return fresh variable parse_trace_expr!(stmts, bindings, f, xs, addr) - elseif MacroTools.@capture(expr, @m_(f_(xs__))) && m == STATIC_DSL_TRACE + elseif (MacroTools.@capture(expr, @m_(f_(xs__))) && isa(m, GlobalRef) && + m.name == STATIC_DSL_TRACE && m.mod == @__MODULE__) # Throw error for @trace expression without address static_dsl_syntax_error(expr, "Address required.") else @@ -219,12 +221,14 @@ function parse_static_dsl_line!(stmts, bindings, line) rewritten = MacroTools.postwalk( e -> parse_and_rewrite_trace!(stmts, bindings, e), line) # If line is a top-level @trace call, we are done - if MacroTools.@capture(line, @m_(f_(x__), a_)) && m == STATIC_DSL_TRACE + if (MacroTools.@capture(line, @m_(f_(x__), a_)) && isa(m, GlobalRef) && + m.name == STATIC_DSL_TRACE && m.mod == @__MODULE__) return end # Match and parse any other top-level expressions line = rewritten - if MacroTools.@capture(line, @m_ expr_) && m == STATIC_DSL_PARAM + if (MacroTools.@capture(line, @m_(expr_)) && isa(m, GlobalRef) && + m.name == STATIC_DSL_PARAM && m.mod == @__MODULE__) # Parse "@param var::T" parse_param_line!(stmts, bindings, expr) elseif MacroTools.@capture(line, lhs_ = rhs_) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index eba97a82..38bf2cca 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -119,7 +119,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) else # regular forward execution. - + # we need the value for initializing gradient to zero (to get the type # and e.g. shape), and for reference by other nodes during # back_codegen! we could be more selective about which JuliaNodes need @@ -266,7 +266,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, subtrace_fieldname = get_subtrace_fieldname(node) call_selection = gensym("call_selection") if node in selected_calls - push!(stmts, :($call_selection = $qn_static_getindex(selection, $(QuoteNode(Val(node.addr)))))) + push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_getindex))(selection, $(QuoteNode(Val(node.addr)))))) else push!(stmts, :($call_selection = EmptySelection())) end @@ -425,7 +425,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # assemble value_trie and gradient_trie value_trie = gensym("value_trie") gradient_trie = gensym("gradient_trie") - push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, + push!(stmts, generate_value_gradient_trie(selected_choices, selected_calls, value_trie, gradient_trie)) # gradients with respect to inputs @@ -434,7 +434,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # return values push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie))) - + Expr(:block, stmts...) end @@ -486,20 +486,20 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # return values push!(stmts, :(return $input_grads)) - + Expr(:block, stmts...) end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:choice_gradients)))(trace::T, selection::$(QuoteNode(Selection)), +@generated function $(GlobalRef(Gen, :choice_gradients))(trace::T, selection::$(QuoteNode(Selection)), retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_choice_gradients))(trace, selection, retval_grad) end end) push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:accumulate_param_gradients!)))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} +@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad) end end) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 2beecca5..a7e9d42a 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -32,16 +32,16 @@ function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode incr = gensym("weight") subconstraints = gensym("subconstraints") if isa(schema, StaticAddressSchema) && (node.addr in keys(schema)) - push!(state.stmts, :($subconstraints = $qn_static_get_submap(constraints, Val($addr)))) - push!(state.stmts, :(($subtrace, $incr) = $qn_generate($gen_fn, $args_tuple, $subconstraints))) + push!(state.stmts, :($subconstraints = $(GlobalRef(Gen, :static_get_submap))(constraints, Val($addr)))) + push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($gen_fn, $args_tuple, $subconstraints))) else - push!(state.stmts, :(($subtrace, $incr) = $qn_generate($gen_fn, $args_tuple, $qn_empty_choice_map))) + push!(state.stmts, :(($subtrace, $incr) = $(GlobalRef(Gen, :generate))($gen_fn, $args_tuple, $(GlobalRef(Gen, :EmptyChoiceMap))()))) end push!(state.stmts, :($weight += $incr)) - push!(state.stmts, :($num_nonempty_fieldname += !$qn_isempty($qn_get_choices($subtrace)) ? 1 : 0)) - push!(state.stmts, :($(node.name) = $qn_get_retval($subtrace))) - push!(state.stmts, :($total_score_fieldname += $qn_get_score($subtrace))) - push!(state.stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection))) + push!(state.stmts, :($num_nonempty_fieldname += !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) ? 1 : 0)) + push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) + push!(state.stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace))) + push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end function codegen_generate(gen_fn_type::Type{T}, args, @@ -51,7 +51,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, # convert the constraints to a static assignment if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $qn_generate(gen_fn, args, $(QuoteNode(StaticChoiceMap))(constraints)) end + return quote $(GlobalRef(Gen, :generate))(gen_fn, args, $(QuoteNode(StaticChoiceMap))(constraints)) end end ir = get_ir(gen_fn_type) @@ -88,7 +88,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:generate)))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), +@generated function $(GlobalRef(Gen, :generate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap))) $(QuoteNode(codegen_generate))(gen_fn, args, constraints) end diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index ed14a2f2..62df493a 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -12,10 +12,10 @@ function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) subtrace = get_subtrace_fieldname(node) subselection = gensym("subselection") if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($subselection = $qn_static_getindex(selection, Val($addr)))) - push!(state.stmts, :($weight += $qn_project(trace.$subtrace, $subselection))) + push!(state.stmts, :($subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) + push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $subselection))) else - push!(state.stmts, :($weight += $qn_project(trace.$subtrace, $qn_empty_selection))) + push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end end @@ -25,7 +25,7 @@ function codegen_project(trace_type::Type, selection_type::Type) # convert the selection to a static selection if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema)) - return quote $qn_project(trace, $(QuoteNode(StaticSelection))(selection)) end + return quote $(GlobalRef(Gen, :project))(trace, $(QuoteNode(StaticSelection))(selection)) end end ir = get_ir(gen_fn_type) @@ -48,11 +48,11 @@ end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:project)))(trace::T, selection::$(QuoteNode(Selection))) where {T <: $(QuoteNode(StaticIRTrace))} +@generated function $(GlobalRef(Gen, :project))(trace::T, selection::$(QuoteNode(Selection))) where {T <: $(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_project))(trace, selection) end -function $(Expr(:(.), Gen, QuoteNode(:project)))(trace::T, selection::$(QuoteNode(EmptySelection))) where {T <: $(QuoteNode(StaticIRTrace))} +function $(GlobalRef(Gen, :project))(trace::T, selection::$(QuoteNode(EmptySelection))) where {T <: $(QuoteNode(StaticIRTrace))} trace.$total_noise_fieldname end diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 267183ac..669f4075 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -28,10 +28,10 @@ function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode subtrace = get_subtrace_fieldname(node) incr = gensym("weight") push!(state.stmts, :($subtrace = $(QuoteNode(simulate))($gen_fn, $args_tuple))) - push!(state.stmts, :($num_nonempty_fieldname += !$qn_isempty($qn_get_choices($subtrace)) ? 1 : 0)) - push!(state.stmts, :($(node.name) = $qn_get_retval($subtrace))) - push!(state.stmts, :($total_score_fieldname += $qn_get_score($subtrace))) - push!(state.stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection))) + push!(state.stmts, :($num_nonempty_fieldname += !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) ? 1 : 0)) + push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) + push!(state.stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace))) + push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenerativeFunction} @@ -70,7 +70,7 @@ function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera end push!(generated_functions, quote -@generated function $(Expr(:(.), Gen, QuoteNode(:simulate)))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple) +@generated function $(GlobalRef(Gen, :simulate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple) $(QuoteNode(codegen_simulate))(gen_fn, args) end end) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index f63fbcdd..5b156d0a 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -51,14 +51,14 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati params_grad::Dict{Symbol,Any} params::Dict{Symbol,Any} end - (gen_fn::$gen_fn_type_name)(args...) = propose(gen_fn, args)[3] - $(Expr(:(.), Gen, QuoteNode(:get_ir)))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir)) - $(Expr(:(.), Gen, QuoteNode(:get_trace_type)))(::Type{$gen_fn_type_name}) = $trace_struct_name - $(Expr(:(.), Gen, QuoteNode(:has_argument_grads)))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads)) - $(Expr(:(.), Gen, QuoteNode(:accepts_output_grad)))(::$gen_fn_type_name) = $(QuoteNode(accepts_output_grad)) - $(Expr(:(.), Gen, QuoteNode(:get_gen_fn)))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) - $(Expr(:(.), Gen, QuoteNode(:get_gen_fn_type)))(::Type{$trace_struct_name}) = $gen_fn_type_name - $(Expr(:(.), Gen, QuoteNode(:get_options)))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) + (gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3] + $(GlobalRef(Gen, :get_ir))(::Type{$gen_fn_type_name}) = $(QuoteNode(ir)) + $(GlobalRef(Gen, :get_trace_type))(::Type{$gen_fn_type_name}) = $trace_struct_name + $(GlobalRef(Gen, :has_argument_grads))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads)) + $(GlobalRef(Gen, :accepts_output_grad))(::$gen_fn_type_name) = $(QuoteNode(accepts_output_grad)) + $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) + $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name + $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) end Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) end @@ -74,31 +74,6 @@ const trace = gensym("trace") const weight = gensym("weight") const subtrace = gensym("subtrace") -# quoted values and function called in generated code (since generated code is -# evaluted in the user's Main module, not Gen) -const qn_isempty = QuoteNode(isempty) -const qn_get_score = QuoteNode(get_score) -const qn_get_retval = QuoteNode(get_retval) -const qn_project = QuoteNode(project) -const qn_logpdf = QuoteNode(logpdf) -const qn_get_choices = QuoteNode(get_choices) -const qn_random = QuoteNode(random) -const qn_simulate = QuoteNode(simulate) -const qn_generate = QuoteNode(generate) -const qn_update = QuoteNode(update) -const qn_regenerate = QuoteNode(regenerate) -const qn_strip_diff = QuoteNode(strip_diff) -const qn_get_diff = QuoteNode(get_diff) -const qn_Diffed = QuoteNode(Diffed) -const qn_unknown_change = QuoteNode(UnknownChange()) -const qn_no_change = QuoteNode(NoChange()) -const qn_get_internal_node = QuoteNode(get_internal_node) -const qn_static_get_value = QuoteNode(static_get_value) -const qn_static_get_submap = QuoteNode(static_get_submap) -const qn_static_getindex = QuoteNode(static_getindex) # for getting a subselection -const qn_empty_choice_map = QuoteNode(EmptyChoiceMap()) -const qn_empty_selection = QuoteNode(EmptySelection()) - include("simulate.jl") include("generate.jl") include("project.jl") diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 56209e84..a38cfdbd 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -104,7 +104,7 @@ end function generate_get_score(trace_struct_name::Symbol) Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_score)), :(trace::$trace_struct_name)), + Expr(:call, GlobalRef(Gen, :get_score), :(trace::$trace_struct_name)), Expr(:block, :(trace.$total_score_fieldname))) end @@ -112,13 +112,13 @@ function generate_get_args(ir::StaticIR, trace_struct_name::Symbol) args = Expr(:tuple, [:(trace.$(get_value_fieldname(node))) for node in ir.arg_nodes]...) Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_args)), :(trace::$trace_struct_name)), + Expr(:call, GlobalRef(Gen, :get_args), :(trace::$trace_struct_name)), Expr(:block, args)) end function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_retval)), :(trace::$trace_struct_name)), + Expr(:call, GlobalRef(Gen, :get_retval), :(trace::$trace_struct_name)), Expr(:block, :(trace.$return_value_fieldname))) end @@ -127,32 +127,32 @@ function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), get_choices($subtrace)))) + push!(elements, :(($(QuoteNode(addr)), $(GlobalRef(Gen, :get_choices))($subtrace)))) end - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_submaps_shallow)), + Expr(:function, + Expr(:call, GlobalRef(Gen, :get_submaps_shallow), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), Expr(:block, Expr(:tuple, elements...))) end -function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) +function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) get_subtrace_exprs = Expr[] for node in ir.call_nodes push!(get_subtrace_exprs, quote - function Gen.static_get_subtrace(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) + function $(GlobalRef(Gen, :static_get_subtrace))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) return trace.$(get_subtrace_fieldname(node)) end end ) end - + call_getindex_exprs = Expr[] for node in ir.call_nodes push!(call_getindex_exprs, quote - function Gen.static_getindex(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) - return get_retval(trace.$(get_subtrace_fieldname(node))) + function $(GlobalRef(Gen, :static_getindex))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) + return $(GlobalRef(Gen, :get_retval))(trace.$(get_subtrace_fieldname(node))) end end ) @@ -165,11 +165,11 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), + Expr(:call, GlobalRef(Gen, :static_get_submap), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), Expr(:block, - :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) + :($(GlobalRef(Gen, :get_choices))(choices.trace.$(get_subtrace_fieldname(node))))))) end methods @@ -178,7 +178,7 @@ end function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol) addrs = [QuoteNode(node.addr) for node in ir.call_nodes] Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_schema)), :(::Type{$trace_struct_name})), + Expr(:call, GlobalRef(Gen, :get_schema), :(::Type{$trace_struct_name})), Expr(:block, :($(QuoteNode(StaticAddressSchema))( Set{Symbol}([$(addrs...)]))))) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index ea4b17f9..b927072e 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -99,7 +99,7 @@ end # this pass is used to determine which JuliaNodes need to be re-run (their # return value is not currently cached in the trace) -struct BackwardPassState +struct BackwardPassState marked::Set{StaticIRNode} end @@ -136,8 +136,8 @@ end ######################## function arg_values_and_diffs_from_tracked_diffs(input_nodes) - arg_values = map((node) -> Expr(:call, qn_strip_diff, node.name), input_nodes) - arg_diffs = map((node) -> Expr(:call, qn_get_diff, node.name), input_nodes) + arg_values = map((node) -> Expr(:call, (GlobalRef(Gen, :strip_diff)), node.name), input_nodes) + arg_diffs = map((node) -> Expr(:call, (GlobalRef(Gen, :get_diff)), node.name), input_nodes) (arg_values, arg_diffs) end @@ -151,7 +151,7 @@ end function process_codegen!(stmts, ::ForwardPassState, ::BackwardPassState, node::ArgumentNode, ::AbstractUpdateMode, options) if options.track_diffs - push!(stmts, :($(get_value_fieldname(node)) = $qn_strip_diff($(node.name)))) + push!(stmts, :($(get_value_fieldname(node)) = $(GlobalRef(Gen, :strip_diff))($(node.name)))) else push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) end @@ -166,13 +166,13 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, # track diffs if run_it arg_values, arg_diffs = arg_values_and_diffs_from_tracked_diffs(node.inputs) - args = map((v, d) -> Expr(:call, qn_Diffed, v, d), arg_values, arg_diffs) + args = map((v, d) -> Expr(:call, (GlobalRef(Gen, :Diffed)), v, d), arg_values, arg_diffs) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...)))) elseif options.cache_julia_nodes - push!(stmts, :($(node.name) = $qn_Diffed(trace.$(get_value_fieldname(node)), $qn_no_change))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), $(GlobalRef(Gen, :NoChange))()))) end if options.cache_julia_nodes - push!(stmts, :($(get_value_fieldname(node)) = $qn_strip_diff($(node.name)))) + push!(stmts, :($(get_value_fieldname(node)) = $(GlobalRef(Gen, :strip_diff))($(node.name)))) end else @@ -206,30 +206,30 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, call_constraints = gensym("call_constraints") if node in fwd.constrained_or_selected_calls || node in fwd.input_changed if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_constraints = $qn_static_get_submap(constraints, Val($addr)))) + push!(stmts, :($call_constraints = $(GlobalRef(Gen, :static_get_submap))(constraints, Val($addr)))) else - push!(stmts, :($call_constraints = $qn_empty_choice_map)) + push!(stmts, :($call_constraints = $(GlobalRef(Gen, :EmptyChoiceMap))())) end - push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = - $qn_update($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) + push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = + $(GlobalRef(Gen, :update))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) push!(stmts, :($weight += $call_weight)) - push!(stmts, :($total_score_fieldname += $qn_get_score($subtrace) - $qn_get_score($prev_subtrace))) - push!(stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection) - $qn_project($prev_subtrace, $qn_empty_selection))) - push!(stmts, :(if !$qn_isempty($qn_get_choices($subtrace)) && $qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace))) + push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) - $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))()))) + push!(stmts, :(if !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && $(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname += 1 end)) - push!(stmts, :(if $qn_isempty($qn_get_choices($subtrace)) && !$qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :(if $(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname -= 1 end)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $(calldiff_var(node))))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(calldiff_var(node))))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end else push!(stmts, :($subtrace = $prev_subtrace)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $(QuoteNode(NoChange()))))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(QuoteNode(NoChange()))))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end end end @@ -251,30 +251,30 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, call_subselection = gensym("call_subselection") if node in fwd.constrained_or_selected_calls || node in fwd.input_changed if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_subselection = $qn_static_getindex(selection, Val($addr)))) + push!(stmts, :($call_subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) else - push!(stmts, :($call_subselection = $qn_empty_selection)) + push!(stmts, :($call_subselection = $(GlobalRef(Gen, :EmptySelection))())) end - push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node))) = - $qn_regenerate($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_subselection))) + push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node))) = + $(GlobalRef(Gen, :regenerate))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_subselection))) push!(stmts, :($weight += $call_weight)) - push!(stmts, :($total_score_fieldname += $qn_get_score($subtrace) - $qn_get_score($prev_subtrace))) - push!(stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection) - $qn_project($prev_subtrace, $qn_empty_selection))) - push!(stmts, :(if !$qn_isempty($qn_get_choices($subtrace)) && !$qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace))) + push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) - $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))()))) + push!(stmts, :(if !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname += 1 end)) - push!(stmts, :(if $qn_isempty($qn_get_choices($subtrace)) && !$qn_isempty($qn_get_choices($prev_subtrace)) + push!(stmts, :(if $(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) $num_nonempty_fieldname -= 1 end)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $(calldiff_var(node))))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(calldiff_var(node))))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end else push!(stmts, :($subtrace = $prev_subtrace)) if options.track_diffs - push!(stmts, :($(node.name) = $qn_Diffed($qn_get_retval($subtrace), $qn_no_change))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(GlobalRef(Gen, :NoChange))()))) else - push!(stmts, :($(node.name) = $qn_get_retval($subtrace))) + push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end end end @@ -289,7 +289,7 @@ end function unpack_arguments!(stmts::Vector{Expr}, arg_nodes::Vector{ArgumentNode}, options) if options.track_diffs arg_names = Symbol[arg_node.name for arg_node in arg_nodes] - push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(map))($qn_Diffed, args, argdiffs))) + push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(map))($(GlobalRef(Gen, :Diffed)), args, argdiffs))) else arg_names = Symbol[arg_node.name for arg_node in arg_nodes] push!(stmts, :($(Expr(:tuple, arg_names...)) = args)) @@ -298,8 +298,8 @@ end function generate_return_value!(stmts::Vector{Expr}, fwd::ForwardPassState, return_node::StaticIRNode, options) if options.track_diffs - push!(stmts, :($return_value_fieldname = $qn_strip_diff($(return_node.name)))) - push!(stmts, :($retdiff = $qn_get_diff($(return_node.name)))) + push!(stmts, :($return_value_fieldname = $(GlobalRef(Gen, :strip_diff))($(return_node.name)))) + push!(stmts, :($retdiff = $(GlobalRef(Gen, :get_diff))($(return_node.name)))) else push!(stmts, :($return_value_fieldname = $(return_node.name))) push!(stmts, :($retdiff = $(QuoteNode(return_node in fwd.value_changed ? UnknownChange() : NoChange())))) @@ -309,7 +309,7 @@ end function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) if options.track_diffs # note that the generative function is the last field - constructor_args = map((name) -> Expr(:call, QuoteNode(strip_diff), name), + constructor_args = map((name) -> Expr(:call, QuoteNode(strip_diff), name), fieldnames(trace_type)[1:end-1]) push!(stmts, :($trace = $(QuoteNode(trace_type))($(constructor_args...), $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref)))))) @@ -347,7 +347,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ # convert the constraints to a static assignment if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $qn_update(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end + return quote $(GlobalRef(Gen, :update))(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end end ir = get_ir(gen_fn_type) @@ -395,7 +395,7 @@ function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type: # convert a hierarchical selection to a static selection if it is not alreay one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema)) - return quote $qn_regenerate(trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end + return quote $(GlobalRef(Gen, :regenerate))(trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end end ir = get_ir(gen_fn_type) @@ -433,14 +433,14 @@ end let T = gensym() push!(generated_functions, quote - @generated function $(Expr(:(.), Gen, QuoteNode(:update)))(trace::$T, args::Tuple, argdiffs::Tuple, + @generated function $(GlobalRef(Gen, :update))(trace::$T, args::Tuple, argdiffs::Tuple, constraints::$(QuoteNode(ChoiceMap))) where {$T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_update))(trace, args, argdiffs, constraints) end end) push!(generated_functions, quote - @generated function $(Expr(:(.), Gen, QuoteNode(:regenerate)))(trace::$T, args::Tuple, argdiffs::Tuple, + @generated function $(GlobalRef(Gen, :regenerate))(trace::$T, args::Tuple, argdiffs::Tuple, selection::$(QuoteNode(Selection))) where {$T<:$(QuoteNode(StaticIRTrace))} $(QuoteNode(codegen_regenerate))(trace, args, argdiffs, selection) end diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl index 88392cc0..a31223f9 100644 --- a/test/benchmarks/dynamic_mh.jl +++ b/test/benchmarks/dynamic_mh.jl @@ -74,4 +74,4 @@ do_inference(xs, ys, 10) @time do_inference(xs, ys, 20) println() -end \ No newline at end of file +end diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl index 0e801631..8c4c5aca 100644 --- a/test/benchmarks/static_mh.jl +++ b/test/benchmarks/static_mh.jl @@ -84,4 +84,4 @@ println("Simple static DSL (including CallAt nodes) MH on regression model:") @time do_inference(xs, ys, 50) @time do_inference(xs, ys, 50) println() -end \ No newline at end of file +end diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 5d311df6..24a8072e 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -567,4 +567,25 @@ tr, w = generate(MyModuleB.foo, (0,), choicemap(:y => 1)) end +@testset "static gen function choicemaps" begin +@gen (static) function bar2() + b ~ normal(0, 1) + return b +end +@gen (static) function bar1() + a ~ bar2() + x ~ normal(0, 1) + return x +end +Gen.load_generated_functions() +tr = simulate(bar1, ()) +ch = get_choices(tr) +@test has_value(ch, :x) +@test !has_value(ch, :y) +@test has_value(get_submap(ch, :a), :b) +@test get_submap(ch, :y) == EmptyChoiceMap() +@test length(collect(get_values_shallow(ch))) == 1 +@test length(collect(get_submaps_shallow(ch))) == 2 +end + end # @testset "static DSL" diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index fbd528b7..3f350196 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,9 +1,10 @@ using Gen import MacroTools -normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) - +@testset "tilde syntax" begin +normalize(ex) = + MacroTools.prewalk(MacroTools.rmlines, Gen.resolve_gen_macros(ex, Main)) # dynamic @testset "tilde syntax smoke test (dynamic)" begin @@ -73,17 +74,32 @@ end @testset "tilde syntax desugars as expected (static)" begin -expected = normalize(:( -@gen (static) function foo() - x = @trace(normal(0, 1), :x) - y = @trace(normal(0, 1), :y) -end)) - -actual = normalize(Gen.desugar_tildes(:( -@gen (static) function foo() - x ~ normal(0, 1) - y = ({:y} ~ normal(0, 1)) -end))) - -@test actual == expected + expected = normalize(:( + @gen (static) function foo() + x = @trace(normal(0, 1), :x) + y = @trace(normal(0, 1), :y) + end)) + + actual = normalize(Gen.desugar_tildes(:( + @gen (static) function foo() + x ~ normal(0, 1) + y = ({:y} ~ normal(0, 1)) + end))) + + @test actual == expected +end + +@testset "tilde syntax preserved in quoted expressions" begin + @gen function tilde_expr() + return :(x ~ normal(0, 1)) + end + @test tilde_expr() == :(x ~ normal(0, 1)) + + @gen (static) function tilde_expr() + return :(x ~ normal(0, 1)) + end + Gen.load_generated_functions() + @test tilde_expr() == :(x ~ normal(0, 1)) +end + end From b1c4fc212c7ba57a8a6d4e76b76899d6ca6e2d0d Mon Sep 17 00:00:00 2001 From: George Matheos Date: Fri, 3 Jul 2020 19:01:47 -0400 Subject: [PATCH 31/34] address_tree v0 --- .../Untitled-checkpoint.ipynb | 11 + src/address_tree/address_schema.jl | 21 ++ src/address_tree/address_tree.jl | 208 ++++++++++++++++++ src/address_tree/array_interface.jl | 110 +++++++++ src/address_tree/choicemap.jl | 130 +++++++++++ src/address_tree/dynamic_address_tree.jl | 95 ++++++++ src/address_tree/nested_view.jl | 82 +++++++ src/address_tree/selection.jl | 77 +++++++ src/address_tree/static_address_tree.jl | 178 +++++++++++++++ 9 files changed, 912 insertions(+) create mode 100644 src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb create mode 100644 src/address_tree/address_schema.jl create mode 100644 src/address_tree/address_tree.jl create mode 100644 src/address_tree/array_interface.jl create mode 100644 src/address_tree/choicemap.jl create mode 100644 src/address_tree/dynamic_address_tree.jl create mode 100644 src/address_tree/nested_view.jl create mode 100644 src/address_tree/selection.jl create mode 100644 src/address_tree/static_address_tree.jl diff --git a/src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 00000000..a11a62ac --- /dev/null +++ b/src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,11 @@ +{ + "cells": [], + "metadata": { + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/address_tree/address_schema.jl b/src/address_tree/address_schema.jl new file mode 100644 index 00000000..b0deb157 --- /dev/null +++ b/src/address_tree/address_schema.jl @@ -0,0 +1,21 @@ +abstract type AddressSchema end + +struct StaticAddressSchema <: AddressSchema + keys::Set{Symbol} +end + +Base.keys(schema::StaticAddressSchema) = schema.keys + +struct VectorAddressSchema <: AddressSchema end +struct SingleDynamicKeyAddressSchema <: AddressSchema end +struct DynamicAddressSchema <: AddressSchema end +struct EmptyAddressSchema <: AddressSchema end +struct AllAddressSchema <: AddressSchema end + +export AddressSchema +export StaticAddressSchema # hierarchical +export VectorAddressSchema # hierarchical +export SingleDynamicKeyAddressSchema # hierarchical +export DynamicAddressSchema # hierarchical +export EmptyAddressSchema +export AllAddressSchema \ No newline at end of file diff --git a/src/address_tree/address_tree.jl b/src/address_tree/address_tree.jl new file mode 100644 index 00000000..1f5a6138 --- /dev/null +++ b/src/address_tree/address_tree.jl @@ -0,0 +1,208 @@ +include("address_schema.jl") + +""" + AddressTree{LeafType} + +Abstract type for trees where each node's subtrees are labelled with +an address. All leaf nodes are of `LeafType` (or are an `EmptyAddressTree`). +""" +abstract type AddressTree{LeafType} end + +""" + AddressTreeLeaf + +Abstract type for address tree leaf nodes. + +## Note: +When declaring a subtype `T` of `AddressTreeLeaf`, +declare `T <: AddressTreeLeaf{T}` to ensure +`T <: AddressTree{T}`. +""" +abstract type AddressTreeLeaf{Type} <: AddressTree{Type} end + +""" + EmptyAddressTree + +An empty address tree with no subtrees. +""" +struct EmptyAddressTree <: AddressTreeLeaf{EmptyAddressTree} end + +""" + Value{T} + +An address tree leaf node storing a value of type `T`. +""" +struct Value{T} <: AddressTreeLeaf{Value{T}} + val::T +end +@inline get_value(v::Value) = v.val + +""" + AllSelection + +An address tree leaf node representing that all sub-addresses +from this point are selected. +""" +struct AllSelection <: AddressTreeLeaf{AllSelection} end + +""" + get_subtree(tree::AddressTree, addr) + +Get the subtree at address `addr` or return `EmptyAddressTree` +if there is no subtree at this address. +""" +function get_subtree end + +function _get_subtree(t::AddressTree, addr::Pair) + get_subtree(get_subtree(t, addr.first), addr.second) +end + +""" + get_subtrees_shallow(tree::AddressTree) + +Return an iterator over tuples `(address, subtree)` for each +top-level address associated with `tree`. + +The length of this iterator must nonzero if this is not a leaf node. +""" +function get_subtrees_shallow end + +get_leaf_type(T::Type{AddressTree{U}}) where {U} = U + +""" +schema = get_address_schema(::Type{T}) where {T <: AddressTree} + +Return the (top-level) address schema for the given address tree type. +""" +function get_address_schema end +@inline get_address_schema(::Type{EmptyAddressTree}) = EmptyAddressSchema() +@inline get_address_schema(::Type{AllSelection}) = AllAddressSchema() + +@inline get_address_schema(::Type{Value}) = error("I don't think this currently gets called, and it's not part of the user-facing interface. If we need this, set the appropriate value then.") + +Base.isempty(::Value) = false +Base.isempty(::AllSelection) = false +Base.isempty(::EmptyAddressTree) = true +Base.isempty(::AddressTreeLeaf) = error("Not implemented") +Base.isempty(t::AddressTree) = all((_, subtree) -> isempty(subtree), get_subtrees_shallow(t)) + +@inline get_subtree(::AddressTreeLeaf, _) = EmptyAddressTree() +@inline get_subtrees_shallow(::AddressTreeLeaf) = () + +@inline Base.:(==)(a::Value, b::Value) = a.val == b.val +function Base.:(==)(a::AddressTree, b::AddressTree) + for (addr, subtree) in get_subtrees_shallow(a) + if get_subtree(b, addr) != subtree + return false + end + end + for (addr, subtree) in get_subtrees_shallow(b) + if get_subtree(a, addr) != subtree + return false + end + end + return true +end + +Base.isapprox(a::Value, b::Value) = isapprox(a.val, b.val) +Base.isapprox(::EmptyAddressTree, ::EmptyAddressTree) = true +Base.isapprox(::AllSelection, ::AllSelection) = true +function Base.isapprox(::AddressTreeLeaf{T}, ::AddressTreeLeaf{U}) where {T, U} + if T != U + false + else + error("Not implemented") + end +end +function Base.isapprox(a::AddressTree, b::AddressTree) + for (addr, subtree) in get_subtrees_shallow(a) + if !isapprox(get_subtree(b, addr), subtree) + return false + end + end + for (addr, subtree) in get_subtrees_shallow(b) + if !isapprox(get_subtree(a, addr), subtree) + return false + end + end + return true +end + +""" + Base.merge(a::AddressTree, b::AddressTree) + +Merge two address trees. +""" +function Base.merge(a::AddressTree{T}, b::AddressTree{U}) where {T, U} + tree = DynamicAddressTree{<:Union{T, U}}() + for (key, subtree) in get_subtrees_shallow(a) + set_subtree!(tree, key, merge(subtree, get_subtree(b, key))) + end + for (key, subtree) in get_subtrees_shallow(b) + if isempty(get_subtree(a, key)) + set_submap!(tree, key, subtree) + end + end + tree +end +Base.merge(t::AddressTree, ::EmptyAddressTree) = t +Base.merge(::EmptyAddressTree, t::AddressTree) = t + +Base.merge(::AddressTreeLeaf, ::AddressTree) = error("Not implemented") +Base.merge(::AddressTree, ::AddressTreeLeaf) = error("Not implemented") + +""" +Variadic merge of address trees. +""" +function Base.merge(first::AddressTree, rest::AddressTree...) + reduce(Base.merge, rest; init=first) +end + +function _show_pretty(io::IO, tree::AddressTree, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_subtrees = collect(get_subtrees_shallow(tree)) + n = length(key_and_subtrees) + cur = 1 + for (key, subtree) in key_and_subtrees + print(io, indent_vert_str) + if subtree isa AddressTreeLeaf + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $subtree\n") + else + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, subtree, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + end + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", tree::AddressTree) + _show_pretty(io, tree, 0, ()) +end +Base.show(io::IO, ::MIME"text/plain", t::AddressTreeLeaf) = print(io, t) + +include("dynamic_address_tree.jl") +include("static_address_tree.jl") + +include("choicemap.jl") +include("selection.jl") + +# include("array_interface.jl") +# include("nested_view.jl") + +export get_subtree, get_subtrees_shallow, EmptyAddressTree, Value, AllSelection \ No newline at end of file diff --git a/src/address_tree/array_interface.jl b/src/address_tree/array_interface.jl new file mode 100644 index 00000000..3fac3865 --- /dev/null +++ b/src/address_tree/array_interface.jl @@ -0,0 +1,110 @@ +### interface for to_array and fill_array ### + +# NOTE: currently this only works for choicemaps, +# but if we found we needed some sort of "to_array" for other types of +# address trees, I don't think it would be too hard to generalize + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: ChoiceMap` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::Value{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) < start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end +function _fill_array!(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. + +# Implementation + +To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::Value, arr::Vector, start_idx::Int) + (1, Value(arr[start_idx])) +end +function _from_array(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, Value(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/address_tree/choicemap.jl b/src/address_tree/choicemap.jl new file mode 100644 index 00000000..86e9b1fe --- /dev/null +++ b/src/address_tree/choicemap.jl @@ -0,0 +1,130 @@ +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + ChoiceMap + +Abstract type for maps from hierarchical addresses to values. +""" +const ChoiceMap = AddressTree{<:Value} + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `ValueChoiceMap`s.) +""" +@inline get_submaps_shallow(c::ChoiceMap) = get_subtrees_shallow(c) + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +@inline get_submap(c::ChoiceMap, addr) = get_subtree(c, addr) + +""" + has_value(choices::ChoiceMap) + +Returns true if `choices` is a `Value`. + + has_value(choices::ChoiceMap, addr) + +Returns true if `choices` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(::ChoiceMap) = false +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; +throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `ChoiceMapGetValueError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `ValueChoiceMap`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) +end + +# support `DynamicChoiceMap` and `StaticChoiceMap` types, and the "legacy" DynamicChoiceMap interface +const DynamicChoiceMap = DynamicAddressTree{Value} +set_submap!(cm::DynamicChoiceMap, addr, submap::ChoiceMap) = set_subtree!(cm, addr, submap) +set_value!(cm::DynamicChoiceMap, addr, val) = set_subtree!(cm, addr, Value(val)) +Base.setindex!(cm::DynamicChoiceMap, addr, val) = set_value!(cm, addr, val) + +const StaticChoiceMap = StaticAddressTree{Value} + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given (address, value) tuples. +(Where `value` is the value to be stored, not a `Value` object.) +""" +function choicemap(tuples...) + cm = DynamicChoiceMap() + for (addr, val) in tuples + set_subtree!(cm, addr, Value(val)) + end + cm +end + +export choicemap +export get_value, has_value +export get_values_shallow, get_submaps_shallow, get_nonvalue_submaps_shallow + +include("array_interface.jl") \ No newline at end of file diff --git a/src/address_tree/dynamic_address_tree.jl b/src/address_tree/dynamic_address_tree.jl new file mode 100644 index 00000000..2ffb1ce9 --- /dev/null +++ b/src/address_tree/dynamic_address_tree.jl @@ -0,0 +1,95 @@ +""" + struct DynamicAddressTree <: AddressTree .. end + +A mutable AddressTree. + + tree = DynamicAddressTree() + +Construct an empty address tree. + +""" +struct DynamicAddressTree{LeafType} <: AddressTree{LeafType} + subtrees::Dict{Any, AddressTree{<:LeafType}} + function DynamicAddressTree{LeafType}() where {LeafType} + new{LeafType}(Dict{Any, AddressTree}()) + end +end + +""" + tree = address_tree() + +Construct an empty, mutable address tree. +""" +address_tree() = DynamicAddressTree{Any}() + +get_address_schema(::Type{<:DynamicAddressTree}) = DynamicAddressTree + +@inline get_subtrees_shallow(t::DynamicAddressTree) = t.subtrees +@inline get_subtree(t::DynamicAddressTree, addr) = get(t.subtrees, addr, EmptyAddressTree()) +@inline get_subtree(t::DynamicAddressTree, addr::Pair) = _get_subtree(t, addr) +@inline Base.isempty(t::DynamicAddressTree) = isempty(t.subtrees) + +function set_subtree!(t::DynamicAddressTree, addr, new_node::AddressTree) + delete!(t.subtrees, addr) + if !isempty(new_node) + t.subtrees[addr] = new_node + end +end +function set_subtree!(t::DynamicAddressTree{T}, addr::Pair, new_node::AddressTree) where {T} + (first, rest) = addr + if !haskey(t.subtrees, first) + t.subtrees[first] = DynamicAddressTree{T}() + end + set_subtree!(t.subtrees[first], rest, new_node) +end + +""" + tree = shallow_dynamic_copy(other::AddressTree) + +Make a shallow `DynamicAddressTree` copy of the given address tree. +""" +function shallow_dynamic_copy(other::AddressTree{LeafType}) where {LeafType} + tree = DynamicAddressTree{LeafType}() + for (addr, subtree) in get_subtrees_shallow(other) + set_subtree!(tree, addr, subtree) + end + tree +end + +""" + tree = deep_dynamic_copy(other::AddressTree) + +Make a deep copy of the given address tree, where every non-leaf-node +is a `DynamicAddressTree`. +""" +function deep_dynamic_copy(other::AddressTree{LeafType}) where {LeafType} + tree = DynamicAddressTree{LeafType}() + for (addr, subtree) in get_subtrees_shallow(other) + if subtree isa AddressTreeLeaf + set_subtree!(tree, addr, subtree) + else + set_subtree!(tree, addr, DynamicAddressTree(subtree)) + end + end + tree +end + +""" + tree = DynamicAddressTree(other::AddressTree) + +Shallowly convert an address tree to dynamic. +""" +DynamicAddressTree(t::AddressTree) = shallow_dynamic_copy(t) +DynamicAddressTree(t::DynamicAddressTree) = t + +function _from_array(proto_choices::DynamicAddressTree{LT}, arr::Vector{T}, start_idx::Int) where {T, LT} + choices = DynamicAddressTree{LT}() + keys_sorted = sort(collect(keys(proto_choices.subtrees))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.subtrees[key], arr, idx) + idx += n_read + choices.subtrees[key] = submap + end + (idx - start_idx, choices) +end diff --git a/src/address_tree/nested_view.jl b/src/address_tree/nested_view.jl new file mode 100644 index 00000000..40593c16 --- /dev/null +++ b/src/address_tree/nested_view.jl @@ -0,0 +1,82 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +# TODO: augment this to work in a form for any address tree? + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::Value) = get_value(cm) +ChoiceMapNestedView(::EmptyAddressTree) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/address_tree/selection.jl b/src/address_tree/selection.jl new file mode 100644 index 00000000..d8e6925a --- /dev/null +++ b/src/address_tree/selection.jl @@ -0,0 +1,77 @@ +const Selection = AddressTree{AllSelection} + +const StaticSelection = StaticAddressTree{AllSelection} +const DynamicSelection = DynamicAddressTree{AllSelection} + +""" + in(addr, selection::Selection) + +Whether the address is selected in the given selection. +""" +function Base.in(addr, selection::Selection) + get_subtree(selection, addr) === AllSelection() +end + +# indexing returns subtrees for selections +function Base.getindex(selection::Selection, addr) + get_subtree(selection, addr) +end + +function select(addrs...) + selection = DynamicSelection() + for addr in addrs + set_subtree!(selection, addr, AllSelection()) + end + selection +end + +Base.merge(::AllSelection, ::Selection) = AllSelection() +Base.merge(::Selection, ::AllSelection) = AllSelection() +Base.merge(::AllSelection, ::AllSelection) = AllSelection() + +""" + filtered = SelectionFilteredAddressTree(tree, selection) + +An address tree containing only the nodes in `tree` whose addresses are selected +in `selection.` +""" +struct SelectionFilteredAddressTree{T} <: AddressTree{T} + tree::AddressTree{T} + sel::Selection +end +SelectionFilteredAddressTree(t::AddressTree, ::AllSelection) = t +SelectionFilteredAddressTree(t::AddressTreeLeaf, ::AllSelection) = t +SelectionFilteredAddressTree(::AddressTree, ::EmptyAddressTree) = EmptyAddressTree() +SelectionFilteredAddressTree(::AddressTreeLeaf, ::EmptyAddressTree) = EmptyAddressTree() +SelectionFilteredAddressTree(::AddressTreeLeaf, ::Selection) = EmptyAddressTree() # if we hit a leaf node before a selected value, the node is not selected + +function get_subtree(t::SelectionFilteredAddressTree, addr) + subselection = get_subtree(t.sel, addr) + if subselection === EmptyAddressTree() + EmptyAddressTree() + else + SelectionFilteredAddressTree(get_subtree(t.tree, addr), subselection) + end +end + +function get_subtrees_shallow(t::SelectionFilteredAddressTree) + all_selected_including_empty = ( + (addr, SelectionFilteredAddressTree(subtree, get_subtree(t.sel, addr))) + for (addr, subtree) in get_subtrees_shallow(t.tree) + ) + + return ( + (addr, tree) for (addr, tree) in all_selected_including_empty + if tree !== EmptyAddressTree() + ) +end + +""" + selected = get_selected(tree::AddressTree, selection::Selection) + +Filter the address tree `tree` to only include leaf nodes at selected +addresses. +""" +get_selected(tree::AddressTree, selection::Selection) = SelectionFilteredAddressTree(tree, selection) + +export select, get_selected \ No newline at end of file diff --git a/src/address_tree/static_address_tree.jl b/src/address_tree/static_address_tree.jl new file mode 100644 index 00000000..282ccb61 --- /dev/null +++ b/src/address_tree/static_address_tree.jl @@ -0,0 +1,178 @@ +struct StaticAddressTree{LeafType, Addrs, SubtreeTypes} <: AddressTree{LeafType} + subtrees::NamedTuple{Addrs, SubtreeTypes} + function StaticAddressTree{LeafType}(nt::NamedTuple{Addrs, Subtrees}) where { + LeafType, Addrs, Subtrees <: Tuple{Vararg{<:AddressTree{<:LeafType}}} + } + new{LeafType, Addrs, Subtrees}(nt) + end +end + +# NOTE: It is probably better to avoid using this constructor when possible since I suspect it is less performant +# than if we specify `LeafType`. +# I could make this into a generated function...this would probably improve runtime performance but hurt compiletime performance. +function StaticAddressTree(subtrees::NamedTuple{Addrs, SubtreeTypes}) where {Addrs, SubtreeTypes <: Tuple{Vararg{AddressTree}}} + uniontype = Union{SubtreeTypes.parameters...} + StaticAddressTree{uniontype}(subtrees) +end +""" + StaticAddressTree{LeafType}(; a=val, b=tree, ...) + StaticAddressTree(; a=val, b=tree, ...) + +Construct a static address tree with the given address-subtree +or address-value pairs. (The addresses must be top-level symbols; +if the RHS is an AddressTree, this will be the subtree; if not, the +subtree will be a `Value` with the given value.) +""" +StaticAddressTree(;addrs_to_vals_and_trees...) = StaticAddressTree(addrs_subtrees_namedtuple(addrs_to_vals_and_trees)) +StaticAddressTree{LeafType}(; addrs_to_vals_and_trees...) where {LeafType} = StaticAddressTree{LeafType}(addrs_subtrees_namedtuple(addrs_to_vals_and_trees)) + +function addrs_subtrees_namedtuple(addrs_to_vals_and_trees) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_trees) + trees = Tuple(val_or_map isa AddressTree ? val_or_map : Value(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_trees) + NamedTuple{addrs}(trees) +end + +@inline get_subtrees_shallow(t::StaticAddressTree) = pairs(t.subtrees) +@inline get_submap(t::StaticAddressTree, addr::Pair) = _get_subtree(t, addr) + +function get_subtree(t::StaticAddressTree{LeafType, Addrs}, addr::Symbol) where {LeafType, Addrs} + if addr in Addrs + t.subtrees[addr] + else + EmptyAddressTree() + end +end + +@generated function static_get_subtree(t::StaticAddressTree{LeafType, Addrs}, ::Val{A}) where {A, Addrs, LeafType} + if A in Addrs + quote t.subtrees[A] end + else + quote EmptyAddressTree() end + end +end +@inline static_get_subtree(::EmptyAddressTree, ::Val) = EmptyAddressTree() + +@inline static_get_value(choices::StaticAddressTree, v::Val) = get_value(static_get_subtree(choices, v)) +@inline static_get_value(::EmptyAddressTree, ::Val) = throw(ChoiceMapGetValueError()) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +StaticAddressTree(t::StaticAddressTree) = t +function StaticAddressTree(other::AddressTree{LeafType}) where {LeafType} + keys_and_nodes = get_subtrees_shallow(other) + if length(keys_and_nodes) > 0 + addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) + else + addrs = () + submaps = () + end + StaticAddressTree{LeafType}(NamedTuple{addrs}(submaps)) +end +StaticAddressTree(::AddressTreeLeaf) = error("Cannot convert a leaf node to a static address tree.") +StaticAddressTree{LeafType}(::NamedTuple{(),Tuple{}}) where {LeafType} = EmptyAddressTree() +StaticAddressTree(::NamedTuple{(),Tuple{}}) = EmptyAddressTree() +StaticAddressTree{LeafType}(other::AddressTree{<:LeafType}) where {LeafType} = StaticAddressTree(other) + +# TODO: deep conversion to static choicemap + +""" + tree = pair(tree1::AddressTree, tree2::AddressTree, key1::Symbol, key2::Symbol) + +Return an address tree that contains `tree1` as a subtree under `key1` +and `tree2` as a subtree under `key2`. +""" +function pair(tree1::AddressTree, tree2::AddressTree, key1::Symbol, key2::Symbol) + StaticAddressTree(NamedTuple{(key1, key2)}((tree1, tree2))) +end + +""" + (tree1, tree2) = unpair(tree::AddressTree, key1::Symbol, key2::Symbol) + +Return the two subtrees at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any subtrees at keys other than `key1` and `key2`. +""" +function unpair(tree::AddressTree, key1::Symbol, key2::Symbol) + if length(collect(get_subtrees_shallow(tree))) != 2 + error("Not a pair") + end + (get_subtree(tree, key1), get_subtree(tree, key2)) +end + +@generated function Base.merge(tree1::StaticAddressTree{T1, Addrs1, SubmapTypes1}, + tree2::StaticAddressTree{T2, Addrs2, SubmapTypes2}) where {T1, T2, Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, Type{<:AddressTree}}() + addr_to_type2 = Dict{Symbol, Type{<:AddressTree}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyAddressTree) + type2 = get(addr_to_type2, addr, EmptyAddressTree) + + if type1 <: EmptyAddressTree + push!(submap_exprs, + quote choices2.submaps.$addr end + ) + elseif type2 <: EmptyAddressTree + push!(submap_exprs, + quote choices1.submaps.$addr end + ) + else + push!(submap_exprs, + quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end + ) + end + end + + leaftype = Union{T1, T2} + + quote + StaticAddressTree{$leaftype}(NamedTuple{$merged_addrs}(($(submap_exprs...),))) + end +end + +@generated function _from_array(proto_choices::StaticAddressTree{LT, Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {LT, T, Addrs, SubmapTypes} + + perm = sortperm(collect(Addrs)) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, $submap_var_name) = _from_array(proto_choices.subtrees.$addr, arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticAddressTree{LT}(submaps) + (idx - start_idx, choices) + end +end + +function get_address_schema(::Type{StaticAddressTree{LT, Addrs}}) where {LT, Addrs} + StaticAddressSchema(Set(Addrs)) +end + +export StaticAddressTree +export pair, unpair +export static_get_subtree, static_get_value \ No newline at end of file From 3d1b519c56b32a322b133b9d452d0216ae460ac7 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 5 Jul 2020 15:44:40 -0400 Subject: [PATCH 32/34] bug fixing --- docs/src/ref/choice_maps.md | 6 +-- docs/src/ref/distributions.md | 2 +- src/Gen.jl | 8 ++-- src/address_tree/address_tree.jl | 19 ++++++-- src/address_tree/choicemap.jl | 45 ++++++++++------- src/address_tree/dynamic_address_tree.jl | 3 ++ src/address_tree/selection.jl | 8 +++- src/address_tree/static_address_tree.jl | 6 +-- src/{ => deprecated}/address.jl | 0 .../choice_map/array_interface.jl | 12 ++--- src/{ => deprecated}/choice_map/choice_map.jl | 38 +++++++-------- .../choice_map/dynamic_choice_map.jl | 6 +-- .../choice_map/nested_view.jl | 2 +- .../choice_map/static_choice_map.jl | 8 ++-- src/distribution.jl | 10 ++-- src/dynamic/dynamic.jl | 8 ++-- src/modeling_library/call_at/call_at.jl | 2 +- src/modeling_library/vector.jl | 2 +- src/static_ir/backprop.jl | 4 +- src/trie.jl | 2 + test/assignment.jl | 48 +++++++++---------- test/modeling_library/map.jl | 2 +- test/runtests.jl | 14 +++--- test/selection.jl | 34 +++---------- 24 files changed, 147 insertions(+), 142 deletions(-) rename src/{ => deprecated}/address.jl (100%) rename src/{ => deprecated}/choice_map/array_interface.jl (88%) rename src/{ => deprecated}/choice_map/choice_map.jl (88%) rename src/{ => deprecated}/choice_map/dynamic_choice_map.jl (95%) rename src/{ => deprecated}/choice_map/nested_view.jl (97%) rename src/{ => deprecated}/choice_map/static_choice_map.jl (93%) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 4a23b7cf..db30bed7 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -11,7 +11,7 @@ Choice maps are also returned by certain Gen inference methods, and are used int A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses for sub-choicemaps. Leaf nodes have type: ```@docs -ValueChoiceMap +Value ``` ### Example Usage Overview @@ -27,7 +27,7 @@ value = choicemap[:a] ``` A choicemap may also have a non-value choicemap stored at an address. For instance, if a choicemap has another choicemap stored at address `:a`, and this internal choicemap -has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: +has a Value stored at address `:b` and another at `:c`, we could perform the following lookups: ```julia value1 = choicemap[:a => :b] value2 = choicemap[:a => :c] @@ -47,7 +47,7 @@ value1 == submap[:b] # is true value_submap = get_submap(choicemap, :a => :b) value_submap[] == value1 # is true ``` -One can think of `ValueChoiceMap`s at storing being a choicemap which has a value at "nesting level zero", +One can think of `Value`s at storing being a choicemap which has a value at "nesting level zero", while other choicemaps have values at "nesting level" one or higher. ### Interface diff --git a/docs/src/ref/distributions.md b/docs/src/ref/distributions.md index f4b6d768..a928fabf 100644 --- a/docs/src/ref/distributions.md +++ b/docs/src/ref/distributions.md @@ -2,7 +2,7 @@ In Gen, a probability distribution is a generative function which makes a single random choice and returns the value of this choice. The choicemap for a probability distribution -is always a [`ValueChoiceMap`](@ref). In addition to supporting the regular `GFI` methods, +is always a [`Value`](@ref). In addition to supporting the regular `GFI` methods, every distribution supports the methods [`random`](@ref) and [`logpdf`](@ref), described in the [Distribution API](@ref custom_distributions). diff --git a/src/Gen.jl b/src/Gen.jl index fe3cdaa1..22623b04 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -33,11 +33,9 @@ export load_generated_functions, @load_generated_functions # built-in extensions to the reverse mode AD include("backprop.jl") -# addresses and address selections -include("address.jl") - -# abstract and built-in concrete choice map data types -include("choice_map/choice_map.jl") +# address and address trees +# (including choicemaps and selections) +include("address_tree/address_tree.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") diff --git a/src/address_tree/address_tree.jl b/src/address_tree/address_tree.jl index 1f5a6138..1249fe1c 100644 --- a/src/address_tree/address_tree.jl +++ b/src/address_tree/address_tree.jl @@ -32,10 +32,13 @@ struct EmptyAddressTree <: AddressTreeLeaf{EmptyAddressTree} end An address tree leaf node storing a value of type `T`. """ -struct Value{T} <: AddressTreeLeaf{Value{T}} +struct Value{T} <: AddressTreeLeaf{Value} val::T end @inline get_value(v::Value) = v.val +# Note that we don't set `Value{T} <: AddressTreeLeaf{Value{T}}`; +# this complicates the type heirarchicy unnecessarily and results in +# unintuitive phenomena, such as `Value <: AddressTreeLeaf{Value} == false`. """ AllSelection @@ -89,7 +92,8 @@ Base.isempty(t::AddressTree) = all((_, subtree) -> isempty(subtree), get_subtree @inline get_subtree(::AddressTreeLeaf, _) = EmptyAddressTree() @inline get_subtrees_shallow(::AddressTreeLeaf) = () -@inline Base.:(==)(a::Value, b::Value) = a.val == b.val +@inline get_subtree(::AllSelection, _) = AllSelection() + function Base.:(==)(a::AddressTree, b::AddressTree) for (addr, subtree) in get_subtrees_shallow(a) if get_subtree(b, addr) != subtree @@ -103,6 +107,9 @@ function Base.:(==)(a::AddressTree, b::AddressTree) end return true end +@inline Base.:(==)(a::Value, b::Value) = a.val == b.val +Base.:(==)(a::AddressTreeLeaf, b::AddressTreeLeaf) = false +Base.:(==)(::T, ::T) where {T <: AddressTreeLeaf} = true Base.isapprox(a::Value, b::Value) = isapprox(a.val, b.val) Base.isapprox(::EmptyAddressTree, ::EmptyAddressTree) = true @@ -134,7 +141,7 @@ end Merge two address trees. """ function Base.merge(a::AddressTree{T}, b::AddressTree{U}) where {T, U} - tree = DynamicAddressTree{<:Union{T, U}}() + tree = DynamicAddressTree{Union{T, U}}() for (key, subtree) in get_subtrees_shallow(a) set_subtree!(tree, key, merge(subtree, get_subtree(b, key))) end @@ -147,6 +154,8 @@ function Base.merge(a::AddressTree{T}, b::AddressTree{U}) where {T, U} end Base.merge(t::AddressTree, ::EmptyAddressTree) = t Base.merge(::EmptyAddressTree, t::AddressTree) = t +Base.merge(t::AddressTreeLeaf, ::EmptyAddressTree) = t +Base.merge(::EmptyAddressTree, t::AddressTreeLeaf) = t Base.merge(::AddressTreeLeaf, ::AddressTree) = error("Not implemented") Base.merge(::AddressTree, ::AddressTreeLeaf) = error("Not implemented") @@ -205,4 +214,6 @@ include("selection.jl") # include("array_interface.jl") # include("nested_view.jl") -export get_subtree, get_subtrees_shallow, EmptyAddressTree, Value, AllSelection \ No newline at end of file +export get_subtree, get_subtrees_shallow +export EmptyAddressTree, Value, AllSelection +export get_address_schema \ No newline at end of file diff --git a/src/address_tree/choicemap.jl b/src/address_tree/choicemap.jl index 86e9b1fe..cbfe5a53 100644 --- a/src/address_tree/choicemap.jl +++ b/src/address_tree/choicemap.jl @@ -12,16 +12,18 @@ showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueErr Abstract type for maps from hierarchical addresses to values. """ -const ChoiceMap = AddressTree{<:Value} +const ChoiceMap = AddressTree{Value} + +const ChoiceMapOrEmpty = Union{ChoiceMap, EmptyAddressTree} """ get_submaps_shallow(choices::ChoiceMap) Returns an iterable collection of tuples `(address, submap)` for each top-level address associated with `choices`. -(This includes `ValueChoiceMap`s.) +(This includes `Value`s.) """ -@inline get_submaps_shallow(c::ChoiceMap) = get_subtrees_shallow(c) +@inline get_submaps_shallow(c::ChoiceMapOrEmpty) = get_subtrees_shallow(c) """ get_submap(choices::ChoiceMap, addr) @@ -29,7 +31,9 @@ for each top-level address associated with `choices`. Return the submap at the given address, or `EmptyChoiceMap` if there is no submap at the given address. """ -@inline get_submap(c::ChoiceMap, addr) = get_subtree(c, addr) +@inline get_submap(c::ChoiceMapOrEmpty, addr) = get_subtree(c, addr) + +@inline static_get_submap(c::ChoiceMapOrEmpty, a) = static_get_subtree(c, a) """ has_value(choices::ChoiceMap) @@ -41,14 +45,15 @@ Returns true if `choices` is a `Value`. Returns true if `choices` has a value stored at address `addr`. """ function has_value end -@inline has_value(::ChoiceMap) = false -@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) +@inline has_value(c::ChoiceMapOrEmpty, addr) = has_value(get_submap(c, addr)) +has_value(::Value) = true +has_value(::ChoiceMapOrEmpty) = false """ get_value(choices::ChoiceMap) -Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; -throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. +Returns the value stored on `choices` is `choices` is a `Value`; +throws a `ChoiceMapGetValueError` if `choices` is not a `Value`. get_value(choices::ChoiceMap, addr) Returns the value stored in the submap with address `addr` or throws @@ -59,8 +64,8 @@ A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) -@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline get_value(::ChoiceMapOrEmpty) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMapOrEmpty, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) """ @@ -71,7 +76,7 @@ for each value stored at a top-level address in `choices`. (Works by applying a filter to `get_submaps_shallow`, so this internally requires iterating over every submap.) """ -function get_values_shallow(choices::ChoiceMap) +function get_values_shallow(choices::ChoiceMapOrEmpty) ( (addr, get_value(submap)) for (addr, submap) in get_submaps_shallow(choices) @@ -84,11 +89,11 @@ end Returns an iterable collection of tuples `(address, submap)` for every top-level submap stored in `choices` which is -not a `ValueChoiceMap`. +not a `Value`. (Works by applying a filter to `get_submaps_shallow`, so this internally requires iterating over every submap.) """ -function get_nonvalue_submaps_shallow(choices::ChoiceMap) +function get_nonvalue_submaps_shallow(choices::ChoiceMapOrEmpty) (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) end @@ -96,9 +101,10 @@ end const DynamicChoiceMap = DynamicAddressTree{Value} set_submap!(cm::DynamicChoiceMap, addr, submap::ChoiceMap) = set_subtree!(cm, addr, submap) set_value!(cm::DynamicChoiceMap, addr, val) = set_subtree!(cm, addr, Value(val)) -Base.setindex!(cm::DynamicChoiceMap, addr, val) = set_value!(cm, addr, val) +Base.setindex!(cm::DynamicChoiceMap, val, addr) = set_value!(cm, addr, val) const StaticChoiceMap = StaticAddressTree{Value} +const EmptyChoiceMap = EmptyAddressTree """ choices = choicemap() @@ -123,8 +129,13 @@ function choicemap(tuples...) cm end -export choicemap -export get_value, has_value +export ChoiceMap, choicemap +export ChoiceMapGetValueError +export get_value, has_value, get_submap export get_values_shallow, get_submaps_shallow, get_nonvalue_submaps_shallow +export EmptyChoiceMap, StaticChoiceMap, DynamicChoiceMap +export set_value!, set_submap! +export static_get_submap -include("array_interface.jl") \ No newline at end of file +include("array_interface.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/address_tree/dynamic_address_tree.jl b/src/address_tree/dynamic_address_tree.jl index 2ffb1ce9..ac40bda9 100644 --- a/src/address_tree/dynamic_address_tree.jl +++ b/src/address_tree/dynamic_address_tree.jl @@ -81,6 +81,7 @@ Shallowly convert an address tree to dynamic. """ DynamicAddressTree(t::AddressTree) = shallow_dynamic_copy(t) DynamicAddressTree(t::DynamicAddressTree) = t +DynamicAddressTree{LeafType}(t::AddressTree{<:LeafType}) where {LeafType} = DynamicAddressTree(t) function _from_array(proto_choices::DynamicAddressTree{LT}, arr::Vector{T}, start_idx::Int) where {T, LT} choices = DynamicAddressTree{LT}() @@ -93,3 +94,5 @@ function _from_array(proto_choices::DynamicAddressTree{LT}, arr::Vector{T}, star end (idx - start_idx, choices) end + +export DynamicAddressTree, set_subtree! \ No newline at end of file diff --git a/src/address_tree/selection.jl b/src/address_tree/selection.jl index d8e6925a..bee36bc3 100644 --- a/src/address_tree/selection.jl +++ b/src/address_tree/selection.jl @@ -2,13 +2,14 @@ const Selection = AddressTree{AllSelection} const StaticSelection = StaticAddressTree{AllSelection} const DynamicSelection = DynamicAddressTree{AllSelection} +const EmptySelection = EmptyAddressTree """ in(addr, selection::Selection) Whether the address is selected in the given selection. """ -function Base.in(addr, selection::Selection) +@inline function Base.in(addr, selection::Selection) get_subtree(selection, addr) === AllSelection() end @@ -17,6 +18,8 @@ function Base.getindex(selection::Selection, addr) get_subtree(selection, addr) end +get_subselections(s::Selection) = get_subtrees_shallow(s) + function select(addrs...) selection = DynamicSelection() for addr in addrs @@ -74,4 +77,5 @@ addresses. """ get_selected(tree::AddressTree, selection::Selection) = SelectionFilteredAddressTree(tree, selection) -export select, get_selected \ No newline at end of file +export select, get_selected +export DynamicSelection, EmptySelection, StaticSelection \ No newline at end of file diff --git a/src/address_tree/static_address_tree.jl b/src/address_tree/static_address_tree.jl index 282ccb61..67b4817a 100644 --- a/src/address_tree/static_address_tree.jl +++ b/src/address_tree/static_address_tree.jl @@ -121,15 +121,15 @@ end if type1 <: EmptyAddressTree push!(submap_exprs, - quote choices2.submaps.$addr end + quote tree2.subtrees.$addr end ) elseif type2 <: EmptyAddressTree push!(submap_exprs, - quote choices1.submaps.$addr end + quote tree1.subtrees.$addr end ) else push!(submap_exprs, - quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end + quote merge(tree1.subtrees.$addr, tree2.subtrees.$addr) end ) end end diff --git a/src/address.jl b/src/deprecated/address.jl similarity index 100% rename from src/address.jl rename to src/deprecated/address.jl diff --git a/src/choice_map/array_interface.jl b/src/deprecated/choice_map/array_interface.jl similarity index 88% rename from src/choice_map/array_interface.jl rename to src/deprecated/choice_map/array_interface.jl index cf9d0bd0..e036d1da 100644 --- a/src/choice_map/array_interface.jl +++ b/src/deprecated/choice_map/array_interface.jl @@ -33,14 +33,14 @@ function to_array(choices::ChoiceMap, ::Type{T}) where {T} arr end -function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} +function _fill_array!(c::Value{<:T}, arr::Vector{T}, start_idx::Int) where {T} if length(arr) < start_idx resize!(arr, 2 * start_idx) end arr[start_idx] = get_value(c) 1 end -function _fill_array!(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} +function _fill_array!(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} value = get_value(c) if length(arr) < start_idx + length(value) resize!(arr, 2 * (start_idx + length(value))) @@ -95,12 +95,12 @@ function from_array(proto_choices::ChoiceMap, arr::Vector) choices end -function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) - (1, ValueChoiceMap(arr[start_idx])) +function _from_array(::Value, arr::Vector, start_idx::Int) + (1, Value(arr[start_idx])) end -function _from_array(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} +function _from_array(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} n_read = length(get_value(c)) - (n_read, ValueChoiceMap(arr[start_idx:start_idx+n_read-1])) + (n_read, Value(arr[start_idx:start_idx+n_read-1])) end export to_array, from_array \ No newline at end of file diff --git a/src/choice_map/choice_map.jl b/src/deprecated/choice_map/choice_map.jl similarity index 88% rename from src/choice_map/choice_map.jl rename to src/deprecated/choice_map/choice_map.jl index a1ca2eae..27c97ceb 100644 --- a/src/choice_map/choice_map.jl +++ b/src/deprecated/choice_map/choice_map.jl @@ -23,7 +23,7 @@ abstract type ChoiceMap end Returns an iterable collection of tuples `(address, submap)` for each top-level address associated with `choices`. -(This includes `ValueChoiceMap`s.) +(This includes `Value`s.) """ function get_submaps_shallow end @@ -46,7 +46,7 @@ end """ has_value(choices::ChoiceMap) -Returns true if `choices` is a `ValueChoiceMap`. +Returns true if `choices` is a `Value`. has_value(choices::ChoiceMap, addr) @@ -59,8 +59,8 @@ function has_value end """ get_value(choices::ChoiceMap) -Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; -throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. +Returns the value stored on `choices` is `choices` is a `Value`; +throws a `ChoiceMapGetValueError` if `choices` is not a `Value`. get_value(choices::ChoiceMap, addr) Returns the value stored in the submap with address `addr` or throws @@ -104,7 +104,7 @@ end Returns an iterable collection of tuples `(address, submap)` for every top-level submap stored in `choices` which is -not a `ValueChoiceMap`. +not a `Value`. (Works by applying a filter to `get_submaps_shallow`, so this internally requires iterating over every submap.) """ @@ -133,21 +133,21 @@ struct EmptyChoiceMap <: ChoiceMap end @inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false """ - ValueChoiceMap + Value A leaf-node choicemap. Stores a single value. """ -struct ValueChoiceMap{T} <: ChoiceMap +struct Value{T} <: ChoiceMap val::T end -@inline has_value(choices::ValueChoiceMap) = true -@inline get_value(choices::ValueChoiceMap) = choices.val -@inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() -@inline get_submaps_shallow(choices::ValueChoiceMap) = () -@inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val -@inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) -@inline get_address_schema(::Type{<:ValueChoiceMap}) = AllAddressSchema() +@inline has_value(choices::Value) = true +@inline get_value(choices::Value) = choices.val +@inline get_submap(choices::Value, addr) = EmptyChoiceMap() +@inline get_submaps_shallow(choices::Value) = () +@inline Base.:(==)(a::Value, b::Value) = a.val == b.val +@inline Base.isapprox(a::Value, b::Value) = isapprox(a.val, b.val) +@inline get_address_schema(::Type{<:Value}) = AllAddressSchema() """ choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) @@ -172,10 +172,10 @@ function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) end Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c -Base.merge(c::ValueChoiceMap, ::EmptyChoiceMap) = c -Base.merge(::EmptyChoiceMap, c::ValueChoiceMap) = c -Base.merge(::ValueChoiceMap, ::ChoiceMap) = error("ValueChoiceMaps cannot be merged") -Base.merge(::ChoiceMap, ::ValueChoiceMap) = error("ValueChoiceMaps cannot be merged") +Base.merge(c::Value, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::Value) = c +Base.merge(::Value, ::ChoiceMap) = error("Values cannot be merged") +Base.merge(::ChoiceMap, ::Value) = error("Values cannot be merged") """ Variadic merge of choice maps. @@ -270,7 +270,7 @@ function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) _show_pretty(io, choices, 0, ()) end -export ChoiceMap, ValueChoiceMap, EmptyChoiceMap +export ChoiceMap, Value, EmptyChoiceMap export _get_submap, get_submap, get_submaps_shallow export get_value, has_value export get_values_shallow, get_nonvalue_submaps_shallow diff --git a/src/choice_map/dynamic_choice_map.jl b/src/deprecated/choice_map/dynamic_choice_map.jl similarity index 95% rename from src/choice_map/dynamic_choice_map.jl rename to src/deprecated/choice_map/dynamic_choice_map.jl index 0f27c89d..e2953b84 100644 --- a/src/choice_map/dynamic_choice_map.jl +++ b/src/deprecated/choice_map/dynamic_choice_map.jl @@ -38,7 +38,7 @@ Copy a choice map, returning a mutable choice map. function DynamicChoiceMap(other::ChoiceMap) choices = DynamicChoiceMap() for (addr, submap) in get_submaps_shallow(other) - if submap isa ValueChoiceMap + if submap isa Value set_submap!(choices, addr, submap) else set_submap!(choices, addr, DynamicChoiceMap(submap)) @@ -47,7 +47,7 @@ function DynamicChoiceMap(other::ChoiceMap) choices end -DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") +DynamicChoiceMap(other::Value) = error("Cannot convert a Value to a DynamicChoiceMap") """ choices = choicemap() @@ -88,7 +88,7 @@ The following syntactic sugar is provided: """ function set_value!(choices::DynamicChoiceMap, addr, value) delete!(choices.submaps, addr) - choices.submaps[addr] = ValueChoiceMap(value) + choices.submaps[addr] = Value(value) end function set_value!(choices::DynamicChoiceMap, addr::Pair, value) diff --git a/src/choice_map/nested_view.jl b/src/deprecated/choice_map/nested_view.jl similarity index 97% rename from src/choice_map/nested_view.jl rename to src/deprecated/choice_map/nested_view.jl index 68add0a0..6ecf9707 100644 --- a/src/choice_map/nested_view.jl +++ b/src/deprecated/choice_map/nested_view.jl @@ -29,7 +29,7 @@ struct ChoiceMapNestedView choice_map::ChoiceMap end -ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) +ChoiceMapNestedView(cm::Value) = get_value(cm) ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") function Base.getindex(choices::ChoiceMapNestedView, addr) diff --git a/src/choice_map/static_choice_map.jl b/src/deprecated/choice_map/static_choice_map.jl similarity index 93% rename from src/choice_map/static_choice_map.jl rename to src/deprecated/choice_map/static_choice_map.jl index 587fc6ee..0f021850 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/deprecated/choice_map/static_choice_map.jl @@ -11,7 +11,7 @@ end function StaticChoiceMap(;addrs_to_vals_and_maps...) addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) - maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : ValueChoiceMap(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) + maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : Value(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) StaticChoiceMap(NamedTuple{addrs}(maps)) end @@ -52,7 +52,7 @@ function StaticChoiceMap(other::ChoiceMap) end StaticChoiceMap(NamedTuple{addrs}(submaps)) end -StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") +StaticChoiceMap(other::Value) = error("Cannot convert a Value to a StaticChoiceMap") StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() # TODO: deep conversion to static choicemap @@ -99,8 +99,8 @@ end for addr in merged_addrs type1 = get(addr_to_type1, addr, EmptyChoiceMap) type2 = get(addr_to_type2, addr, EmptyChoiceMap) - if ((type1 <: ValueChoiceMap && type2 != EmptyChoiceMap) - || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) + if ((type1 <: Value && type2 != EmptyChoiceMap) + || (type2 <: Value && type1 != EmptyChoiceMap)) error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") end if type1 <: EmptyChoiceMap diff --git a/src/distribution.jl b/src/distribution.jl index 6354558d..8816082c 100644 --- a/src/distribution.jl +++ b/src/distribution.jl @@ -66,7 +66,7 @@ get_return_type(::Distribution{T}) where {T} = T @inline Base.getindex(trace::DistributionTrace) = trace.val @inline Gen.get_args(trace::DistributionTrace) = trace.args -@inline Gen.get_choices(trace::DistributionTrace) = ValueChoiceMap(trace.val) # should be able to get type of val +@inline Gen.get_choices(trace::DistributionTrace) = Value(trace.val) # should be able to get type of val @inline Gen.get_retval(trace::DistributionTrace) = trace.val @inline Gen.get_gen_fn(trace::DistributionTrace) = dist(trace) @inline Gen.get_score(trace::DistributionTrace) = trace.score @@ -78,12 +78,12 @@ get_return_type(::Distribution{T}) where {T} = T DistributionTrace(val, args, dist) end @inline Gen.generate(dist::Distribution, args::Tuple, ::EmptyChoiceMap) = (simulate(dist, args), 0.) -@inline function Gen.generate(dist::Distribution, args::Tuple, constraints::ValueChoiceMap) +@inline function Gen.generate(dist::Distribution, args::Tuple, constraints::Value) tr = DistributionTrace(get_value(constraints), args, dist) weight = get_score(tr) (tr, weight) end -@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::ValueChoiceMap) +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, constraints::Value) new_tr = DistributionTrace(get_value(constraints), args, dist(tr)) weight = get_score(new_tr) - get_score(tr) (new_tr, weight, UnknownChange(), get_choices(tr)) @@ -109,9 +109,9 @@ end @inline function Gen.propose(dist::Distribution, args::Tuple) val = random(dist, args...) score = logpdf(dist, val, args...) - (ValueChoiceMap(val), score, val) + (Value(val), score, val) end -@inline function Gen.assess(dist::Distribution, args::Tuple, choices::ValueChoiceMap) +@inline function Gen.assess(dist::Distribution, args::Tuple, choices::Value) weight = logpdf(dist, get_value(choices), args...) (weight, choices.val) end diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 0d8e03b4..dd62e486 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -124,8 +124,8 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end -all_visited(::Selection, ::ValueChoiceMap) = false -all_visited(::AllSelection, ::ValueChoiceMap) = true +all_visited(::Selection, ::Value) = false +all_visited(::AllSelection, ::Value) = true function all_visited(visited::Selection, choices::ChoiceMap) for (key, submap) in get_submaps_shallow(choices) if !all_visited(visited[key], submap) @@ -135,8 +135,8 @@ function all_visited(visited::Selection, choices::ChoiceMap) return true end -get_unvisited(::Selection, v::ValueChoiceMap) = v -get_unvisited(::AllSelection, v::ValueChoiceMap) = EmptyChoiceMap() +get_unvisited(::Selection, v::Value) = v +get_unvisited(::AllSelection, v::Value) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() for (key, submap) in get_submaps_shallow(choices) diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index f997ba39..951c519d 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -154,7 +154,7 @@ function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) if retval_grad !== nothing choice_grad += retval_grad end - gradient_choices = CallAtChoiceMap(trace.key, ValueChoiceMap(choice_grad)) + gradient_choices = CallAtChoiceMap(trace.key, Value(choice_grad)) else value_choices = EmptyChoiceMap() gradient_choices = EmptyChoiceMap() diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 3af416ef..607ee747 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -113,7 +113,7 @@ function get_retained_and_selected(selection::EmptySelection, prev_length::Int, Set{Int}() end -function get_retained_and_selected(selection::HierarchicalSelection, prev_length::Int, new_length::Int) +function get_retained_and_selected(selection::Selection, prev_length::Int, new_length::Int) keys = Set{Int}() for (key::Int, _) in get_subselections(selection) if key > 0 && key <= new_length diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 38bf2cca..15479f12 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -322,8 +322,8 @@ function generate_value_gradient_trie(selected_choices::Set{GenerativeFunctionCa value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(get_retval(trace.$(get_subtrace_fieldname(node))))), selected_choices_vec) - leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(Value(get_retval(trace.$(get_subtrace_fieldname(node))))), selected_choices_vec) + leaf_gradient_choicemaps = map((node) -> :(Value($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) diff --git a/src/trie.jl b/src/trie.jl index 0d1c2a8a..d7a3f962 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -2,6 +2,8 @@ # Trie # ################## +# TODO: use the `DynamicAddressTree` implementation for this + struct Trie{K,V} <: ChoiceMap leaf_nodes::Dict{K,V} internal_nodes::Dict{K,Trie{K,V}} diff --git a/test/assignment.jl b/test/assignment.jl index 7485f92c..9a889b81 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,10 +1,10 @@ -@testset "ValueChoiceMap" begin - vcm1 = ValueChoiceMap(2) - vcm2 = ValueChoiceMap(2.) - vcm3 = ValueChoiceMap([1,2]) - @test vcm1 isa ValueChoiceMap{Int} - @test vcm2 isa ValueChoiceMap{Float64} - @test vcm3 isa ValueChoiceMap{Vector{Int}} +@testset "Value" begin + vcm1 = Value(2) + vcm2 = Value(2.) + vcm3 = Value([1,2]) + @test vcm1 isa Value{Int} + @test vcm2 isa Value{Float64} + @test vcm3 isa Value{Vector{Int}} @test vcm1[] == 2 @test vcm1[] == get_value(vcm1) @@ -16,8 +16,8 @@ @test isempty(get_values_shallow(vcm1)) @test isempty(get_nonvalue_submaps_shallow(vcm1)) @test to_array(vcm1, Int) == [2] - @test from_array(vcm1, [4]) == ValueChoiceMap(4) - @test from_array(vcm3, [4, 5]) == ValueChoiceMap([4, 5]) + @test from_array(vcm1, [4]) == Value(4) + @test from_array(vcm3, [4, 5]) == Value([4, 5]) @test_throws Exception merge(vcm1, vcm2) @test_throws Exception merge(vcm1, choicemap(:a, 5)) @test merge(vcm1, EmptyChoiceMap()) == vcm1 @@ -25,19 +25,17 @@ @test get_submap(vcm1, :addr) == EmptyChoiceMap() @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) @test !has_value(vcm1, :addr) - @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) - @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) - @test get_address_schema(typeof(vcm1)) == AllAddressSchema() - @test get_address_schema(ValueChoiceMap) == AllAddressSchema() + @test isapprox(vcm2, Value(prevfloat(2.))) + @test isapprox(vcm1, Value(prevfloat(2.))) @test nested_view(vcm1) == 2 end @testset "static choicemap constructor" begin - @test StaticChoiceMap((a=ValueChoiceMap(5), b=ValueChoiceMap(6))) == StaticChoiceMap(a=5, b=6) + @test StaticChoiceMap((a=Value(5), b=Value(6))) == StaticChoiceMap(a=5, b=6) submap = StaticChoiceMap(a=1., b=[2., 2.5]) - @test submap == StaticChoiceMap((a=ValueChoiceMap(1.), b=ValueChoiceMap([2., 2.5]))) + @test submap == StaticChoiceMap((a=Value(1.), b=Value([2., 2.5]))) outer = StaticChoiceMap(c=3, d=submap, e=submap) - @test outer == StaticChoiceMap((c=ValueChoiceMap(3), d=submap, e=submap)) + @test outer == StaticChoiceMap((c=Value(3), d=submap, e=submap)) end @testset "static assignment to/from array" begin @@ -83,8 +81,8 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test get_submap(choices, :c) == ValueChoiceMap(1.0) - @test get_submap(choices, :d => :b) == ValueChoiceMap([3.0, 4.0]) + @test get_submap(choices, :c) == Value(1.0) + @test get_submap(choices, :d => :b) == Value([3.0, 4.0]) @test length(collect(get_submaps_shallow(choices))) == 3 @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 @@ -182,13 +180,13 @@ end # they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; # should we relabel this testset? @testset "static assignment errors" begin - # get_choices on an address that returns a ValueChoiceMap + # get_choices on an address that returns a Value choices = StaticChoiceMap(x=1) - @test get_submap(choices, :x) == ValueChoiceMap(1) + @test get_submap(choices, :x) == Value(1) - # static_get_submap on an address that contains a value returns a ValueChoiceMap + # static_get_submap on an address that contains a value returns a Value choices = StaticChoiceMap(x=1) - @test static_get_submap(choices, Val(:x)) == ValueChoiceMap(1) + @test static_get_submap(choices, Val(:x)) == Value(1) # get_submap on an address whose prefix contains a value returns EmptyChoiceMap choices = StaticChoiceMap(x=1) @@ -226,10 +224,10 @@ end end @testset "dynamic assignment errors" begin - # get_choices on an address that contains a value returns a ValueChoiceMap + # get_choices on an address that contains a value returns a Value choices = choicemap() choices[:x] = 1 - @test get_submap(choices, :x) == ValueChoiceMap(1) + @test get_submap(choices, :x) == Value(1) # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() @@ -272,7 +270,7 @@ end choices = choicemap() choices[:x => :y] = 1 choices[:x] = 2 - @test get_submap(choices, :x) == ValueChoiceMap(2) + @test get_submap(choices, :x) == Value(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index ffe07d77..1bb6cb02 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -408,7 +408,7 @@ @test weight == 0. @test coinflips_tr[20] isa Bool choices = get_choices(coinflips_tr) - @test get_submap(choices, 42) isa ValueChoiceMap{Bool} + @test get_submap(choices, 42) isa Value{Bool} val42 = get_value(choices, 42) new_tr, weight, retdiff, discard = update(coinflips_tr, (fill(0.4, 100),), (NoChange(),), choicemap((42, !val42))) @test new_tr[42] == !val42 diff --git a/test/runtests.jl b/test/runtests.jl index 74923603..a6e408e0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,10 +68,10 @@ include("autodiff.jl") include("diff.jl") include("selection.jl") include("assignment.jl") -include("dynamic_dsl.jl") -include("optional_args.jl") -include("static_ir/static_ir.jl") -include("static_dsl.jl") -include("tilde_sugar.jl") -include("inference/inference.jl") -include("modeling_library/modeling_library.jl") \ No newline at end of file +# include("dynamic_dsl.jl") +# include("optional_args.jl") +# include("static_ir/static_ir.jl") +# include("static_dsl.jl") +# include("tilde_sugar.jl") +# include("inference/inference.jl") +# include("modeling_library/modeling_library.jl \ No newline at end of file diff --git a/test/selection.jl b/test/selection.jl index 6436eaec..b165bfd8 100644 --- a/test/selection.jl +++ b/test/selection.jl @@ -14,28 +14,28 @@ # test Base.getindex @test s[:x] == AllSelection() sub = s[:y] - @test isa(sub, DynamicSelection) + @test isa(sub, DynamicAddressTree) @test :z in sub @test :w in sub - @test s[:u] == EmptySelection() + @test s[:u] == EmptyAddressTree() @test s[:y => :z] == AllSelection() # test set_subselection! - set_subselection!(s, :y, select(:z)) + set_subtree!(s, :y, select(:z)) @test (:y => :z) in s @test !((:y => :w) in s) selection = select(:x) @test :x in selection subselection = select(:y) - set_subselection!(selection, :x, subselection) + set_subtree!(selection, :x, subselection) @test (:x => :y) in selection @test !(:x in selection) end @testset begin "all selection" - s = selectall() + s = AllSelection() # test Base.in @test :x in s @@ -44,26 +44,4 @@ end # test Base.getindex @test s[:x] == AllSelection() @test s[:x => :y] == AllSelection() -end - -@testset begin "complement selection" - - @test !(:x in complement(selectall())) - @test :x in complement(select()) - - @test !(:x in complement(select(:x))) - @test :y in complement(select(:x)) - - @test :x in complement(select(:x => :y => :z)) - @test (:x => :y) in complement(select(:x => :y => :z)) - @test !((:x => :y => :z) in complement(select(:x => :y => :z))) - - @test !(:x in complement(complement(select(:x => :y => :z)))) - @test !((:x => :y) in complement(complement(select(:x => :y => :z)))) - @test (:x => :y => :z) in complement(complement(select(:x => :y => :z))) - - s = complement(select(:x => :y => :z))[:x] - @test !((:y => :z) in s) - @test :w in s - @test :y in s -end +end \ No newline at end of file From bc9b7f6d40f836cf084c002e99d4bfb7f4359290 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 6 Jul 2020 12:08:22 -0400 Subject: [PATCH 33/34] bug fixes -> tests passing --- examples/pmmh/model.jl | 2 +- src/address_tree/address_tree.jl | 10 ++-- src/address_tree/array_interface.jl | 4 +- src/address_tree/choicemap.jl | 22 ++++---- src/address_tree/selection.jl | 56 ++++++++++++++----- src/address_tree/static_address_tree.jl | 21 ++++++- src/deprecated/choice_map/array_interface.jl | 4 +- src/deprecated/choice_map/choice_map.jl | 2 +- .../choice_map/dynamic_choice_map.jl | 4 +- .../choice_map/static_choice_map.jl | 2 +- src/dynamic/backprop.jl | 2 +- src/dynamic/dynamic.jl | 2 +- src/dynamic/project.jl | 4 +- src/dynamic/regenerate.jl | 2 +- src/dynamic/trace.jl | 8 +-- src/modeling_library/call_at/call_at.jl | 14 ++--- src/modeling_library/map/backprop.jl | 2 +- src/modeling_library/map/regenerate.jl | 4 +- src/modeling_library/recurse/recurse.jl | 10 ++-- src/modeling_library/unfold/regenerate.jl | 4 +- src/modeling_library/vector.jl | 10 ++-- src/static_ir/backprop.jl | 2 +- src/static_ir/project.jl | 2 +- src/static_ir/trace.jl | 16 +++--- src/static_ir/update.jl | 2 +- src/trie.jl | 2 +- test/runtests.jl | 14 ++--- test/selection.jl | 1 - 28 files changed, 135 insertions(+), 93 deletions(-) diff --git a/examples/pmmh/model.jl b/examples/pmmh/model.jl index ce589706..5421dd69 100644 --- a/examples/pmmh/model.jl +++ b/examples/pmmh/model.jl @@ -150,7 +150,7 @@ get_call_record(trace::CollapsedHMMTrace) = trace.vector.call has_choices(trace::CollapsedHMMTrace) = length(trace.vector.call.retval) > 0 get_choices(trace::CollapsedHMMTrace) = CollapsedHMMChoiceMap(get_choices(trace.vector)) -struct CollapsedHMMChoiceMap <: ChoiceMap +struct CollapsedHMMChoiceMap <: AddressTree{Value} y_assignment::VectorDistTraceChoiceMap end diff --git a/src/address_tree/address_tree.jl b/src/address_tree/address_tree.jl index 1249fe1c..8148c316 100644 --- a/src/address_tree/address_tree.jl +++ b/src/address_tree/address_tree.jl @@ -32,13 +32,10 @@ struct EmptyAddressTree <: AddressTreeLeaf{EmptyAddressTree} end An address tree leaf node storing a value of type `T`. """ -struct Value{T} <: AddressTreeLeaf{Value} +struct Value{T} <: AddressTreeLeaf{Value{T}} val::T end @inline get_value(v::Value) = v.val -# Note that we don't set `Value{T} <: AddressTreeLeaf{Value{T}}`; -# this complicates the type heirarchicy unnecessarily and results in -# unintuitive phenomena, such as `Value <: AddressTreeLeaf{Value} == false`. """ AllSelection @@ -53,6 +50,9 @@ struct AllSelection <: AddressTreeLeaf{AllSelection} end Get the subtree at address `addr` or return `EmptyAddressTree` if there is no subtree at this address. + +Invariant: `get_subtree(::AddressTree{LeafType}, addr)` either returns +an object of `LeafType` or an `EmptyAddressTree`. """ function get_subtree end @@ -87,7 +87,7 @@ Base.isempty(::Value) = false Base.isempty(::AllSelection) = false Base.isempty(::EmptyAddressTree) = true Base.isempty(::AddressTreeLeaf) = error("Not implemented") -Base.isempty(t::AddressTree) = all((_, subtree) -> isempty(subtree), get_subtrees_shallow(t)) +Base.isempty(t::AddressTree) = all(((_, subtree),) -> isempty(subtree), get_subtrees_shallow(t)) @inline get_subtree(::AddressTreeLeaf, _) = EmptyAddressTree() @inline get_subtrees_shallow(::AddressTreeLeaf) = () diff --git a/src/address_tree/array_interface.jl b/src/address_tree/array_interface.jl index 3fac3865..0b6eae2f 100644 --- a/src/address_tree/array_interface.jl +++ b/src/address_tree/array_interface.jl @@ -19,7 +19,7 @@ the addresses of the choicemap using the `sort` function, then iterating over each submap in this order and filling the array for that submap. To override the default implementation of `to_array`, -a concrete subtype `T <: ChoiceMap` should implement the following method: +a concrete subtype `T <: AddressTree{Value}` should implement the following method: n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} @@ -82,7 +82,7 @@ and then filling in the submaps depth-first in this order. # Implementation -To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +To support `from_array`, a concrete subtype `T AddressTree{Value}` must implement the following method: (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} diff --git a/src/address_tree/choicemap.jl b/src/address_tree/choicemap.jl index cbfe5a53..94f09e53 100644 --- a/src/address_tree/choicemap.jl +++ b/src/address_tree/choicemap.jl @@ -12,9 +12,7 @@ showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueErr Abstract type for maps from hierarchical addresses to values. """ -const ChoiceMap = AddressTree{Value} - -const ChoiceMapOrEmpty = Union{ChoiceMap, EmptyAddressTree} +const ChoiceMap = AddressTree{<:Union{Value, EmptyAddressTree}} """ get_submaps_shallow(choices::ChoiceMap) @@ -23,7 +21,7 @@ Returns an iterable collection of tuples `(address, submap)` for each top-level address associated with `choices`. (This includes `Value`s.) """ -@inline get_submaps_shallow(c::ChoiceMapOrEmpty) = get_subtrees_shallow(c) +@inline get_submaps_shallow(c::ChoiceMap) = get_subtrees_shallow(c) """ get_submap(choices::ChoiceMap, addr) @@ -31,9 +29,9 @@ for each top-level address associated with `choices`. Return the submap at the given address, or `EmptyChoiceMap` if there is no submap at the given address. """ -@inline get_submap(c::ChoiceMapOrEmpty, addr) = get_subtree(c, addr) +@inline get_submap(c::ChoiceMap, addr) = get_subtree(c, addr) -@inline static_get_submap(c::ChoiceMapOrEmpty, a) = static_get_subtree(c, a) +@inline static_get_submap(c::ChoiceMap, a) = static_get_subtree(c, a) """ has_value(choices::ChoiceMap) @@ -45,9 +43,9 @@ Returns true if `choices` is a `Value`. Returns true if `choices` has a value stored at address `addr`. """ function has_value end -@inline has_value(c::ChoiceMapOrEmpty, addr) = has_value(get_submap(c, addr)) +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) has_value(::Value) = true -has_value(::ChoiceMapOrEmpty) = false +has_value(::ChoiceMap) = false """ get_value(choices::ChoiceMap) @@ -64,8 +62,8 @@ A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -@inline get_value(::ChoiceMapOrEmpty) = throw(ChoiceMapGetValueError()) -@inline get_value(c::ChoiceMapOrEmpty, addr) = get_value(get_submap(c, addr)) +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) """ @@ -76,7 +74,7 @@ for each value stored at a top-level address in `choices`. (Works by applying a filter to `get_submaps_shallow`, so this internally requires iterating over every submap.) """ -function get_values_shallow(choices::ChoiceMapOrEmpty) +function get_values_shallow(choices::ChoiceMap) ( (addr, get_value(submap)) for (addr, submap) in get_submaps_shallow(choices) @@ -93,7 +91,7 @@ not a `Value`. (Works by applying a filter to `get_submaps_shallow`, so this internally requires iterating over every submap.) """ -function get_nonvalue_submaps_shallow(choices::ChoiceMapOrEmpty) +function get_nonvalue_submaps_shallow(choices::ChoiceMap) (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) end diff --git a/src/address_tree/selection.jl b/src/address_tree/selection.jl index bee36bc3..d2dbb3df 100644 --- a/src/address_tree/selection.jl +++ b/src/address_tree/selection.jl @@ -1,7 +1,6 @@ -const Selection = AddressTree{AllSelection} +const Selection = AddressTree{<:Union{AllSelection, EmptyAddressTree}} const StaticSelection = StaticAddressTree{AllSelection} -const DynamicSelection = DynamicAddressTree{AllSelection} const EmptySelection = EmptyAddressTree """ @@ -14,19 +13,12 @@ Whether the address is selected in the given selection. end # indexing returns subtrees for selections -function Base.getindex(selection::Selection, addr) - get_subtree(selection, addr) -end +Base.getindex(selection::AddressTree{AllSelection}, addr) = get_subtree(selection, addr) -get_subselections(s::Selection) = get_subtrees_shallow(s) +# TODO: deprecate indexing syntax and only use this +get_subselection(s::Selection, addr) = get_subtree(s, addr) -function select(addrs...) - selection = DynamicSelection() - for addr in addrs - set_subtree!(selection, addr, AllSelection()) - end - selection -end +get_subselections(s::Selection) = get_subtrees_shallow(s) Base.merge(::AllSelection, ::Selection) = AllSelection() Base.merge(::Selection, ::AllSelection) = AllSelection() @@ -77,5 +69,43 @@ addresses. """ get_selected(tree::AddressTree, selection::Selection) = SelectionFilteredAddressTree(tree, selection) +""" + struct DynamicSelection <: HierarchicalSelection .. end +A hierarchical, mutable, selection with arbitrary addresses. +Can be mutated with the following methods: + Base.push!(selection::DynamicSelection, addr) +Add the address and all of its sub-addresses to the selection. +Example: +```julia +selection = select() +@assert !(:x in selection) +push!(selection, :x) +@assert :x in selection +``` + set_subselection!(selection::DynamicSelection, addr, other::Selection) +Change the selection status of the given address and its sub-addresses that defined by `other`. +Example: +```julia +selection = select(:x) +@assert :x in selection +subselection = select(:y) +set_subselection!(selection, :x, subselection) +@assert (:x => :y) in selection +@assert !(:x in selection) +``` +Note that `set_subselection!` does not copy data in `other`, so `other` may be mutated by a later calls to `set_subselection!` for addresses under `addr`. +""" +const DynamicSelection = DynamicAddressTree{AllSelection} +Base.push!(s::DynamicSelection, addr) = set_subtree!(s, addr, AllSelection()) +set_subselection!(s::DynamicSelection, addr, sub::Selection) = set_subtree!(s, addr, sub) + +function select(addrs...) + selection = DynamicSelection() + for addr in addrs + set_subtree!(selection, addr, AllSelection()) + end + selection +end + export select, get_selected export DynamicSelection, EmptySelection, StaticSelection \ No newline at end of file diff --git a/src/address_tree/static_address_tree.jl b/src/address_tree/static_address_tree.jl index 67b4817a..1d760b99 100644 --- a/src/address_tree/static_address_tree.jl +++ b/src/address_tree/static_address_tree.jl @@ -7,12 +7,27 @@ struct StaticAddressTree{LeafType, Addrs, SubtreeTypes} <: AddressTree{LeafType} end end +# If some of the subtrees are `EmptyAddressTree`s, but the `LeafType` does not have `EmptyAddressTree` +# as a subtype, strip away the addresses which point to EmptyAddressTrees. +# TODO: is this the implementation we want? is the performance hit for removing the empty subtrees here worthwhile? +# also TODO: should I make this @generated? +function StaticAddressTree{LeafType}(nt::NamedTuple{Addrs, Subtrees}) where { + LeafType, Addrs, Subtrees <: Tuple{Vararg{<:AddressTree{<:Union{LeafType, EmptyAddressTree}}}} +} + nonempty_addrs = Tuple(findall(x -> x != EmptyAddressTree(), nt)) + nonempty_subtrees = Tuple(nt[addr] for addr in nonempty_addrs) + StaticAddressTree{LeafType}(NamedTuple{nonempty_addrs}(nonempty_subtrees)) +end + # NOTE: It is probably better to avoid using this constructor when possible since I suspect it is less performant # than if we specify `LeafType`. -# I could make this into a generated function...this would probably improve runtime performance but hurt compiletime performance. function StaticAddressTree(subtrees::NamedTuple{Addrs, SubtreeTypes}) where {Addrs, SubtreeTypes <: Tuple{Vararg{AddressTree}}} uniontype = Union{SubtreeTypes.parameters...} - StaticAddressTree{uniontype}(subtrees) + if @generated + quote StaticAddressTree{$uniontype}(subtrees) end + else + StaticAddressTree{uniontype}(subtrees) + end end """ StaticAddressTree{LeafType}(; a=val, b=tree, ...) @@ -169,7 +184,7 @@ end end end -function get_address_schema(::Type{StaticAddressTree{LT, Addrs}}) where {LT, Addrs} +function get_address_schema(::Type{StaticAddressTree{LT, Addrs, SubtreeTypes}}) where {LT, Addrs, SubtreeTypes} StaticAddressSchema(Set(Addrs)) end diff --git a/src/deprecated/choice_map/array_interface.jl b/src/deprecated/choice_map/array_interface.jl index e036d1da..314b9bc5 100644 --- a/src/deprecated/choice_map/array_interface.jl +++ b/src/deprecated/choice_map/array_interface.jl @@ -15,7 +15,7 @@ the addresses of the choicemap using the `sort` function, then iterating over each submap in this order and filling the array for that submap. To override the default implementation of `to_array`, -a concrete subtype `T <: ChoiceMap` should implement the following method: +a concrete subtype `T <: AddressTree{Value}` should implement the following method: n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} @@ -78,7 +78,7 @@ and then filling in the submaps depth-first in this order. # Implementation -To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +To support `from_array`, a concrete subtype `T AddressTree{Value}` must implement the following method: (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} diff --git a/src/deprecated/choice_map/choice_map.jl b/src/deprecated/choice_map/choice_map.jl index 27c97ceb..74ddac6f 100644 --- a/src/deprecated/choice_map/choice_map.jl +++ b/src/deprecated/choice_map/choice_map.jl @@ -222,7 +222,7 @@ function get_selected( if has_value(submap) && addr in selection output[addr] = get_value(submap) else - subselection = selection[addr] + subselection = get_subselection(selection, addr) set_submap!(output, addr, get_selected(submap, subselection)) end end diff --git a/src/deprecated/choice_map/dynamic_choice_map.jl b/src/deprecated/choice_map/dynamic_choice_map.jl index e2953b84..f54a221c 100644 --- a/src/deprecated/choice_map/dynamic_choice_map.jl +++ b/src/deprecated/choice_map/dynamic_choice_map.jl @@ -3,7 +3,7 @@ ####################### """ - struct DynamicChoiceMap <: ChoiceMap .. end + struct DynamicChoiceMap AddressTree{Value} .. end A mutable map from arbitrary hierarchical addresses to values. @@ -15,7 +15,7 @@ Construct an empty map. Construct a map containing each of the given (addr, value) tuples. """ -struct DynamicChoiceMap <: ChoiceMap +struct DynamicChoiceMap AddressTree{Value} submaps::Dict{Any, ChoiceMap} function DynamicChoiceMap() new(Dict()) diff --git a/src/deprecated/choice_map/static_choice_map.jl b/src/deprecated/choice_map/static_choice_map.jl index 0f021850..3adcb775 100644 --- a/src/deprecated/choice_map/static_choice_map.jl +++ b/src/deprecated/choice_map/static_choice_map.jl @@ -2,7 +2,7 @@ # static assignment # ###################### -struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap +struct StaticChoiceMap{Addrs, SubmapTypes} AddressTree{Value} submaps::NamedTuple{Addrs, SubmapTypes} function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} new{Addrs, SubmapTypes}(submaps) diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index e870d0c7..3ad34643 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -317,7 +317,7 @@ function traceat(state::GFBackpropTraceState, gen_fn::GenerativeFunction{T,U}, retval_maybe_tracked = retval @assert !istracked(retval_maybe_tracked) end - selection = state.selection[key] + selection = get_subselection(state.selection, key) record = BackpropTraceRecord(gen_fn, subtrace, selection, state.value_choices, state.gradient_choices, key) record!(state.tape, ReverseDiff.SpecialInstruction, record, (args_maybe_tracked...,), retval_maybe_tracked) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index dd62e486..9042882b 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -128,7 +128,7 @@ all_visited(::Selection, ::Value) = false all_visited(::AllSelection, ::Value) = true function all_visited(visited::Selection, choices::ChoiceMap) for (key, submap) in get_submaps_shallow(choices) - if !all_visited(visited[key], submap) + if !all_visited(get_subselection(visited, key), submap) return false end end diff --git a/src/dynamic/project.jl b/src/dynamic/project.jl index 81b45b3c..33452581 100644 --- a/src/dynamic/project.jl +++ b/src/dynamic/project.jl @@ -2,11 +2,11 @@ function project_recurse(trie::Trie{Any, CallRecord}, selection::Selection) weight = 0. for (key, call) in get_leaf_nodes(trie) - subselection = selection[key] + subselection = get_subselection(selection, key) weight += project(call.subtrace, subselection) end for (key, subtrie) in get_internal_nodes(trie) - subselection = selection[key] + subselection = get_subselection(selection, key) weight += project_recurse(subtrie, subselection) end weight diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index a4006a6c..5a48b360 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -24,7 +24,7 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # check whether the key was selected - subselection = state.selection[key] + subselection = get_subselection(state.selection, key) # get subtrace has_previous = has_call(state.prev_trace, key) diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 0a169d73..bf20420b 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -49,13 +49,13 @@ get_gen_fn(trace::DynamicDSLTrace) = trace.gen_fn get_choices(trace::DynamicDSLTrace) = DynamicDSLChoiceMap(trace.trie) -struct DynamicDSLChoiceMap <: ChoiceMap +struct DynamicDSLChoiceMap <: AddressTree{Value} trie::Trie{Any,CallRecord} end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) -function get_submap(choices::DynamicDSLChoiceMap, addr) +get_subtree(choices::DynamicDSLChoiceMap, addr::Pair) = _get_subtree(choices, addr) +function get_subtree(choices::DynamicDSLChoiceMap, addr) if haskey(choices.trie.leaf_nodes, addr) get_choices(choices.trie[addr].subtrace) elseif haskey(choices.trie.internal_nodes, addr) @@ -65,7 +65,7 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) end end -function get_submaps_shallow(choices::DynamicDSLChoiceMap) +function get_subtrees_shallow(choices::DynamicDSLChoiceMap) leafs = ((key, get_choices(record.subtrace)) for (key, record) in get_leaf_nodes(choices.trie)) internals = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) Iterators.flatten((leafs, internals)) diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 951c519d..805f589a 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -1,4 +1,4 @@ -struct CallAtChoiceMap{K,T} <: ChoiceMap +struct CallAtChoiceMap{K,T} <: AddressTree{Value} key::K submap::T end @@ -9,12 +9,12 @@ function get_address_schema(::Type{T}) where {T<:CallAtChoiceMap} SingleDynamicKeyAddressSchema() end -function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} +function get_subtree(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} choices.key == addr ? choices.submap : EmptyChoiceMap() end -get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) +get_subtree(choices::CallAtChoiceMap, addr::Pair) = _get_subtree(choices, addr) +get_subtrees_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) # TODO optimize CallAtTrace using type parameters @@ -96,7 +96,7 @@ function generate(gen_fn::CallAtCombinator{T,U,K}, args::Tuple, end function project(trace::CallAtTrace, selection::Selection) - subselection = selection[trace.key] + subselection = get_subselection(selection, trace.key) project(trace.subtrace, subselection) end @@ -123,7 +123,7 @@ function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, selection::Selection) (key, kernel_args) = unpack_call_at_args(args) key_changed = (key != trace.key) - subselection = selection[key] + subselection = get_subselection(selection, key) if key_changed if !isempty(subselection) error("Cannot select addresses under new key $key in regenerate") @@ -162,7 +162,7 @@ function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) input_grads = (kernel_arg_grads[2:end]..., nothing) return (input_grads, value_choices, gradient_choices) else - subselection = selection[trace.key] + subselection = get_subselection(selection, trace.key) (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( trace.subtrace, subselection, retval_grad) input_grads = (kernel_input_grads..., nothing) diff --git a/src/modeling_library/map/backprop.jl b/src/modeling_library/map/backprop.jl index e14dbd05..4af8a6fd 100644 --- a/src/modeling_library/map/backprop.jl +++ b/src/modeling_library/map/backprop.jl @@ -20,7 +20,7 @@ function choice_gradients(trace::VectorTrace{MapType,T,U}, selection::Selection, for key=1:len subtrace = trace.subtraces[key] - sub_selection = selection[key] + sub_selection = get_subselection(selection, key) kernel_retval_grad = (retval_grad == nothing) ? nothing : retval_grad[key] (kernel_arg_grad::Tuple, kernel_value_choices, kernel_gradient_choices) = choice_gradients( subtrace, sub_selection, kernel_retval_grad) diff --git a/src/modeling_library/map/regenerate.jl b/src/modeling_library/map/regenerate.jl index 8634bbc3..93476134 100644 --- a/src/modeling_library/map/regenerate.jl +++ b/src/modeling_library/map/regenerate.jl @@ -15,7 +15,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple, local prev_subtrace::U local retval::T - subselection = selection[key] + subselection = get_subselection(selection, key) kernel_args = get_args_for_key(args, key) # get new subtrace with recursive call to regenerate() @@ -48,7 +48,7 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, selection::Selection, key:: state::MapRegenerateState{T,U}) where {T,U} local subtrace::U local retval::T - if !isempty(selection[key]) + if !isempty(get_subselection(selection, key)) error("Tried to select new address in regenerate at key $key") end kernel_args = get_args_for_key(args, key) diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 1f101725..c93903b4 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -51,7 +51,7 @@ project(trace::RecurseTrace, ::EmptySelection) = 0. # recurse assignment wrapper # ############################## -struct RecurseTraceChoiceMap <: ChoiceMap +struct RecurseTraceChoiceMap <: AddressTree{Value} trace::RecurseTrace end @@ -64,7 +64,7 @@ end get_address_schema(::Type{RecurseTraceChoiceMap}) = DynamicAddressSchema() -function get_submap(choices::RecurseTraceChoiceMap, +function get_subtree(choices::RecurseTraceChoiceMap, addr::Tuple{Int,Val{:production}}) idx = addr[1] if !haskey(choices.trace.aggregation_traces, idx) @@ -74,7 +74,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, +function get_subtree(choices::RecurseTraceChoiceMap, addr::Tuple{Int,Val{:aggregation}}) idx = addr[1] if !haskey(choices.trace.aggregation_traces, idx) @@ -84,11 +84,11 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -get_submap(choices::RecurseTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) +get_subtree(choices::RecurseTraceChoiceMap, addr::Pair) = _get_subtree(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () -function get_submaps_shallow(choices::RecurseTraceChoiceMap) +function get_subtrees_shallow(choices::RecurseTraceChoiceMap) production_iter = (((idx, Val(:production)), get_choices(subtrace)) for (idx, subtrace) in choices.trace.production_traces) aggregation_iter = (((idx, Val(:aggregation)), get_choices(subtrace)) diff --git a/src/modeling_library/unfold/regenerate.jl b/src/modeling_library/unfold/regenerate.jl index 7e8480ca..6b464fa8 100644 --- a/src/modeling_library/unfold/regenerate.jl +++ b/src/modeling_library/unfold/regenerate.jl @@ -17,7 +17,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple, local prev_state::T local new_state::T - subselection = selection[key] + subselection = get_subselection(selection, key) prev_state = (key == 1) ? state.init_state : state.retval[key-1] kernel_args = (key, prev_state, params...) @@ -55,7 +55,7 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, selection::Selection, local prev_state::T local new_state::T - if !isempty(selection[key]) + if !isempty(get_subselection(selection, key)) error("Cannot select new addresses in regenerate") end prev_state = (key == 1) ? state.init_state : state.retval[key-1] diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 607ee747..e98709d3 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -65,21 +65,21 @@ end function project(trace::VectorTrace, selection::Selection) weight = 0. for key=1:trace.len - subselection = selection[key] + subselection = get_subselection(selection, key) weight += project(trace.subtraces[key], subselection) end weight end project(trace::VectorTrace, ::EmptySelection) = trace.noise -struct VectorTraceChoiceMap{GenFnType, T, U} <: ChoiceMap +struct VectorTraceChoiceMap{GenFnType, T, U} <: AddressTree{Value} trace::VectorTrace{GenFnType, T, U} end @inline Base.isempty(assignment::VectorTraceChoiceMap) = assignment.trace.num_nonempty == 0 @inline get_address_schema(::Type{VectorTraceChoiceMap}) = VectorAddressSchema() -@inline function get_submap(choices::VectorTraceChoiceMap, addr::Int) +@inline function get_subtree(choices::VectorTraceChoiceMap, addr::Int) if addr <= choices.trace.len get_choices(choices.trace.subtraces[addr]) else @@ -87,11 +87,11 @@ end end end -@inline function get_submaps_shallow(choices::VectorTraceChoiceMap) +@inline function get_subtrees_shallow(choices::VectorTraceChoiceMap) ((i, get_choices(choices.trace.subtraces[i])) for i=1:choices.trace.len) end -@inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline get_subtree(choices::VectorTraceChoiceMap, addr::Pair) = _get_subtree(choices, addr) ############################################ # code shared by vector-shaped combinators # diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 15479f12..5d902d8b 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -266,7 +266,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, subtrace_fieldname = get_subtrace_fieldname(node) call_selection = gensym("call_selection") if node in selected_calls - push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_getindex))(selection, $(QuoteNode(Val(node.addr)))))) + push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_get_subtree))(selection, $(QuoteNode(Val(node.addr)))))) else push!(stmts, :($call_selection = EmptySelection())) end diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index 62df493a..042d8bf1 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -12,7 +12,7 @@ function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) subtrace = get_subtrace_fieldname(node) subselection = gensym("subselection") if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) + push!(state.stmts, :($subselection = $(GlobalRef(Gen, :static_get_subtree))(selection, Val($addr)))) push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $subselection))) else push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $(GlobalRef(Gen, :EmptySelection))()))) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index a38cfdbd..225eedfb 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -2,7 +2,7 @@ # assignment wrapper # ###################### -struct StaticIRTraceAssmt{T} <: ChoiceMap +struct StaticIRTraceAssmt{T} <: AddressTree{Value} trace::T end @@ -10,8 +10,8 @@ function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) -@inline get_submap(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) -@inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) +@inline get_subtree(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) +@inline get_subtree(choices::StaticIRTraceAssmt, addr::Pair) = _get_subtree(choices, addr) ######################### # trace type generation # @@ -26,7 +26,7 @@ abstract type StaticIRTrace <: Trace end @inline static_haskey(trace::StaticIRTrace, ::Val) = false @inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_get_subtree(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] @@ -122,7 +122,7 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) +function generate_get_subtrees_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.call_nodes addr = node.addr @@ -130,7 +130,7 @@ function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) push!(elements, :(($(QuoteNode(addr)), $(GlobalRef(Gen, :get_choices))($subtrace)))) end Expr(:function, - Expr(:call, GlobalRef(Gen, :get_submaps_shallow), + Expr(:call, GlobalRef(Gen, :get_subtrees_shallow), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), Expr(:block, Expr(:tuple, elements...))) end @@ -151,7 +151,7 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) for node in ir.call_nodes push!(call_getindex_exprs, quote - function $(GlobalRef(Gen, :static_getindex))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) + function $(GlobalRef(Gen, :static_get_subtree))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) return $(GlobalRef(Gen, :get_retval))(trace.$(get_subtrace_fieldname(node))) end end @@ -192,7 +192,7 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) + get_submaps_shallow_expr = generate_get_subtrees_shallow(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index b927072e..7f3bb5da 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -251,7 +251,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, call_subselection = gensym("call_subselection") if node in fwd.constrained_or_selected_calls || node in fwd.input_changed if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) + push!(stmts, :($call_subselection = $(GlobalRef(Gen, :static_get_subtree))(selection, Val($addr)))) else push!(stmts, :($call_subselection = $(GlobalRef(Gen, :EmptySelection))())) end diff --git a/src/trie.jl b/src/trie.jl index d7a3f962..2e5f6e1d 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -4,7 +4,7 @@ # TODO: use the `DynamicAddressTree` implementation for this -struct Trie{K,V} <: ChoiceMap +struct Trie{K,V} <: AddressTree{Value} leaf_nodes::Dict{K,V} internal_nodes::Dict{K,Trie{K,V}} end diff --git a/test/runtests.jl b/test/runtests.jl index a6e408e0..74923603 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,10 +68,10 @@ include("autodiff.jl") include("diff.jl") include("selection.jl") include("assignment.jl") -# include("dynamic_dsl.jl") -# include("optional_args.jl") -# include("static_ir/static_ir.jl") -# include("static_dsl.jl") -# include("tilde_sugar.jl") -# include("inference/inference.jl") -# include("modeling_library/modeling_library.jl \ No newline at end of file +include("dynamic_dsl.jl") +include("optional_args.jl") +include("static_ir/static_ir.jl") +include("static_dsl.jl") +include("tilde_sugar.jl") +include("inference/inference.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git a/test/selection.jl b/test/selection.jl index b165bfd8..af062903 100644 --- a/test/selection.jl +++ b/test/selection.jl @@ -1,5 +1,4 @@ @testset begin "dynamic selection" - s = select(:x, :y => :z, :y => :w) # test Base.in From 9c7436067a880dc22cb087e41d6169fb6bc8d19b Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 6 Jul 2020 12:10:08 -0400 Subject: [PATCH 34/34] remove old choicemap & selection code --- src/deprecated/address.jl | 371 ------------------ src/deprecated/choice_map/array_interface.jl | 106 ----- src/deprecated/choice_map/choice_map.jl | 283 ------------- .../choice_map/dynamic_choice_map.jl | 149 ------- src/deprecated/choice_map/nested_view.jl | 80 ---- .../choice_map/static_choice_map.jl | 160 -------- 6 files changed, 1149 deletions(-) delete mode 100644 src/deprecated/address.jl delete mode 100644 src/deprecated/choice_map/array_interface.jl delete mode 100644 src/deprecated/choice_map/choice_map.jl delete mode 100644 src/deprecated/choice_map/dynamic_choice_map.jl delete mode 100644 src/deprecated/choice_map/nested_view.jl delete mode 100644 src/deprecated/choice_map/static_choice_map.jl diff --git a/src/deprecated/address.jl b/src/deprecated/address.jl deleted file mode 100644 index ad33cfe7..00000000 --- a/src/deprecated/address.jl +++ /dev/null @@ -1,371 +0,0 @@ -################### -# address schemas # -################### - -abstract type AddressSchema end - -struct StaticAddressSchema <: AddressSchema - keys::Set{Symbol} -end - -Base.keys(schema::StaticAddressSchema) = schema.keys - -struct VectorAddressSchema <: AddressSchema end -struct SingleDynamicKeyAddressSchema <: AddressSchema end -struct DynamicAddressSchema <: AddressSchema end -struct EmptyAddressSchema <: AddressSchema end -struct AllAddressSchema <: AddressSchema end - -export AddressSchema -export StaticAddressSchema # hierarchical -export VectorAddressSchema # hierarchical -export SingleDynamicKeyAddressSchema # hierarchical -export DynamicAddressSchema # hierarchical -export EmptyAddressSchema -export AllAddressSchema - -###################### -# abstract selection # -###################### - -""" - abstract type Selection end - -Abstract type for selections of addresses. - -All selections implement the following methods: - - Base.in(addr, selection) - -Is the address selected? - - Base.getindex(selection, addr) - -Get the subselection at the given address. - - Base.isempty(selection) - -Is the selection guaranteed to be empty? - - get_address_schema(T) - -Return a shallow, compile-time address schema, where `T` is the concrete type of the selection. -""" -abstract type Selection end - -Base.in(addr, ::Selection) = false -Base.getindex(::Selection, addr) = EmptySelection() - -export Selection - -########################## -# hierarchical selection # -########################## - -""" - abstract type HierarchicalSelection <: Selection end - -Abstract type for selections that have a notion of sub-selections. - - get_subselections(selection::HierarchicalSelection) - -Return an iterator over pairs of addresses and subselections at associated addresses. -""" -abstract type HierarchicalSelection <: Selection end - -export HierarchicalSelection -export get_subselections - -################### -# empty selection # -################### - -""" - struct EmptySelection <: Selection end - -A singleton type for a selection that is always empty. -""" -struct EmptySelection <: Selection end -get_address_schema(::Type{EmptySelection}) = EmptyAddressSchema() -Base.isempty(::EmptySelection) = true - -export EmptySelection - -################# -# all selection # -################# - -""" - struct AllSelection <: Selection end - -A singleton type for a selection that contains all choices at or under an address. -""" -struct AllSelection <: Selection end -get_address_schema(::Type{AllSelection}) = AllAddressSchema() -Base.isempty(::AllSelection) = false # it is not guaranteed to be empty -Base.in(addr, ::AllSelection) = true -Base.getindex(::AllSelection, addr) = AllSelection() - -export AllSelection - -######################## -# complement selection # -######################## - -struct ComplementSelection <: Selection - complement::Selection -end -get_address_schema(::Type{ComplementSelection}) = DynamicAddressSchema() -Base.isempty(::ComplementSelection) = false # it is not guaranteed to be empty -Base.in(addr, selection::ComplementSelection) = !(addr in selection.complement) -function Base.getindex(selection::ComplementSelection, addr) - ComplementSelection(selection.complement[addr]) -end - -""" - comp_selection = complement(selection::Selection) - -Return a selection that is the complement of the given selection. - -An address is in the selection if it is not in the complement selection. -""" -function complement(selection::Selection) - ComplementSelection(selection) -end - -export ComplementSelection, complement - -#################### -# static selection # -#################### - -# R is a tuple of symbols.. -# T is a tuple of symbols -# U the tuple type of subselections - -""" - struct StaticSelection{T,U} <: HierarchicalSelection .. end - -A hierarchical selection whose keys are among its type parameters. -""" -struct StaticSelection{T,U} <: HierarchicalSelection - subselections::NamedTuple{T,U} -end -StaticSelection(::NamedTuple{(), Tuple{}}) = EmptySelection() - -function Base.isempty(selection::StaticSelection{T,U}) where {T,U} - length(R) == 0 && all(isempty(node) for node in selection.subselections) -end - -function get_address_schema(::Type{StaticSelection{T,U}}) where {T,U} - keys = Set{Symbol}() - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -get_subselections(selection::StaticSelection) = pairs(selection.subselections) - -function static_getindex(selection::StaticSelection, ::Val{A}) where {A} - selection.subselections[A] -end - -# TODO do we no longer need static_in? - -function Base.getindex(selection::StaticSelection, addr::Symbol) - if haskey(selection.subselections, addr) - selection.subselections[addr] - else - EmptySelection() - end -end - -function Base.getindex(selection::StaticSelection, addr::Pair) - (first, rest) = addr - subselection = selection.subselections[first] - subselection[rest] -end - -function Base.in(addr::Symbol, selection::StaticSelection{T,U}) where {T,U} - addr in T && selection.subselections[addr] == AllSelection() -end - -function Base.in(addr::Pair, selection::StaticSelection) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - in(subselection, rest) - else - false - end -end - -function StaticSelection(other::HierarchicalSelection) - keys_and_subselections = collect(get_subselections(other)) - if length(keys_and_subselections) > 0 - (keys, subselections) = collect(zip(keys_and_subselections...)) - else - (keys, subselections) = ((), ()) - end - types = map(typeof, subselections) - StaticSelection(NamedTuple{keys}(subselections)) -end - -export StaticSelection - - -##################### -# dynamic selection # -##################### - -""" - struct DynamicSelection <: HierarchicalSelection .. end - -A hierarchical, mutable, selection with arbitrary addresses. - -Can be mutated with the following methods: - - - Base.push!(selection::DynamicSelection, addr) - -Add the address and all of its sub-addresses to the selection. - -Example: -```julia -selection = select() -@assert !(:x in selection) -push!(selection, :x) -@assert :x in selection -``` - - set_subselection!(selection::DynamicSelection, addr, other::Selection) - -Change the selection status of the given address and its sub-addresses that defined by `other`. - -Example: -```julia -selection = select(:x) -@assert :x in selection -subselection = select(:y) -set_subselection!(selection, :x, subselection) -@assert (:x => :y) in selection -@assert !(:x in selection) -``` - -Note that `set_subselection!` does not copy data in `other`, so `other` may be mutated by a later calls to `set_subselection!` for addresses under `addr`. -""" -struct DynamicSelection <: HierarchicalSelection - # note: only store subselections for which isempty = false - subselections::Dict{Any,Selection} -end - -function Base.isempty(selection::DynamicSelection) - isempty(selection.subselections) -end - -DynamicSelection() = DynamicSelection(Dict{Any,Selection}()) - -get_address_schema(::Type{DynamicSelection}) = DynamicAddressSchema() - -function Base.in(addr, selection::DynamicSelection) - if haskey(selection.subselections, addr) - selection.subselections[addr] == AllSelection() - else - false - end -end - -function Base.in(addr::Pair, selection::DynamicSelection) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - @assert !isempty(subselection) - rest in subselection - else - false - end -end - -function Base.getindex(selection::DynamicSelection, addr) - if haskey(selection.subselections, addr) - selection.subselections[addr] - else - EmptySelection() - end -end - -function Base.getindex(selection::DynamicSelection, addr::Pair) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - @assert !isempty(subselection) - getindex(subselection, rest) - else - EmptySelection() - end -end - -function Base.push!(selection::DynamicSelection, addr) - selection.subselections[addr] = AllSelection() -end - -function Base.push!(selection::DynamicSelection, addr::Pair) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - else - subselection = DynamicSelection() - selection.subselections[first] = subselection - end - push!(subselection, rest) -end - -function set_subselection!(selection::DynamicSelection, addr, other::Selection) - selection.subselections[addr] = other -end - -function set_subselection!(selection::DynamicSelection, addr::Pair, other::Selection) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - else - subselection = DynamicSelection() - selection.subselections[first] = subselection - end - set_subselection!(subselection, rest, other) -end - -get_subselections(selection::DynamicSelection) = selection.subselections - -""" - selection = select(addrs...) - -Return a selection containing a given set of addresses. - -Examples: -```julia -selection = select(:x, "foo", :y => 1 => :z) -selection = select() -selection = select(:x => 1, :x => 2) -``` -""" -function select(addrs...) - selection = DynamicSelection() - for addr in addrs - push!(selection, addr) - end - selection -end - -""" - selection = selectall() - -Construct a selection that includes all random choices. -""" -function selectall() - AllSelection() -end - -export DynamicSelection -export select, selectall, set_subselection! diff --git a/src/deprecated/choice_map/array_interface.jl b/src/deprecated/choice_map/array_interface.jl deleted file mode 100644 index 314b9bc5..00000000 --- a/src/deprecated/choice_map/array_interface.jl +++ /dev/null @@ -1,106 +0,0 @@ -### interface for to_array and fill_array ### - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -Implementation - -The default implmentation of `fill_array` will populate the array by sorting -the addresses of the choicemap using the `sort` function, then iterating over -each submap in this order and filling the array for that submap. - -To override the default implementation of `to_array`, -a concrete subtype `T <: AddressTree{Value}` should implement the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. - -(This is for performance; it is more efficient to fill in values in a preallocated array -by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(c::Value{<:T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = get_value(c) - 1 -end -function _fill_array!(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} - value = get_value(c) - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - -# default _fill_array! implementation -function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - key_to_submap = collect(get_submaps_shallow(choices)) - sort!(key_to_submap, by = ((key, submap),) -> key) - idx = start_idx - for (key, submap) in key_to_submap - n_written = _fill_array!(submap, arr, idx) - idx += n_written - end - idx - start_idx -end - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -The order in which addresses are populated with values from the array -should match the order in which the array is populated with values -in a call to `to_array(proto_choices, T)`. By default, -this means sorting the top-level addresses for `proto_choices` -and then filling in the submaps depth-first in this order. - -# Implementation - -To support `from_array`, a concrete subtype `T AddressTree{Value}` must implement -the following method: - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`. Return the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::Value, arr::Vector, start_idx::Int) - (1, Value(arr[start_idx])) -end -function _from_array(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(get_value(c)) - (n_read, Value(arr[start_idx:start_idx+n_read-1])) -end - -export to_array, from_array \ No newline at end of file diff --git a/src/deprecated/choice_map/choice_map.jl b/src/deprecated/choice_map/choice_map.jl deleted file mode 100644 index 74ddac6f..00000000 --- a/src/deprecated/choice_map/choice_map.jl +++ /dev/null @@ -1,283 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - ChoiceMapGetValueError - -The error returned when a user attempts to call `get_value` -on an choicemap for an address which does not contain a value in that choicemap. -""" -struct ChoiceMapGetValueError <: Exception end -showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - get_submaps_shallow(choices::ChoiceMap) - -Returns an iterable collection of tuples `(address, submap)` -for each top-level address associated with `choices`. -(This includes `Value`s.) -""" -function get_submaps_shallow end - -""" - get_submap(choices::ChoiceMap, addr) - -Return the submap at the given address, or `EmptyChoiceMap` -if there is no submap at the given address. -""" -function get_submap end - -# provide _get_submap so when users overwrite get_submap(choices::CustomChoiceMap, addr::Pair) -# they can just call _get_submap for convenience if they want -@inline function _get_submap(choices::ChoiceMap, addr::Pair) - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -""" - has_value(choices::ChoiceMap) - -Returns true if `choices` is a `Value`. - - has_value(choices::ChoiceMap, addr) - -Returns true if `choices` has a value stored at address `addr`. -""" -function has_value end -@inline has_value(::ChoiceMap) = false -@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) - -""" - get_value(choices::ChoiceMap) - -Returns the value stored on `choices` is `choices` is a `Value`; -throws a `ChoiceMapGetValueError` if `choices` is not a `Value`. - - get_value(choices::ChoiceMap, addr) -Returns the value stored in the submap with address `addr` or throws -a `ChoiceMapGetValueError` if no value exists at this address. - -A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end -@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) -@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) -@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) - -""" -schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -# get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow -""" - get_values_shallow(choices::ChoiceMap) - -Returns an iterable collection of tuples `(address, value)` -for each value stored at a top-level address in `choices`. -(Works by applying a filter to `get_submaps_shallow`, -so this internally requires iterating over every submap.) -""" -function get_values_shallow(choices::ChoiceMap) - ( - (addr, get_value(submap)) - for (addr, submap) in get_submaps_shallow(choices) - if has_value(submap) - ) -end - -""" - get_nonvalue_submaps_shallow(choices::ChoiceMap) - -Returns an iterable collection of tuples `(address, submap)` -for every top-level submap stored in `choices` which is -not a `Value`. -(Works by applying a filter to `get_submaps_shallow`, -so this internally requires iterating over every submap.) -""" -function get_nonvalue_submaps_shallow(choices::ChoiceMap) - (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) -end - -# a choicemap is empty if it has no submaps and no value -Base.isempty(c::ChoiceMap) = all(((addr, submap),) -> isempty(submap), get_submaps_shallow(c)) && !has_value(c) - -""" - EmptyChoiceMap - -A choicemap with no submaps or values. -""" -struct EmptyChoiceMap <: ChoiceMap end - -@inline has_value(::EmptyChoiceMap, addr...) = false -@inline get_value(::EmptyChoiceMap) = throw(ChoiceMapGetValueError()) -@inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() -@inline Base.isempty(::EmptyChoiceMap) = true -@inline get_submaps_shallow(::EmptyChoiceMap) = () -@inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -@inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true -@inline Base.:(==)(::ChoiceMap, ::EmptyChoiceMap) = false -@inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false - -""" - Value - -A leaf-node choicemap. Stores a single value. -""" -struct Value{T} <: ChoiceMap - val::T -end - -@inline has_value(choices::Value) = true -@inline get_value(choices::Value) = choices.val -@inline get_submap(choices::Value, addr) = EmptyChoiceMap() -@inline get_submaps_shallow(choices::Value) = () -@inline Base.:(==)(a::Value, b::Value) = a.val == b.val -@inline Base.isapprox(a::Value, b::Value) = isapprox(a.val, b.val) -@inline get_address_schema(::Type{<:Value}) = AllAddressSchema() - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. -""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, submap) in get_submaps_shallow(choices1) - set_submap!(choices, key, merge(submap, get_submap(choices2, key))) - end - for (key, submap) in get_submaps_shallow(choices2) - if isempty(get_submap(choices1, key)) - set_submap!(choices, key, submap) - end - end - choices -end -Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c -Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c -Base.merge(c::Value, ::EmptyChoiceMap) = c -Base.merge(::EmptyChoiceMap, c::Value) = c -Base.merge(::Value, ::ChoiceMap) = error("Values cannot be merged") -Base.merge(::ChoiceMap, ::Value) = error("Values cannot be merged") - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, submap) in get_submaps_shallow(a) - if get_submap(b, addr) != submap - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if get_submap(a, addr) != submap - return false - end - end - return true -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(get_submap(b, addr), submap) - return false - end - end - return true -end - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - # TODO: return a `FilteringChoiceMap` which does this filtering lazily! - output = choicemap() - for (addr, submap) in get_submaps_shallow(choices) - if has_value(submap) && addr in selection - output[addr] = get_value(submap) - else - subselection = get_subselection(selection, addr) - set_submap!(output, addr, get_selected(submap, subselection)) - end - end - output -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_nonvalue_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -export ChoiceMap, Value, EmptyChoiceMap -export _get_submap, get_submap, get_submaps_shallow -export get_value, has_value -export get_values_shallow, get_nonvalue_submaps_shallow -export get_address_schema, get_selected -export ChoiceMapGetValueError - -include("array_interface.jl") -include("dynamic_choice_map.jl") -include("static_choice_map.jl") -include("nested_view.jl") \ No newline at end of file diff --git a/src/deprecated/choice_map/dynamic_choice_map.jl b/src/deprecated/choice_map/dynamic_choice_map.jl deleted file mode 100644 index f54a221c..00000000 --- a/src/deprecated/choice_map/dynamic_choice_map.jl +++ /dev/null @@ -1,149 +0,0 @@ -####################### -# dynamic assignment # -####################### - -""" - struct DynamicChoiceMap AddressTree{Value} .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -struct DynamicChoiceMap AddressTree{Value} - submaps::Dict{Any, ChoiceMap} - function DynamicChoiceMap() - new(Dict()) - end -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for (addr, value) in tuples - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, submap) in get_submaps_shallow(other) - if submap isa Value - set_submap!(choices, addr, submap) - else - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - end - choices -end - -DynamicChoiceMap(other::Value) = error("Cannot convert a Value to a DynamicChoiceMap") - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) -end - -@inline get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps -@inline get_submap(choices::DynamicChoiceMap, addr) = get(choices.submaps, addr, EmptyChoiceMap()) -@inline get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.submaps, addr) - choices.submaps[addr] = Value(value) -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if !haskey(choices.submaps, first) - choices.submaps[first] = DynamicChoiceMap() - elseif has_value(choices.submaps[first]) - error("Tried to create assignment at $first but there was already a value there.") - end - set_value!(choices.submaps[first], rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node::ChoiceMap) - delete!(choices.submaps, addr) - if !isempty(new_node) - choices.submaps[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node::ChoiceMap) - (first, rest) = addr - if !haskey(choices.submaps, first) - choices.submaps[first] = DynamicChoiceMap() - elseif has_value(choices.submaps[first]) - error("Tried to create assignment at $first but there was already a value there.") - end - set_submap!(choices.submaps[first], rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - choices = DynamicChoiceMap() - keys_sorted = sort(collect(keys(proto_choices.submaps))) - idx = start_idx - for key in keys_sorted - (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) - idx += n_read - choices.submaps[key] = submap - end - (idx - start_idx, choices) -end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! \ No newline at end of file diff --git a/src/deprecated/choice_map/nested_view.jl b/src/deprecated/choice_map/nested_view.jl deleted file mode 100644 index 6ecf9707..00000000 --- a/src/deprecated/choice_map/nested_view.jl +++ /dev/null @@ -1,80 +0,0 @@ -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -ChoiceMapNestedView(cm::Value) = get_value(cm) -ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") - -function Base.getindex(choices::ChoiceMapNestedView, addr) - ChoiceMapNestedView(get_submap(choices.choice_map, addr)) -end - -function Base.iterate(c::ChoiceMapNestedView) - itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) - r = Base.iterate(itr) - if r === nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (itr, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (itr, st) = state - r = Base.iterate(itr, st) - if r === nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (itr, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) - -Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map - -function Base.length(cv::ChoiceMapNestedView) - length(collect(get_submaps_shallow(cv.choice_map))) -end -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. - -export nested_view \ No newline at end of file diff --git a/src/deprecated/choice_map/static_choice_map.jl b/src/deprecated/choice_map/static_choice_map.jl deleted file mode 100644 index 3adcb775..00000000 --- a/src/deprecated/choice_map/static_choice_map.jl +++ /dev/null @@ -1,160 +0,0 @@ -###################### -# static assignment # -###################### - -struct StaticChoiceMap{Addrs, SubmapTypes} AddressTree{Value} - submaps::NamedTuple{Addrs, SubmapTypes} - function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} - new{Addrs, SubmapTypes}(submaps) - end -end - -function StaticChoiceMap(;addrs_to_vals_and_maps...) - addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) - maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : Value(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) - StaticChoiceMap(NamedTuple{addrs}(maps)) -end - -@inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) -@inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# TODO: would it be faster to do static_get_submap? -function get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, addr::Symbol) where {Addrs, SubmapTypes} - if addr in Addrs - choices.submaps[addr] - else - EmptyChoiceMap() - end -end - -@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} - if A in Addrs - quote choices.submaps[A] end - else - quote EmptyChoiceMap() end - end -end -@inline static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() - -@inline static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) -@inline static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) - -# convert a nonvalue choicemap all of whose top-level-addresses -# are symbols into a staticchoicemap at the top level -function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = get_submaps_shallow(other) - if length(keys_and_nodes) > 0 - addrs = Tuple(key for (key, _) in keys_and_nodes) - submaps = Tuple(submap for (_, submap) in keys_and_nodes) - else - addrs = () - submaps = () - end - StaticChoiceMap(NamedTuple{addrs}(submaps)) -end -StaticChoiceMap(other::Value) = error("Cannot convert a Value to a StaticChoiceMap") -StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() - -# TODO: deep conversion to static choicemap - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple{(key1, key2)}((choices1, choices2))) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any submaps at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if length(collect(get_submaps_shallow(choices))) != 2 - error("Not a pair") - end - (get_submap(choices, key1), get_submap(choices, key2)) -end - -@generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, - choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} - - addr_to_type1 = Dict{Symbol, Type{<:ChoiceMap}}() - addr_to_type2 = Dict{Symbol, Type{<:ChoiceMap}}() - for (i, addr) in enumerate(Addrs1) - addr_to_type1[addr] = SubmapTypes1.parameters[i] - end - for (i, addr) in enumerate(Addrs2) - addr_to_type2[addr] = SubmapTypes2.parameters[i] - end - - merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) - submap_exprs = [] - - for addr in merged_addrs - type1 = get(addr_to_type1, addr, EmptyChoiceMap) - type2 = get(addr_to_type2, addr, EmptyChoiceMap) - if ((type1 <: Value && type2 != EmptyChoiceMap) - || (type2 <: Value && type1 != EmptyChoiceMap)) - error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") - end - if type1 <: EmptyChoiceMap - push!(submap_exprs, - quote choices2.submaps.$addr end - ) - elseif type2 <: EmptyChoiceMap - push!(submap_exprs, - quote choices1.submaps.$addr end - ) - else - push!(submap_exprs, - quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end - ) - end - end - - quote - StaticChoiceMap(NamedTuple{$merged_addrs}(($(submap_exprs...),))) - end -end - -@generated function _from_array(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, - arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} - - perm = sortperm(collect(Addrs)) - sorted_addrs = Addrs[perm] - submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) - - exprs = [quote idx = start_idx end] - - for (idx, addr) in zip(perm, sorted_addrs) - submap_var_name = gensym(addr) - submap_var_names[idx] = submap_var_name - push!(exprs, - quote - (n_read, $submap_var_name) = _from_array(proto_choices.submaps.$addr, arr, idx) - idx += n_read - end - ) - end - - quote - $(exprs...) - submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) - choices = StaticChoiceMap(submaps) - (idx - start_idx, choices) - end -end - -function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} - StaticAddressSchema(Set(Addrs)) -end - -export StaticChoiceMap -export pair, unpair -export static_get_submap, static_get_value \ No newline at end of file