Skip to content

Commit

Permalink
ohio data and attention optim
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-vaselli committed May 23, 2024
1 parent 9fbb3e8 commit 0e12677
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions configs/ohio_data_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# config for the full dataset builder
data_dir: "/home/fvaselli/Documents/PHD/TSA/TSA/data/data_ohio"
data_dir: "/home/fvaselli/Documents/PHD/TSA/TSA/data/test"
# patients ids
ids: ['540', '552', '544', '567', '584', '596']
test_ids: []
Expand All @@ -12,15 +12,15 @@ scale: 1
# outtype
outtype: "History"
# smooth
smooth: True
smooth: False
# target_weight
target_weight: 1
# standardize
standardize: False
standardize_by_ref: True
standardize_params:
mean: 127.836 # 144.982
std: 60.410 #57.941
mean: 144.96
std: 58.062 #57.941
# Computed Mean: 144.98199462890625, Computed Std: 58.11943817138672
# dataset smooth Computed Mean: 144.98204040527344, Computed Std: 57.940860748291016
# cutpoint (negative= take all the data)
Expand Down
8 changes: 4 additions & 4 deletions src/data_processing/build_ohio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def build_dataset(
files = []
files_ids = []
for pid in ids:
files += [f"/home/fvaselli/Documents/TSA/data/data_ohio/{pid}-ws-testing.xml"]
files += [f"/home/fvaselli/Documents/PHD/TSA/TSA/data/test/{pid}-ws-testing.xml"]
reader = DataReader(
"ohio", f"/home/fvaselli/Documents/TSA/data/data_ohio/{pid}-ws-testing.xml", 5
"ohio", f"/home/fvaselli/Documents/PHD/TSA/TSA/data/test/{pid}-ws-testing.xml", 5
)
train_data[pid] = reader.read()

Expand Down Expand Up @@ -121,10 +121,10 @@ def main(data_config):

# save data and targets as numpy arrays, in same file
dataset = np.concatenate((data, targets), axis=1)
np.save("/home/fvaselli/Documents/TSA/data/data_ohio/dataset_ohio_smooth_stdbyupsampled.npy", dataset)
np.save("/home/fvaselli/Documents/PHD/TSA/TSA/data/test/dataset_ohio_stdby.npy", dataset)
# dataset = tf.data.Dataset.from_tensor_slices((data, targets))
# save
# dataset.save("data/dataset")

if __name__ == "__main__":
main('/home/fvaselli/Documents/TSA/configs/ohio_data_config.yaml')
main('/home/fvaselli/Documents/PHD/TSA/TSA/configs/ohio_data_config.yaml')
2 changes: 1 addition & 1 deletion src/models/param_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def main():
config = yaml.load(f, Loader=yaml.FullLoader)

# models = ["cnn", "rnn", "transformer"]
targets = ["regression", "classification", "multi_classification"]
targets = ["multi_classification"]

parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down
16 changes: 8 additions & 8 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ def check_classification(
pred *= 100
true *= 100

# NOTE now it is 0 for hypo and 1 for hyper
# diffrent from exam project!!
pred_label = (pred[:, ind] > threshold).astype(int) # Assuming feature_dim = 1
true_label = (true[:, ind] > threshold).astype(int) # Adjust index if different
# NOTE now it is 1 for hypo and 0 for hyper
# we make it so for the other models
pred_label = (pred[:, ind] < threshold).astype(int) # Assuming feature_dim = 1
true_label = (true[:, ind] < threshold).astype(int) # Adjust index if different

fpr, tpr, _ = roc_curve(true_label, pred_label)
roc_auc = auc(fpr, tpr)
Expand Down Expand Up @@ -240,7 +240,7 @@ def on_epoch_end(self, epoch, logs=None):
specificity = tn / (tn + fp)
precision = tp / (tp + fp)
npv = tn / (tn + fn)
f1 = 2 * (precision * sensitivity) / (precision + sensitivity)
f1 = 2 * (precision * sensitivity) / (precision + sensitivity) # in this way the f1 is relative to the hyper class
tf.summary.scalar("Accuracy", accuracy, step=epoch)
tf.summary.scalar("Sensitivity", sensitivity, step=epoch)
tf.summary.scalar("Specificity", specificity, step=epoch)
Expand Down Expand Up @@ -394,10 +394,10 @@ def plot_to_image(self, figure):

def check_classification1(true, pred, threshold=0.5):
# Assuming true and pred have shape [batch_size, 1]
# 0 for hypo and 1 for hyper
pred_label = (pred >= threshold).astype(int)
# 1 for hypo and 0 for hyper
pred_label = (pred < threshold).astype(int)
#
true_label = (true >= threshold).astype(int)
true_label = (true < threshold).astype(int)
print("true_label shape:", true_label.shape)
print("pred_label shape:", pred_label.shape)
print("example pred vs true:", pred_label[0], true_label[0])
Expand Down

0 comments on commit 0e12677

Please sign in to comment.