Skip to content

Commit

Permalink
Implement in-place contract! methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed May 17, 2024
1 parent d9e38c2 commit 8254549
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,9 @@ function contract(a::Tensor, b::Tensor; dims=(∩(inds(a), inds(b))), out=nothin
out
end

_ia = __omeinsum_sym2str(ia)
_ib = __omeinsum_sym2str(ib)
_ic = __omeinsum_sym2str(ic)

data = EinCode((_ia, _ib), _ic)(parent(a), parent(b))

return Tensor(data, ic)
data = OMEinsum.get_output_array((parent(a), parent(b)), [size(i in ia ? a : b, i) for i in ic]; fillzero=false)
c = Tensor(data, ic)
return contract!(c, a, b)
end

function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing)
Expand All @@ -61,9 +57,9 @@ function contract(a::Tensor; dims=nonunique(inds(a)), out=nothing)
out
end

data = EinCode((String.(ia),), String.(ic))(parent(a))

return Tensor(data, ic)
data = OMEinsum.get_output_array((parent(a),), [size(a, i) for i in ic]; fillzero=false)
c = Tensor(data, ic)
return contract!(c, a)
end

contract(a::Union{T,AbstractArray{T,0}}, b::Tensor{T}) where {T} = contract(Tensor(a), b)
Expand All @@ -72,6 +68,26 @@ contract(a::AbstractArray{<:Any,0}, b::AbstractArray{<:Any,0}) = only(contract(T
contract(a::Number, b::Number) = contract(fill(a), fill(b))
contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs...), tensors)

function contract!(c::Tensor, a::Tensor, b::Tensor)
ixs = (inds(a), inds(b))
iy = inds(c)
xs = (parent(a), parent(b))
y = parent(c)
size_dict = merge!(Dict{Symbol,Int}.([inds(a) .=> size(a), inds(b) .=> size(b)])...)

einsum!(ixs, iy, xs, y, true, false, size_dict)
return c
end

function contract!(y::Tensor, x::Tensor)
ixs = (inds(x),)
iy = inds(y)
size_dict = Dict{Symbol,Int}(inds(x) .=> size(x))

einsum!(ixs, iy, (parent(x),), parent(y), true, false, size_dict)
return y
end

"""
*(::Tensor, ::Tensor)
Expand Down

0 comments on commit 8254549

Please sign in to comment.