Skip to content

Commit

Permalink
Refactor code and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Nov 6, 2023
1 parent 1e4b31f commit 96355ca
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 198 deletions.
24 changes: 7 additions & 17 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,14 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pull
@non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)

"""
    ChainRulesCore.ProjectTo(tn::AbstractTensorNetwork)

Build a projector for `tn`. Only the tensor list carries differentiable data,
so the projector stores just a `ProjectTo` of `tensors(tn)`; any other fields
of the concrete network type are not part of the (co)tangent.
"""
function ChainRulesCore.ProjectTo(tn::T) where {T<:AbstractTensorNetwork}
    return ProjectTo{T}(; tensors = ProjectTo(tensors(tn)))
end

"""
    (projector::ProjectTo{T})(dx::T) where {T<:AbstractTensorNetwork}

Project a primal-typed (co)tangent onto the tangent space of a
`TensorNetwork`, keeping only the projected tensor list.
"""
function (projector::ProjectTo{T})(dx::T) where {T<:AbstractTensorNetwork}
    # BUG FIX: the body referenced `tn`, which is not in scope in this method —
    # the argument is named `dx`. Project the tensors of the incoming value.
    return Tangent{TensorNetwork}(tensors = projector.tensors(tensors(dx)))
end

"""
    (projector::ProjectTo{T})(dx::Tangent{T}) where {T<:AbstractTensorNetwork}

Project a structural tangent. If its `tensors` component is absent
(`NoTangent`), the whole tangent collapses to `NoTangent()`.
"""
function (projector::ProjectTo{T})(dx::Tangent{T}) where {T<:AbstractTensorNetwork}
    # Nothing to project when the differentiable component is absent.
    dx.tensors isa NoTangent && return NoTangent()
    return Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
end
"""
    Base.:+(x::AbstractTensorNetwork, Δ::Tangent{TensorNetwork})

Apply a tangent to a tensor network by adding the tangent's tensors
element-wise to the network's tensors, returning a new network of the same
concrete type.
"""
function Base.:+(x::T, Δ::Tangent{TensorNetwork}) where {T<:AbstractTensorNetwork}
    # BUG FIX: binding the result to a local named `tensors` shadows the
    # `tensors` function called on the same line, raising `UndefVarError`.
    # Use a distinct local name instead.
    updated_tensors = map(+, tensors(x), Δ.tensors)

    # TODO create function fitted for this? or maybe standardize constructors?
    return T(updated_tensors)
end

function ChainRulesCore.frule((_, Δ), T::Type{<:AbstractTensorNetwork}, tensors)
Expand Down
15 changes: 2 additions & 13 deletions ext/TenetFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,8 @@ using Tenet: AbstractTensorNetwork
using FiniteDifferences

"""
    FiniteDifferences.to_vec(x::AbstractTensorNetwork)

Flatten `x` into a vector of numbers for finite differencing. Only the tensor
list is flattened; the returned closure rebuilds a network of the same
concrete type `T` from such a vector.
"""
function FiniteDifferences.to_vec(x::T) where {T<:AbstractTensorNetwork}
    # Delegate flattening to the tensor list; `back` is its inverse.
    x_vec, back = to_vec(tensors(x))

    # Reconstruct a network of the same concrete type from the flat vector.
    # NOTE(review): assumes `T` has a constructor accepting a tensor list — confirm.
    TensorNetwork_from_vec(v) = T(back(v))

    return x_vec, TensorNetwork_from_vec
end
Expand Down
6 changes: 3 additions & 3 deletions ext/TenetMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::AbstractTensorNetw
tn = transform(tn, Tenet.HyperindConverter)

# TODO how to mark multiedges? (i.e. parallel edges)
graph = SimpleGraph([Edge(tensors...) for (_, tensors) in tn.indices if length(tensors) > 1])
graph = SimpleGraph([Edge(tensors...) for (_, tensors) in tn.indexmap if length(tensors) > 1])

Check warning on line 54 in ext/TenetMakieExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TenetMakieExt.jl#L54

Added line #L54 was not covered by tests

# TODO recognise `copytensors` by using `DeltaArray` or `Diagonal` representations
copytensors = findall(tensor -> any(flatinds -> issetequal(inds(tensor), flatinds), keys(hypermap)), tensors(tn))
ghostnodes = map(inds(tn, :open)) do ind
ghostnodes = map(inds(tn, :open)) do index

Check warning on line 58 in ext/TenetMakieExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TenetMakieExt.jl#L58

Added line #L58 was not covered by tests
# create new ghost node
add_vertex!(graph)
node = nv(graph)

# connect ghost node
tensor = only(tn.indices[ind])
tensor = only(tn.indexmap[index])

Check warning on line 64 in ext/TenetMakieExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TenetMakieExt.jl#L64

Added line #L64 was not covered by tests
add_edge!(graph, node, tensor)

return node
Expand Down
27 changes: 21 additions & 6 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TensorNetwork() = TensorNetwork(Tensor[])
Return a shallow copy of a [`TensorNetwork`](@ref).
"""
Base.copy(tn::T) where {T<:AbstractTensorNetwork} = TensorNetwork(copy(tn.indexmap), copy(tn.tensormap))
Base.copy(tn::T) where {T<:AbstractTensorNetwork} = TensorNetwork(tensors(tn))

Check warning on line 42 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L42

Added line #L42 was not covered by tests

Base.summary(io::IO, tn::AbstractTensorNetwork) = print(io, "$(length(tn.tensormap))-tensors $(typeof(tn))")

Check warning on line 44 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L44

Added line #L44 was not covered by tests
Base.show(io::IO, tn::AbstractTensorNetwork) =
Expand Down Expand Up @@ -115,7 +115,7 @@ function Base.push!(tn::AbstractTensorNetwork, tensor::Tensor)
end

tn.tensormap[tensor] = collect(inds(tensor))
for index in inds(tensor)
for index in unique(inds(tensor))
push!(get!(tn.indexmap, index, Tensor[]), tensor)

Check warning on line 119 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L117-L119

Added lines #L117 - L119 were not covered by tests
end

Expand Down Expand Up @@ -174,7 +174,7 @@ Base.delete!(tn::AbstractTensorNetwork, x) = (_ = pop!(tn, x); tn)
tryprune!(tn::AbstractTensorNetwork, i::Symbol) = (x = isempty(tn.indexmap[i]) && delete!(tn.indexmap, i); x)

Check warning on line 174 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L174

Added line #L174 was not covered by tests

function Base.delete!(tn::AbstractTensorNetwork, tensor::Tensor)
for index in inds(tensor)
for index in unique(inds(tensor))
filter!(Base.Fix1(!==, tensor), tn.indexmap[index])
tryprune!(tn, index)
end
Expand Down Expand Up @@ -212,16 +212,31 @@ function Base.replace!(tn::AbstractTensorNetwork, pair::Pair{<:Tensor,<:Tensor})
return tn

Check warning on line 212 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L212

Added line #L212 was not covered by tests
end

"""
    Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol}...)

Rename several indices of `tn` at once. All old indices must currently exist
and no new index may clash with an existing one; both conditions are checked
up front so the network is never left half-renamed.
"""
function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol}...)
    # NOTE the `⊆` / `isdisjoint` guards validate the whole batch before any
    # mutation happens (the subset operator was lost in the page extraction).
    first.(old_new) ⊆ keys(tn.indexmap) ||
        throw(ArgumentError("set of old indices must be a subset of current indices"))
    isdisjoint(last.(old_new), keys(tn.indexmap)) ||
        throw(ArgumentError("set of new indices must be disjoint to current indices"))
    for pair in old_new
        replace!(tn, pair)
    end
    return tn
end

"""
    Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol})

Rename index `old` to `new` in every tensor of `tn` that uses it.
Throws `ArgumentError` if `old` does not exist or `new` is already present.
"""
function Base.replace!(tn::AbstractTensorNetwork, old_new::Pair{Symbol,Symbol})
    old, new = old_new
    # Membership operators restored (`∈` / `∉` were lost in the page extraction).
    old ∈ keys(tn.indexmap) || throw(ArgumentError("index $old does not exist"))
    new ∉ keys(tn.indexmap) || throw(ArgumentError("index $new is already present"))

    # NOTE `copy` because the collection underneath is mutated while iterating
    for tensor in copy(tn.indexmap[old])
        # NOTE do not `delete!` before `push!` as indices can be lost due to `tryprune!`
        push!(tn, replace(tensor, old_new))
        delete!(tn, tensor)
    end

    delete!(tn.indexmap, old)

    return tn
end

Expand All @@ -246,7 +261,7 @@ Return tensors whose indices match with the list of indices `i`.
select(tn::AbstractTensorNetwork, i::Symbol) = copy(tn.indexmap[i])
select(tn::AbstractTensorNetwork, is::AbstractVecOrTuple{Symbol}) =

Check warning on line 262 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L261-L262

Added lines #L261 - L262 were not covered by tests
filter(tn.indexmap[first(is)]) do tensor
issetequal(inds(tensor), is)
is ⊆ inds(tensor)

Check warning on line 264 in src/TensorNetwork.jl

View check run for this annotation

Codecov / codecov/patch

src/TensorNetwork.jl#L264

Added line #L264 was not covered by tests
end

"""
Expand Down
81 changes: 7 additions & 74 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ end
function transform!(tn::AbstractTensorNetwork, config::AntiDiagonalGauging)
skip_inds = isempty(config.skip) ? inds(tn, set = :open) : config.skip

for idx in keys(tn.tensors)
tensor = tn.tensors[idx]

for tensor in keys(tn.tensormap)

Check warning on line 179 in src/Transformations.jl

View check run for this annotation

Codecov / codecov/patch

src/Transformations.jl#L179

Added line #L179 was not covered by tests
anti_diag_axes = find_anti_diag_axes(parent(tensor), atol = config.atol)

for (i, j) in anti_diag_axes # loop over all anti-diagonal axes
Expand Down Expand Up @@ -215,56 +213,14 @@ end
function transform!(tn::AbstractTensorNetwork, config::ColumnReduction)
skip_inds = isempty(config.skip) ? inds(tn, set = :open) : config.skip

for tensor in tn.tensors
zero_columns = find_zero_columns(parent(tensor), atol = config.atol)
zero_columns_by_axis = [filter(x -> x[1] == d, zero_columns) for d in 1:length(size(tensor))]

# find non-zero column for each axis
non_zero_columns =
[(d, setdiff(1:size(tensor, d), [x[2] for x in zero_columns_by_axis[d]])) for d in 1:length(size(tensor))]

# remove axes that have more than one non-zero column
axes_to_reduce = [(d, c[1]) for (d, c) in filter(x -> length(x[2]) == 1, non_zero_columns)]

# First try to reduce the whole index if only one column is non-zeros
for (d, c) in axes_to_reduce # loop over all column axes
ix_i = inds(tensor)[d]

# do not reduce output indices
if ix_i skip_inds
continue
end
for tensor in tensors(tn)
for (dim, index) in enumerate(inds(tensor))
index ∈ skip_inds && continue

Check warning on line 218 in src/Transformations.jl

View check run for this annotation

Codecov / codecov/patch

src/Transformations.jl#L216-L218

Added lines #L216 - L218 were not covered by tests

# reduce all tensors where ix_i appears
for (ind, t) in enumerate(tensors(tn))
if ix_i inds(t)
# Replace the tensor with the reduced one
new_tensor = selectdim(parent(t), findfirst(l -> l == ix_i, inds(t)), c)
new_inds = filter(l -> l != ix_i, inds(t))
zeroslices = iszero.(eachslice(tensor, dims = dim))
any(zeroslices) || continue

Check warning on line 221 in src/Transformations.jl

View check run for this annotation

Codecov / codecov/patch

src/Transformations.jl#L220-L221

Added lines #L220 - L221 were not covered by tests

tn.tensors[ind] = Tensor(new_tensor, new_inds)
end
end
delete!(tn.indices, ix_i)
end

# Then try to reduce the dimensionality of the index in the other tensors
zero_columns = find_zero_columns(parent(tensor), atol = config.atol)
for (d, c) in zero_columns # loop over all column axes
ix_i = inds(tensor)[d]

# do not reduce output indices
if ix_i skip_inds
continue
end

# reduce all tensors where ix_i appears
for (ind, t) in enumerate(tensors(tn))
if ix_i inds(t)
reduced_dims = [i == ix_i ? filter(j -> j != c, 1:size(t, i)) : (1:size(t, i)) for i in inds(t)]
tn.tensors[ind] = Tensor(view(parent(t), reduced_dims...), inds(t))
end
end
slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices))

Check warning on line 223 in src/Transformations.jl

View check run for this annotation

Codecov / codecov/patch

src/Transformations.jl#L223

Added line #L223 was not covered by tests
end
end

Expand Down Expand Up @@ -321,29 +277,6 @@ function transform!(tn::AbstractTensorNetwork, config::SplitSimplification)
return tn
end

function find_zero_columns(x; atol = 1e-12)
dims = size(x)

# Create an initial set of all possible column pairs
zero_columns = Set((d, c) for d in 1:length(dims) for c in 1:dims[d])

# Iterate over each element in tensor
for index in CartesianIndices(x)
val = x[index]

# For each non-zero element, eliminate the corresponding column from the zero_columns set
if abs(val) > atol
for d in 1:length(dims)
c = index[d]
delete!(zero_columns, (d, c))
end
end
end

# Now the zero_columns set only contains column pairs where all elements are zero
return collect(zero_columns)
end

function find_diag_axes(x; atol = 1e-12)
# skip 1D tensors
ndims(parent(x)) == 1 && return []
Expand Down
Loading

0 comments on commit 96355ca

Please sign in to comment.