-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
35 lines (29 loc) · 982 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder
# -- settings
# Pick the compute device once: the GPU when PyTorch can see one, else the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# TabNetPretrainer: the unsupervised TabNet model from pytorch_tabnet.pretraining.
# The device comes from the `device` resolved in the settings section rather
# than a hard-coded "cuda", so the chosen device is actually honored and the
# script stays consistent on CPU-only machines.
unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-3),
    mask_type='entmax',  # alternative attention mask: "sparsemax"
    device_name=device.type,  # "cuda" or "cpu"
)
# Supervised TabNet classifier (pytorch_tabnet.tab_model.TabNetClassifier).
clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-3),
    # StepLR multiplies the learning rate by `gamma` every `step_size`
    # scheduler steps. NOTE(review): pytorch_tabnet presumably steps the
    # scheduler once per epoch by default — confirm for the installed version.
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 7, "gamma": 0.9},
    # Per the original author's note, this is overwritten when a pretrained
    # model is used.
    mask_type='entmax',
    # Use the device resolved in the settings section instead of hard-coding
    # "cuda", so the computed `device` is honored.
    device_name=device.type,
)