Skip to content

Commit

Permalink
make plot names contain config info
Browse files Browse the repository at this point in the history
  • Loading branch information
syrkis committed Jun 15, 2024
1 parent 9322154 commit a86644d
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ COPY requirements.txt .
# Install the required packages
RUN python3.11 -m pip install -r requirements.txt

# Install JAX with CUDA support. HPC is on CUDA 11, and JAX 0.2.25 is the latest version for that
# Install JAX with CUDA support.
RUN python3.11 -m pip install --upgrade \
pip install -U "jax[cuda12]" \
optax
Expand Down
12 changes: 6 additions & 6 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
base: 12
n: 1024
emb: 32
n: 16384
emb: 64
lr: 0.001
depth: 3
heads: 4
epochs: 400
depth: 2
heads: 8
epochs: 10000
l2: 0.00001 # lambda
block: vaswani
dropout: 0.15
dropout: 0.2
gamma: 2
186 changes: 184 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ numpy-hilbert-curve = "^1.0.1"
wandb = "^0.17.0"
seaborn = "^0.13.2"
scikit-learn = "^1.5.0"
jax = "^0.4.29"
optax = "^0.2.2"


[build-system]
Expand Down
15 changes: 12 additions & 3 deletions src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
plt.rcParams["font.family"] = "Monospace"


def fname_fn(conf, fname):
return (
"_".join([f"{k}_{v}" for k, v in conf.items() if k not in ["in_d", "out_d"]])
+ f"_{fname}"
)


# functions
def polar_plot(gold, pred, conf, fname, offset=0): # maps v to a polar plot
conf = conf.__dict__
Expand Down Expand Up @@ -74,8 +81,8 @@ def polar_plot(gold, pred, conf, fname, offset=0): # maps v to a polar plot
]
)
ax.set_xlabel(xlabel, color=ink)
if darkdetect.isLight():
plt.savefig(f"figs/{fname}", dpi=100)
fname = fname_fn(conf, fname)
plt.savefig(f"figs/{fname}", dpi=100)


def curve_plot(
Expand Down Expand Up @@ -110,8 +117,10 @@ def curve_plot(
color=ink,
)
ax.legend(info["legend"], frameon=False, labelcolor=ink)
# make fname contain conf
fname = fname_fn(conf, "curves")
if darkdetect.isLight():
plt.savefig(f"figs/curves.pdf")
plt.savefig(f"figs/{fname}.pdf", dpi=100)


############################################
Expand Down

0 comments on commit a86644d

Please sign in to comment.