Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support GenerateOneTimeTokenRequestResolver #16297

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import java.util.Collections;
import java.util.Map;
import java.util.Objects;

import jakarta.servlet.http.HttpServletRequest;

import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.ott.GenerateOneTimeTokenRequest;
import org.springframework.security.authentication.ott.InMemoryOneTimeTokenService;
import org.springframework.security.authentication.ott.OneTimeToken;
import org.springframework.security.authentication.ott.OneTimeTokenAuthenticationProvider;
Expand All @@ -40,7 +42,9 @@
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import org.springframework.security.web.authentication.ott.DefaultGenerateOneTimeTokenRequestResolver;
import org.springframework.security.web.authentication.ott.GenerateOneTimeTokenFilter;
import org.springframework.security.web.authentication.ott.GenerateOneTimeTokenRequestResolver;
import org.springframework.security.web.authentication.ott.OneTimeTokenAuthenticationConverter;
import org.springframework.security.web.authentication.ott.OneTimeTokenGenerationSuccessHandler;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
Expand Down Expand Up @@ -79,6 +83,8 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>

private AuthenticationProvider authenticationProvider;

private GenerateOneTimeTokenRequestResolver requestResolver;

public OneTimeTokenLoginConfigurer(ApplicationContext context) {
this.context = context;
}
Expand Down Expand Up @@ -135,6 +141,7 @@ private void configureOttGenerateFilter(H http) {
GenerateOneTimeTokenFilter generateFilter = new GenerateOneTimeTokenFilter(getOneTimeTokenService(http),
getOneTimeTokenGenerationSuccessHandler(http));
generateFilter.setRequestMatcher(antMatcher(HttpMethod.POST, this.tokenGeneratingUrl));
generateFilter.setRequestResolver(getGenerateRequestResolver(http));
http.addFilter(postProcess(generateFilter));
http.addFilter(DefaultResourcesFilter.css());
}
Expand Down Expand Up @@ -301,6 +308,28 @@ private AuthenticationFailureHandler getAuthenticationFailureHandler() {
return this.authenticationFailureHandler;
}

/**
* Use this {@link GenerateOneTimeTokenRequestResolver} when resolving
* {@link GenerateOneTimeTokenRequest} from {@link HttpServletRequest}. By default,
* the {@link DefaultGenerateOneTimeTokenRequestResolver} is used.
* @param requestResolver the {@link GenerateOneTimeTokenRequestResolver}
* @since 6.5
*/
public OneTimeTokenLoginConfigurer<H> generateRequestResolver(GenerateOneTimeTokenRequestResolver requestResolver) {
Assert.notNull(requestResolver, "requestResolver cannot be null");
this.requestResolver = requestResolver;
return this;
}

private GenerateOneTimeTokenRequestResolver getGenerateRequestResolver(H http) {
if (this.requestResolver != null) {
return this.requestResolver;
}
GenerateOneTimeTokenRequestResolver bean = getBeanOrNull(http, GenerateOneTimeTokenRequestResolver.class);
this.requestResolver = Objects.requireNonNullElseGet(bean, DefaultGenerateOneTimeTokenRequestResolver::new);
return this.requestResolver;
}

private OneTimeTokenService getOneTimeTokenService(H http) {
if (this.oneTimeTokenService != null) {
return this.oneTimeTokenService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.security.config.annotation.web.configurers.ott;

import java.io.IOException;
import java.time.Instant;
import java.time.ZoneOffset;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -29,6 +31,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.security.authentication.ott.GenerateOneTimeTokenRequest;
import org.springframework.security.authentication.ott.OneTimeToken;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
Expand All @@ -40,6 +43,8 @@
import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
import org.springframework.security.web.authentication.ott.DefaultGenerateOneTimeTokenRequestResolver;
import org.springframework.security.web.authentication.ott.GenerateOneTimeTokenRequestResolver;
import org.springframework.security.web.authentication.ott.OneTimeTokenGenerationSuccessHandler;
import org.springframework.security.web.authentication.ott.RedirectOneTimeTokenGenerationSuccessHandler;
import org.springframework.security.web.csrf.CsrfToken;
Expand Down Expand Up @@ -194,6 +199,55 @@ Please provide it as a bean or pass it to the oneTimeTokenLogin() DSL.
""");
}

@Test
void oneTimeTokenWhenCustomTokenExpirationTimeSetThenAuthenticate() throws Exception {
this.spring.register(OneTimeTokenConfigWithCustomTokenExpirationTime.class).autowire();
this.mvc.perform(post("/ott/generate").param("username", "user").with(csrf()))
.andExpectAll(status().isFound(), redirectedUrl("/login/ott"));

OneTimeToken token = TestOneTimeTokenGenerationSuccessHandler.lastToken;

this.mvc.perform(post("/login/ott").param("token", token.getTokenValue()).with(csrf()))
.andExpectAll(status().isFound(), redirectedUrl("/"), authenticated());
assertThat(getCurrentMinutes(token.getExpiresAt())).isEqualTo(10);
}

private int getCurrentMinutes(Instant expiresAt) {
int expiresMinutes = expiresAt.atZone(ZoneOffset.UTC).getMinute();
int currentMinutes = Instant.now().atZone(ZoneOffset.UTC).getMinute();
return expiresMinutes - currentMinutes;
}

@Configuration(proxyBeanMethods = false)
@EnableWebSecurity
@Import(UserDetailsServiceConfig.class)
static class OneTimeTokenConfigWithCustomTokenExpirationTime {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeHttpRequests((authz) -> authz
.anyRequest().authenticated()
)
.oneTimeTokenLogin((ott) -> ott
.tokenGenerationSuccessHandler(new TestOneTimeTokenGenerationSuccessHandler())
);
// @formatter:on
return http.build();
}

@Bean
GenerateOneTimeTokenRequestResolver generateOneTimeTokenRequestResolver() {
DefaultGenerateOneTimeTokenRequestResolver delegate = new DefaultGenerateOneTimeTokenRequestResolver();
return (request) -> {
GenerateOneTimeTokenRequest generate = delegate.resolve(request);
return new GenerateOneTimeTokenRequest(generate.getUsername(), 600);
};
}

}

@Configuration(proxyBeanMethods = false)
@EnableWebSecurity
@Import(UserDetailsServiceConfig.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,37 @@
*/
public class GenerateOneTimeTokenRequest {

private static final int DEFAULT_EXPIRES_IN = 300;

private final String username;

private final int expiresIn;

public GenerateOneTimeTokenRequest(String username) {
Assert.hasText(username, "username cannot be empty");
this.username = username;
this.expiresIn = DEFAULT_EXPIRES_IN;
}

/**
* Constructs an <code>GenerateOneTimeTokenRequest</code> with the specified username
* and expiresIn
* @param username username
* @param expiresIn one-time token expiration time (seconds)
*/
public GenerateOneTimeTokenRequest(String username, int expiresIn) {
Assert.hasText(username, "username cannot be empty");
Assert.isTrue(expiresIn > 0, "expiresIn must be > 0");
this.username = username;
this.expiresIn = expiresIn;
}

public String getUsername() {
return this.username;
}

public int getExpiresIn() {
return this.expiresIn;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ public final class InMemoryOneTimeTokenService implements OneTimeTokenService {
@NonNull
public OneTimeToken generate(GenerateOneTimeTokenRequest request) {
String token = UUID.randomUUID().toString();
Instant fiveMinutesFromNow = this.clock.instant().plusSeconds(300);
OneTimeToken ott = new DefaultOneTimeToken(token, request.getUsername(), fiveMinutesFromNow);
Instant expiresAt = this.clock.instant().plusSeconds(request.getExpiresIn());
OneTimeToken ott = new DefaultOneTimeToken(token, request.getUsername(), expiresAt);
this.oneTimeTokenByToken.put(token, ott);
cleanExpiredTokensIfNeeded();
return ott;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.sql.Timestamp;
import java.sql.Types;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -132,8 +131,8 @@ public void setCleanupCron(String cleanupCron) {
public OneTimeToken generate(GenerateOneTimeTokenRequest request) {
Assert.notNull(request, "generateOneTimeTokenRequest cannot be null");
String token = UUID.randomUUID().toString();
Instant fiveMinutesFromNow = this.clock.instant().plus(Duration.ofMinutes(5));
OneTimeToken oneTimeToken = new DefaultOneTimeToken(token, request.getUsername(), fiveMinutesFromNow);
Instant expiresAt = this.clock.instant().plusSeconds(request.getExpiresIn());
OneTimeToken oneTimeToken = new DefaultOneTimeToken(token, request.getUsername(), expiresAt);
insertOneTimeToken(oneTimeToken);
return oneTimeToken;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.web.authentication.ott;

import jakarta.servlet.http.HttpServletRequest;

import org.springframework.security.authentication.ott.GenerateOneTimeTokenRequest;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Default implementation of {@link GenerateOneTimeTokenRequestResolver}. Resolves
* {@link GenerateOneTimeTokenRequest} from username parameter.
*
* @author Max Batischev
* @since 6.5
*/
public final class DefaultGenerateOneTimeTokenRequestResolver implements GenerateOneTimeTokenRequestResolver {

private static final int DEFAULT_EXPIRES_IN = 300;

private int expiresIn = DEFAULT_EXPIRES_IN;

@Override
public GenerateOneTimeTokenRequest resolve(HttpServletRequest request) {
String username = request.getParameter("username");
if (!StringUtils.hasText(username)) {
return null;
}
return new GenerateOneTimeTokenRequest(username, this.expiresIn);
}

/**
* Sets one-time token expiration time (seconds)
* @param expiresIn one-time token expiration time
*/
public void setExpiresIn(int expiresIn) {
Assert.isTrue(expiresIn > 0, "expiresAt must be > 0");
this.expiresIn = expiresIn;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.springframework.security.authentication.ott.OneTimeTokenService;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import static org.springframework.security.web.util.matcher.AntPathRequestMatcher.antMatcher;
Expand All @@ -49,6 +48,8 @@ public final class GenerateOneTimeTokenFilter extends OncePerRequestFilter {

private RequestMatcher requestMatcher = antMatcher(HttpMethod.POST, "/ott/generate");

private GenerateOneTimeTokenRequestResolver requestResolver = new DefaultGenerateOneTimeTokenRequestResolver();

public GenerateOneTimeTokenFilter(OneTimeTokenService tokenService,
OneTimeTokenGenerationSuccessHandler tokenGenerationSuccessHandler) {
Assert.notNull(tokenService, "tokenService cannot be null");
Expand All @@ -64,12 +65,11 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
filterChain.doFilter(request, response);
return;
}
String username = request.getParameter("username");
if (!StringUtils.hasText(username)) {
GenerateOneTimeTokenRequest generateRequest = this.requestResolver.resolve(request);
if (generateRequest == null) {
filterChain.doFilter(request, response);
return;
}
GenerateOneTimeTokenRequest generateRequest = new GenerateOneTimeTokenRequest(username);
OneTimeToken ott = this.tokenService.generate(generateRequest);
this.tokenGenerationSuccessHandler.handle(request, response, ott);
}
Expand All @@ -83,4 +83,15 @@ public void setRequestMatcher(RequestMatcher requestMatcher) {
this.requestMatcher = requestMatcher;
}

/**
* Use the given {@link GenerateOneTimeTokenRequestResolver} to resolve
* {@link GenerateOneTimeTokenRequest}.
* @param requestResolver {@link GenerateOneTimeTokenRequestResolver}
* @since 6.5
*/
public void setRequestResolver(GenerateOneTimeTokenRequestResolver requestResolver) {
Assert.notNull(requestResolver, "requestResolver cannot be null");
this.requestResolver = requestResolver;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.web.authentication.ott;

import jakarta.servlet.http.HttpServletRequest;

import org.springframework.lang.Nullable;
import org.springframework.security.authentication.ott.GenerateOneTimeTokenRequest;

/**
* A strategy for resolving a {@link GenerateOneTimeTokenRequest} from the
* {@link HttpServletRequest}.
*
* @author Max Batischev
* @since 6.5
*/
public interface GenerateOneTimeTokenRequestResolver {

/**
* Resolves {@link GenerateOneTimeTokenRequest} from {@link HttpServletRequest}
* @param request {@link HttpServletRequest} to resolve
* @return {@link GenerateOneTimeTokenRequest}
*/
@Nullable
GenerateOneTimeTokenRequest resolve(HttpServletRequest request);

}
Loading
Loading