Skip to content

Commit

Permalink
Development: Enable bearer authentication (#9403)
Browse files Browse the repository at this point in the history
  • Loading branch information
janthoXO authored Dec 9, 2024
1 parent 390d00c commit 4ebb290
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.regex.Pattern;

import jakarta.annotation.Nullable;
import jakarta.servlet.http.Cookie;
import jakarta.validation.constraints.NotNull;

import org.slf4j.Logger;
Expand All @@ -28,6 +27,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
Expand All @@ -52,7 +52,6 @@
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import org.springframework.web.util.WebUtils;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Iterators;
Expand Down Expand Up @@ -201,9 +200,14 @@ public HandshakeInterceptor httpSessionHandshakeInterceptor() {
public boolean beforeHandshake(@NotNull ServerHttpRequest request, @NotNull ServerHttpResponse response, @NotNull WebSocketHandler wsHandler,
@NotNull Map<String, Object> attributes) {
if (request instanceof ServletServerHttpRequest servletRequest) {
attributes.put(IP_ADDRESS, servletRequest.getRemoteAddress());
Cookie jwtCookie = WebUtils.getCookie(servletRequest.getServletRequest(), JWTFilter.JWT_COOKIE_NAME);
return JWTFilter.isJwtCookieValid(tokenProvider, jwtCookie);
try {
attributes.put(IP_ADDRESS, servletRequest.getRemoteAddress());
return JWTFilter.extractValidJwt(servletRequest.getServletRequest(), tokenProvider) != null;
}
catch (IllegalArgumentException e) {
response.setStatusCode(HttpStatusCode.valueOf(400));
return false;
}
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import java.io.IOException;

import jakarta.annotation.Nullable;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
Expand All @@ -22,6 +24,10 @@ public class JWTFilter extends GenericFilterBean {

public static final String JWT_COOKIE_NAME = "jwt";

private static final String AUTHORIZATION_HEADER = "Authorization";

private static final String BEARER_PREFIX = "Bearer ";

private final TokenProvider tokenProvider;

public JWTFilter(TokenProvider tokenProvider) {
Expand All @@ -31,26 +37,89 @@ public JWTFilter(TokenProvider tokenProvider) {
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
Cookie jwtCookie = WebUtils.getCookie(httpServletRequest, JWT_COOKIE_NAME);
if (isJwtCookieValid(this.tokenProvider, jwtCookie)) {
Authentication authentication = this.tokenProvider.getAuthentication(jwtCookie.getValue());
HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
String jwtToken;
try {
jwtToken = extractValidJwt(httpServletRequest, this.tokenProvider);
}
catch (IllegalArgumentException e) {
httpServletResponse.sendError(HttpServletResponse.SC_BAD_REQUEST);
return;
}

if (jwtToken != null) {
Authentication authentication = this.tokenProvider.getAuthentication(jwtToken);
SecurityContextHolder.getContext().setAuthentication(authentication);
}

filterChain.doFilter(servletRequest, servletResponse);
}

/**
* Checks if the cookie containing the jwt is valid
* Extracts the valid jwt found in the cookie or the Authorization header
*
* @param tokenProvider the artemis token provider used to generate and validate jwt's
* @param jwtCookie the cookie containing the jwt
* @return true if the jwt is valid, false if missing or invalid
* @param httpServletRequest the http request
* @param tokenProvider the Artemis token provider used to generate and validate jwt's
* @return the valid jwt or null if not found or invalid
*/
public static boolean isJwtCookieValid(TokenProvider tokenProvider, Cookie jwtCookie) {
public static @Nullable String extractValidJwt(HttpServletRequest httpServletRequest, TokenProvider tokenProvider) {
var cookie = WebUtils.getCookie(httpServletRequest, JWT_COOKIE_NAME);
var authHeader = httpServletRequest.getHeader(AUTHORIZATION_HEADER);

if (cookie == null && authHeader == null) {
return null;
}

if (cookie != null && authHeader != null) {
// Single Method Enforcement: Only one method of authentication is allowed
throw new IllegalArgumentException("Multiple authentication methods detected: Both JWT cookie and Bearer token are present");
}

String jwtToken = cookie != null ? getJwtFromCookie(cookie) : getJwtFromBearer(authHeader);

if (!isJwtValid(tokenProvider, jwtToken)) {
return null;
}

return jwtToken;
}

/**
* Extracts the jwt from the cookie
*
* @param jwtCookie the cookie with Key "jwt"
* @return the jwt or null if not found
*/
private static @Nullable String getJwtFromCookie(@Nullable Cookie jwtCookie) {
if (jwtCookie == null) {
return false;
return null;
}
return jwtCookie.getValue();
}

/**
* Extracts the jwt from the Authorization header
*
* @param jwtBearer the content of the Authorization header
* @return the jwt or null if not found
*/
private static @Nullable String getJwtFromBearer(@Nullable String jwtBearer) {
if (!StringUtils.hasText(jwtBearer) || !jwtBearer.startsWith(BEARER_PREFIX)) {
return null;
}
String jwt = jwtCookie.getValue();
return StringUtils.hasText(jwt) && tokenProvider.validateTokenForAuthority(jwt);

String token = jwtBearer.substring(BEARER_PREFIX.length()).trim();
return StringUtils.hasText(token) ? token : null;
}

/**
* Checks if the jwt is valid
*
* @param tokenProvider the Artemis token provider used to generate and validate jwt's
* @param jwtToken the jwt
* @return true if the jwt is valid, false if missing or invalid
*/
private static boolean isJwtValid(TokenProvider tokenProvider, @Nullable String jwtToken) {
return StringUtils.hasText(jwtToken) && tokenProvider.validateTokenForAuthority(jwtToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ private Claims parseClaims(String authToken) {
return Jwts.parser().verifyWith(key).build().parseSignedClaims(authToken).getPayload();
}

public <T> T getClaim(String token, String claimName, Class<T> claimType) {
Claims claims = parseClaims(token);
return claims.get(claimName, claimType);
}

public Date getExpirationDate(String authToken) {
return parseClaims(authToken).getExpiration();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import java.util.Map;
import java.util.Optional;

import jakarta.servlet.ServletException;
Expand Down Expand Up @@ -69,7 +70,7 @@ public PublicUserJwtResource(JWTCookieService jwtCookieService, AuthenticationMa
*/
@PostMapping("authenticate")
@EnforceNothing
public ResponseEntity<Void> authorize(@Valid @RequestBody LoginVM loginVM, @RequestHeader("User-Agent") String userAgent, HttpServletResponse response) {
public ResponseEntity<Map<String, String>> authorize(@Valid @RequestBody LoginVM loginVM, @RequestHeader("User-Agent") String userAgent, HttpServletResponse response) {

var username = loginVM.getUsername();
var password = loginVM.getPassword();
Expand All @@ -86,7 +87,7 @@ public ResponseEntity<Void> authorize(@Valid @RequestBody LoginVM loginVM, @Requ
ResponseCookie responseCookie = jwtCookieService.buildLoginCookie(rememberMe);
response.addHeader(HttpHeaders.SET_COOKIE, responseCookie.toString());

return ResponseEntity.ok().build();
return ResponseEntity.ok(Map.of("access_token", responseCookie.getValue()));
}
catch (BadCredentialsException ex) {
log.warn("Wrong credentials during login for user {}", loginVM.getUsername());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.time.ZonedDateTime;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

Expand All @@ -26,6 +27,9 @@
import org.springframework.security.test.context.support.WithAnonymousUser;
import org.springframework.security.test.context.support.WithMockUser;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import de.tum.cit.aet.artemis.core.connector.GitlabRequestMockProvider;
import de.tum.cit.aet.artemis.core.domain.Authority;
import de.tum.cit.aet.artemis.core.domain.Course;
Expand All @@ -35,6 +39,7 @@
import de.tum.cit.aet.artemis.core.repository.AuthorityRepository;
import de.tum.cit.aet.artemis.core.security.Role;
import de.tum.cit.aet.artemis.core.security.SecurityUtils;
import de.tum.cit.aet.artemis.core.security.jwt.TokenProvider;
import de.tum.cit.aet.artemis.core.service.user.PasswordService;
import de.tum.cit.aet.artemis.core.util.CourseFactory;
import de.tum.cit.aet.artemis.programming.test_repository.ProgrammingExerciseTestRepository;
Expand All @@ -50,6 +55,9 @@ class InternalAuthenticationIntegrationTest extends AbstractSpringIntegrationJen
@Autowired
private PasswordService passwordService;

@Autowired
private TokenProvider tokenProvider;

@Autowired
private ProgrammingExerciseTestRepository programmingExerciseRepository;

Expand Down Expand Up @@ -223,6 +231,10 @@ void testJWTAuthentication() throws Exception {

MockHttpServletResponse response = request.postWithoutResponseBody("/api/public/authenticate", loginVM, HttpStatus.OK, httpHeaders);
AuthenticationIntegrationTestHelper.authenticationCookieAssertions(response.getCookie("jwt"), false);

var responseBody = new ObjectMapper().readValue(response.getContentAsString(), new TypeReference<Map<String, Object>>() {
});
assertThat(tokenProvider.validateTokenForAuthority(responseBody.get("access_token").toString())).isTrue();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void setup() {
}

@Test
void testJWTFilter() throws Exception {
void testJWTFilterCookie() throws Exception {
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("test-user", "test-password",
Collections.singletonList(new SimpleGrantedAuthority(Role.STUDENT.getAuthority())));
String jwt = tokenProvider.createToken(authentication, false);
Expand All @@ -61,6 +61,40 @@ void testJWTFilter() throws Exception {
assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("test-user");
}

@Test
void testJWTFilterBearer() throws Exception {
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("test-user", "test-password",
Collections.singletonList(new SimpleGrantedAuthority(Role.STUDENT.getAuthority())));

String jwt = tokenProvider.createToken(authentication, false);
MockHttpServletRequest request = new MockHttpServletRequest();
request.setCookies(new Cookie(JWTFilter.JWT_COOKIE_NAME, jwt));
request.addHeader("Authorization", "Bearer " + jwt);
request.setRequestURI("/api/test");

MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();
jwtFilter.doFilter(request, response, filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
}

@Test
void testJWTFilterCookieAndBearer() throws Exception {
UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken("test-user", "test-password",
Collections.singletonList(new SimpleGrantedAuthority(Role.STUDENT.getAuthority())));

String jwt = tokenProvider.createToken(authentication, false);
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader("Authorization", "Bearer " + jwt);
request.setRequestURI("/api/test");

MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();
jwtFilter.doFilter(request, response, filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("test-user");
}

@Test
void testJWTFilterInvalidToken() throws Exception {
String jwt = "wrong_jwt";
Expand Down

0 comments on commit 4ebb290

Please sign in to comment.