Skip to content

Commit

Permalink
Support binding OAuth2 user automatically
Browse files Browse the repository at this point in the history
Signed-off-by: JohnNiang <[email protected]>
  • Loading branch information
JohnNiang committed Sep 29, 2024
1 parent ca9adfc commit a286769
Show file tree
Hide file tree
Showing 36 changed files with 889 additions and 136 deletions.
31 changes: 2 additions & 29 deletions api/src/main/java/run/halo/app/core/extension/UserConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,9 @@ public static class UserConnectionSpec {
private String providerUserId;

/**
* The display name for the user's connection to the OAuth provider.
* The time when the user connection was last updated.
*/
@Schema(requiredMode = REQUIRED)
private String displayName;

/**
* The URL to the user's profile page on the OAuth provider.
* For example, the user's GitHub profile URL.
*/
private String profileUrl;

/**
* The URL to the user's avatar image on the OAuth provider.
* For example, the user's GitHub avatar URL.
*/
private String avatarUrl;

/**
* The access token provided by the OAuth provider.
*/
@Schema(requiredMode = REQUIRED)
private String accessToken;

/**
* The refresh token provided by the OAuth provider (if applicable).
*/
private String refreshToken;

private Instant expiresAt;

private Instant updatedAt;

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package run.halo.app.security;

import org.pf4j.ExtensionPoint;
import org.springframework.web.server.WebFilter;

/**
* Security web filter for HTTP basic.
*
* @author johnniang
* @since 2.20.0
*/
public interface HttpBasicSecurityWebFilter extends WebFilter, ExtensionPoint {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package run.halo.app.security;

import org.pf4j.ExtensionPoint;
import org.springframework.web.server.WebFilter;

/**
* Security web filter for OAuth2 authorization code.
*
* @author johnniang
* @since 2.20.0
*/
public interface OAuth2AuthorizationCodeSecurityWebFilter extends WebFilter, ExtensionPoint {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package run.halo.app.security.authentication.oauth2;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import lombok.Getter;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.core.user.OAuth2User;

/**
* Halo OAuth2 authentication token which combines {@link UserDetails} and original
* {@link OAuth2AuthenticationToken}.
*
* @author johnniang
* @since 2.20.0
*/
public class HaloOAuth2AuthenticationToken extends AbstractAuthenticationToken {

@Getter
private final UserDetails userDetails;

@Getter
private final OAuth2AuthenticationToken original;

/**
* Constructs an {@code HaloOAuth2AuthenticationToken} using {@link UserDetails} and original
* {@link OAuth2AuthenticationToken}.
*
* @param userDetails the {@link UserDetails}
* @param original the original {@link OAuth2AuthenticationToken}
*/
public HaloOAuth2AuthenticationToken(UserDetails userDetails,
OAuth2AuthenticationToken original) {
super(combineAuthorities(userDetails, original));
this.userDetails = userDetails;
this.original = original;
setAuthenticated(true);
}

@Override
public String getName() {
return userDetails.getUsername();
}

@Override
public Collection<GrantedAuthority> getAuthorities() {
var originalAuthorities = super.getAuthorities();
var userDetailsAuthorities = getUserDetails().getAuthorities();
var authorities = new ArrayList<GrantedAuthority>(
originalAuthorities.size() + userDetailsAuthorities.size()
);
authorities.addAll(originalAuthorities);
authorities.addAll(userDetailsAuthorities);
return Collections.unmodifiableList(authorities);
}

@Override
public Object getCredentials() {
return "";
}

@Override
public OAuth2User getPrincipal() {
return original.getPrincipal();
}

/**
* Creates an authenticated {@link HaloOAuth2AuthenticationToken} using {@link UserDetails} and
* original {@link OAuth2AuthenticationToken}.
*
* @param userDetails the {@link UserDetails}
* @param original the original {@link OAuth2AuthenticationToken}
* @return an authenticated {@link HaloOAuth2AuthenticationToken}
*/
public static HaloOAuth2AuthenticationToken authenticated(
UserDetails userDetails, OAuth2AuthenticationToken original
) {
return new HaloOAuth2AuthenticationToken(userDetails, original);
}

private static Collection<? extends GrantedAuthority> combineAuthorities(
UserDetails userDetails, OAuth2AuthenticationToken original) {
var userDetailsAuthorities = userDetails.getAuthorities();
var originalAuthorities = original.getAuthorities();
var authorities = new ArrayList<GrantedAuthority>(
originalAuthorities.size() + userDetailsAuthorities.size()
);
authorities.addAll(originalAuthorities);
authorities.addAll(userDetailsAuthorities);
return Collections.unmodifiableList(authorities);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
import run.halo.app.infra.exception.RateLimitExceededException;
import run.halo.app.infra.exception.UnsatisfiedAttributeValueException;
import run.halo.app.infra.utils.JsonUtils;
import run.halo.app.security.authentication.twofactor.TwoFactorAuthentication;

@Component
@RequiredArgsConstructor
Expand Down Expand Up @@ -600,7 +599,7 @@ record ChangePasswordRequest(
Mono<ServerResponse> me(ServerRequest request) {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.filter(auth -> !(auth instanceof TwoFactorAuthentication))
.filter(Authentication::isAuthenticated)
.flatMap(auth -> userService.getUser(auth.getName())
.flatMap(user -> {
var roleNames = authoritiesToRoles(auth.getAuthorities());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package run.halo.app.core.user.service;

import org.springframework.security.oauth2.core.user.OAuth2User;
import reactor.core.publisher.Mono;
import run.halo.app.core.extension.UserConnection;

public interface UserConnectionService {

/**
* Create user connection.
*
* @param username Username
* @param registrationId Registration id
* @param oauth2User OAuth2 user
* @return Created user connection
*/
Mono<UserConnection> createUserConnection(
String username,
String registrationId,
OAuth2User oauth2User
);

/**
* Get user connection by registration id and OAuth2 user.
* If found, update updatedAt timestamp of the user connection.
*
* @param registrationId Registration id
* @param oauth2User OAuth2 user
* @return Updated user connection or empty
*/
Mono<UserConnection> getUserConnection(
String registrationId, OAuth2User oauth2User
);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package run.halo.app.core.user.service.impl;

import static run.halo.app.extension.ExtensionUtil.defaultSort;
import static run.halo.app.extension.index.query.QueryFactory.and;
import static run.halo.app.extension.index.query.QueryFactory.equal;

import java.time.Clock;
import java.util.HashMap;
import java.util.Optional;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import run.halo.app.core.extension.UserConnection;
import run.halo.app.core.extension.UserConnection.UserConnectionSpec;
import run.halo.app.core.user.service.UserConnectionService;
import run.halo.app.extension.ListOptions;
import run.halo.app.extension.Metadata;
import run.halo.app.extension.MetadataOperator;
import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.infra.utils.JsonUtils;

@Service
public class UserConnectionServiceImpl implements UserConnectionService {

private final ReactiveExtensionClient client;

private Clock clock = Clock.systemDefaultZone();

public UserConnectionServiceImpl(ReactiveExtensionClient client) {
this.client = client;
}

void setClock(Clock clock) {
this.clock = clock;
}

@Override
public Mono<UserConnection> createUserConnection(
String username,
String registrationId,
OAuth2User oauth2User
) {
var connection = new UserConnection();
connection.setMetadata(new Metadata());
var metadata = connection.getMetadata();
updateUserInfo(metadata, oauth2User);
metadata.setGenerateName(username + "-");
connection.setSpec(new UserConnectionSpec());
var spec = connection.getSpec();
spec.setUsername(username);
spec.setProviderUserId(oauth2User.getName());
spec.setRegistrationId(registrationId);
spec.setUpdatedAt(clock.instant());
return client.create(connection);
}

private Mono<UserConnection> updateUserConnection(UserConnection connection,
OAuth2User oauth2User) {
connection.getSpec().setUpdatedAt(clock.instant());
updateUserInfo(connection.getMetadata(), oauth2User);
return client.update(connection);
}

@Override
public Mono<UserConnection> getUserConnection(String registrationId,
OAuth2User oauth2User) {
var listOptions = ListOptions.builder()
.fieldQuery(and(
equal("spec.registrationId", registrationId),
equal("spec.providerUserId", oauth2User.getName())
))
.build();
return client.listAll(UserConnection.class, listOptions, defaultSort()).next()
.flatMap(connection -> updateUserConnection(connection, oauth2User));
}

private void updateUserInfo(MetadataOperator metadata, OAuth2User oauth2User) {
var annotations = Optional.ofNullable(metadata.getAnnotations())
.orElseGet(HashMap::new);
metadata.setAnnotations(annotations);
annotations.put(
"auth.halo.run/oauth2-user-info",
JsonUtils.objectToJson(oauth2User.getAttributes())
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,22 @@ public void onApplicationEvent(@NonNull ApplicationContextInitializedEvent event
.map(UserConnectionSpec::getUsername)
.orElse(null)
)));
is.add(new IndexSpec()
.setName("spec.registrationId")
.setIndexFunc(simpleAttribute(UserConnection.class,
connection -> Optional.ofNullable(connection.getSpec())
.map(UserConnectionSpec::getRegistrationId)
.orElse(null)
))
);
is.add(new IndexSpec()
.setName("spec.providerUserId")
.setIndexFunc(simpleAttribute(UserConnection.class,
connection -> Optional.ofNullable(connection.getSpec())
.map(UserConnectionSpec::getProviderUserId)
.orElse(null)
))
);
});

// security.halo.run
Expand Down
Loading

0 comments on commit a286769

Please sign in to comment.