-
Notifications
You must be signed in to change notification settings - Fork 0
/
probe_eval.py
68 lines (57 loc) · 2.04 KB
/
probe_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from vitx import config_parser, get_method, sync_checkpoints
from vitx.data import CIFAR100Dataset, get_image_transforms
from vitx.methods.evaluators import ProbeEvaluator
def main():
    """Fit a linear-probe evaluator on CIFAR-100 on top of a restored backbone.

    Steps:
      1. Parse the "default" config from ``./configs/`` (job name "test").
      2. Run ``sync_checkpoints`` for the parsed config.
      3. Build CIFAR-100 train/test datasets and their ``DataLoader``s.
      4. Restore the method's weights from ``config.ckpt_checkpoint_path``.
      5. Wrap ``method.model`` in a ``ProbeEvaluator`` with 100 classes.
      6. Fit it with a PyTorch Lightning ``Trainer`` using ``DDPStrategy``,
         validating on the test split.
    """
    config = config_parser(
        config_path="./configs/", config_name="default", job_name="test"
    )

    # NOTE(review): the original bound this call's return value to a local
    # that was never read, while the checkpoint below is loaded from
    # ``config.ckpt_checkpoint_path``. The call is kept for its side effects;
    # confirm whether its return value is the path that should be loaded.
    sync_checkpoints(config=config)

    # Single source of truth for the dataset location (was duplicated).
    images_path = "../datasets/cl-datasets/data/"
    # NOTE(review): the same transform is applied to both splits; confirm it
    # contains no train-only augmentation.
    image_transform = get_image_transforms(transform_config=config.data.transform)

    train_dataset = CIFAR100Dataset(
        images_path=images_path,
        image_transform=image_transform,
        train=True,
    )
    test_dataset = CIFAR100Dataset(
        images_path=images_path,
        image_transform=image_transform,
        train=False,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.train.batch_size,
        num_workers=config.data.n_workers,
        shuffle=True,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=config.train.batch_size,
        num_workers=config.data.n_workers,
        shuffle=False,
    )

    # Load onto CPU. Without map_location, torch.load restores each tensor to
    # the device it was saved from (e.g. cuda:0). Under DDP, every rank would
    # then allocate on that one GPU, and CPU-only hosts would fail outright.
    # Lightning moves the module to the correct device inside trainer.fit().
    checkpoint = torch.load(config.ckpt_checkpoint_path, map_location="cpu")
    method = get_method(config=config)
    method.load_state_dict(checkpoint["state_dict"])
    # load_state_dict copies the values into the module's own tensors, so the
    # raw checkpoint dict is no longer needed; release it before training.
    del checkpoint

    probe_evaluator = ProbeEvaluator(
        model=method.model,
        embed_dim=config.model.vision_model.embed_dim,
        n_classes=100,  # CIFAR-100
    )

    trainer = pl.Trainer(
        accelerator=config.train.accelerator_type,
        devices=config.train.n_devices,
        strategy=DDPStrategy(),
        precision=16 if config.train.mixed_precision else 32,
        max_epochs=config.train.n_epochs,
        check_val_every_n_epoch=config.train.check_val_every_n_epoch,
    )
    trainer.fit(
        model=probe_evaluator,
        train_dataloaders=train_loader,
        val_dataloaders=test_loader,
    )
# Entry point: run the probe evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()