Skip to content

Commit

Permalink
Enhance Tranformations docs (#86)
Browse files Browse the repository at this point in the history
* Add plot example for DiagonalReduction and RankSimplification transformations

* Update transformation docs

* Add dual metadata in copy tensor from DiagonalReduction

* Update docs/src/transformations.md

* Skip 1D tensors from checking anti-diagonal/diagonal condition

* Apply suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* Enhance HyperindConverter docstring

* Add more transformations in the docs example

* Change labels

* Apply suggestions from code review

* Put docstring before plot in DiagonalReduction

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
jofrevalles and mofeing authored Sep 13, 2023
1 parent f51f45b commit 5edc49d
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 15 deletions.
287 changes: 273 additions & 14 deletions docs/src/transformations.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,301 @@ transform
transform!
```

# Example
Here we show how can we reduce the complexity of the tensor network by applying a tranformation to it:
# Example: RQC simplification
Here we show how can we reduce the complexity of the tensor network by applying a tranformation to it. We take as an example the Sycamore circuit from the [Google's quantum supremacy paper](https://www.nature.com/articles/s41586-019-1666-5)

```@setup plot
using Makie
Makie.inline!(true)
using CairoMakie
CairoMakie.activate!(type = "svg")
using Tenet
using NetworkLayout
using Pkg
Pkg.add("QuacIO")
function smooth_annotation!(f; color=Makie.RGBAf(110 // 256, 170 // 256, 250 // 256, 60 // 256), xlims=[-2, 2], ylims=[-2, 2], offset_x=0.0, offset_y=0.0, radius_x=1.0, radius_y=1.0, num_waves=5, fluctuation_amplitude=0.1, phase_shift=0.0)
ax = Axis(f)
hidedecorations!(ax)
hidespines!(ax)
# Define limits of the plot
xlims!(ax, xlims...)
ylims!(ax, ylims...)
# Create a perturbed filled shape
theta = LinRange(0, 2π, 100)
fluctuations = fluctuation_amplitude .* sin.(num_waves .* theta .+ phase_shift)
# Apply the fluctuations and radius scaling
perturbed_radius_x = radius_x .+ fluctuations
perturbed_radius_y = radius_y .+ fluctuations
circle_points = [Point2f((perturbed_radius_x[i]) * cos(theta[i]) + offset_x,
(perturbed_radius_y[i]) * sin(theta[i]) + offset_y) for i in 1:length(theta)]
poly!(ax, circle_points, color=color, closed=true)
end
bg_blue = Makie.RGBAf(110 // 256, 170 // 256, 250 // 256, 50 // 256) #hide
orange = Makie.RGBf(240 // 256, 180 // 256, 100 // 256) #hide
red = Makie.RGBf(240 // 256, 90 // 256, 70 // 256) #hide
```
```@example transformation
```@example plot
using QuacIO
using CairoMakie
using Tenet
set_theme!(resolution=(800,400)) # hide
sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 72, 73, 74, 75, 76, 83, 84, 85, 94]
circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format=QuacIO.Qflex(), sites=sites)
tn = TensorNetwork(circuit)
transformed_tn = transform(tn, Tenet.RankSimplification)
# Apply transformations to the tensor network
transformed_tn = transform(tn, [Tenet.AntiDiagonalGauging, Tenet.DiagonalReduction, Tenet.ColumnReduction, Tenet.RankSimplification])
fig = Figure() # hide
ax1 = Axis(fig[1, 1]; title="Original TensorNetwork") # hide
p1 = plot!(ax1, tn; node_size=5.) # hide
ax2 = Axis(fig[1, 2], title="Transformed TensorNetwork") # hide
p2 = plot!(ax2, transformed_tn; node_size=5.) # hide
ax1.titlesize=20 # hide
ax2.titlesize=20 # hide
ax1 = Axis(fig[1, 1]) # hide
p1 = plot!(ax1, tn; edge_width=0.75, node_size=8., node_attr=(strokecolor=:black, strokewidth=0.5)) # hide
ax2 = Axis(fig[1, 2]) # hide
p2 = plot!(ax2, transformed_tn; edge_width=0.75, node_size=8., node_attr=(strokecolor=:black, strokewidth=0.5)) # hide
ax1.titlesize, ax2.titlesize = 20, 20 # hide
hidedecorations!(ax1) # hide
hidespines!(ax1) # hide
hidedecorations!(ax2) # hide
hidespines!(ax2) # hide
Label(fig[1, 1, Bottom()], "Original") # hide
Label(fig[1, 2, Bottom()], "Transformed") # hide
fig # hide
```

# Transformations

## Hyperindex converter
```@docs
Tenet.HyperindConverter
```

## Diagonal reduction
```@docs
Tenet.DiagonalReduction
Tenet.RankSimplification
```

```@example plot
set_theme!(resolution=(800,200)) # hide
fig = Figure() #hide
data = zeros(Float64, 2, 2, 2, 2) #hide
for i in 1:2 #hide
for j in 1:2 #hide
for k in 1:2 #hide
data[i, i, j, k] = k #hide
end #hide
end #hide
end #hide
A = Tensor(data, (:i, :j, :k, :l)) #hide
B = Tensor(rand(2, 2), (:i, :m)) #hide
C = Tensor(rand(2, 2), (:j, :n)) #hide
tn = TensorNetwork([A, B, C]) #hide
reduced = transform(tn, Tenet.DiagonalReduction) #hide
smooth_annotation!( #hide
fig[1, 1]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = -0.21, #hide
offset_y = -0.42, #hide
radius_x = 0.38, #hide
radius_y = 0.8, #hide
num_waves = 6, #hide
fluctuation_amplitude = 0.02, #hide
phase_shift = 0.0) #hide
plot!(fig[1, 1], tn, layout=Spring(iterations=1000, C=0.5, seed=100); node_color=[red, orange, orange, :black, :black,:black, :black]) #hide
smooth_annotation!( #hide
fig[1, 2]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = 0.1, #hide
offset_y = -0.35, #hide
radius_x = 0.38, #hide
radius_y = 1.1, #hide
num_waves = 5, #hide
fluctuation_amplitude = 0.02, #hide
phase_shift = 1.9) #hide
plot!(fig[1, 2], reduced, layout=Spring(iterations=1000, C=0.5, seed=100), node_color=[orange, orange, red, :black, :black, :black, :black, :black]) #hide
Label(fig[1, 1, Bottom()], "Original") #hide
Label(fig[1, 2, Bottom()], "Transformed") #hide
fig #hide
```

## Anti-diagonal reduction
```@docs
Tenet.AntiDiagonalGauging
```

## Rank simplification
```@docs
Tenet.RankSimplification
```

```@example plot
set_theme!(resolution=(800,200)) # hide
fig = Figure() #hide
A = Tensor(rand(2, 2, 2, 2), (:i, :j, :k, :l)) #hide
B = Tensor(rand(2, 2), (:i, :m)) #hide
C = Tensor(rand(2, 2, 2), (:m, :n, :o)) #hide
E = Tensor(rand(2, 2, 2, 2), (:o, :p, :q, :j)) #hide
tn = TensorNetwork([A, B, C, E]) #hide
reduced = transform(tn, Tenet.RankSimplification) #hide
smooth_annotation!( #hide
fig[1, 1]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = -0.32, #hide
offset_y = -0.5, #hide
radius_x = 0.25, #hide
radius_y = 0.94, #hide
num_waves = 6, #hide
fluctuation_amplitude = 0.01, #hide
phase_shift = 0.0) #hide
plot!(fig[1, 1], tn, layout=Spring(iterations=1000, C=0.5, seed=20); node_color=[orange, red, orange, orange, :black, :black, :black, :black, :black]) #hide
smooth_annotation!( #hide
fig[1, 2]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = 0.12, #hide
offset_y = -0.62, #hide
radius_x = 0.18, #hide
radius_y = 0.46, #hide
num_waves = 5, #hide
fluctuation_amplitude = 0.01, #hide
phase_shift = 0) #hide
plot!(fig[1, 2], reduced, layout=Spring(iterations=1000, C=0.5, seed=1); node_color=[red, orange, orange, :black, :black, :black, :black, :black]) #hide
Label(fig[1, 1, Bottom()], "Original") #hide
Label(fig[1, 2, Bottom()], "Transformed") #hide
fig #hide
```

## Column reduction

```@docs
Tenet.ColumnReduction
```

```@example plot
set_theme!(resolution=(800,200)) # hide
fig = Figure() #hide
data = rand(3, 3, 3) #hide
data[:, 1:2, :] .= 0 #hide
A = Tensor(data, (:i, :j, :k)) #hide
B = Tensor(rand(3, 3), (:j, :l)) #hide
C = Tensor(rand(3, 3), (:l, :m)) #hide
tn = TensorNetwork([A, B, C]) #hide
reduced = transform(tn, Tenet.ColumnReduction) #hide
smooth_annotation!( #hide
fig[1, 1]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = -1.12, #hide
offset_y = -0.22, #hide
radius_x = 0.35, #hide
radius_y = 0.84, #hide
num_waves = 4, #hide
fluctuation_amplitude = 0.02, #hide
phase_shift = 0.0) #hide
plot!(fig[1, 1], tn, layout=Spring(iterations=1000, C=0.5, seed=6); node_color=[red, orange, orange, :black, :black, :black]) #hide
smooth_annotation!( #hide
fig[1, 2]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = -0.64, #hide
offset_y = 1.2, #hide
radius_x = 0.32, #hide
radius_y = 0.78, #hide
num_waves = 5, #hide
fluctuation_amplitude = 0.02, #hide
phase_shift = 0) #hide
Label(fig[1, 1, Bottom()], "Original") #hide
Label(fig[1, 2, Bottom()], "Transformed") #hide
plot!(fig[1, 2], reduced, layout=Spring(iterations=2000, C=40, seed=8); node_color=[red, orange, orange, :black, :black, :black]) #hide
fig #hide
```

## Split simplification
```@docs
Tenet.SplitSimplification
```

```@example plot
set_theme!(resolution=(800,200)) # hide
fig = Figure() #hide
v1 = Tensor([1, 2, 3], (:i,)) #hide
v2 = Tensor([4, 5, 6], (:j,)) #hide
m1 = Tensor(rand(3, 3), (:k, :l)) #hide
t1 = contract(v1, v2) #hide
tensor = contract(t1, m1) #hide
tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide
reduced = transform(tn, Tenet.SplitSimplification) #hide
smooth_annotation!( #hide
fig[1, 1]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = 0.24, #hide
offset_y = 0.6, #hide
radius_x = 0.32, #hide
radius_y = 0.78, #hide
num_waves = 5, #hide
fluctuation_amplitude = 0.015, #hide
phase_shift = 0.0) #hide
plot!(fig[1, 1], tn, layout=Spring(iterations=10000, C=0.5, seed=12); node_color=[red, orange, orange, :black, :black, :black, :black]) #hide
smooth_annotation!( #hide
fig[1, 2]; #hide
color = bg_blue, #hide
xlims = [-2, 2], #hide
ylims = [-2, 2], #hide
offset_x = -0.2, #hide
offset_y = -0.4, #hide
radius_x = 1.1, #hide
radius_y = 0.75, #hide
num_waves = 3, #hide
fluctuation_amplitude = 0.18, #hide
phase_shift = 0.8) #hide
Label(fig[1, 1, Bottom()], "Original") #hide
Label(fig[1, 2, Bottom()], "Transformed") #hide
plot!(fig[1, 2], reduced, layout=Spring(iterations=10000, C=13, seed=151); node_color=[orange, orange, red, red, red, :black, :black, :black, :black]) #hide
fig #hide
```
9 changes: 8 additions & 1 deletion src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ end
HyperindConverter <: Transformation
Convert hyperindices to COPY-tensors, represented by `DeltaArray`s.
This transformation is always used by default when visualizing a `TensorNetwork` with `plot`.
"""
struct HyperindConverter <: Transformation end

Expand Down Expand Up @@ -92,7 +93,7 @@ function transform!(tn::TensorNetwork, config::DiagonalReduction)
# insert COPY tensor
new_index = Symbol(uuid4())
data = DeltaArray{N + 1}(ones(size(target, first(inds))))
push!(copies, Tensor(data, (new_index, inds...)))
push!(copies, Tensor(data, (new_index, inds...), dual = new_index))

# extract diagonal of target tensor
# TODO rewrite using `einsum!` when implemented in Tensors
Expand Down Expand Up @@ -345,6 +346,9 @@ function find_zero_columns(x; atol = 1e-12)
end

function find_diag_axes(x; atol = 1e-12)
# skip 1D tensors
ndims(parent(x)) == 1 && return []

# find all the potential diagonals
potential_diag_axes = [(i, j) for i in 1:ndims(x) for j in i+1:ndims(x) if size(x, i) == size(x, j)]

Expand All @@ -367,6 +371,9 @@ function find_diag_axes(x; atol = 1e-12)
end

function find_anti_diag_axes(x; atol = 1e-12)
# skip 1D tensors
ndims(parent(x)) == 1 && return []

# Find all the potential anti-diagonals
potential_anti_diag_axes = [(i, j) for i in 1:ndims(x) for j in i+1:ndims(x) if size(x, i) == size(x, j)]

Expand Down

0 comments on commit 5edc49d

Please sign in to comment.