diff --git a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java
index 6537723fcc0..77851ede00f 100644
--- a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java
+++ b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java
@@ -16,19 +16,22 @@
package org.springframework.security.web;
import java.io.IOException;
-
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+
import org.springframework.security.web.util.UrlUtils;
+import org.springframework.web.util.UriComponents;
+import org.springframework.web.util.UriComponentsBuilder;
/**
* Simple implementation of RedirectStrategy which is the default used throughout
* the framework.
*
* @author Luke Taylor
+ * @author Josh Cummings
* @since 3.0
*/
public class DefaultRedirectStrategy implements RedirectStrategy {
@@ -36,6 +39,7 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
protected final Log logger = LogFactory.getLog(getClass());
private boolean contextRelative;
+ private boolean hostRelative = true;
/**
* Redirects the response to the supplied URL.
@@ -68,25 +72,40 @@ protected String calculateRedirectUrl(String contextPath, String url) {
}
// Full URL, including http(s)://
+ boolean hostRelative = this.hostRelative;
+ boolean contextRelative = isContextRelative();
- if (!isContextRelative()) {
+ if (!hostRelative && !contextRelative) {
return url;
}
- // Calculate the relative URL from the fully qualified URL, minus the last
- // occurrence of the scheme and base context.
- url = url.substring(url.lastIndexOf("://") + 3); // strip off scheme
- url = url.substring(url.indexOf(contextPath) + contextPath.length());
+ UriComponents components = UriComponentsBuilder
+ .fromHttpUrl(url).build();
- if (url.length() > 1 && url.charAt(0) == '/') {
- url = url.substring(1);
+ String path = components.getPath();
+ if (contextRelative) {
+ path = path.substring(path.indexOf(contextPath) + contextPath.length());
+ if (path.length() > 1 && path.charAt(0) == '/') {
+ path = path.substring(1);
+ }
}
- return url;
+ return UriComponentsBuilder
+ .fromPath(path)
+ .query(components.getQuery())
+ .build().toString();
+ }
+
+ /**
+ * If true, causes any redirection URLs to be calculated minus the authority
+ * (defaults to true).
+ */
+ public void setHostRelative(boolean hostRelative) {
+ this.hostRelative = hostRelative;
}
/**
- * If true, causes any redirection URLs to be calculated minus the protocol
+ * If true, causes any redirection URLs to be calculated minus the authority
* and context path (defaults to false).
*/
public void setContextRelative(boolean useRelativeContext) {
diff --git a/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java b/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java
index 925feb900d4..d71a03bcf43 100644
--- a/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java
+++ b/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java
@@ -15,12 +15,13 @@
*/
package org.springframework.security.web;
-import static org.assertj.core.api.Assertions.*;
-
import org.junit.Test;
+
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
+import static org.assertj.core.api.Assertions.assertThat;
+
/**
*
* @author Luke Taylor
@@ -56,4 +57,34 @@ public void contextRelativeUrlWithMultipleSchemesInHostnameIsHandledCorrectly()
assertThat(response.getRedirectedUrl()).isEqualTo("remainder");
}
+
+ // gh-7273
+ @Test
+ public void sendRedirectWhenUsingDefaultsThenRemovesHost()
+ throws Exception {
+ DefaultRedirectStrategy rds = new DefaultRedirectStrategy();
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setContextPath("/context");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+
+ rds.sendRedirect(request, response,
+ "https://context.blah.com/context/remainder");
+ assertThat(response.getRedirectedUrl()).isEqualTo("/context/remainder");
+ }
+
+ // gh-7273
+ @Test
+ public void sendRedirectWhenHostRelativeFalseThenKeepsHost()
+ throws Exception {
+ DefaultRedirectStrategy rds = new DefaultRedirectStrategy();
+ rds.setHostRelative(false);
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ request.setContextPath("/context");
+ MockHttpServletResponse response = new MockHttpServletResponse();
+
+ rds.sendRedirect(request, response,
+ "https://context.blah.com/context/remainder");
+
+ assertThat(response.getRedirectedUrl()).isEqualTo("https://context.blah.com/context/remainder");
+ }
}