Limit the scope of Looplet/Unfurl closures #608

Open
willow-ahrens opened this issue Jun 18, 2024 · 2 comments
Comments

@willow-ahrens (Owner)
In order to support parallelism on more exciting architectures, we must redesign the unfurl function so that it avoids carrying closed-over variables from one dimension to another, allowing us to re-label those symbols. As an example of the problem, consider the following Gustavson matmul kernel:

function spgemm_finch_gustavson_kernel_parallel(A, B)
    # @assert Threads.nthreads() >= 2
    z = default(A) * default(B) + false
    C = Tensor(Dense(Separate(SparseList(Element(z)))))
    w = Tensor(Dense(Element(z)))
    @finch_code begin
        C .= 0
        for j=parallel(_)
            w .= 0
            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
            for i=_; C[i, j] = w[i] end
        end
    end
    @finch begin
        C .= 0
        for j=parallel(_)
            w .= 0
            for k=_, i=_; w[i] += A[i, k] * B[k, j] end
            for i=_; C[i, j] = w[i] end
        end
    end
    return C
end

This produces:

quote
    C_lvl = ((ex.bodies[1]).bodies[1]).tns.bind.lvl
    C_lvl_2 = C_lvl.lvl
    C_lvl_3 = C_lvl_2.lvl
    C_lvl_2_val = C_lvl_2.lvl.val
    w_lvl = (((ex.bodies[1]).bodies[2]).body.bodies[1]).tns.bind.lvl
    w_lvl_val = w_lvl.lvl.val
    A_lvl = ((((ex.bodies[1]).bodies[2]).body.bodies[2]).body.body.rhs.args[1]).tns.bind.lvl
    A_lvl_2 = A_lvl.lvl
    A_lvl_2_val = A_lvl_2.lvl.val
    B_lvl = ((((ex.bodies[1]).bodies[2]).body.bodies[2]).body.body.rhs.args[2]).tns.bind.lvl
    B_lvl_2 = B_lvl.lvl
    B_lvl_2_val = B_lvl_2.lvl.val
    B_lvl_2.shape == A_lvl.shape || throw(DimensionMismatch("mismatched dimension limits ($(B_lvl_2.shape) != $(A_lvl.shape))"))
    result = nothing
    pos_stop = A_lvl_2.shape * B_lvl.shape
    Finch.resize_if_smaller!(C_lvl_2_val, pos_stop)
    Finch.fill_range!(C_lvl_2_val, 0.0, 1, pos_stop)
    B_lvl_2_val = (Finch).moveto(B_lvl_2_val, CPU(Threads.nthreads()))
    A_lvl_2_val = (Finch).moveto(A_lvl_2_val, CPU(Threads.nthreads()))
    val_3 = C_lvl_2_val
    C_lvl_2_val = (Finch).moveto(C_lvl_2_val, CPU(Threads.nthreads()))
    Threads.@threads for i_7 = 1:Threads.nthreads()
            val_4 = w_lvl_val
            w_lvl_val = (Finch).moveto(w_lvl_val, CPUThread(i_7, CPU(Threads.nthreads()), Serial()))
            phase_start_2 = max(1, 1 + fld(B_lvl.shape * (i_7 + -1), Threads.nthreads()))
            phase_stop_2 = min(B_lvl.shape, fld(B_lvl.shape * i_7, Threads.nthreads()))
            if phase_stop_2 >= phase_start_2
                for j_6 = phase_start_2:phase_stop_2
                    B_lvl_q = (1 - 1) * B_lvl.shape + j_6
                    C_lvl_q = (1 - 1) * B_lvl.shape + j_6
                    Finch.resize_if_smaller!(w_lvl_val, A_lvl_2.shape)
                    Finch.fill_range!(w_lvl_val, 0.0, 1, A_lvl_2.shape)
                    for k_4 = 1:B_lvl_2.shape
                        A_lvl_q = (1 - 1) * A_lvl.shape + k_4
                        B_lvl_2_q = (B_lvl_q - 1) * B_lvl_2.shape + k_4
                        B_lvl_3_val = B_lvl_2_val[B_lvl_2_q]
                        for i_8 = 1:A_lvl_2.shape
                            w_lvl_q = (1 - 1) * A_lvl_2.shape + i_8
                            A_lvl_2_q = (A_lvl_q - 1) * A_lvl_2.shape + i_8
                            A_lvl_3_val = A_lvl_2_val[A_lvl_2_q]
                            w_lvl_val[w_lvl_q] += B_lvl_3_val * A_lvl_3_val
                        end
                    end
                    resize!(w_lvl_val, A_lvl_2.shape)
                    for i_9 = 1:A_lvl_2.shape
                        C_lvl_2_q = (C_lvl_q - 1) * A_lvl_2.shape + i_9
                        w_lvl_q_2 = (1 - 1) * A_lvl_2.shape + i_9
                        w_lvl_2_val = w_lvl_val[w_lvl_q_2]
                        C_lvl_2_val[C_lvl_2_q] = w_lvl_2_val
                    end
                end
            end
            w_lvl_val = val_4
        end
    resize!(val_3, A_lvl_2.shape * B_lvl.shape)
    result = (C = Tensor((DenseLevel){Int64}((DenseLevel){Int64}(C_lvl_3, A_lvl_2.shape), B_lvl.shape)),)
    result
end

We try to moveto the workspace into thread-local memory, but we end up
reassigning the global variable with w_lvl_val = copy(w_lvl_val). In general,
some looplet structures currently capture state from enclosing outer loops,
state that moveto misses because it is not accounted for explicitly. The
solution here is to represent the state of all subfibers (including COO and
Masks) as explicit structs with named fields that we can call moveto on. Each
call to unfurl would then occur in the scope of the loop being unfurled,
rather than in the top-level call to instantiate. More generally, it would
benefit the project to distinguish between instantiate and unfurl at every
loop nest, so we can insert side effects that expand eagerly as soon as a
tensor is unfurled, versus side effects that wait until the last possible
moment to expand with instantiate. This might clean up some of the lowering
for Scalars as well.

This may involve a fair amount of code change, but it would be for the better.
For example, the state carried by a COO level would be encapsulated in a more
explicit COOSubLevel struct that reflects the current COO index search variables.

In the end, this would enable moveto to "re-virtualize" tensors upon entering
parallel regions, which is critical for local variables.
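To make the proposal concrete, here is a rough sketch of what an explicit subfiber state struct might look like. The name SparseListSubFiber and the field layout are hypothetical illustrations, not the current Finch API; the point is only that every piece of state the looplet previously closed over becomes a named field that moveto can relocate:

    # Hypothetical sketch: the state a SparseList looplet currently closes over
    # (the position inherited from the enclosing loop, the current search
    # position, etc.) becomes an explicit struct with named fields.
    struct SparseListSubFiber{Lvl, Ti}
        lvl::Lvl   # the underlying level, whose buffers moveto can relocate
        pos::Ti    # position captured from the enclosing loop
        qos::Ti    # current search position within this subfiber
    end

    # moveto now sees every captured field by name and can relocate each one,
    # returning a re-virtualized subfiber for use inside the parallel region.
    function moveto(fbr::SparseListSubFiber, device)
        SparseListSubFiber(moveto(fbr.lvl, device), fbr.pos, fbr.qos)
    end

Because unfurl would then run inside the scope of the loop being lowered, the struct built there could be handed to moveto on entry to a parallel region, instead of the closure silently capturing w_lvl_val at the top level as in the generated code above.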

@wraith1995
Copy link
Collaborator

wraith1995 commented Jun 18, 2024

@willow-ahrens I want to add that I think this issue is also potentially related to the "closure" issue we've had in generating parallel code. In particular, part of the problem there is that we don't know much about the local variables once something has been lowered, so we lose track and need to recover what needs to be passed into a closure/function for running parallel code. If everything had structs describing the local state, with types, then at lowering time we could figure out the fields currently in use and just pass those in.
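As a hedged illustration of that idea (the struct name and fields here are invented for the example, not Finch internals), if each tensor's local state were a typed struct, the lowering pass could enumerate the fields in use and build the parallel closure's argument list mechanically via reflection:

    # Hypothetical sketch: a typed struct describing one tensor's local state,
    # e.g. the workspace buffer w_lvl_val and its extent.
    struct DenseWorkspaceState
        val::Vector{Float64}  # backing buffer
        shape::Int
    end

    # List (name, type) pairs for everything the worker closure must receive.
    closure_args(state::T) where {T} =
        [(name, fieldtype(T, name)) for name in fieldnames(T)]

    closure_args(DenseWorkspaceState(zeros(4), 4))
    # → [(:val, Vector{Float64}), (:shape, Int64)]

With the types in hand, the compiler could also emit concretely-typed closure signatures rather than recovering captures after the fact.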

@willow-ahrens
Copy link
Owner Author

willow-ahrens commented Jun 18, 2024

Right, although the closure issue is also resolved with

Finch.jl/src/util/shims.jl

Lines 59 to 208 in a7f5b89

"""
@barrier args... ex
Wrap `ex` in a let block that captures all free variables in `ex` that are bound in the arguments. This is useful for
ensuring that the variables in `ex` are not mutated by the arguments.
"""
macro barrier(args_ex...)
(args, ex) = args_ex[1:end-1], args_ex[end]
f = gensym()
esc(quote
$f = @closure ($(args...),) -> $ex
$f()
end)
end
# wrap_closure is taken from https://github.com/c42f/FastClosures.jl
#
# The FastClosures.jl package is licensed under the MIT "Expat" License:
# Copyright (c) 2017: Claire Foster.
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# Wrap `closure_expression` in a `let` block to improve efficiency.
function wrap_closure(module_, ex)
bound_vars = Symbol[]
captured_vars = Symbol[]
if @capture ex :->(:tuple(~args...), ~body)
elseif @capture ex :function(:call(~f, ~args...), ~body)
push!(bound_vars, f)
else
throw(ArgumentError("Argument to @closure must be a closure! (Got $closure_expression)"))
end
append!(bound_vars, [v for v in args])
find_var_uses!(captured_vars, bound_vars, body)
quote
let $(map(var -> :($var = $var), captured_vars)...)
$ex
end
end
end
"""
@closure closure_expression
Wrap the closure definition `closure_expression` in a let block to encourage
the julia compiler to generate improved type information. For example:
```julia
callfunc(f) = f()
function foo(n)
for i=1:n
if i >= n
# Unlikely event - should be fast. However, capture of `i` inside
# the closure confuses the julia-0.6 compiler and causes it to box
# the variable `i`, leading to a 100x performance hit if you remove
# the `@closure`.
callfunc(@closure ()->println("Hello \$i"))
end
end
end
```
There's nothing nice about this - it's a heuristic workaround for some
inefficiencies in the type information inferred by the julia 0.6 compiler.
However, it can result in large speedups in many cases, without the need to
restructure the code to avoid the closure.
"""
macro closure(ex)
esc(wrap_closure(__module__, ex))
end
# Utility function - fill `varlist` with all accesses to variables inside `ex`
# which are not bound before being accessed. Variables which were bound
# before access are returned in `bound_vars` as a side effect.
#
# With works with the surface syntax so it unfortunately has to reproduce some
# of the lowering logic (and consequently likely has bugs!)
function find_var_uses!(capture_vars, bound_vars, ex)
if isa(ex, Symbol)
if !(ex in bound_vars)
#occursin("threadsfor", string(ex)) && error()
ex capture_vars || push!(capture_vars, ex)
end
return capture_vars
elseif isa(ex, Expr)
if ex.head == :(=)
find_var_uses_lhs!(capture_vars, bound_vars, ex.args[1])
find_var_uses!(capture_vars, bound_vars, ex.args[2])
elseif @capture ex :->(:tuple(~args...), ~body)
body_vars = copy(bound_vars)
for arg in args
find_var_uses_lhs!(capture_vars, body_vars, arg)
end
find_var_uses!(capture_vars, body_vars, body)
elseif ex.head == :kw
find_var_uses!(capture_vars, bound_vars, ex.args[2])
elseif ex.head == :for || ex.head == :while || ex.head == :let
# New scopes
inner_bindings = copy(bound_vars)
find_var_uses!(capture_vars, inner_bindings, ex.args)
elseif ex.head == :try
# New scope + ex.args[2] is a new binding
find_var_uses!(capture_vars, copy(bound_vars), ex.args[1])
catch_bindings = copy(bound_vars)
!isa(ex.args[2], Symbol) || push!(catch_bindings, ex.args[2])
find_var_uses!(capture_vars,catch_bindings,ex.args[3])
if length(ex.args) > 3
finally_bindings = copy(bound_vars)
find_var_uses!(capture_vars,finally_bindings,ex.args[4])
end
elseif ex.head == :call
find_var_uses!(capture_vars, bound_vars, ex.args[2:end])
elseif ex.head == :local
foreach(ex.args) do e
if !isa(e, Symbol)
find_var_uses!(capture_vars, bound_vars, e)
end
end
elseif ex.head == :(::)
find_var_uses_lhs!(capture_vars, bound_vars, ex)
else
find_var_uses!(capture_vars, bound_vars, ex.args)
end
end
capture_vars
end
find_var_uses!(capture_vars, bound_vars, exs::Vector) =
foreach(e->find_var_uses!(capture_vars, bound_vars, e), exs)
# Find variable uses on the left hand side of an assignment. Some of what may
# be variable uses turn into bindings in this context (cf. tuple unpacking).
function find_var_uses_lhs!(capture_vars, bound_vars, ex)
if isa(ex, Symbol)
ex bound_vars || push!(bound_vars, ex)
elseif isa(ex, Expr)
if ex.head == :tuple
find_var_uses_lhs!(capture_vars, bound_vars, ex.args)
elseif ex.head == :(::)
find_var_uses!(capture_vars, bound_vars, ex.args[2])
find_var_uses_lhs!(capture_vars, bound_vars, ex.args[1])
else
find_var_uses!(capture_vars, bound_vars, ex.args)
end
end
end
find_var_uses_lhs!(capture_vars, bound_vars, exs::Vector) = foreach(e->find_var_uses_lhs!(capture_vars, bound_vars, e), exs)
We can certainly copy that logic into the compiler, so that as long as we rename state variables correctly, we can generate explicit closures at the moment of parallel lowering, based on the variables the code ends up using (i.e. lower it and see what variables it uses). You are correct, though, that explicit state structs would enable the analysis you describe, where we could write a function that lists all variables reachable from a block of finch code without lowering it.
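For readers unfamiliar with the shim quoted above, the transformation it performs is small: each captured variable is rebound in a let block, so the closure captures a fresh local binding rather than the (possibly boxed) outer variable. A minimal standalone illustration of the expansion, not Finch code:

    # What `@closure`/`@barrier` effectively expand to: rebinding the captured
    # variable in a `let` gives the closure its own local, non-boxed binding.
    function scaled_total(xs)
        acc = 0.0
        for x in xs
            acc += x
        end
        # Without the let, the closure would capture the mutated outer `acc`;
        # the let captures its current value under a new binding.
        f = let acc = acc
            () -> acc * 2
        end
        f()
    end

    scaled_total([1.0, 2.0, 3.0])  # returns 12.0

The open problem the comments discuss is generating exactly this let block automatically during parallel lowering, which requires knowing the full set of captured state variables.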
