From 0bba901a9c1cba0ea55e4a6fec277b18fecf4e48 Mon Sep 17 00:00:00 2001
From: ST John
Date: Mon, 26 Apr 2021 15:36:19 +0100
Subject: [PATCH 1/2] initial commit saved from v1 - WIP

---
 .../deep_nonstationary_gp_samples.py     |  59 +++++++++
 gpflux/kernels.py                        | 117 ++++++++++++++++++
 gpflux/layers/nonstationary_layer.py     |  55 ++++++++
 tests/gpflux/test_nonstationary_unit.py  |  70 +++++++++++
 .../test_nonstationary_integration.py    |  64 ++++++++++
 5 files changed, 365 insertions(+)
 create mode 100644 docs/notebooks/deep_nonstationary_gp_samples.py
 create mode 100644 gpflux/kernels.py
 create mode 100644 gpflux/layers/nonstationary_layer.py
 create mode 100644 tests/gpflux/test_nonstationary_unit.py
 create mode 100644 tests/integration/test_nonstationary_integration.py

diff --git a/docs/notebooks/deep_nonstationary_gp_samples.py b/docs/notebooks/deep_nonstationary_gp_samples.py
new file mode 100644
index 00000000..cbe3378a
--- /dev/null
+++ b/docs/notebooks/deep_nonstationary_gp_samples.py
@@ -0,0 +1,59 @@
+# ---
+# jupyter:
+#   jupytext:
+#     formats: ipynb,py:percent
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.3.2
+#   kernelspec:
+#     display_name: Python 3
+#     language: python
+#     name: python3
+# ---
+
+# %%
+import gpflow
+import gpflux
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from gpflux.nonstationary import NonstationaryKernel
+from plotting import plot_layers
+
+# %%
+Ns = 1000
+D = 1
+a, b = 0, 1
+X = np.linspace(a, b, 1000).reshape(-1, 1)
+Xlam = np.c_[X, np.zeros_like(X)]
+
+# %%
+D_in = D1 = D2 = D_out = D
+
+# Layer 1
+Z1 = X.copy()
+feat1 = gpflow.features.InducingPoints(Z1)
+kern1 = gpflow.kernels.RBF(D_in, lengthscales=0.1)
+layer1 = gpflux.layers.GPLayer(kern1, feat1, D1)
+
+# Layer 2
+Z2 = Xlam.copy()
+feat2 = gpflow.features.InducingPoints(Z2)
+kern2 = NonstationaryKernel(gpflow.kernels.RBF(D1), D_in, scaling_offset=-0.1,
+                            positivity=gpflow.transforms.positive.forward_tensor)
+layer2 = gpflux.layers.NonstationaryGPLayer(kern2, feat2, D2)
+
+# Layer 3
+Z3 = Xlam.copy()
+feat3 = gpflow.features.InducingPoints(Z3)
+kern3 = NonstationaryKernel(gpflow.kernels.RBF(D2), D1, scaling_offset=-0.1,
+                            positivity=gpflow.transforms.positive.forward_tensor)
+layer3 = gpflux.layers.NonstationaryGPLayer(kern3, feat3, D_out)
+
+model = gpflux.DeepGP(np.empty((1, 1)), np.empty((1, 1)), [layer1, layer2, layer3])
+
+# %%
+plot_layers(X, model)
+plt.show()
diff --git a/gpflux/kernels.py b/gpflux/kernels.py
new file mode 100644
index 00000000..5843459f
--- /dev/null
+++ b/gpflux/kernels.py
@@ -0,0 +1,117 @@
+#
+# Copyright (c) 2021 The GPflux Contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+
+import gpflow
+from gpflow import Param, params_as_tensors
+
+
+class NonstationaryKernel(gpflow.kernels.Kernel):
+    def __init__(self, basekernel, lengthscales_dim, scaling_offset=0.0, positivity=tf.exp):
+        """
+        basekernel.input_dim == data_dim: dimension of original (first-layer) input
+        lengthscales_dim: output dimension of previous layer
+
+        For the nonstationary deep GP, the `positivity` is actually part of the *model*
+        specification and does change model behaviour.
+        `scaling_offset` behaves equivalently to a constant mean function in the previous
+        layer; when the previous layer is given a constant mean function, the scaling
+        offset should be set to non-trainable.
+
+        When the positivity is chosen to be exp(), the scaling offset is also
+        equivalent to the basekernel lengthscale, hence the default is 0 and non-trainable.
+        For other positivity choices, it may be useful to make scaling_offset trainable.
+
+        The scaling offset can be a scalar or a vector of length `data_dim`.
+        """
+        if not isinstance(basekernel, gpflow.kernels.Stationary):
+            raise TypeError("base kernel must be stationary")
+
+        data_dim = basekernel.input_dim
+        # must call super().__init__() before adding parameters:
+        super().__init__(data_dim + lengthscales_dim)
+
+        self.data_dim = data_dim
+        self.lengthscales_dim = lengthscales_dim
+        if lengthscales_dim not in (data_dim, 1):
+            raise ValueError("lengthscales_dim must be equal to basekernel's input_dim or 1")
+
+        self.basekernel = basekernel
+        self.scaling_offset = Param(scaling_offset)
+        if positivity is tf.exp:
+            self.scaling_offset.trainable = False
+        self._positivity = positivity  # XXX use softplus / log(1+exp)?
+
+    @params_as_tensors
+    def _split_scale(self, X):
+        """
+        X has data_dim + lengthscales_dim elements
+        """
+        original_input = X[..., :self.data_dim]
+        log_lengthscales = X[..., self.data_dim:]
+        lengthscales = self._positivity(log_lengthscales + self.scaling_offset)
+        return original_input, lengthscales
+
+    @params_as_tensors
+    def K(self, X, X2=None):
+        base_lengthscales = self.basekernel.lengthscales
+        x1, lengthscales1 = self._split_scale(X)
+        if X2 is not None:
+            x2, lengthscales2 = self._split_scale(X2)
+        else:
+            x2, lengthscales2 = x1, lengthscales1
+
+        # d = self.lengthscales_dim, either 1 or D
+        # x1: [N, D], lengthscales1: [N, d]
+        # x2: [M, D], lengthscales2: [M, d]
+
+        # axis=0 and axis=-3 are equivalent, and axis=1 and axis=-2 are equivalent
+        lengthscales1 = tf.expand_dims(lengthscales1, axis=-2)  # [N, 1, d]
+        lengthscales2 = tf.expand_dims(lengthscales2, axis=-3)  # [1, M, d]
+        squared_lengthscales_sum = (lengthscales1 ** 2 + lengthscales2 ** 2)  # [N, M, d]
+
+        q = tf.reduce_prod(tf.sqrt(2 * lengthscales1 * lengthscales2 / squared_lengthscales_sum),
+                           axis=-1)  # [N, M]
+
+        if self.lengthscales_dim == 1:  # also use this when data_dim==1
+            # Last dimension should be 1, so the reduce_prod above *should* have no effect here.
+            # The Stationary._scaled_square_dist() handles the base lengthscale rescaling,
+            # but we need to correct the `q` prefactor:
+            q = q ** tf.cast(self.data_dim, q.dtype)
+            r2 = self.basekernel._scaled_square_dist(x1, None if X2 is None else x2)
+            r2 /= 0.5 * squared_lengthscales_sum[..., 0]  # [N, M]
+
+        elif self.lengthscales_dim == self.data_dim:
+            assert self.data_dim > 1
+
+            # Here, we have to handle the rescaling and expanding manually, as the full case cannot
+            # be implemented with a matmul.
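+            # What follows is the Gibbs-style (lengthscale-mixing) squared distance,
+            #   r2[n, m] = sum_i 2 * (x1[n, i] - x2[m, i])**2 / (lengthscales1[n, i]**2 + lengthscales2[m, i]**2),
+            # computed after rescaling x1 and x2 by the base kernel lengthscales.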
+            x1 /= base_lengthscales
+            x2 /= base_lengthscales
+            x1 = tf.expand_dims(x1, axis=-2)  # [N, 1, D]
+            x2 = tf.expand_dims(x2, axis=-3)  # [1, M, D]
+
+            r2 = tf.reduce_sum(2 * (x1 - x2) ** 2 / squared_lengthscales_sum, axis=-1)  # [N, M]
+
+        else:
+            assert False, "lengthscales_dim must be 1 or same as data_dim"
+
+        return q * self.basekernel.K_r2(r2)
+
+    @params_as_tensors
+    def Kdiag(self, X):
+        return self.basekernel.Kdiag(X)
diff --git a/gpflux/layers/nonstationary_layer.py b/gpflux/layers/nonstationary_layer.py
new file mode 100644
index 00000000..81898316
--- /dev/null
+++ b/gpflux/layers/nonstationary_layer.py
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2021 The GPflux Contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from typing import NamedTuple, Optional
+
+import numpy as np
+import tensorflow as tf
+
+import gpflow
+from gpflow import Param, Parameterized, params_as_tensors, settings
+from gpflow.conditionals import sample_conditional
+from gpflow.kullback_leiblers import gauss_kl
+from gpflow.mean_functions import Zero
+
+from gpflux.kernels import NonstationaryKernel
+from gpflux.layers.gp_layer import GPLayer
+from gpflux.types import TensorLike
+
+
+class NonstationaryGPLayer(GPLayer):
+    def __init__(self,
+                 kernel: NonstationaryKernel,
+                 feature: gpflow.features.InducingFeature,
+                 *args, **kwargs):
+        assert isinstance(kernel, NonstationaryKernel)
+        if isinstance(feature, gpflow.features.InducingPoints):
+            assert feature.Z.shape[1] == kernel.input_dim
+        super().__init__(kernel, feature, *args, **kwargs)
+
+    @params_as_tensors
+    def propagate(self, H, *, X=None, **kwargs):
+        """
+        Concatenates original input X with output H of the previous layer
+        for the non-stationary kernel: the latter will be interpreted as
+        lengthscales.
+
+        :param H: input to this layer [N, P]
+        :param X: original input [N, D]
+        """
+        XH = tf.concat([X, H], axis=1)
+        return super().propagate(XH, **kwargs)
diff --git a/tests/gpflux/test_nonstationary_unit.py b/tests/gpflux/test_nonstationary_unit.py
new file mode 100644
index 00000000..b98817bf
--- /dev/null
+++ b/tests/gpflux/test_nonstationary_unit.py
@@ -0,0 +1,70 @@
+#
+# Copyright (c) 2021 The GPflux Contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import numpy as np
+import pytest
+from numpy.testing import assert_allclose
+
+import gpflow
+
+from gpflux.nonstationary import NonstationaryKernel
+
+
+def test_positive_definite(session_tf):
+    D = 2
+    kern = NonstationaryKernel(gpflow.kernels.RBF(D), D)
+
+    XL = np.random.rand(3, 2 * D)
+    K1 = kern.compute_K_symm(XL)
+    K2 = kern.compute_K(XL, XL)
+
+    assert_allclose(K1, K2)
+
+    np.linalg.cholesky(K1)
+
+
+@pytest.mark.parametrize("lengthscales_dim", [1, 2])
+def test_shapes(session_tf, lengthscales_dim):
+    D = 2
+    N = 5
+    M = 6
+    d = lengthscales_dim
+    X1 = np.random.randn(N, D)
+    X2 = np.random.randn(M, D)
+    L1 = np.random.randn(N, d)
+    L2 = np.random.randn(M, d)
+    XL1 = np.hstack([X1, L1])
+    XL2 = np.hstack([X2, L2])
+    basekern = gpflow.kernels.RBF(D)
+    kern = NonstationaryKernel(basekern, d)
+    baseK = basekern.compute_K(X1, X2)
+    K = kern.compute_K(XL1, XL2)
+    assert baseK.shape == K.shape
+
+
+def test_1D_equivalence(session_tf):
+    D = 2
+    kern = NonstationaryKernel(gpflow.kernels.RBF(D), D)
+    kern_1D = NonstationaryKernel(gpflow.kernels.RBF(D), 1)
+
+    XL1D = np.random.randn(3, D + 1)
+    XL = np.concatenate([XL1D, np.tile(XL1D[:, -1, None], [1, D - 1])], 1)
+
+    K = kern.compute_K_symm(XL)
+    K_1D = kern_1D.compute_K_symm(XL1D)
+
+    assert_allclose(K, K_1D)
diff --git a/tests/integration/test_nonstationary_integration.py b/tests/integration/test_nonstationary_integration.py
new file mode 100644
index 00000000..8c8814b7
--- /dev/null
+++ b/tests/integration/test_nonstationary_integration.py
@@ -0,0 +1,64 @@
+#
+# Copyright (c) 2021 The GPflux Contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import numpy as np
+
+import gpflow
+
+import gpflux
+from gpflux.models.deep_gp import DeepGP
+from gpflux.nonstationary import NonstationaryKernel
+
+
+def test_nonstationary_gp_1d(session_tf):
+    """
+    This test builds a deep nonstationary GP model consisting of 3 layers,
+    and checks that the model can be optimized.
+ """ + D_in = D1 = D2 = D_out = 1 + X = np.linspace(0, 10, 500).reshape(-1, 1) + Y = np.random.randn(500, 1) + Xlam = np.c_[X, np.zeros_like(X)] + + # Layer 1 + Z1 = X.copy() + feat1 = gpflow.features.InducingPoints(Z1) + kern1 = gpflow.kernels.RBF(D_in, lengthscales=0.1) + layer1 = gpflux.layers.GPLayer(kern1, feat1, D1) + + # Layer 2 + Z2 = Xlam.copy() + feat2 = gpflow.features.InducingPoints(Z2) + kern2 = NonstationaryKernel(gpflow.kernels.RBF(D1), D_in, scaling_offset=-0.1, + positivity=gpflow.transforms.positive.forward_tensor) + layer2 = gpflux.layers.NonstationaryGPLayer(kern2, feat2, D2) + + # Layer 3 + Z3 = Xlam.copy() + feat3 = gpflow.features.InducingPoints(Z3) + kern3 = NonstationaryKernel(gpflow.kernels.RBF(D2), D1, scaling_offset=-0.1, + positivity=gpflow.transforms.positive.forward_tensor) + layer3 = gpflux.layers.NonstationaryGPLayer(kern3, feat3, D_out) + + model = DeepGP(X, Y, [layer1, layer2, layer3]) + + # minimize + likelihood_before_opt = model.compute_log_likelihood() + gpflow.train.AdamOptimizer(0.01).minimize(model, maxiter=100) + likelihood_after_opt = model.compute_log_likelihood() + + assert likelihood_before_opt < likelihood_after_opt From e3f1ff23e47b601e6d3b7ec4d8f3d70a3b02d12f Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 28 Oct 2022 14:44:29 +0300 Subject: [PATCH 2/2] make format --- docs/notebooks/deep_nonstationary_gp_samples.py | 16 ++++++++++++---- gpflux/kernels.py | 11 ++++++----- gpflux/layers/nonstationary_layer.py | 7 +++---- .../test_nonstationary_integration.py | 16 ++++++++++++---- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/docs/notebooks/deep_nonstationary_gp_samples.py b/docs/notebooks/deep_nonstationary_gp_samples.py index cbe3378a..0ca7a45c 100644 --- a/docs/notebooks/deep_nonstationary_gp_samples.py +++ b/docs/notebooks/deep_nonstationary_gp_samples.py @@ -41,15 +41,23 @@ # Layer 2 Z2 = Xlam.copy() feat2 = gpflow.features.InducingPoints(Z2) -kern2 = NonstationaryKernel(gpflow.kernels.RBF(D1), D_in, scaling_offset=-0.1, - positivity=gpflow.transforms.positive.forward_tensor) +kern2 = NonstationaryKernel( + gpflow.kernels.RBF(D1), + D_in, + scaling_offset=-0.1, + positivity=gpflow.transforms.positive.forward_tensor, +) layer2 = gpflux.layers.NonstationaryGPLayer(kern2, feat2, D2) # Layer 3 Z3 = Xlam.copy() feat3 = gpflow.features.InducingPoints(Z3) -kern3 = NonstationaryKernel(gpflow.kernels.RBF(D2), D1, scaling_offset=-0.1, - positivity=gpflow.transforms.positive.forward_tensor) +kern3 = NonstationaryKernel( + gpflow.kernels.RBF(D2), + D1, + scaling_offset=-0.1, + positivity=gpflow.transforms.positive.forward_tensor, +) layer3 = gpflux.layers.NonstationaryGPLayer(kern3, feat3, D_out) model = gpflux.DeepGP(np.empty((1, 1)), np.empty((1, 1)), [layer1, layer2, layer3]) diff --git a/gpflux/kernels.py b/gpflux/kernels.py index 5843459f..630c5ffa 100644 --- a/gpflux/kernels.py +++ b/gpflux/kernels.py @@ -61,8 +61,8 @@ def _split_scale(self, X): """ X has data_dim + lengthscales_dim elements """ - original_input = X[..., :self.data_dim] - log_lengthscales = X[..., self.data_dim:] + original_input = X[..., : self.data_dim] + log_lengthscales = X[..., self.data_dim :] lengthscales = self._positivity(log_lengthscales + self.scaling_offset) return original_input, lengthscales @@ -82,10 +82,11 @@ def K(self, X, X2=None): # axis=0 and axis=-3 are equivalent, and axis=1 and axis=-2 are equivalent lengthscales1 = tf.expand_dims(lengthscales1, axis=-2) # [N, 1, d] lengthscales2 = tf.expand_dims(lengthscales2, axis=-3) # [1, M, d] - 
squared_lengthscales_sum = (lengthscales1 ** 2 + lengthscales2 ** 2) # [N, M, d] + squared_lengthscales_sum = lengthscales1 ** 2 + lengthscales2 ** 2 # [N, M, d] - q = tf.reduce_prod(tf.sqrt(2 * lengthscales1 * lengthscales2 / squared_lengthscales_sum), - axis=-1) # [N, M] + q = tf.reduce_prod( + tf.sqrt(2 * lengthscales1 * lengthscales2 / squared_lengthscales_sum), axis=-1 + ) # [N, M] if self.lengthscales_dim == 1: # also use this when data_dim==1 # Last dimension should be 1, so the reduce_prod above *should* have no effect here. diff --git a/gpflux/layers/nonstationary_layer.py b/gpflux/layers/nonstationary_layer.py index 81898316..8e4f05f8 100644 --- a/gpflux/layers/nonstationary_layer.py +++ b/gpflux/layers/nonstationary_layer.py @@ -32,10 +32,9 @@ class NonstationaryGPLayer(GPLayer): - def __init__(self, - kernel: NonstationaryKernel, - feature: gpflow.features.InducingFeature, - *args, **kwargs): + def __init__( + self, kernel: NonstationaryKernel, feature: gpflow.features.InducingFeature, *args, **kwargs + ): assert isinstance(kernel, NonstationaryKernel) if isinstance(feature, gpflow.features.InducingPoints): assert feature.Z.shape[1] == kernel.input_dim diff --git a/tests/integration/test_nonstationary_integration.py b/tests/integration/test_nonstationary_integration.py index 8c8814b7..31b894c0 100644 --- a/tests/integration/test_nonstationary_integration.py +++ b/tests/integration/test_nonstationary_integration.py @@ -43,15 +43,23 @@ def test_nonstationary_gp_1d(session_tf): # Layer 2 Z2 = Xlam.copy() feat2 = gpflow.features.InducingPoints(Z2) - kern2 = NonstationaryKernel(gpflow.kernels.RBF(D1), D_in, scaling_offset=-0.1, - positivity=gpflow.transforms.positive.forward_tensor) + kern2 = NonstationaryKernel( + gpflow.kernels.RBF(D1), + D_in, + scaling_offset=-0.1, + positivity=gpflow.transforms.positive.forward_tensor, + ) layer2 = gpflux.layers.NonstationaryGPLayer(kern2, feat2, D2) # Layer 3 Z3 = Xlam.copy() feat3 = gpflow.features.InducingPoints(Z3) - kern3 = NonstationaryKernel(gpflow.kernels.RBF(D2), D1, scaling_offset=-0.1, - positivity=gpflow.transforms.positive.forward_tensor) + kern3 = NonstationaryKernel( + gpflow.kernels.RBF(D2), + D1, + scaling_offset=-0.1, + positivity=gpflow.transforms.positive.forward_tensor, + ) layer3 = gpflux.layers.NonstationaryGPLayer(kern3, feat3, D_out) model = DeepGP(X, Y, [layer1, layer2, layer3])
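Reviewer note (not part of either patch): the `lengthscales_dim == data_dim` branch of `NonstationaryKernel.K` computes a Gibbs-style lengthscale-mixing RBF. The standalone NumPy sketch below restates that branch, assuming the per-point lengthscales have already been made positive (i.e. after the `positivity` transform and `scaling_offset` are applied); the helper name `nonstationary_rbf_reference` is made up here and is intended only as a cross-check of the TensorFlow code, not as part of gpflux.

```python
import numpy as np


def nonstationary_rbf_reference(X1, L1, X2, L2, base_lengthscales, variance=1.0):
    """Gibbs-style RBF with per-point, per-dimension lengthscales.

    X1: [N, D] and X2: [M, D] inputs; L1: [N, D] and L2: [M, D] positive lengthscales;
    base_lengthscales: [D] lengthscales of the stationary base RBF kernel.
    Returns the [N, M] matrix q * variance * exp(-r2 / 2), which is what
    `q * self.basekernel.K_r2(r2)` evaluates to for an RBF base kernel.
    """
    l1 = L1[:, None, :]  # [N, 1, D]
    l2 = L2[None, :, :]  # [1, M, D]
    sq_sum = l1 ** 2 + l2 ** 2  # [N, M, D]
    # prefactor: prod_i sqrt(2 * l1_i * l2_i / (l1_i**2 + l2_i**2))
    q = np.prod(np.sqrt(2 * l1 * l2 / sq_sum), axis=-1)  # [N, M]
    x1 = (X1 / base_lengthscales)[:, None, :]  # [N, 1, D]
    x2 = (X2 / base_lengthscales)[None, :, :]  # [1, M, D]
    r2 = np.sum(2 * (x1 - x2) ** 2 / sq_sum, axis=-1)  # [N, M]
    return variance * q * np.exp(-r2 / 2)


# With constant lengthscales L1 = L2 = l, q == 1 and the expression reduces to a
# stationary RBF with effective lengthscale base_lengthscales * l in each dimension.
```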