Refactor @unsafe_region macro #233

Merged (14 commits) on Nov 12, 2024


@jofrevalles (Member) commented Nov 8, 2024

Summary

This PR resolves #230 by refactoring the @unsafe_region macro. It adds a new field to the TensorNetwork struct, check_index_sizes::Ref{Bool}. When it is false, the push! functions skip the inconsistent-dimension checks.
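A minimal sketch of the flag idea, using hypothetical `MiniTensor` and `MiniTN` types rather than Tenet's real `Tensor` and `TensorNetwork`. A `Ref{Bool}` field lets an otherwise immutable struct carry a mutable per-network switch that `push!` consults before running the dimension check. The field name `unsafe` (where `true` means "skip the check") is an assumption of this sketch, not necessarily how the PR names or encodes it.

```julia
# Hypothetical, simplified stand-ins for Tenet's Tensor and TensorNetwork.
struct MiniTensor
    inds::Vector{Symbol}
    dims::Vector{Int}
end

struct MiniTN
    tensors::Vector{MiniTensor}
    unsafe::Ref{Bool}   # mutable flag inside an immutable struct; true => skip checks
end

MiniTN(tensors::Vector{MiniTensor}) = MiniTN(tensors, Ref(false))

# Dimension of index `i` in `tn`, or `nothing` if no tensor carries it.
function dimof(tn::MiniTN, i::Symbol)
    for t in tn.tensors
        k = findfirst(==(i), t.inds)
        k === nothing || return t.dims[k]
    end
    return nothing
end

function Base.push!(tn::MiniTN, t::MiniTensor)
    if !tn.unsafe[]
        # Safe mode: every shared index must keep its existing dimension.
        for (i, d) in zip(t.inds, t.dims)
            d0 = dimof(tn, i)
            if d0 !== nothing && d0 != d
                throw(DimensionMismatch("size(tensor,$i)=$d but should be equal to size(tn,$i)=$d0"))
            end
        end
    end
    push!(tn.tensors, t)
    return tn
end

tn = MiniTN([MiniTensor([:a, :b], [2, 2]), MiniTensor([:b, :c], [2, 2])])
bad = MiniTensor([:c, :d], [3, 2])

err = try
    push!(tn, bad)      # checks enabled: throws DimensionMismatch
    nothing
catch e
    e
end

tn.unsafe[] = true      # switch checks off for this network only
push!(tn, bad)          # accepted without any check
tn.unsafe[] = false
```

Because the switch lives on each network, flipping it on `tn` leaves the checks of every other network untouched, which is the behavior #230 asks for.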

Now @unsafe_region only toggles this field on the specific TensorNetwork passed as its argument, and it keeps track of every copy of that network created inside the macro block, so those copies are also checked for consistency when the block ends.
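The following sketch, again with hypothetical simplified types, shows one way a macro can scope unsafe mode to a single network and its in-block copies: it attaches a shared `Scope` object to the network, `copy` adds new networks to that scope, and when the block finishes every tracked network leaves unsafe mode and must pass a consistency check. `LiteTensor`, `LiteTN`, `Scope`, `is_consistent`, and `@unsafe_region_sketch` are illustrative names, not Tenet's API. The three cases at the end mirror the REPL examples in this PR (in a different order).

```julia
# Hypothetical, simplified types; not Tenet's real implementation.
struct LiteTensor
    inds::Vector{Symbol}
    dims::Vector{Int}
end

struct Scope
    nets::Vector{Any}   # every network tracked by the active region
end

mutable struct LiteTN
    tensors::Vector{LiteTensor}
    scope::Union{Nothing,Scope}   # `nothing` outside any unsafe region
end

# True if each index has the same dimension in every tensor that carries it.
function is_consistent(tn::LiteTN)
    seen = Dict{Symbol,Int}()
    for t in tn.tensors
        for (i, d) in zip(t.inds, t.dims)
            get!(seen, i, d) == d || return false
        end
    end
    return true
end

function Base.push!(tn::LiteTN, t::LiteTensor)
    push!(tn.tensors, t)
    # Outside an unsafe region, reject a tensor that breaks consistency right away.
    if tn.scope === nothing && !is_consistent(tn)
        pop!(tn.tensors)
        throw(DimensionMismatch("tensor has inconsistent index sizes"))
    end
    return tn
end

function Base.pop!(tn::LiteTN, t::LiteTensor)
    k = findfirst(x -> x === t, tn.tensors)
    k === nothing && throw(KeyError(t))
    deleteat!(tn.tensors, k)
    return t
end

# A copy made while `tn` is inside an unsafe region joins the same scope.
function Base.copy(tn::LiteTN)
    dup = LiteTN(copy(tn.tensors), tn.scope)
    tn.scope === nothing || push!(tn.scope.nets, dup)
    return dup
end

macro unsafe_region_sketch(tn, block)
    quote
        local net = $(esc(tn))
        local scope = Scope(Any[net])
        net.scope = scope
        local result = try
            $(esc(block))
        finally
            # Leave unsafe mode for the network and every copy made in the block.
            for n in scope.nets
                n.scope = nothing
            end
        end
        # Everything touched in unsafe mode must be consistent once the block ends.
        for n in scope.nets
            is_consistent(n) || throw(DimensionMismatch("Inconsistent size of indices"))
        end
        result
    end
end

tn = LiteTN([LiteTensor([:a, :b], [2, 2]), LiteTensor([:b, :c], [2, 2])], nothing)
bad = LiteTensor([:c, :d], [3, 2])

# Passes: the copy made inside the region is consistent again by the end.
@unsafe_region_sketch tn begin
    push!(tn, bad)    # not checked inside the region
    tn2 = copy(tn)    # tracked, because it is created inside the region
    pop!(tn, bad)
    pop!(tn2, bad)
end

# Throws at the end: the tracked copy still holds `bad`.
err_end = try
    @unsafe_region_sketch tn begin
        push!(tn, bad)
        tn3 = copy(tn)
        pop!(tn, bad)
    end
    nothing
catch e
    e
end

# Throws immediately: `outside` was copied before the region, so it is not tracked.
outside = copy(tn)
err_now = try
    @unsafe_region_sketch tn begin
        push!(outside, bad)
    end
    nothing
catch e
    e
end
```

The `try`/`finally` switches unsafe mode off even when the block throws, as it does in the last case.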

Example

Here is how the example from #230 behaves now:

julia> using Tenet, Test

julia> tn = TensorNetwork([
           Tensor(ones(2, 2), [:a, :b]),
           Tensor(ones(2, 2), [:b, :c])
       ])
TensorNetwork (#tensors=2, #inds=3)

julia> tn2 = copy(tn)
TensorNetwork (#tensors=2, #inds=3)

julia> Tenet.@unsafe_region tn begin # Now it errors properly!
           tensor = Tensor(ones(3, 2), [:c, :d])
           push!(tn, tensor)
           push!(tn2, tensor)  # tn2 is not specified in @unsafe_region
           @test length(tensors(tn)) == 3
           pop!(tn, tensor)
       end
ERROR: DimensionMismatch: size(tensor,c)=3 but should be equal to size(tn,c)=2
Stacktrace:
 [1] push!(tn::TensorNetwork, tensor::Tensor{Float64, 2, Matrix{Float64}})
   @ Tenet ~/git/Tenet.jl/src/TensorNetwork.jl:381
 [2] macro expansion
   @ ./REPL[14]:4 [inlined]
 [3] top-level scope

But if tn2 is instead created inside the macro block, it is tracked as a copy of tn:

julia> tn = TensorNetwork([
                  Tensor(ones(2, 2), [:a, :b]),
                  Tensor(ones(2, 2), [:b, :c])
              ])
TensorNetwork (#tensors=2, #inds=3)

julia> Tenet.@unsafe_region tn begin # Still errors, since `tn2` is inconsistent at the end of the block
                  tensor = Tensor(ones(3, 2), [:c, :d])
                  push!(tn, tensor)
                  tn2 = copy(tn)
                  push!(tn2, tensor)  # tn2 is not specified in @unsafe_region
                  @test length(tensors(tn)) == 3
                  pop!(tn, tensor)
       end
ERROR: DimensionMismatch: Inconsistent size of indices
Stacktrace:
 [1] top-level scope
   @ ~/git/Tenet.jl/src/TensorNetwork.jl:351

julia> Tenet.@unsafe_region tn begin # No error: both networks are consistent at the end of the block
                  tensor = Tensor(ones(3, 2), [:c, :d])
                  push!(tn, tensor)
                  tn2 = copy(tn)
                  push!(tn2, tensor)  # tn2 is not specified in @unsafe_region
                  @test length(tensors(tn)) == 3
                  pop!(tn, tensor)
                  pop!(tn2, tensor)
       end
3×2 Tensor{Float64, 2, Matrix{Float64}}:
 1.0  1.0
 1.0  1.0
 1.0  1.0

This gives the expected behavior.

Review thread on src/TensorNetwork.jl (outdated, resolved):
Base.values(uc::UnsafeContext) = map(x -> x.value, uc.refs)

# Global stack to manage nested unsafe contexts
const _unsafe_context_stack = Ref{Vector{UnsafeContext}}(Vector{UnsafeContext}())
Member:

mmm, you don't need this? The idea was that @unsafe_region first creates an UnsafeContext and passes it to the TensorNetwork.

Member Author:

You need this to check if there is an active UnsafeContext when you call copy and related functions.
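To illustrate why a global stack lets `copy` see the active context, here is a hypothetical sketch built around the two lines quoted in this thread (`Base.values` and `_unsafe_context_stack`). The `x.value` in the quoted `Base.values` suggests the stored references are `WeakRef`s, so this sketch assumes that. `Net`, `current_context`, and `with_unsafe_context` are made-up names standing in for `TensorNetwork` and the macro; the thread later removes this global in favor of per-network state.

```julia
# Hypothetical sketch of the global-stack design discussed in this thread.
mutable struct Net
    data::Vector{Int}
end

struct UnsafeContext
    refs::Vector{WeakRef}
end

Base.values(uc::UnsafeContext) = map(x -> x.value, uc.refs)

# Global stack to manage nested unsafe contexts
const _unsafe_context_stack = Ref{Vector{UnsafeContext}}(Vector{UnsafeContext}())

# The innermost active context, or `nothing` when no region is active.
current_context() = isempty(_unsafe_context_stack[]) ? nothing : last(_unsafe_context_stack[])

function Base.copy(n::Net)
    dup = Net(copy(n.data))
    ctx = current_context()
    # Track the copy only if its source belongs to the active context.
    if ctx !== nothing && any(r -> r.value === n, ctx.refs)
        push!(ctx.refs, WeakRef(dup))
    end
    return dup
end

# Hypothetical helper standing in for what @unsafe_region would do.
function with_unsafe_context(f, n::Net)
    push!(_unsafe_context_stack[], UnsafeContext([WeakRef(n)]))
    try
        return f()
    finally
        pop!(_unsafe_context_stack[])
    end
end

a = Net([1, 2])
b = copy(a)               # no active context: not tracked

ntracked = with_unsafe_context(a) do
    c = copy(a)           # copy of the tracked network: tracked
    d = copy(b)           # copy of an untracked network: not tracked
    length(values(current_context()))
end
```

The cost of this design is shared mutable global state, which is what the review below pushes back on.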

@jofrevalles requested a review from @mofeing on November 11, 2024 11:27
@mofeing (Member) left a comment:

In general, I really like it. The way it solves the problem is quite elegant.

I just want to talk with you about the global variable that stores the current unsafe context.

@jofrevalles requested a review from @mofeing on November 12, 2024 10:28

@jofrevalles (Member Author):

@mofeing the global const is gone and everything looks nice now :)

@mofeing (Member) left a comment:

Perfect! Aside from the small requests, would you mind doing two more things?

  1. Rename it to UnsafeScope: the word "context" has too many meanings, and @unsafe_region actually defines a scope. The name also takes inspiration from the ScopedValues.jl package.
  2. Implement a get_unsafe_context method to retrieve it from any AbstractTensorNetwork: this way we wouldn't need to call TensorNetwork(tn) at the beginning of @unsafe_region. You will need to change every reference to tn.unsafe to get_unsafe_context(tn).
get_unsafe_context(tn::AbstractTensorNetwork) = TensorNetwork(tn).unsafe
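A small sketch of the suggested accessor, with a hypothetical two-type hierarchy (`TN` and `WrappedTN` under `AbstractTN`) standing in for Tenet's `AbstractTensorNetwork` subtypes. The one-line `get_unsafe_context` definition follows the suggestion above; everything else, including the empty `UnsafeScope` placeholder, is assumed for illustration.

```julia
# Hypothetical stand-ins; Tenet's real types carry much more state.
abstract type AbstractTN end

struct UnsafeScope end   # placeholder for the scope object

struct TN <: AbstractTN
    unsafe::Ref{Union{Nothing,UnsafeScope}}
end

# A wrapper network type that stores a TN internally.
struct WrappedTN <: AbstractTN
    tn::TN
end

# Extract the underlying TN from any network type.
TN(tn::TN) = tn
TN(w::WrappedTN) = w.tn

# One accessor for every subtype, so callers never convert by hand.
get_unsafe_context(tn::AbstractTN) = TN(tn).unsafe

inner = TN(Ref{Union{Nothing,UnsafeScope}}(nothing))
w = WrappedTN(inner)

get_unsafe_context(w)[] = UnsafeScope()   # set through the wrapper
```

Since the `Ref` object is shared, writing through the accessor on the wrapper updates the wrapped network directly.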

Review thread on src/TensorNetwork.jl (outdated, resolved):
end
end

# Define the @unsafe_region macro
Member:

This comment is a lil bit self-explanatory? Maybe we can just remove it or add a docstring.

Member Author:

Maybe it's better to remove it. I prefer not to have a docstring here since, for now, this is an internal operation.

@jofrevalles (Member Author):

@mofeing I added some of your suggestions and changed the name to UnsafeScope. I think this is ready to merge.

@jofrevalles requested a review from @mofeing on November 12, 2024 12:11
@mofeing merged commit 6acfe1c into master on Nov 12, 2024
5 checks passed
@mofeing deleted the feature/refactor-unsafe-region branch on November 12, 2024 12:22

Linked issue: @unsafe_region applies to all TensorNetworks within the block, not just the one passed in the argument
2 participants