Skip to content

Commit

Permalink
Add allowAllOrigins option (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
markelliot authored Mar 2, 2022
1 parent 8d1b5fb commit 9a286ad
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
13 changes: 12 additions & 1 deletion barista/src/main/java/com/markelliot/barista/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public static final class Builder {
private final Set<String> allowedOrigins = new LinkedHashSet<>();
private SerDe serde = new SerDe.ObjectMapperSerDe();
private Authz authz = Authz.denyAll();
private boolean allowAllOrigins = false;
private boolean tls = true;
private double tracingRate = 0.2;

Expand Down Expand Up @@ -111,6 +112,11 @@ public Builder allowOrigin(String origin) {
return this;
}

public Builder allowAllOrigins() {
allowAllOrigins = true;
return this;
}

public Builder serde(SerDe serde) {
Objects.requireNonNull(serde);
this.serde = serde;
Expand Down Expand Up @@ -157,7 +163,12 @@ public Server start() {
Undertow.builder()
.setHandler(
HandlerChain.of(DispatchFromIoThreadHandler::new)
.then(h -> new CorsHandler(allowedOrigins, h))
.then(
h ->
new CorsHandler(
allowAllOrigins,
allowedOrigins,
h))
.then(h -> new TracingHandler(tracingRate, h))
.last(
handler.build(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import io.undertow.util.HttpString;
import io.undertow.util.Methods;
import java.util.Set;
import java.util.function.Predicate;

/** An {@link HttpHandler} that returns reasonable CORS allow headers for {@code allowedOrigins}. */
public record CorsHandler(Set<String> allowedOrigins, HttpHandler delegate) implements HttpHandler {
public final class CorsHandler implements HttpHandler {
private static final HttpString ACCESS_CONTROL_ALLOW_ORIGIN =
new HttpString("Access-Control-Allow-Origin");
private static final String ORIGIN_ALL = "*";
Expand All @@ -43,10 +44,21 @@ public record CorsHandler(Set<String> allowedOrigins, HttpHandler delegate) impl
private static final HttpString ACCESS_CONTROL_ALLOW_HEADERS =
new HttpString("Access-Control-Allow-Headers");

private final Set<String> allowedOrigins;
private final HttpHandler delegate;
private final Predicate<String> originCheck;

public CorsHandler(boolean allowAllOrigins, Set<String> allowedOrigins, HttpHandler delegate) {
this.originCheck = allowAllOrigins ? origin -> true : this::checkAllowedOrigin;
this.allowedOrigins = allowedOrigins;
this.delegate = delegate;
}

@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
String origin = exchange.getRequestHeaders().getFirst(Headers.ORIGIN);
if (origin != null && !allowedOrigins.contains(origin)) {
if (!originCheck.test(origin)) {
// not an allowed origin, hard deny
exchange.setStatusCode(403)
.getResponseSender()
.send("Origin '" + origin + "' not allowed.");
Expand All @@ -70,4 +82,9 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
delegate.handleRequest(exchange);
}
}

/** Returns true if origin is null or if origin is in the allowedOrigins set. */
private boolean checkAllowedOrigin(String origin) {
return origin == null || allowedOrigins.contains(origin);
}
}

0 comments on commit 9a286ad

Please sign in to comment.