Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into chore/editor-config
Browse files Browse the repository at this point in the history
  • Loading branch information
guqing committed Aug 2, 2024
2 parents 98c2444 + 2a03e03 commit b73ffd9
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 15 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ test {

halo {
version = '2.17'
debug = true
}
3 changes: 2 additions & 1 deletion src/main/java/run/halo/oauth/Oauth2Authenticator.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package run.halo.oauth;

import static org.apache.commons.lang3.StringUtils.defaultIfBlank;
import static org.apache.commons.lang3.StringUtils.defaultString;
import static run.halo.oauth.SocialServerOauth2AuthorizationRequestResolver.SOCIAL_CONNECTION;

Expand Down Expand Up @@ -170,7 +171,7 @@ private ServerAuthenticationSuccessHandler registrationPageHandler(String regist
Assert.notNull(oauth2User, "oauth2User cannot be null");

String loginName = oauth2User.getName();
String name = defaultString(oauth2User.getAttribute("name"), loginName);
String name = defaultIfBlank(oauth2User.getAttribute("name"), loginName);
MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
queryParams.add("login", loginName);
queryParams.add("name", name);
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/run/halo/oauth/Oauth2LoginConfiguration.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package run.halo.oauth;

import com.google.common.base.Throwables;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.DelegatingReactiveAuthenticationManager;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
Expand Down Expand Up @@ -45,6 +48,7 @@
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;
import org.springframework.util.MultiValueMap;
import reactor.core.publisher.Mono;
import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.security.LoginHandlerEnhancer;
Expand All @@ -55,6 +59,7 @@
* @author guqing
* @since 1.0.0
*/
@Slf4j
@Getter
@Component
public final class Oauth2LoginConfiguration {
Expand Down Expand Up @@ -114,13 +119,43 @@ ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
@Override
public Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange,
AuthenticationException exception) {
var queryParams = webFilterExchange.getExchange().getRequest().getQueryParams();
var response = new OAuth2ErrorResponse(queryParams);
log.error("An error occurred while attempting to oauth2 authenticate: \n{}",
response, Throwables.getRootCause(exception));
return loginHandlerEnhancer.onLoginFailure(webFilterExchange.getExchange(),
exception)
.then(super.onAuthenticationFailure(webFilterExchange, exception));
}
};
}

@RequiredArgsConstructor
static class OAuth2ErrorResponse {
private final MultiValueMap<String, String> queryParams;

public String error() {
return queryParams.getFirst("error");
}

public String errorDescription() {
return queryParams.getFirst("error_description");
}

public String errorUri() {
return queryParams.getFirst("error_uri");
}

@Override
public String toString() {
return """
error: %s
error_description: %s
error_uri: %s
""".formatted(error(), errorDescription(), errorUri());
}
}

GrantedAuthoritiesMapper getAuthoritiesMapper() {
return new SimpleAuthorityMapper();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ public Mono<ClientRegistration> findByRegistrationId(String registrationId) {
Mono.error(new ProviderNotFoundException(
"Unsupported OAuth2 provider: " + registrationId)))
.flatMap(provider -> fetchEnabledProviders()
.map(enabledNames -> {
if (enabledNames.contains(registrationId)) {
return provider;
.doOnNext(enabledNames -> {
if (!enabledNames.contains(registrationId)) {
throw new OAuth2AuthenticationException(
"Authentication provider is not enabled: " + registrationId);
}
throw new OAuth2AuthenticationException(
"Authentication provider is not enabled: " + registrationId);
})
.thenReturn(provider)
)
.flatMap(this::getClientRegistrationMono);
}
Expand Down
34 changes: 25 additions & 9 deletions src/main/java/run/halo/oauth/UserConnectionServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebInputException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import run.halo.app.core.extension.UserConnection;
Expand Down Expand Up @@ -39,14 +40,22 @@ public Mono<UserConnection> createConnection(String username,

UserConnection connection = convert(username, authentication);
String providerUserId = authentication.getPrincipal().getName();
return fetchUserConnection(connection.getSpec().getRegistrationId(), providerUserId)
.flatMap(persisted -> {
connection.getMetadata().setName(persisted.getMetadata().getName());
connection.getMetadata()
.setVersion(persisted.getMetadata().getVersion());
return client.update(connection);
})
.switchIfEmpty(Mono.defer(() -> client.create(connection)));
return findByRegistrationId(connection.getSpec().getRegistrationId())
.hasElement()
.flatMap(exists -> {
if (exists) {
return Mono.error(new ServerWebInputException(
"已经绑定过 " + connection.getSpec().getRegistrationId() + " 账号,请先解绑"));
}
return fetchUserConnection(connection.getSpec().getRegistrationId(), providerUserId)
.flatMap(persisted -> {
connection.getMetadata().setName(persisted.getMetadata().getName());
connection.getMetadata()
.setVersion(persisted.getMetadata().getVersion());
return client.update(connection);
})
.switchIfEmpty(Mono.defer(() -> client.create(connection)));
});
}

@Override
Expand Down Expand Up @@ -81,6 +90,12 @@ Flux<UserConnection> listByRegistrationIdAndUsername(String registrationId, Stri
&& persisted.getSpec().getUsername().equals(username), null);
}

private Mono<UserConnection> findByRegistrationId(String registrationId) {
return client.list(UserConnection.class,
persisted -> persisted.getSpec().getRegistrationId().equals(registrationId), null)
.next();
}

private Mono<UserConnection> fetchUserConnection(String registrationId, String providerUserId) {
return client.list(UserConnection.class, persisted -> persisted.getSpec()
.getProviderUserId().equals(providerUserId)
Expand Down Expand Up @@ -111,7 +126,8 @@ UserConnection convert(String username, OAuth2LoginAuthenticationToken authentic

Oauth2UserProfile oauth2UserProfile =
oauth2UserProfileMapperManager.mapProfile(registrationId, oauth2User);
spec.setDisplayName(oauth2UserProfile.getDisplayName());
var displayName = StringUtils.defaultIfBlank(oauth2UserProfile.getDisplayName(), username);
spec.setDisplayName(displayName);
spec.setAvatarUrl(oauth2UserProfile.getAvatarUrl());
spec.setProfileUrl(oauth2UserProfile.getProfileUrl());
return userConnection;
Expand Down

0 comments on commit b73ffd9

Please sign in to comment.