Skip to content

Commit

Permalink
Fix autodiff on contraction of Tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 12, 2023
1 parent f335f50 commit 00aa2a3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
5 changes: 5 additions & 0 deletions ext/TenetChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ function ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...)
return T(data, inds; meta...), Tensor_pullback
end

@non_differentiable copy(tn::TensorNetwork)

# NOTE fix problem with vector generator in `contract`
@non_differentiable Tenet.__omeinsum_sym2str(x)

# WARN type-piracy
@non_differentiable setdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
@non_differentiable union(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
Expand Down
16 changes: 12 additions & 4 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,27 @@ for op in [
@eval Base.$op(a::Tensor{A,0}, b::Tensor{B,0}) where {A,B} = broadcast($op, a, b)
end

# NOTE used for marking non-differentiability
# NOTE use `String[...]` code instead of `map` or broadcasting to set eltype in empty cases
__omeinsum_sym2str(x) = String[string(i) for i in x]

"""
contract(a::Tensor[, b::Tensor, dims=nonunique([inds(a)..., inds(b)...])])
Perform tensor contraction operation.
"""
function contract(a::Tensor, b::Tensor; dims = ((inds(a), inds(b))))
ia = inds(a)
ib = inds(b)
ia = inds(a) |> collect
ib = inds(b) |> collect
i = (dims, ia, ib)

ic = tuple(setdiff(ia ib, i isa Base.AbstractVecOrTuple ? i : (i,))...)
ic = setdiff(ia ib, i isa Base.AbstractVecOrTuple ? i : (i,))::Vector{Symbol}

_ia = __omeinsum_sym2str(ia)
_ib = __omeinsum_sym2str(ib)
_ic = __omeinsum_sym2str(ic)

data = EinCode((String.(ia), String.(ib)), String.(ic))(parent(a), parent(b))
data = EinCode((_ia, _ib), _ic)(parent(a), parent(b))

# TODO merge metadata?
return Tensor(data, ic)
Expand Down

0 comments on commit 00aa2a3

Please sign in to comment.