Skip to content

Commit

Permalink
Refine HaloAuthenticationToken
Browse files Browse the repository at this point in the history
Signed-off-by: JohnNiang <[email protected]>
  • Loading branch information
JohnNiang committed Sep 29, 2024
1 parent 363de39 commit 2edb28c
Show file tree
Hide file tree
Showing 24 changed files with 375 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import java.util.Collection;
import java.util.Collections;
import lombok.Getter;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import run.halo.app.infra.AnonymousUserConst;
import org.springframework.security.oauth2.core.user.OAuth2User;

/**
* Halo OAuth2 authentication token which combines {@link UserDetails} and original
Expand All @@ -17,8 +17,7 @@
* @author johnniang
* @since 2.20.0
*/
// TODO Make the class serializable by JSON
public class HaloOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
public class HaloOAuth2AuthenticationToken extends AbstractAuthenticationToken {

@Getter
private final UserDetails userDetails;
Expand All @@ -35,11 +34,7 @@ public class HaloOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
*/
public HaloOAuth2AuthenticationToken(UserDetails userDetails,
OAuth2AuthenticationToken original) {
super(
original.getPrincipal(),
original.getAuthorities(),
original.getAuthorizedClientRegistrationId()
);
super(combineAuthorities(userDetails, original));
this.userDetails = userDetails;
this.original = original;
setAuthenticated(true);
Expand All @@ -62,6 +57,16 @@ public Collection<GrantedAuthority> getAuthorities() {
return Collections.unmodifiableList(authorities);
}

@Override
public Object getCredentials() {
return "";
}

@Override
public OAuth2User getPrincipal() {
return original.getPrincipal();
}

/**
* Creates an authenticated {@link HaloOAuth2AuthenticationToken} using {@link UserDetails} and
* original {@link OAuth2AuthenticationToken}.
Expand All @@ -76,23 +81,16 @@ public static HaloOAuth2AuthenticationToken authenticated(
return new HaloOAuth2AuthenticationToken(userDetails, original);
}

/**
* Creates an unauthenticated {@link HaloOAuth2AuthenticationToken} using original {@link
* OAuth2AuthenticationToken}.
*
* @param original the original {@link OAuth2AuthenticationToken}
* @return an unauthenticated {@link HaloOAuth2AuthenticationToken}
*/
public static HaloOAuth2AuthenticationToken unauthenticated(
OAuth2AuthenticationToken original
) {
var anonymousUser = User.builder()
.username(AnonymousUserConst.PRINCIPAL)
.authorities("ROLE_" + AnonymousUserConst.Role)
.password("")
.build();
var token = new HaloOAuth2AuthenticationToken(anonymousUser, original);
token.setAuthenticated(false);
return token;
private static Collection<? extends GrantedAuthority> combineAuthorities(
UserDetails userDetails, OAuth2AuthenticationToken original) {
var userDetailsAuthorities = userDetails.getAuthorities();
var originalAuthorities = original.getAuthorities();
var authorities = new ArrayList<GrantedAuthority>(
originalAuthorities.size() + userDetailsAuthorities.size()
);
authorities.addAll(originalAuthorities);
authorities.addAll(userDetailsAuthorities);
return Collections.unmodifiableList(authorities);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ Mono<UserConnection> createUserConnection(
);

/**
* Get user connection and then update it.
* Get user connection by registration id and OAuth2 user.
* If found, update updatedAt timestamp of the user connection.
*
* @param registrationId Registration id
* @param oauth2User OAuth2 user
* @return Updated user connection or empty
*/
Mono<UserConnection> getAndUpdateUserConnection(
Mono<UserConnection> getUserConnection(
String registrationId, OAuth2User oauth2User
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ private Mono<UserConnection> updateUserConnection(UserConnection connection,
}

@Override
public Mono<UserConnection> getAndUpdateUserConnection(String registrationId,
public Mono<UserConnection> getUserConnection(String registrationId,
OAuth2User oauth2User) {
var listOptions = ListOptions.builder()
.fieldQuery(and(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import static org.springframework.util.ResourceUtils.FILE_URL_PREFIX;
import static org.springframework.web.reactive.function.server.RequestPredicates.accept;
import static org.springframework.web.reactive.function.server.RequestPredicates.method;
import static org.springframework.web.reactive.function.server.RequestPredicates.path;
import static org.springframework.web.reactive.function.server.RouterFunctions.route;
import static run.halo.app.infra.utils.FileUtils.checkDirectoryTraversal;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -18,7 +16,6 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.http.CacheControl;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.http.codec.HttpMessageWriter;
Expand All @@ -30,15 +27,14 @@
import org.springframework.web.reactive.config.ResourceHandlerRegistration;
import org.springframework.web.reactive.config.ResourceHandlerRegistry;
import org.springframework.web.reactive.config.WebFluxConfigurer;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerResponse;
import org.springframework.web.reactive.resource.EncodedResourceResolver;
import org.springframework.web.reactive.resource.PathResourceResolver;
import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter;
import org.springframework.web.reactive.result.view.ViewResolutionResultHandler;
import org.springframework.web.reactive.result.view.ViewResolver;
import reactor.core.publisher.Mono;
import run.halo.app.core.endpoint.WebSocketHandlerMapping;
import run.halo.app.core.endpoint.console.CustomEndpointsBuilder;
import run.halo.app.core.extension.endpoint.CustomEndpoint;
Expand Down Expand Up @@ -126,34 +122,33 @@ public WebSocketHandlerMapping webSocketHandlerMapping() {
}

@Bean
RouterFunction<ServerResponse> consoleIndexRedirection() {
var consolePredicate = method(HttpMethod.GET)
.and(path("/console/**").and(path("/console/assets/**").negate()))
RouterFunction<ServerResponse> consoleEndpoints() {
var consolePredicate = path("/console/**").and(path("/console/assets/**").negate())
.and(accept(MediaType.TEXT_HTML))
.and(new WebSocketRequestPredicate().negate());
return route(consolePredicate,
request -> this.serveIndex(haloProp.getConsole().getLocation() + "index.html"));
}

@Bean
RouterFunction<ServerResponse> ucIndexRedirect() {
var consolePredicate = method(HttpMethod.GET)
.and(path("/uc/**").and(path("/uc/assets/**").negate()))
var ucPredicate = path("/uc/**").and(path("/uc/assets/**").negate())
.and(accept(MediaType.TEXT_HTML))
.and(new WebSocketRequestPredicate().negate());
return route(consolePredicate,
request -> this.serveIndex(haloProp.getUc().getLocation() + "index.html"));
}

private Mono<ServerResponse> serveIndex(String indexLocation) {
var indexResource = applicationContext.getResource(indexLocation);
try {
return ServerResponse.ok()
.cacheControl(CacheControl.noStore())
.body(BodyInserters.fromResource(indexResource));
} catch (Throwable e) {
return Mono.error(e);
}
var consoleIndexHtml =
applicationContext.getResource(haloProp.getConsole().getLocation() + "index.html");

var ucIndexHtml =
applicationContext.getResource(haloProp.getUc().getLocation() + "index.html");

return RouterFunctions.route()
.GET(consolePredicate,
request -> ServerResponse.ok()
.cacheControl(CacheControl.noStore())
.bodyValue(consoleIndexHtml)
)
.GET(ucPredicate,
request -> ServerResponse.ok()
.cacheControl(CacheControl.noStore())
.bodyValue(ucIndexHtml)
)
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.authorization.AuthorizationDecision;
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.core.Authentication;
import org.springframework.security.crypto.factory.PasswordEncoderFactories;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.web.server.SecurityWebFilterChain;
Expand All @@ -28,6 +30,7 @@
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.session.MapSession;
import org.springframework.session.config.annotation.web.server.EnableSpringWebSession;
import reactor.core.publisher.Mono;
import run.halo.app.core.user.service.RoleService;
import run.halo.app.core.user.service.UserService;
import run.halo.app.extension.ReactiveExtensionClient;
Expand Down Expand Up @@ -68,8 +71,6 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http,
var staticResourcesMatcher = pathMatchers(HttpMethod.GET,
"/themes/{themeName}/assets/{*resourcePaths}",
"/plugins/{pluginName}/assets/**",
"/console/**",
"/uc/**",
"/upload/**",
"/webjars/**",
"/js/**",
Expand All @@ -86,9 +87,22 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http,
"/api/**",
"/apis/**",
"/actuator/**"
).access(new RequestInfoAuthorizationManager(roleService))
.pathMatchers(
"/login/**",
"/challenges/**",
"/password-reset/**",
"/signup",
"/logout"
).permitAll()
.pathMatchers("/console/**", "/uc/**").authenticated()
.matchers(createHtmlMatcher()).access((authentication, context) ->
// we only need to check the authentication is authenticated
// because we treat anonymous user as authenticated
authentication.map(Authentication::isAuthenticated)
.map(AuthorizationDecision::new)
.switchIfEmpty(Mono.fromSupplier(() -> new AuthorizationDecision(false)))
)
.access(new RequestInfoAuthorizationManager(roleService))
.matchers(createHtmlMatcher()).authenticated()
.anyExchange().permitAll())
.anonymous(spec -> {
spec.authorities(AuthorityUtils.ROLE_PREFIX + AnonymousUserConst.Role);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,6 @@ public static ApplicationContext create(ApplicationContext rootContext) {
.ifUnique(userDetailsService ->
beanFactory.registerSingleton("userDetailsService", userDetailsService)
);
rootContext.getBeanProvider(UserConnectionService.class)
.ifUnique(userConnectionService ->
beanFactory.registerSingleton("userConnectionService", userConnectionService)
);
// TODO add more shared instance here

sharedContext.refresh();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.server.ServerResponse;
import run.halo.app.security.authentication.SecurityConfigurer;
import run.halo.app.security.authentication.oauth2.OAuth2AuthenticationEntryPoint;
import run.halo.app.security.authentication.oauth2.OAuth2UserUnboundAccessDeniedHandler;
import run.halo.app.security.authentication.twofactor.TwoFactorAuthenticationEntryPoint;

Expand All @@ -34,7 +33,8 @@ public void configure(ServerHttpSecurity http) {
http.exceptionHandling(exception -> {
var accessDeniedHandlers =
new ArrayList<ServerWebExchangeDelegatingServerAccessDeniedHandler.DelegateEntry>(
2);
2
);
accessDeniedHandlers.add(
new ServerWebExchangeDelegatingServerAccessDeniedHandler.DelegateEntry(
OAuth2UserUnboundAccessDeniedHandler.MATCHER,
Expand All @@ -54,10 +54,6 @@ public void configure(ServerHttpSecurity http) {
TwoFactorAuthenticationEntryPoint.MATCHER,
new TwoFactorAuthenticationEntryPoint(messageSource, context)
));
entryPoints.add(new DelegatingServerAuthenticationEntryPoint.DelegateEntry(
OAuth2AuthenticationEntryPoint.MATCHER,
new OAuth2AuthenticationEntryPoint()
));
entryPoints.add(new DelegatingServerAuthenticationEntryPoint.DelegateEntry(
exchange -> ServerWebExchangeMatcher.MatchResult.match(),
new DefaultServerAuthenticationEntryPoint()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package run.halo.app.security.authentication.oauth2;

import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
Expand All @@ -20,28 +22,30 @@ public class DefaultOAuth2LoginHandlerEnhancer implements OAuth2LoginHandlerEnha

private final UserConnectionService connectionService;

@Setter
private OAuth2AuthenticationTokenCache oauth2TokenCache =
new WebSessionOAuth2AuthenticationTokenCache();

private final AuthenticationTrustResolver authenticationTrustResolver =
new AuthenticationTrustResolverImpl();

public DefaultOAuth2LoginHandlerEnhancer(UserConnectionService connectionService) {
this.connectionService = connectionService;
}

@Override
public Mono<Void> loginSuccess(ServerWebExchange exchange, Authentication authentication) {
if (authentication instanceof HaloOAuth2AuthenticationToken) {
// Skip handling if logging in with OAuth2
return Mono.empty();
if (!authenticationTrustResolver.isFullyAuthenticated(authentication)) {
// Should never happen
// Remove token directly if not fully authenticated
return oauth2TokenCache.removeToken(exchange).then();
}
return exchange.getSession()
.flatMap(session -> {
var oauth2TokenObject =
session.getAttribute(HaloOAuth2AuthenticationCacheFilter.CACHE_KEY);
if (!(oauth2TokenObject instanceof HaloOAuth2AuthenticationToken haloOAuth2Token)) {
return Mono.empty();
}
var oauth2User = haloOAuth2Token.getPrincipal();
return oauth2TokenCache.getToken(exchange)
.flatMap(oauth2Token -> {
var oauth2User = oauth2Token.getPrincipal();
var username = authentication.getName();
var registrationId = haloOAuth2Token.getAuthorizedClientRegistrationId();
var providerUserId = oauth2User.getName();
return connectionService.getAndUpdateUserConnection(registrationId, oauth2User)
var registrationId = oauth2Token.getAuthorizedClientRegistrationId();
return connectionService.getUserConnection(registrationId, oauth2User)
.doOnNext(connection -> {
if (log.isDebugEnabled()) {
log.debug(
Expand All @@ -54,14 +58,9 @@ public Mono<Void> loginSuccess(ServerWebExchange exchange, Authentication authen
username,
registrationId,
oauth2User
).doOnNext(connection -> {
log.info("Bound user {} to {} in registration {}",
username, providerUserId, registrationId
);
session.getAttributes().remove(OAuth2AuthenticationToken.class.getName());
})));
})
.then();
)))
.then(oauth2TokenCache.removeToken(exchange));
});
}

}
Loading

0 comments on commit 2edb28c

Please sign in to comment.