From dbe32b57460c18a2d07ee95313430aa67f0d1689 Mon Sep 17 00:00:00 2001 From: Nate McMaster Date: Fri, 29 Mar 2024 21:20:51 -0700 Subject: [PATCH] fix: update UseServerCertificateSelector to call the original selector (#290) --- Directory.Build.props | 2 +- .../KestrelHttpsOptionsExtensions.cs | 9 ++- .../KestrelHttpsOptionsExtensionsTests.cs | 65 +++++++++++++++++++ 3 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 test/LettuceEncrypt.UnitTests/KestrelHttpsOptionsExtensionsTests.cs diff --git a/Directory.Build.props b/Directory.Build.props index 8bc163e2..f5fc5edd 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -34,7 +34,7 @@ - 1.3.0 + 1.3.1 beta true $([MSBuild]::ValueOrDefault($(BUILD_NUMBER), 0)) diff --git a/src/Kestrel.Certificates/KestrelHttpsOptionsExtensions.cs b/src/Kestrel.Certificates/KestrelHttpsOptionsExtensions.cs index 02e80147..54e27841 100644 --- a/src/Kestrel.Certificates/KestrelHttpsOptionsExtensions.cs +++ b/src/Kestrel.Certificates/KestrelHttpsOptionsExtensions.cs @@ -22,7 +22,14 @@ public static HttpsConnectionAdapterOptions UseServerCertificateSelector( this HttpsConnectionAdapterOptions httpsOptions, IServerCertificateSelector certificateSelector) { - httpsOptions.ServerCertificateSelector = certificateSelector.Select!; + var fallbackSelector = httpsOptions.ServerCertificateSelector; + httpsOptions.ServerCertificateSelector = (connectionContext, domainName) => + { + var primaryCert = certificateSelector.Select(connectionContext!, domainName); + // fallback to the original selector if the injected selector fails to find a certificate. + return primaryCert ?? fallbackSelector?.Invoke(connectionContext, domainName); + }; + return httpsOptions; } } diff --git a/test/LettuceEncrypt.UnitTests/KestrelHttpsOptionsExtensionsTests.cs b/test/LettuceEncrypt.UnitTests/KestrelHttpsOptionsExtensionsTests.cs new file mode 100644 index 00000000..0adaaef8 --- /dev/null +++ b/test/LettuceEncrypt.UnitTests/KestrelHttpsOptionsExtensionsTests.cs @@ -0,0 +1,65 @@ +// Copyright (c) Nate McMaster. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Security.Cryptography.X509Certificates; +using McMaster.AspNetCore.Kestrel.Certificates; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Moq; +using Xunit; + +namespace LettuceEncrypt.UnitTests; + +using SelectorFunc = Func; + +public class KestrelHttpsOptionsExtensionsTests +{ + [Fact] + public void UseServerCertificateSelectorFallsbackToOriginalSelector() + { + var injectedSelector = new Mock(); + injectedSelector + .Setup(c => c.Select(It.IsAny(), It.IsAny())) + .Returns(() => null); + + var originalSelectorWasCalled = false; + SelectorFunc originalSelector = (_, __) => { originalSelectorWasCalled = true; return null; }; + + var options = new HttpsConnectionAdapterOptions + { + ServerCertificateSelector = originalSelector + }; + + KestrelHttpsOptionsExtensions.UseServerCertificateSelector(options, injectedSelector.Object); + options.ServerCertificateSelector(null, null); + + Assert.NotSame(options.ServerCertificateSelector, originalSelector); + Assert.True(originalSelectorWasCalled); + injectedSelector.VerifyAll(); + } + + [Fact] + public void UseServerCertificateSelectorDoesNotCallFallback() + { + var injectedSelector = new Mock(); + injectedSelector + .Setup(c => c.Select(It.IsAny(), It.IsAny())) + .Returns(() => TestUtils.CreateTestCert("foo.test")); + + var originalSelectorWasCalled = false; + SelectorFunc originalSelector = (_, __) => { originalSelectorWasCalled = true; return null; }; + + var options = new HttpsConnectionAdapterOptions + { + ServerCertificateSelector = originalSelector + }; + + KestrelHttpsOptionsExtensions.UseServerCertificateSelector(options, injectedSelector.Object); + options.ServerCertificateSelector(null, null); + + Assert.NotSame(options.ServerCertificateSelector, originalSelector); + Assert.False(originalSelectorWasCalled); + injectedSelector.VerifyAll(); + } +}