Refactor @unsafe_region macro #233
src/TensorNetwork.jl
Base.values(uc::UnsafeContext) = map(x -> x.value, uc.refs)

# Global stack to manage nested unsafe contexts
const _unsafe_context_stack = Ref{Vector{UnsafeContext}}(Vector{UnsafeContext}())
Hmm, you don't need this, right? The idea was that @unsafe_region first creates an UnsafeContext and passes it to the TensorNetwork.
You need this to check whether there is an active UnsafeContext when you call copy and related functions.
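The reply above argues that functions such as copy must be able to ask whether an unsafe region is active. A minimal sketch of why a stack (rather than a single global) fits that need, assuming nested regions must restore the enclosing scope on exit. It borrows the UnsafeScope name proposed later in the review; register!, with_unsafe_scope and current_unsafe_scope are hypothetical helpers, not the PR's code.

```julia
# Sketch only: hypothetical helpers illustrating a global stack of unsafe
# scopes that any function (e.g. `copy`) can query.

struct UnsafeScope
    refs::Vector{Any}   # objects (networks, copies) registered with this scope
end
UnsafeScope() = UnsafeScope(Any[])

const _unsafe_scope_stack = UnsafeScope[]

# Innermost active scope, or `nothing` outside any region.
current_unsafe_scope() = isempty(_unsafe_scope_stack) ? nothing : last(_unsafe_scope_stack)

# Register an object (e.g. a fresh copy) with the active scope, if any.
function register!(x)
    scope = current_unsafe_scope()
    scope === nothing || push!(scope.refs, x)
    return x
end

# Push a new scope, run `f(scope)`, and always pop it again (even if `f`
# throws), so a nested region hands control back to the enclosing scope.
function with_unsafe_scope(f)
    scope = UnsafeScope()
    push!(_unsafe_scope_stack, scope)
    try
        return f(scope)
    finally
        pop!(_unsafe_scope_stack)
    end
end
```

With nesting, an object registered inside the inner region belongs only to the inner scope, and the outer scope becomes current again once the inner block finishes.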
In general, I really like it. The way it solves the problem is quite elegant.
I just want to talk with you about the global variable that stores the current unsafe context.
@mofeing global |
Perfect! Aside from the small requests, would you mind doing 2 more things?
- Rename it to UnsafeScope: the word "context" has too many meanings, and @unsafe_region is actually a scope. The name also takes inspiration from the ScopedValues.jl package.
- Implement a get_unsafe_context method to retrieve it from any AbstractTensorNetwork: this way we wouldn't need to call TensorNetwork(tn) at the beginning of @unsafe_region. You will need to change any reference to tn.unsafe to get_unsafe_context(tn).
get_unsafe_context(tn::AbstractTensorNetwork) = TensorNetwork(tn).unsafe
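A minimal sketch of the dispatch pattern behind the suggested one-liner, using hypothetical toy types (AbstractToyNet, ToyTensorNetwork, WrappedNet, ToyScope) in place of the real AbstractTensorNetwork hierarchy: each subtype only needs a conversion to the underlying network, and a single generic accessor then works for all of them.

```julia
# Hypothetical toy hierarchy standing in for AbstractTensorNetwork/TensorNetwork.
abstract type AbstractToyNet end

struct ToyScope
    label::String
end

mutable struct ToyTensorNetwork <: AbstractToyNet
    unsafe::Union{Nothing,ToyScope}   # active scope, or `nothing`
end

# A wrapper type that stores a network rather than being one.
struct WrappedNet <: AbstractToyNet
    tn::ToyTensorNetwork
end

# Conversions to the underlying network: identity for the base type,
# field access for the wrapper.
ToyTensorNetwork(tn::ToyTensorNetwork) = tn
ToyTensorNetwork(w::WrappedNet) = w.tn

# One generic accessor; call sites use this instead of reading `tn.unsafe`.
get_unsafe_context(tn::AbstractToyNet) = ToyTensorNetwork(tn).unsafe
```

The design choice is that new wrapper types never touch the field directly; adding a conversion method is enough for get_unsafe_context to work on them.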
src/TensorNetwork.jl
    end
end

# Define the @unsafe_region macro
This comment is a bit redundant, isn't it? Maybe we can just remove it or add a docstring.
Maybe it's better to remove it. I'd prefer not to have a docstring here since, for now, this is an internal operation.
@mofeing I added some of your suggestions and changed the name to |
Summary
This PR resolves #230 by refactoring the @unsafe_region macro. It adds a new field to the TensorNetwork struct, check_index_sizes::Ref{Bool}. When true, there will be no inconsistent-dimension checks in the push! functions.

Now, @unsafe_region just modifies the value of this field in the specific TensorNetwork passed as an argument, and keeps track of all the copies that may be created inside the macro block.

Example
Let's see how we fixed the example shown in #230:

But if instead tn2 is created inside the macro block:

So we get the expected, correct behavior.
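The mechanism described in the Summary can be illustrated with a self-contained toy model. This is a sketch under stated assumptions, not the PR's implementation: ToyTensor, ToyNetwork, unsafe_region and _region_copies are hypothetical; the flag is called unsafe (as in the review's tn.unsafe), with true meaning "skip size checks"; and a plain function with a do-block stands in for the macro.

```julia
# Toy model (hypothetical names, not the package's real types): `push!` checks
# index sizes unless the network's `unsafe` flag is set, and a function-based
# stand-in for `@unsafe_region` sets the flag and tracks copies made inside it.

struct ToyTensor
    inds::Vector{Symbol}
    dims::Vector{Int}
end

struct ToyNetwork
    tensors::Vector{ToyTensor}
    unsafe::Ref{Bool}   # in this toy, `true` means "skip size checks"
end
ToyNetwork() = ToyNetwork(ToyTensor[], Ref(false))

# One list per active region, collecting networks created by `copy` inside it.
const _region_copies = Vector{Vector{ToyNetwork}}()

function Base.copy(tn::ToyNetwork)
    new_tn = ToyNetwork(copy(tn.tensors), Ref(tn.unsafe[]))
    isempty(_region_copies) || push!(last(_region_copies), new_tn)
    return new_tn
end

# Reject a tensor whose index sizes disagree with existing tensors,
# unless the unsafe flag is set.
function Base.push!(tn::ToyNetwork, t::ToyTensor)
    if !tn.unsafe[]
        for other in tn.tensors, (i, d) in zip(t.inds, t.dims)
            j = findfirst(==(i), other.inds)
            if j !== nothing && other.dims[j] != d
                throw(DimensionMismatch("index $i: size $(other.dims[j]) vs $d"))
            end
        end
    end
    push!(tn.tensors, t)
    return tn
end

# Set the flag on `tn`, run `f`, then clear it on `tn` and on every copy
# created inside the region, even if `f` throws.
function unsafe_region(f, tn::ToyNetwork)
    copies = ToyNetwork[]
    push!(_region_copies, copies)
    tn.unsafe[] = true
    try
        return f()
    finally
        pop!(_region_copies)
        tn.unsafe[] = false
        for c in copies
            c.unsafe[] = false
        end
    end
end
```

In this toy, a copy created inside the region inherits the unsafe flag, so it accepts inconsistent sizes there; because it was tracked, its flag is cleared on exit and later pushes are checked again.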