Skip to content

Commit

Permalink
[HWORKS-527] fix jupyter jwt monitor for HA (#1444)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErmiasG authored Jun 13, 2023
1 parent 42c330c commit 83e9851
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package io.hops.hopsworks.common.jupyter;

import java.io.Serializable;
import java.util.Objects;

public class CidAndPort {
public class CidAndPort implements Serializable {
private static final long serialVersionUID = -7736027812979433344L;
String cid;
Integer port;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ public JupyterJWT(Project project, Users user, LocalDateTime expiration, CidAndP
this.pidAndPort = pidAndPort;
}

public JupyterJWT(Project project, Users user, LocalDateTime expiration, CidAndPort pidAndPort, String token,
Path tokenFile) {
super(project, user, expiration);
this.pidAndPort = pidAndPort;
this.tokenFile = tokenFile;
this.token = token;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* This file is part of Hopsworks
* Copyright (C) 2023, Hopsworks AB. All rights reserved
*
* Hopsworks is free software: you can redistribute it and/or modify it under the terms of
* the GNU Affero General Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*
* Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
* PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see <https://www.gnu.org/licenses/>.
*/
package io.hops.hopsworks.common.jupyter;

import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.map.IMap;
import com.hazelcast.query.Predicate;
import com.hazelcast.query.Predicates;
import io.hops.hopsworks.common.util.DateUtils;

import javax.ejb.Singleton;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import javax.inject.Inject;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

@Singleton
@TransactionAttribute(TransactionAttributeType.NEVER)
public class JupyterJWTCache {
private static final String MAP_NAME = "jupyterJWTMap";
@Inject
private HazelcastInstance hazelcastInstance;

private final TreeSet<JupyterJWTDTO> jupyterJWTs = new TreeSet<>((t0, t1) -> {
if (t0.equals(t1)) {
return 0;
} else {
if (t0.getExpiration().isBefore(t1.getExpiration())) {
return -1;
} else if (t0.getExpiration().isAfter(t1.getExpiration())) {
return 1;
}
return 0;
}
});

private final HashMap<CidAndPort, JupyterJWT> pidAndPortToJWT = new HashMap<>();

public void add(JupyterJWT jupyterJWT) {
if (hazelcastInstance != null) {
IMap<CidAndPort, JupyterJWTDTO> pidAndPortToJWTMap = hazelcastInstance.getMap(MAP_NAME);
pidAndPortToJWTMap.put(jupyterJWT.pidAndPort, new JupyterJWTDTO(jupyterJWT));
} else {
jupyterJWTs.add(new JupyterJWTDTO(jupyterJWT));
pidAndPortToJWT.put(jupyterJWT.pidAndPort, jupyterJWT);
}
}

public Optional<JupyterJWT> get(CidAndPort pidAndPort) {
if (hazelcastInstance != null) {
IMap<CidAndPort, JupyterJWTDTO> pidAndPortToJWTMap = hazelcastInstance.getMap(MAP_NAME);
JupyterJWTDTO jupyterJWTDTO = pidAndPortToJWTMap.get(pidAndPort);
if (jupyterJWTDTO != null) {
return Optional.of(
new JupyterJWT(jupyterJWTDTO.getProject(), jupyterJWTDTO.getUser(), jupyterJWTDTO.getExpiration(),
pidAndPort, jupyterJWTDTO.getToken(), Paths.get(jupyterJWTDTO.getTokenFile())));
}
return Optional.empty();
} else {
return Optional.ofNullable(pidAndPortToJWT.get(pidAndPort));
}
}

public void remove(CidAndPort pidAndPort) {
if (hazelcastInstance != null) {
IMap<CidAndPort, JupyterJWTDTO> pidAndPortToJWTMap = hazelcastInstance.getMap(MAP_NAME);
pidAndPortToJWTMap.remove(pidAndPort);
} else {
JupyterJWT jupyterJWT = pidAndPortToJWT.remove(pidAndPort);
jupyterJWTs.remove(new JupyterJWTDTO(jupyterJWT));
}
}

public void replaceAll(Set<JupyterJWT> renewedJWTs) {
if (hazelcastInstance != null) {
IMap<CidAndPort, JupyterJWTDTO> pidAndPortToJWTMap = hazelcastInstance.getMap(MAP_NAME);
renewedJWTs.forEach(t -> pidAndPortToJWTMap.replace(t.pidAndPort, new JupyterJWTDTO(t)));
} else {
renewedJWTs.forEach(t -> {
//remove old token
JupyterJWT jupyterJWT = pidAndPortToJWT.remove(t.pidAndPort);
jupyterJWTs.remove(new JupyterJWTDTO(jupyterJWT));
//Add the new token
jupyterJWTs.add(new JupyterJWTDTO(t));
pidAndPortToJWT.put(t.pidAndPort, t);
});
}
}

public int getSize() {
if (hazelcastInstance != null) {
IMap<CidAndPort, JupyterJWTDTO> pidAndPortToJWTMap = hazelcastInstance.getMap(MAP_NAME);
return pidAndPortToJWTMap.size();
} else {
return jupyterJWTs.size();
}
}

public Iterator<JupyterJWTDTO> getMaybeExpired() {
if (hazelcastInstance != null) {
IMap<CidAndPort, JupyterJWTDTO> pidAndPortToJWTMap = hazelcastInstance.getMap(MAP_NAME);
Predicate<CidAndPort, JupyterJWTDTO> expirationPredicate = Predicates.lessEqual("expiration", DateUtils.getNow());
Collection<JupyterJWTDTO> jupyterJWTDTOS = pidAndPortToJWTMap.values(expirationPredicate);
return jupyterJWTDTOS.iterator();
} else {
return jupyterJWTs.iterator();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* This file is part of Hopsworks
* Copyright (C) 2023, Hopsworks AB. All rights reserved
*
* Hopsworks is free software: you can redistribute it and/or modify it under the terms of
* the GNU Affero General Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*
* Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
* PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see <https://www.gnu.org/licenses/>.
*/
package io.hops.hopsworks.common.jupyter;

import io.hops.hopsworks.common.util.DateUtils;
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.persistence.entity.user.Users;

import java.io.Serializable;
import java.time.LocalDateTime;
import java.util.Objects;

public class JupyterJWTDTO implements Serializable {
private static final long serialVersionUID = -5687462769985361531L;
private Project project;
private Users user;
private LocalDateTime expiration;
private String token;
private String tokenFile;
private final CidAndPort pidAndPort;

public JupyterJWTDTO(JupyterJWT jupyterJWT) {
this.project = jupyterJWT.project;
this.user = jupyterJWT.user;
this.expiration = jupyterJWT.expiration;
this.token = jupyterJWT.token;
this.tokenFile = jupyterJWT.tokenFile.toString();
this.pidAndPort = jupyterJWT.pidAndPort;
}

public Project getProject() {
return project;
}

public void setProject(Project project) {
this.project = project;
}

public Users getUser() {
return user;
}

public void setUser(Users user) {
this.user = user;
}

public LocalDateTime getExpiration() {
return expiration;
}

public void setExpiration(LocalDateTime expiration) {
this.expiration = expiration;
}

public String getToken() {
return token;
}

public void setToken(String token) {
this.token = token;
}

public String getTokenFile() {
return tokenFile;
}

public void setTokenFile(String tokenFile) {
this.tokenFile = tokenFile;
}

public CidAndPort getPidAndPort() {
return pidAndPort;
}

public boolean maybeRenew(LocalDateTime now) {
return now.isAfter(expiration) || now.isEqual(expiration);
}

public boolean isExpired() {
LocalDateTime now = DateUtils.getNow();
return now.isAfter(expiration) || now.isEqual(expiration);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
JupyterJWTDTO that = (JupyterJWTDTO) o;
return Objects.equals(project.getId(), that.project.getId()) && Objects.equals(user.getUid(), that.user.getUid());
}

@Override
public int hashCode() {
return Objects.hash(project.getId(), user.getUid());
}
}
Loading

0 comments on commit 83e9851

Please sign in to comment.