feat(gold_standard): arbitrary gold standards (#912)
* feat(gold_standard): filter by protein coding genes

* feat: arbitrary gold standards

* feat: read model from gcs

* feat: read model from gcs

* feat: get untrusted types from blob

* revert: changes to gene_index

* fix: correct list of missing and unexpected fields

* chore: addressing comments

* fix: selective check on the schema issues

---------

Co-authored-by: Szymon Szyszkowski <[email protected]>
project-defiant and Szymon Szyszkowski authored Nov 14, 2024
1 parent e5b3c9e commit 253fe31
Showing 3 changed files with 165 additions and 95 deletions.
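A note on the headline change before the diff: the training step now accepts an arbitrary gold standard in parquet or json, inferring the format from the file extension and defaulting to parquet. A minimal, self-contained sketch of that inference rule (the helper name is illustrative, not from the diff):

def infer_gold_standard_format(path: str) -> str:
    # Default to parquet when the extension is not a recognised format.
    ext = path.split(".")[-1]
    return ext if ext in ["parquet", "json"] else "parquet"

assert infer_gold_standard_format("gs://bucket/gold_standard.json") == "json"
assert infer_gold_standard_format("gs://bucket/gold_standard") == "parquet"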
227 changes: 139 additions & 88 deletions src/gentropy/l2g.py
@@ -2,12 +2,14 @@

from __future__ import annotations

import logging
from typing import Any

import pyspark.sql.functions as f
from sklearn.ensemble import GradientBoostingClassifier
from wandb import login as wandb_login

from gentropy.common.schemas import compare_struct_schemas
from gentropy.common.session import Session
from gentropy.common.spark_helpers import calculate_harmonic_sum
from gentropy.common.utils import access_gcp_secret
@@ -152,6 +154,9 @@ def __init__(
self.download_from_hub = download_from_hub
self.hf_model_commit_message = hf_model_commit_message
self.l2g_threshold = l2g_threshold or 0.0
self.gold_standard_curation_path = gold_standard_curation_path
self.gene_interactions_path = gene_interactions_path
self.variant_index_path = variant_index_path

# Load common inputs
self.credible_set = StudyLocus.from_parquet(
@@ -160,27 +165,105 @@
self.feature_matrix = L2GFeatureMatrix(
_df=session.load_data(feature_matrix_path), features_list=self.features_list
)
self.variant_index = (
VariantIndex.from_parquet(session, variant_index_path)
if variant_index_path
else None
)

if run_mode == "predict":
self.run_predict()
elif run_mode == "train":
self.gs_curation = (
self.session.spark.read.json(gold_standard_curation_path)
if gold_standard_curation_path
else None
)
self.interactions = (
self.session.spark.read.parquet(gene_interactions_path)
if gene_interactions_path
else None
)
self.gold_standard = self.prepare_gold_standard()
self.run_train()

def prepare_gold_standard(self) -> L2GGoldStandard:
"""Prepare the gold standard for training.
Returns:
L2GGoldStandard: training dataset.
Raises:
ValueError: When the gold standard path is not provided, or when
parsing an OTG gold standard without interactions and variant index paths.
TypeError: When the gold standard is neither an OTG gold standard nor an L2GGoldStandard.
"""
if self.gold_standard_curation_path is None:
raise ValueError("Gold Standard is required for model training.")
# Read the gold standard from json or parquet; default to parquet when the format cannot be inferred from the extension.
ext = self.gold_standard_curation_path.split(".")[-1]
ext = "parquet" if ext not in ["parquet", "json"] else ext
gold_standard = self.session.load_data(self.gold_standard_curation_path, ext)
schema_issues = compare_struct_schemas(
gold_standard.schema, L2GGoldStandard.get_schema()
)
# Parse the gold standard depending on the input schema
match schema_issues:
case {**extra} if not extra:
# Schema matches L2GGoldStandard - load the gold standard directly.
# NOTE: a bare `case {}` pattern matches any mapping, so capture the
# dict and test for emptiness instead;
# see https://stackoverflow.com/questions/75389166/how-to-match-an-empty-dictionary
logging.info("Successfully parsed gold standard.")
return L2GGoldStandard(
_df=gold_standard,
_schema=L2GGoldStandard.get_schema(),
)
case {
"missing_mandatory_columns": [
"studyLocusId",
"variantId",
"studyId",
"geneId",
"goldStandardSet",
],
"unexpected_columns": [
"association_info",
"gold_standard_info",
"metadata",
"sentinel_variant",
"trait_info",
],
}:
# This exact set of missing and unexpected columns identifies the legacy OTG curation format.
logging.info("Detected OTG Gold Standard. Attempting to parse it.")
otg_curation = gold_standard
if self.gene_interactions_path is None:
raise ValueError("Interactions are required for parsing curation.")
if self.variant_index_path is None:
raise ValueError("Variant Index are required for parsing curation.")

interactions = self.session.load_data(
self.gene_interactions_path, "parquet"
)
variant_index = VariantIndex.from_parquet(
self.session, self.variant_index_path
)
study_locus_overlap = StudyLocus(
_df=self.credible_set.df.join(
otg_curation.select(
f.concat_ws(
"_",
f.col("sentinel_variant.locus_GRCh38.chromosome"),
f.col("sentinel_variant.locus_GRCh38.position"),
f.col("sentinel_variant.alleles.reference"),
f.col("sentinel_variant.alleles.alternative"),
).alias("variantId"),
f.col("association_info.otg_id").alias("studyId"),
),
on=[
"studyId",
"variantId",
],
how="inner",
),
_schema=StudyLocus.get_schema(),
).find_overlaps()

return L2GGoldStandard.from_otg_curation(
gold_standard_curation=otg_curation,
variant_index=variant_index,
study_locus_overlap=study_locus_overlap,
interactions=interactions,
)
case _:
raise TypeError("Incorrect gold standard dataset provided.")

def run_predict(self) -> None:
"""Run the prediction step.
@@ -207,87 +290,55 @@ def run_predict(self) -> None:

def run_train(self) -> None:
"""Run the training step."""
if (
self.gs_curation
and self.interactions
and self.wandb_run_name
and self.model_path
):
wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev")

# Instantiate classifier and train model
l2g_model = LocusToGeneModel(
model=GradientBoostingClassifier(random_state=42),
hyperparameters=self.hyperparameters,
)
wandb_login(key=wandb_key)
trained_model = LocusToGeneTrainer(
model=l2g_model,
feature_matrix=self._annotate_gold_standards_w_feature_matrix(),
).train(self.wandb_run_name)
if trained_model.training_data and trained_model.model and self.model_path:
trained_model.save(self.model_path)
if self.hf_hub_repo_id and self.hf_model_commit_message:
hf_hub_token = access_gcp_secret(
"hfhub-key", "open-targets-genetics-dev"
)
trained_model.export_to_hugging_face_hub(
# upload the model under its file name in the local filesystem
self.model_path.split("/")[-1],
hf_hub_token,
data=trained_model.training_data._df.drop(
"goldStandardSet", "geneId"
).toPandas(),
repo_id=self.hf_hub_repo_id,
commit_message=self.hf_model_commit_message,
)
# Initialize access to weights and biases
wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev")
wandb_login(key=wandb_key)

# Instantiate classifier and train model
l2g_model = LocusToGeneModel(
model=GradientBoostingClassifier(random_state=42),
hyperparameters=self.hyperparameters,
)

# Calculate the gold standard features
feature_matrix = self._annotate_gold_standards_w_feature_matrix()

# Run the training
trained_model = LocusToGeneTrainer(
model=l2g_model, feature_matrix=feature_matrix
).train(self.wandb_run_name)

# Export the model
if trained_model.training_data and trained_model.model and self.model_path:
trained_model.save(self.model_path)
if self.hf_hub_repo_id and self.hf_model_commit_message:
hf_hub_token = access_gcp_secret(
"hfhub-key", "open-targets-genetics-dev"
)
trained_model.export_to_hugging_face_hub(
# upload the model under its file name in the local filesystem
self.model_path.split("/")[-1],
hf_hub_token,
data=trained_model.training_data._df.drop(
"goldStandardSet", "geneId"
).toPandas(),
repo_id=self.hf_hub_repo_id,
commit_message=self.hf_model_commit_message,
)

def _annotate_gold_standards_w_feature_matrix(self) -> L2GFeatureMatrix:
"""Generate the feature matrix of annotated gold standards.
Returns:
L2GFeatureMatrix: Feature matrix with gold standards annotated with features.
Raises:
ValueError: Not all training dependencies are defined
"""
if self.gs_curation and self.interactions and self.variant_index:
study_locus_overlap = StudyLocus(
_df=self.credible_set.df.join(
self.gs_curation.select(
f.concat_ws(
"_",
f.col("sentinel_variant.locus_GRCh38.chromosome"),
f.col("sentinel_variant.locus_GRCh38.position"),
f.col("sentinel_variant.alleles.reference"),
f.col("sentinel_variant.alleles.alternative"),
).alias("variantId"),
f.col("association_info.otg_id").alias("studyId"),
),
on=[
"studyId",
"variantId",
],
how="inner",
),
_schema=StudyLocus.get_schema(),
).find_overlaps()

gold_standards = L2GGoldStandard.from_otg_curation(
gold_standard_curation=self.gs_curation,
variant_index=self.variant_index,
study_locus_overlap=study_locus_overlap,
interactions=self.interactions,
return (
self.gold_standard.build_feature_matrix(
self.feature_matrix, self.credible_set
)

return (
gold_standards.build_feature_matrix(
self.feature_matrix, self.credible_set
)
.select_features(self.features_list)
.persist()
)
raise ValueError("Dependencies for train mode not set.")
.select_features(self.features_list)
.persist()
)


class LocusToGeneEvidenceStep:
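Editorial note on the dispatch above: prepare_gold_standard branches on the dict of schema issues returned by compare_struct_schemas. A self-contained sketch of that pattern (toy issue dicts; the looser [*_] list patterns stand in for the exact column lists matched in the diff):

def classify_gold_standard(schema_issues: dict) -> str:
    # Toy version of the match statement in prepare_gold_standard.
    match schema_issues:
        case {**extra} if not extra:
            # No issues: the schema already matches L2GGoldStandard.
            return "l2g"
        case {"missing_mandatory_columns": [*_], "unexpected_columns": [*_]}:
            # Mismatch signature resembling the legacy OTG curation.
            return "otg"
        case _:
            raise TypeError("Incorrect gold standard dataset provided.")

assert classify_gold_standard({}) == "l2g"
assert classify_gold_standard(
    {"missing_mandatory_columns": ["studyLocusId"], "unexpected_columns": ["metadata"]}
) == "otg"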
23 changes: 17 additions & 6 deletions src/gentropy/method/l2g/model.py
@@ -42,21 +42,32 @@ def __post_init__(self: LocusToGeneModel) -> None:
self.model.set_params(**self.hyperparameters_dict)

@classmethod
def load_from_disk(
cls: Type[LocusToGeneModel], path: str | Path
) -> LocusToGeneModel:
def load_from_disk(cls: Type[LocusToGeneModel], path: str) -> LocusToGeneModel:
"""Load a fitted model from disk.
Args:
path (str | Path): Path to the model
path (str): Path to the model
Returns:
LocusToGeneModel: L2G model loaded from disk
Raises:
ValueError: If the model has not been fitted yet
"""
loaded_model = sio.load(path, trusted=sio.get_untrusted_types(file=path))
if path.startswith("gs://"):
path = path.removeprefix("gs://")
bucket_name = path.split("/")[0]
blob_name = "/".join(path.split("/")[1:])
from google.cloud import storage

client = storage.Client()
bucket = storage.Bucket(client=client, name=bucket_name)
blob = storage.Blob(name=blob_name, bucket=bucket)
data = blob.download_as_string(client=client)
loaded_model = sio.loads(data, trusted=sio.get_untrusted_types(data=data))
else:
loaded_model = sio.load(path, trusted=sio.get_untrusted_types(file=path))

if not loaded_model._is_fitted():
raise ValueError("Model has not been fitted yet.")
return cls(model=loaded_model)
@@ -80,7 +91,7 @@ def load_from_hub(
"""
local_path = Path(model_id)
hub_utils.download(repo_id=model_id, dst=local_path, token=hf_token)
return cls.load_from_disk(Path(local_path) / model_name)
return cls.load_from_disk(str(Path(local_path) / model_name))

@property
def hyperparameters_dict(self) -> dict[str, Any]:
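For context on the skops calls above: get_untrusted_types lists the types skops will not deserialize by default, and passing them back via trusted makes loading an explicit opt-in. A minimal round-trip sketch against a local file (no GCS credentials needed; assumes skops and scikit-learn are installed):

import skops.io as sio
from sklearn.ensemble import GradientBoostingClassifier

sio.dump(GradientBoostingClassifier(random_state=42), "model.skops")

# Read raw bytes (as if downloaded from a GCS blob), inspect the types skops
# does not trust by default, and pass them back explicitly when loading.
with open("model.skops", "rb") as f:
    data = f.read()
untrusted = sio.get_untrusted_types(data=data)
loaded = sio.loads(data, trusted=untrusted)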
10 changes: 9 additions & 1 deletion src/gentropy/method/l2g/trainer.py
@@ -97,6 +97,7 @@ def _get_shap_explanation(
Raises:
ValueError: Train data not set, cannot get SHAP values.
Exception: (ExplanationError) When the additivity check fails.
"""
if self.x_train is not None and self.x_test is not None:
training_data = pd.concat([self.x_train, self.x_test], ignore_index=True)
@@ -105,7 +106,14 @@
data=training_data,
feature_perturbation="interventional",
)
return explainer(training_data)
try:
return explainer(training_data)
except Exception as e:
if "Additivity check failed in TreeExplainer" in repr(e):
return explainer(training_data, check_additivity=False)
else:
raise

raise ValueError("Train data not set.")

def log_plot_image_to_wandb(
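The fallback above retries only when SHAP's additivity check is what failed; a self-contained sketch of the same pattern on a small fitted model (assumes shap and scikit-learn are installed):

import numpy as np
import shap
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
model = GradientBoostingClassifier(random_state=42).fit(X, y)

explainer = shap.TreeExplainer(model, data=X, feature_perturbation="interventional")
try:
    explanation = explainer(X)
except Exception as e:
    # The interventional TreeExplainer can fail its additivity check on some
    # inputs; retry with the check disabled, mirroring the diff above.
    if "Additivity check failed in TreeExplainer" in repr(e):
        explanation = explainer(X, check_additivity=False)
    else:
        raise
print(explanation.values.shape)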
