Skip to content

Commit

Permalink
refactor: enforce single OAuth2 account binding per platform (#59)
Browse files Browse the repository at this point in the history
### What this PR does?
限制每个平台的 OAuth2 帐户一次只能连接到一个用户帐户

/kind improvement

```release-note
限制每个平台的 OAuth2 帐户一次只能连接到一个用户帐户
```
  • Loading branch information
guqing authored Aug 2, 2024
1 parent 00e42ea commit 2a03e03
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/main/java/run/halo/oauth/Oauth2Authenticator.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package run.halo.oauth;

import static org.apache.commons.lang3.StringUtils.defaultIfBlank;
import static org.apache.commons.lang3.StringUtils.defaultString;
import static run.halo.oauth.SocialServerOauth2AuthorizationRequestResolver.SOCIAL_CONNECTION;

Expand Down Expand Up @@ -170,7 +171,7 @@ private ServerAuthenticationSuccessHandler registrationPageHandler(String regist
Assert.notNull(oauth2User, "oauth2User cannot be null");

String loginName = oauth2User.getName();
String name = defaultString(oauth2User.getAttribute("name"), loginName);
String name = defaultIfBlank(oauth2User.getAttribute("name"), loginName);
MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
queryParams.add("login", loginName);
queryParams.add("name", name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ public Mono<ClientRegistration> findByRegistrationId(String registrationId) {
Mono.error(new ProviderNotFoundException(
"Unsupported OAuth2 provider: " + registrationId)))
.flatMap(provider -> fetchEnabledProviders()
.map(enabledNames -> {
if (enabledNames.contains(registrationId)) {
return provider;
.doOnNext(enabledNames -> {
if (!enabledNames.contains(registrationId)) {
throw new OAuth2AuthenticationException(
"Authentication provider is not enabled: " + registrationId);
}
throw new OAuth2AuthenticationException(
"Authentication provider is not enabled: " + registrationId);
})
.thenReturn(provider)
)
.flatMap(this::getClientRegistrationMono);
}
Expand Down
31 changes: 23 additions & 8 deletions src/main/java/run/halo/oauth/UserConnectionServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebInputException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import run.halo.app.core.extension.UserConnection;
Expand Down Expand Up @@ -39,14 +40,22 @@ public Mono<UserConnection> createConnection(String username,

UserConnection connection = convert(username, authentication);
String providerUserId = authentication.getPrincipal().getName();
return fetchUserConnection(connection.getSpec().getRegistrationId(), providerUserId)
.flatMap(persisted -> {
connection.getMetadata().setName(persisted.getMetadata().getName());
connection.getMetadata()
.setVersion(persisted.getMetadata().getVersion());
return client.update(connection);
})
.switchIfEmpty(Mono.defer(() -> client.create(connection)));
return findByRegistrationId(connection.getSpec().getRegistrationId())
.hasElement()
.flatMap(exists -> {
if (exists) {
return Mono.error(new ServerWebInputException(
"已经绑定过 " + connection.getSpec().getRegistrationId() + " 账号,请先解绑"));
}
return fetchUserConnection(connection.getSpec().getRegistrationId(), providerUserId)
.flatMap(persisted -> {
connection.getMetadata().setName(persisted.getMetadata().getName());
connection.getMetadata()
.setVersion(persisted.getMetadata().getVersion());
return client.update(connection);
})
.switchIfEmpty(Mono.defer(() -> client.create(connection)));
});
}

@Override
Expand Down Expand Up @@ -81,6 +90,12 @@ Flux<UserConnection> listByRegistrationIdAndUsername(String registrationId, Stri
&& persisted.getSpec().getUsername().equals(username), null);
}

private Mono<UserConnection> findByRegistrationId(String registrationId) {
return client.list(UserConnection.class,
persisted -> persisted.getSpec().getRegistrationId().equals(registrationId), null)
.next();
}

private Mono<UserConnection> fetchUserConnection(String registrationId, String providerUserId) {
return client.list(UserConnection.class, persisted -> persisted.getSpec()
.getProviderUserId().equals(providerUserId)
Expand Down

0 comments on commit 2a03e03

Please sign in to comment.