Skip to content

Commit

Permalink
Merge branch '6.3.x'
Browse files Browse the repository at this point in the history
  • Loading branch information
jzheaux committed Dec 6, 2024
2 parents a446968 + dd8ee38 commit 4dd00fe
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.web.server;

import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.json.GsonHttpMessageConverter;
import org.springframework.http.converter.json.JsonbHttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.util.ClassUtils;

/**
* Utility methods for {@link HttpMessageConverter}'s.
*
* @author Joe Grandja
* @author luamas
* @since 5.1
*/
final class HttpMessageConverters {

private static final boolean jackson2Present;

private static final boolean gsonPresent;

private static final boolean jsonbPresent;

static {
ClassLoader classLoader = HttpMessageConverters.class.getClassLoader();
jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader)
&& ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader);
gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader);
jsonbPresent = ClassUtils.isPresent("jakarta.json.bind.Jsonb", classLoader);
}

private HttpMessageConverters() {
}

static GenericHttpMessageConverter<Object> getJsonMessageConverter() {
if (jackson2Present) {
return new MappingJackson2HttpMessageConverter();
}
if (gsonPresent) {
return new GsonHttpMessageConverter();
}
if (jsonbPresent) {
return new JsonbHttpMessageConverter();
}
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.web.server;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.jetbrains.annotations.NotNull;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageEncoder;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.util.MimeType;

class OAuth2ErrorEncoder implements HttpMessageEncoder<OAuth2Error> {

private final HttpMessageConverter<Object> messageConverter = HttpMessageConverters.getJsonMessageConverter();

@NotNull
@Override
public List<MediaType> getStreamingMediaTypes() {
return List.of();
}

@Override
public boolean canEncode(ResolvableType elementType, MimeType mimeType) {
return getEncodableMimeTypes().contains(mimeType);
}

@NotNull
@Override
public Flux<DataBuffer> encode(Publisher<? extends OAuth2Error> error, DataBufferFactory bufferFactory,
ResolvableType elementType, MimeType mimeType, Map<String, Object> hints) {
return Mono.from(error).flatMap((data) -> {
ByteArrayHttpOutputMessage bytes = new ByteArrayHttpOutputMessage();
try {
this.messageConverter.write(data, MediaType.APPLICATION_JSON, bytes);
return Mono.just(bytes.getBody().toByteArray());
}
catch (IOException ex) {
return Mono.error(ex);
}
}).map(bufferFactory::wrap).flux();
}

@NotNull
@Override
public List<MimeType> getEncodableMimeTypes() {
return List.of(MediaType.APPLICATION_JSON);
}

private static class ByteArrayHttpOutputMessage implements HttpOutputMessage {

private final ByteArrayOutputStream body = new ByteArrayOutputStream();

@NotNull
@Override
public ByteArrayOutputStream getBody() {
return this.body;
}

@NotNull
@Override
public HttpHeaders getHeaders() {
return new HttpHeaders();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@

package org.springframework.security.config.web.server;

import java.nio.charset.StandardCharsets;
import java.util.Collections;

import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.core.ResolvableType;
import org.springframework.http.MediaType;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
Expand Down Expand Up @@ -61,6 +62,9 @@ class OidcBackChannelLogoutWebFilter implements WebFilter {

private final ServerLogoutHandler logoutHandler;

private final HttpMessageWriter<OAuth2Error> errorHttpMessageConverter = new EncoderHttpMessageWriter<>(
new OAuth2ErrorEncoder());

/**
* Construct an {@link OidcBackChannelLogoutWebFilter}
* @param authenticationConverter the {@link AuthenticationConverter} for deriving
Expand All @@ -85,7 +89,7 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (ex instanceof AuthenticationServiceException) {
return Mono.error(ex);
}
return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty());
return handleAuthenticationFailure(exchange, ex).then(Mono.empty());
})
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.flatMap(this.authenticationManager::authenticate)
Expand All @@ -94,27 +98,20 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (ex instanceof AuthenticationServiceException) {
return Mono.error(ex);
}
return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty());
return handleAuthenticationFailure(exchange, ex).then(Mono.empty());
})
.flatMap((authentication) -> {
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
return this.logoutHandler.logout(webFilterExchange, authentication);
});
}

private Mono<Void> handleAuthenticationFailure(ServerHttpResponse response, Exception ex) {
private Mono<Void> handleAuthenticationFailure(ServerWebExchange exchange, Exception ex) {
this.logger.debug("Failed to process OIDC Back-Channel Logout", ex);
response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
OAuth2Error error = oauth2Error(ex);
byte[] bytes = String.format("""
{
"error_code": "%s",
"error_description": "%s",
"error_uri: "%s"
}
""", error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = response.bufferFactory().wrap(bytes);
return response.writeWith(Flux.just(buffer));
exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
return this.errorHttpMessageConverter.write(Mono.just(oauth2Error(ex)), ResolvableType.forClass(Object.class),
ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(),
exchange.getResponse(), Collections.emptyMap());
}

private OAuth2Error oauth2Error(Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,24 @@

package org.springframework.security.config.web.server;

import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.ResolvableType;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.oidc.server.session.ReactiveOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
Expand All @@ -44,6 +45,7 @@
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

Expand All @@ -63,6 +65,9 @@ public final class OidcBackChannelServerLogoutHandler implements ServerLogoutHan

private final ReactiveOidcSessionRegistry sessionRegistry;

private final HttpMessageWriter<OAuth2Error> errorHttpMessageConverter = new EncoderHttpMessageWriter<>(
new OAuth2ErrorEncoder());

private WebClient web = WebClient.create();

private String logoutUri = "{baseUrl}/logout/connect/back-channel/{registrationId}";
Expand Down Expand Up @@ -101,7 +106,7 @@ public Mono<Void> logout(WebFilterExchange exchange, Authentication authenticati
totalCount.intValue()));
}
if (!list.isEmpty()) {
return handleLogoutFailure(exchange.getExchange().getResponse(), oauth2Error(list));
return handleLogoutFailure(exchange.getExchange(), oauth2Error(list));
}
else {
return Mono.empty();
Expand Down Expand Up @@ -164,17 +169,11 @@ private OAuth2Error oauth2Error(Collection<?> errors) {
"https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation");
}

private Mono<Void> handleLogoutFailure(ServerHttpResponse response, OAuth2Error error) {
response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
byte[] bytes = String.format("""
{
"error_code": "%s",
"error_description": "%s",
"error_uri: "%s"
}
""", error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = response.bufferFactory().wrap(bytes);
return response.writeWith(Flux.just(buffer));
private Mono<Void> handleLogoutFailure(ServerWebExchange exchange, OAuth2Error error) {
exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
return this.errorHttpMessageConverter.write(Mono.just(error), ResolvableType.forClass(Object.class),
ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(),
exchange.getResponse(), Collections.emptyMap());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.annotation.Order;
import org.springframework.http.ResponseCookie;
import org.springframework.http.client.reactive.ClientHttpConnector;
Expand Down Expand Up @@ -101,6 +102,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
Expand Down Expand Up @@ -199,7 +201,10 @@ void logoutWhenInvalidLogoutTokenThenBadRequest() {
.body(BodyInserters.fromFormData("logout_token", "invalid"))
.exchange()
.expectStatus()
.isBadRequest();
.isBadRequest()
.expectBody(new ParameterizedTypeReference<Map<String, String>>() {
})
.value(hasValue("invalid_request"));
this.test.get().uri("/token/logout").cookie("SESSION", session).exchange().expectStatus().isOk();
}

Expand Down Expand Up @@ -266,9 +271,10 @@ void logoutWhenRemoteLogoutUriThenUses() {
.exchange()
.expectStatus()
.isBadRequest()
.expectBody(String.class)
.value(containsString("partial_logout"))
.value(containsString("not all sessions were terminated"));
.expectBody(new ParameterizedTypeReference<Map<String, String>>() {
})
.value(hasValue("partial_logout"))
.value(hasValue(containsString("not all sessions were terminated")));
this.test.get().uri("/token/logout").cookie("SESSION", one).exchange().expectStatus().isOk();
}

Expand Down

0 comments on commit 4dd00fe

Please sign in to comment.