Skip to content

Commit

Permalink
design matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
vladdez committed Nov 15, 2023
1 parent 96f8d0b commit bf6324e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 51 deletions.
24 changes: 11 additions & 13 deletions docs/src/tutorials/designmatrix.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# [Designmatrix Visualization](@id dm_vis)

Here we discuss designmatrix visualization.
Here we discuss design matrix visualization.
Make sure you have looked into the [installation instructions](@ref install_instruct) section.

## Include used Modules
Expand Down Expand Up @@ -31,20 +31,18 @@ plot_designmatrix(designmatrix(uf))
# kwargs `plot_designmatrix(...; ...)`.


- sortData (boolean,false) - Indicating whether the data is sorted; using sortslices() of Base Julia.
- `sort_data` (bool, `true`): indicates whether the data is sorted; using sortslices() of Base Julia.


In order to make the designmatrix easier to read, you may want to sort it.
To make the design matrix easier to read, you may want to sort it.
```
plot_designmatrix(designmatrix(uf); sortData=true)
```

- standardizeData (boolean, false) - Indicating whether the data is standardized, mapping the values between 0 and 1.
- xTicks (number, nothing)

Indicating the number of labels on the x-axis. Behavior if specified in configuration:
- xTicks = 0: no labels are placed.
- xTicks = 1: first possible label is placed.
- xTicks = 2: first and last possible labels are placed.
- 2 < xTicks < number of labels: xTicks-2 labels are placed between the first and last.
- xTicks ≥ number of labels: all labels are placed.
- `standardize_data` (bool,`true`): indicates whether the data is standardized by pointwise division of the data with its sampled standard deviation.
- `sort_data` (bool, `true`): indicates whether the data is sorted; using sortslices() of Base Julia.
- `xticks` (`nothing`): returns the number of labels on the x-axis. Behavior is set in the configuration:
- xticks = 0: no labels are placed.
- xticks = 1: first possible label is placed.
- xticks = 2: first and last possible labels are placed.
- 2 < xticks < `number of labels`: equally distribute the labels.
- xticks ≥ `number of labels`: all labels are placed.
69 changes: 33 additions & 36 deletions src/plot_designmatrix.jl
Original file line number Diff line number Diff line change
@@ -1,73 +1,70 @@
"""
plot_designmatrix!(f::Union{GridPosition, GridLayout, Figure}, plotData::Unfold.DesignMatrix; kwargs...)
plot_designmatrix(plotData::Unfold.DesignMatrix; kwargs...)
plot_designmatrix!(f::Union{GridPosition, GridLayout, Figure}, data::Unfold.DesignMatrix; kwargs...)
plot_designmatrix(data::Unfold.DesignMatrix; kwargs...)
Plot a designmatrix.
## Arguments:
- `f::Union{GridPosition, GridLayout, Figure}`: Figure or GridPosition (e.g. f[2, 3]) that the plot should be drawn into. New axis is created.
- `plotData::Unfold.DesignMatrix`: Data for the plot visualization.
- `f::Union{GridPosition, GridLayout, Figure}`: Figure or GridPosition (e.g. f[2, 3]) in which the plot will be placed into. A new axis is created.
- `data::Unfold.DesignMatrix`: data for the plot visualization.
## kwargs
- `standardizeData`: (bool,`true`) - Indicating whether the data is standardized by pointwise division of the data with its sampled standard deviation.
- `sortData`: (bool, `true`) - Indicating whether the data is sorted; using sortslices() of Base Julia.
- `xTicks`: (`nothing`)
Indicating the number of labels on the x-axis.
Behavior if specified in configuration:
- xTicks = 0: no labels are placed.
- xTicks = 1: first possible label is placed.
- xTicks = 2: first and last possible labels are placed.
- 2 < xTicks < `number of labels`: Equally distribute the labels.
- xTicks ≥ `number of labels`: all labels are placed.
- `standardize_data` (bool,`true`): indicates whether the data is standardized by pointwise division of the data with its sampled standard deviation.
- `sort_data` (bool, `true`): indicates whether the data is sorted; using sortslices() of Base Julia.
- `xticks` (`nothing`): returns the number of labels on the x-axis. Behavior is set in the configuration:
- xticks = 0: no labels are placed.
- xticks = 1: first possible label is placed.
- xticks = 2: first and last possible labels are placed.
- 2 < xticks < `number of labels`: equally distribute the labels.
- xticks ≥ `number of labels`: all labels are placed.
$(_docstring(:designmat))
## Return Value:
A figure displaying the designmatrix.
"""
plot_designmatrix(plotData::Unfold.DesignMatrix; kwargs...) =
plot_designmatrix!(Figure(), plotData; kwargs...)
plot_designmatrix(data::Unfold.DesignMatrix; kwargs...) =
plot_designmatrix!(Figure(), data; kwargs...)
function plot_designmatrix!(
f::Union{GridPosition,GridLayout,Figure},
plotData::Unfold.DesignMatrix;
xTicks = nothing,
sortData = false,
standardizeData = false,
data::Unfold.DesignMatrix;
xticks = nothing,
sort_data = false,
standardize_data = false,
kwargs...,
)
config = PlotConfig(:designmat)
config_kwargs!(config; kwargs...)
designmat = Unfold.get_Xs(plotData)
if standardizeData
designmat = Unfold.get_Xs(data)
if standardize_data
designmat = designmat ./ std(designmat, dims = 1)
designmat[isinf.(designmat)] .= 1.0
end

if isa(designmat, SparseMatrixCSC)
if sortData
@warn "Sorting does not make sense for timeexpanded designmatrices. sortData has been set to `false`"
if sort_data
@warn "Sorting does not make sense for time-expanded designmatrices. sort_data has been set to `false`"

sortData = false
sort_data = false
end
designmat = Matrix(designmat[end÷2-2000:end÷2+2000, :])
end

if sortData
if sort_data
designmat = Base.sortslices(designmat, dims = 1)
end
labels = Unfold.get_coefnames(plotData)
labels = Unfold.get_coefnames(data)

lLength = length(labels)
# only change xTicks if we want less then all
if (xTicks !== nothing && xTicks < lLength)
@assert(xTicks >= 0, "xTicks shouldn't be negative")
# sections between xTicks
sectionSize = (lLength - 2) / (xTicks - 1)
# only change xticks if we want less then all
if (xticks !== nothing && xticks < lLength)
@assert(xticks >= 0, "xticks shouldn't be negative")
# sections between xticks
sectionSize = (lLength - 2) / (xticks - 1)
newLabels = []

# first tick. Empty if 0 ticks
if xTicks >= 1
if xticks >= 1
push!(newLabels, labels[1])
else
push!(newLabels, "")
Expand All @@ -76,15 +73,15 @@ function plot_designmatrix!(
# fill in ticks in the middle
for i = 1:(lLength-2)
# checks if we're at the end of a section, but NO tick on the very last section
if i % sectionSize < 1 && i < ((xTicks - 1) * sectionSize)
if i % sectionSize < 1 && i < ((xticks - 1) * sectionSize)
push!(newLabels, labels[i+1])
else
push!(newLabels, "")
end
end

# last tick at the end
if xTicks >= 2
if xticks >= 2
push!(newLabels, labels[lLength-1])
else
push!(newLabels, "")
Expand Down
4 changes: 2 additions & 2 deletions test/test_dm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ uf = example_data("UnfoldLinearModel")
end

@testset "sort data" begin
plot_designmatrix(designmatrix(uf); sortData = true)
plot_designmatrix(designmatrix(uf); sort_data = true)
end


@testset "designmatrix plot in GridLayout" begin
f = Figure(resolution=(1200, 1400))
ga = f[1, 1] = GridLayout()
plot_designmatrix!(ga, designmatrix(uf); sortData = true)
plot_designmatrix!(ga, designmatrix(uf); sort_data = true)
f
end

0 comments on commit bf6324e

Please sign in to comment.