-
Notifications
You must be signed in to change notification settings - Fork 0
/
Snakefile3
78 lines (59 loc) · 2.92 KB
/
Snakefile3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import pandas as pd
from pprint import pprint
# Recipient address interpolated into the `mail` commands of the
# onsuccess/onerror handlers.
# NOTE(review): this value looks like an e-mail-obfuscation placeholder rather
# than a real address — confirm before relying on the notifications.
EMAIL = "[email protected]"
# CSV run table read by the input functions below; they use its `path` column.
FILE = 'runs/spectral_runs2.csv'
# Alternative run table (disabled):
#FILE = 'runs/vacs_runs.csv'
def get_files(wildcards):
    """Return every entry of the run table's ``path`` column as a list.

    The table is re-read from ``FILE`` on each call. ``wildcards`` is accepted
    so the function has the signature of a Snakemake input function, but it is
    not used.
    """
    run_table = pd.read_csv(FILE)
    return list(run_table["path"])
def get_score_files(wildcards):
    """List the ``scores.npy`` file that sits beside each path in the run table.

    For every entry of the ``path`` column in ``FILE``, the final path
    component is dropped and replaced with ``scores.npy``. ``wildcards`` is
    accepted for the input-function signature and is not used.
    """
    run_table = pd.read_csv(FILE)
    score_paths = []
    for listed_path in run_table["path"]:
        # dirname() is the first element of os.path.split().
        run_dir = os.path.dirname(listed_path)
        score_paths.append(os.path.join(run_dir, 'scores.npy'))
    return score_paths
def get_black_scores(wildcards):
    """List the ``black_scores.npy`` file beside each path in the run table.

    Same derivation as the regular score files: take the directory part of
    every ``path`` entry in ``FILE`` and append ``black_scores.npy``.
    ``wildcards`` is accepted for the input-function signature and is not used.
    """
    run_dirs = (os.path.dirname(entry) for entry in pd.read_csv(FILE)["path"])
    return [os.path.join(run_dir, 'black_scores.npy') for run_dir in run_dirs]
# Anonymous target rule: it has no output or command and only requests every
# black_scores.npy file derived from the run table.
# Alternative targets that were used previously (disabled):
#   input: "/data/atong/anomaly/mnist/all_scores.pkl"
#   input: "/data/atong/anomaly/mnist2/all_scores.pkl"
#   input: get_score_files
#   input: "/data/atong/anomaly/mnist2/all_black_scores.pkl"
rule:
    input:
        get_black_scores
# Train one model. The wildcards parsed from the output path are passed to
# train_all3.py as positional arguments (with "2" appended to the dataset
# name), followed by the fixed trailing arguments 256 and 20000.
rule train:
    output:
        "{prefix}/{dataset}2/{model}/{class}/{seed}/{frac_corrupt}/model.json"
    shell:
        "python train_all3.py {wildcards.prefix} {wildcards.dataset}2 "
        "{wildcards.model} {wildcards.class} {wildcards.seed} "
        "{wildcards.frac_corrupt} 256 20000"
# Score a trained model: needs the model.json produced by `train` and writes
# scores.npy into the same run directory.
# CUDA_VISIBLE_DEVICES is set to the empty string for the scoring process —
# presumably to keep it off the GPUs; confirm against score_all3.py.
rule predict_models:
    input:
        rules.train.output
    output:
        "{prefix}/{dataset}2/{model}/{class}/{seed}/{frac_corrupt}/scores.npy"
    shell:
        "CUDA_VISIBLE_DEVICES='' python score_all3.py {wildcards.prefix} "
        "{wildcards.dataset}2 {wildcards.model} {wildcards.class} "
        "{wildcards.seed} {wildcards.frac_corrupt}"
# Produce scores.npy for a baseline run, i.e. one whose model directory is
# named "shallow_<model>". The {model} wildcard captures only the part after
# the "shallow_" prefix, and that bare name is what train_baseline2.py gets.
# This rule has no input files.
rule predict_baseline:
    output:
        "{prefix}/{dataset}2/shallow_{model}/{class}/{seed}/{frac_corrupt}/scores.npy"
    shell:
        "python train_baseline2.py {wildcards.prefix} {wildcards.dataset}2 "
        "{wildcards.model} {wildcards.class} {wildcards.seed} "
        "{wildcards.frac_corrupt}"
# A scores.npy path under a "shallow_*" model directory matches the output
# patterns of both rules; prefer the baseline rule for those paths.
ruleorder: predict_baseline > predict_models
# Produce black_scores.npy for a baseline ("shallow_<model>") run.
# The script itself is declared as the rule's only input, so {input} in the
# command expands to the script path.
rule predict_black_baseline:
    input:
        'black_image_baseline.py'
    output:
        "{prefix}/{dataset}2/shallow_{model}/{class}/{seed}/{frac_corrupt}/black_scores.npy"
    shell:
        "python {input} {wildcards.prefix} {wildcards.dataset}2 "
        "{wildcards.model} {wildcards.class} {wildcards.seed} "
        "{wildcards.frac_corrupt}"
# Combine every scores.npy listed by get_score_files into a single pickle.
# The accumulator script is declared as an input and is invoked with all
# score files followed by the output path.
rule accumulate_scores:
    input:
        scores = get_score_files,
        program = 'accumulate_scores.py'
    output:
        "{prefix}/{dataset}/all_scores.pkl"
    shell:
        "python {input.program} {input.scores} {output}"
# Produce black_scores.npy for a trained model: needs the model.json from
# `train` plus the black_image.py script, which is declared as an input.
# CUDA_VISIBLE_DEVICES is set to the empty string for the process —
# presumably to keep it off the GPUs; confirm against black_image.py.
rule predict_black:
    input:
        models = rules.train.output,
        program = 'black_image.py'
    output:
        "{prefix}/{dataset}2/{model}/{class}/{seed}/{frac_corrupt}/black_scores.npy"
    shell:
        "CUDA_VISIBLE_DEVICES='' python {input.program} {wildcards.prefix} "
        "{wildcards.dataset}2 {wildcards.model} {wildcards.class} "
        "{wildcards.seed} {wildcards.frac_corrupt}"
# A black_scores.npy path under a "shallow_*" model directory matches the
# output patterns of both rules; prefer the baseline rule for those paths.
ruleorder: predict_black_baseline > predict_black
# Combine every black_scores.npy listed by get_black_scores into a single
# pickle, using the same accumulator script as the regular scores.
rule accumulate_black:
    input:
        scores = get_black_scores,
        program = 'accumulate_scores.py'
    output:
        "{prefix}/{dataset}/all_black_scores.pkl"
    shell:
        "python {input.program} {input.scores} {output}"
# Workflow-level success handler: mail the Snakemake log to EMAIL.
onsuccess:
    # EMAIL is substituted here with %; {log} is left in the string for
    # shell() to expand.
    notify_cmd = "mail -s 'Snakemake Completed!' %s < {log}" % EMAIL
    shell(notify_cmd)
# Workflow-level failure handler: mail the Snakemake log to EMAIL.
onerror:
    # EMAIL is substituted here with %; {log} is left in the string for
    # shell() to expand.
    notify_cmd = "mail -s 'Snakemake Error' %s < {log}" % EMAIL
    shell(notify_cmd)