implement GPD from wikipedia definition with 3 params #3218

Open

wants to merge 30 commits into dev from gp-distribution

Commits (30)
e2c2962
implement GPD from wikipedia definition with 3 params
kashif Sep 13, 2024
a638d17
renamed to Exponentiated Generalised Pareto
kashif Sep 13, 2024
1c96087
fix docs
kashif Sep 13, 2024
cdbd4f3
add tests
kashif Sep 13, 2024
065757e
fix ruff
kashif Sep 13, 2024
f1e524a
clamp concentration
kashif Sep 13, 2024
467e16e
fix variance
kashif Sep 13, 2024
8a18417
do not clamp concentration
kashif Sep 13, 2024
2c3ac40
fix flakey test
kashif Sep 13, 2024
c405183
update docstring
kashif Sep 13, 2024
c61ba39
docstrings
kashif Sep 13, 2024
e8d4a96
initial anomaly detection script
kashif Sep 19, 2024
d5a093f
fix comments
kashif Sep 19, 2024
b8122a8
remove print
kashif Sep 19, 2024
cf3af3d
doc string
kashif Sep 19, 2024
bc9ad40
concat all the anomalies
kashif Sep 20, 2024
277056e
save as csv
kashif Sep 20, 2024
e34a600
make threshold and percentage an argument
kashif Sep 20, 2024
55257e3
fit returns GPD without validation
kashif Sep 20, 2024
e127f2b
use _gdk_domain_map
kashif Sep 20, 2024
87d851d
use values directly
kashif Sep 20, 2024
745ca70
use args for fitting gpd
kashif Sep 20, 2024
13d777d
isort
kashif Sep 20, 2024
d8730ad
learn scores in the orignal scale
kashif Sep 20, 2024
5a3419f
fix next_lags
kashif Sep 20, 2024
25143bc
1 - gpd.cdf(score) < threshold
kashif Sep 21, 2024
8b10b66
Merge branch 'dev' into gp-distribution
kashif Oct 18, 2024
d70e1ed
Update anomaly_detection_pytorch.py
kashif Nov 5, 2024
f0b9887
Merge branch 'dev' into gp-distribution
kashif Nov 6, 2024
4dfee43
rsample
kashif Nov 15, 2024
290 changes: 290 additions & 0 deletions examples/anomaly_detection_pytorch.py
@@ -0,0 +1,290 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import argparse

import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F

from gluonts.dataset.field_names import FieldName
from gluonts.dataset.repository import get_dataset
from gluonts.itertools import select
from gluonts.torch import DeepAREstimator
from gluonts.torch.distributions import GeneralizedPareto
from gluonts.torch.util import lagged_sequence_values, take_last


def fit_gpd(data, num_iterations=100, learning_rate=0.001):
"""
    Fit a Generalized Pareto Distribution to the given data using the RMSprop optimizer.

Args:
data (torch.Tensor): Input tensor of shape (batch_size, num_samples)
num_iterations (int): Number of optimization iterations
learning_rate (float): Learning rate for the optimizer

Returns:
        GeneralizedPareto: fitted GPD(loc, scale, concentration) distribution, returned with argument validation disabled
"""
batch_size, _ = data.shape

# Initialize parameters for the GPD so that the loc is always less than the data to begin with
loc = data.min(dim=1, keepdim=True)[0] - 1
loc.requires_grad = True
scale = torch.ones(batch_size, 1, device=data.device, requires_grad=True)
concentration = torch.ones(
batch_size, 1, device=data.device, requires_grad=True
).div_(3)

optimizer = torch.optim.RMSprop(
[loc, scale, concentration], lr=learning_rate
)
# Define learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=3
)

def _gdk_domain_map(loc, scale, concentration, validate_args=None):
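        # softplus keeps the scale strictly positive; for a negative concentration
        # (bounded support) the loc is shifted before constructing the distribution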
scale = F.softplus(scale)
neg_conc = concentration < 0
loc = torch.where(neg_conc, loc - scale / concentration, loc)
return GeneralizedPareto(
loc, scale, concentration, validate_args=validate_args
)

def closure():
optimizer.zero_grad()
gpd = _gdk_domain_map(loc, scale, concentration)
loss = -gpd.log_prob(data).mean()
loss.backward()
lr_scheduler.step(loss)
return loss

for _ in range(num_iterations):
optimizer.step(closure)

return _gdk_domain_map(
loc.detach(),
scale.detach(),
concentration.detach(),
validate_args=False,
)
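
# A minimal usage sketch for fit_gpd (illustrative only; shapes and values assumed):
#
#   scores = torch.rand(8, 100)  # (batch_size, num_samples) of surprisal scores
#   gpd = fit_gpd(scores, num_iterations=50, learning_rate=1e-3)
#   tail_probability = 1 - gpd.cdf(scores.max(dim=1, keepdim=True)[0])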


def main(args):
dataset = get_dataset(dataset_name=args.dataset)
estimator = DeepAREstimator(
prediction_length=dataset.metadata.prediction_length,
context_length=args.context_length,
freq=dataset.metadata.freq,
num_feat_static_cat=len(dataset.metadata.feat_static_cat),
cardinality=[
int(cat_feat_info.cardinality)
for cat_feat_info in dataset.metadata.feat_static_cat
],
embedding_dimension=[3],
trainer_kwargs=dict(
max_epochs=args.max_epochs,
),
batch_size=args.batch_size,
)
predictor = estimator.train(dataset.train, cache_data=True)
print(f"Training completed for dataset: {args.dataset}")
model = predictor.prediction_net.model

# Load the test dataset
transformation = estimator.create_transformation()
transformed_test_data = transformation.apply(dataset.test, is_train=True)
test_data_loader = estimator.create_validation_data_loader(
transformed_test_data,
predictor.prediction_net,
)

anomalies = []
means = []
model.cpu().eval()
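    # For each test batch: score the context window under the trained model, fit a
    # GPD to the top surprisal scores (peaks over threshold), then roll the RNN
    # forward one step at a time and flag future targets whose tail probability
    # under the fitted GPD falls below the anomaly threshold.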
with torch.no_grad():
for batch in tqdm(test_data_loader, desc="Processing batches"):
inputs = select(
predictor.input_names + [f"future_{FieldName.TARGET}"],
batch,
ignore_missing=True,
)
params, scale, _, static_feat, state = model.unroll_lagged_rnn(
inputs["feat_static_cat"],
inputs["feat_static_real"],
inputs["past_time_feat"],
inputs["past_target"],
inputs["past_observed_values"],
inputs["future_time_feat"][:, :1],
)
            # drop the last time step from each parameter tensor
sliced_params = [p[:, :-1] for p in params]
distr = model.output_distribution(sliced_params, scale=scale)

            # get the last context targets and calculate their anomaly scores
context_target = take_last(
inputs["past_target"], dim=-1, num=model.context_length - 1
)
# calculate the surprisal scores for the context target
scores = -distr.log_prob(context_target)

            # keep the top args.top_score_percentage of the scores for each time series in the batch
top_scores = torch.topk(
scores,
k=int(scores.shape[1] * args.top_score_percentage),
dim=1,
)
            # fit a Generalized Pareto Distribution to these top surprisal scores
gpd = fit_gpd(
top_scores.values,
num_iterations=args.gpd_iterations,
learning_rate=args.gpd_learning_rate,
)

            # Loop over each step of the prediction horizon
distr = model.output_distribution(
params, trailing_n=1, scale=scale
)
scaled_past_target = inputs["past_target"] / scale
batch_anomalies = []
batch_means = []
for i in tqdm(
range(inputs["future_target"].shape[1]),
desc="Processing prediction length",
leave=False,
):
target = inputs["future_target"][:, i : i + 1]
score = -distr.log_prob(target)
batch_means.append(distr.mean)
                # only evaluate scores above gpd.loc; flag an anomaly when the tail
                # probability 1 - gpd.cdf(score) falls below the threshold
is_anomaly = torch.where(
score < gpd.loc,
False,
1 - gpd.cdf(score) < args.anomaly_threshold,
)
batch_anomalies.append(is_anomaly)

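                # build the next RNN input from the static features, the upcoming time
                # features, and the lagged (scaled) targets, then advance the RNN state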
next_features = torch.cat(
(
static_feat.unsqueeze(dim=1),
inputs["future_time_feat"][:, i : i + 1],
),
dim=-1,
)
next_lags = lagged_sequence_values(
model.lags_seq,
scaled_past_target,
target / scale,
dim=-1,
)
rnn_input = torch.cat((next_lags, next_features), dim=-1)
output, state = model.rnn(rnn_input, state)
scaled_past_target = torch.cat(
(scaled_past_target, target / scale), dim=1
)

params = model.param_proj(output)
distr = model.output_distribution(params, scale=scale)
# stack the batch_anomalies along the prediction length dimension
anomalies.append(torch.stack(batch_anomalies, dim=1))
means.append(torch.stack(batch_means, dim=1))
# concat the anomalies along the batch dimension
anomalies = torch.cat(anomalies, dim=0).cpu().numpy()
means = torch.cat(means, dim=0).cpu().numpy()

    # collect per-series results and save them as a pickle
all_dates = []
all_flags = []
all_targets = []
all_means = []
for i, (entry, flags, mean) in enumerate(
zip(dataset.test, anomalies, means)
):
start_date = entry["start"].to_timestamp()
target = entry["target"]
dates = pd.date_range(
start=start_date, periods=len(target), freq=dataset.metadata.freq
)
# take the last prediction_length dates
date_index = dates[-dataset.metadata.prediction_length :]
target_slice = target[-dataset.metadata.prediction_length :]
all_dates.append(date_index)
all_flags.append(flags.flatten().astype(bool))
all_targets.append(target_slice)
all_means.append(mean.flatten())

# create a dataframe with the date_index and the flags
anomaly_df = pd.DataFrame(
{
"date": all_dates,
"is_anomaly": all_flags,
"target": all_targets,
"mean": all_means,
}
)
anomaly_df.set_index("date", inplace=True)
anomaly_df.to_pickle(f"anomalies_{args.dataset}.pkl")
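
    # The saved results can be reloaded later for inspection, e.g.
    #   pd.read_pickle(f"anomalies_{args.dataset}.pkl")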


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Anomaly detection using PyTorch and GluonTS"
)
parser.add_argument(
"--dataset", type=str, default="electricity", help="Dataset name"
)
parser.add_argument(
"--context_length", type=int, default=None, help="Context length"
)
parser.add_argument(
"--max_epochs", type=int, default=30, help="Maximum number of epochs"
)
parser.add_argument(
"--batch_size", type=int, default=32, help="Batch size"
)
parser.add_argument(
"--anomaly_threshold",
type=float,
default=0.05,
help="Threshold for anomaly detection",
)
parser.add_argument(
"--top_score_percentage",
type=float,
default=0.1,
help="Percentage of top scores to consider for GPD fitting",
)
parser.add_argument(
"--gpd_iterations",
type=int,
default=100,
help="Number of iterations for GPD fitting",
)
parser.add_argument(
"--gpd_learning_rate",
type=float,
default=0.001,
help="Learning rate for GPD fitting",
)
args = parser.parse_args()

if args.context_length is None:
args.context_length = (
get_dataset(args.dataset).metadata.prediction_length * 10
)

main(args)