diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index d5e59b3c6..ad8e601b5 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -45,7 +45,6 @@ struct TensorNetwork <: AbstractTensorNetwork tensormap = IdDict{Tensor,Vector{Symbol}}(tensor => inds(tensor) for tensor in tensors) indexmap = reduce(tensors; init=Dict{Symbol,Vector{Tensor}}()) do dict, tensor - # TODO check for inconsistent dimensions? for index in inds(tensor) # TODO use lambda? `Tensor[]` might be reused push!(get!(dict, index, Tensor[]), tensor) @@ -53,6 +52,12 @@ struct TensorNetwork <: AbstractTensorNetwork dict end + # Check for inconsistent index dimensions + for ind in keys(indexmap) + dims = map(tensor -> size(tensor, ind), indexmap[ind]) + length(unique(dims)) == 1 || throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)")) + end + return new(indexmap, tensormap, CachedField{Vector{Tensor}}()) end end diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index 6b5448fbb..c30b97b31 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -23,7 +23,7 @@ @testset "TensorNetwork with tensors of different dimensions" begin tensor1 = Tensor(zeros(2, 2), (:i, :j)) tensor2 = Tensor(zeros(3, 3), (:j, :k)) - @test_skip @test_throws DimensionMismatch tn = TensorNetwork([tensor1, tensor2]) + @test_throws DimensionMismatch tn = TensorNetwork([tensor1, tensor2]) end end