diff --git a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java index 6537723fcc0..77851ede00f 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java @@ -16,19 +16,22 @@ package org.springframework.security.web; import java.io.IOException; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.security.web.util.UrlUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; /** * Simple implementation of RedirectStrategy which is the default used throughout * the framework. * * @author Luke Taylor + * @author Josh Cummings * @since 3.0 */ public class DefaultRedirectStrategy implements RedirectStrategy { @@ -36,6 +39,7 @@ public class DefaultRedirectStrategy implements RedirectStrategy { protected final Log logger = LogFactory.getLog(getClass()); private boolean contextRelative; + private boolean hostRelative = true; /** * Redirects the response to the supplied URL. @@ -68,25 +72,40 @@ protected String calculateRedirectUrl(String contextPath, String url) { } // Full URL, including http(s):// + boolean hostRelative = this.hostRelative; + boolean contextRelative = isContextRelative(); - if (!isContextRelative()) { + if (!hostRelative && !contextRelative) { return url; } - // Calculate the relative URL from the fully qualified URL, minus the last - // occurrence of the scheme and base context. - url = url.substring(url.lastIndexOf("://") + 3); // strip off scheme - url = url.substring(url.indexOf(contextPath) + contextPath.length()); + UriComponents components = UriComponentsBuilder + .fromHttpUrl(url).build(); - if (url.length() > 1 && url.charAt(0) == '/') { - url = url.substring(1); + String path = components.getPath(); + if (contextRelative) { + path = path.substring(path.indexOf(contextPath) + contextPath.length()); + if (path.length() > 1 && path.charAt(0) == '/') { + path = path.substring(1); + } } - return url; + return UriComponentsBuilder + .fromPath(path) + .query(components.getQuery()) + .build().toString(); + } + + /** + * If true, causes any redirection URLs to be calculated minus the authority + * (defaults to true). + */ + public void setHostRelative(boolean hostRelative) { + this.hostRelative = hostRelative; } /** - * If true, causes any redirection URLs to be calculated minus the protocol + * If true, causes any redirection URLs to be calculated minus the authority * and context path (defaults to false). */ public void setContextRelative(boolean useRelativeContext) { diff --git a/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java b/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java index 925feb900d4..d71a03bcf43 100644 --- a/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java @@ -15,12 +15,13 @@ */ package org.springframework.security.web; -import static org.assertj.core.api.Assertions.*; - import org.junit.Test; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import static org.assertj.core.api.Assertions.assertThat; + /** * * @author Luke Taylor @@ -56,4 +57,34 @@ public void contextRelativeUrlWithMultipleSchemesInHostnameIsHandledCorrectly() assertThat(response.getRedirectedUrl()).isEqualTo("remainder"); } + + // gh-7273 + @Test + public void sendRedirectWhenUsingDefaultsThenRemovesHost() + throws Exception { + DefaultRedirectStrategy rds = new DefaultRedirectStrategy(); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContextPath("/context"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + rds.sendRedirect(request, response, + "https://context.blah.com/context/remainder"); + assertThat(response.getRedirectedUrl()).isEqualTo("/context/remainder"); + } + + // gh-7273 + @Test + public void sendRedirectWhenHostRelativeFalseThenKeepsHost() + throws Exception { + DefaultRedirectStrategy rds = new DefaultRedirectStrategy(); + rds.setHostRelative(false); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContextPath("/context"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + rds.sendRedirect(request, response, + "https://context.blah.com/context/remainder"); + + assertThat(response.getRedirectedUrl()).isEqualTo("https://context.blah.com/context/remainder"); + } }