Skip to content

Commit

Permalink
add expiration_time
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 committed Jun 3, 2024
1 parent 2f50d7d commit 60fce37
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.doris.common.classloader;

import org.apache.doris.common.jni.utils.ExpiringMap;

import com.google.common.collect.Streams;
import org.apache.log4j.Logger;

Expand All @@ -33,7 +35,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
Expand All @@ -44,7 +45,7 @@
public class ScannerLoader {
public static final Logger LOG = Logger.getLogger(ScannerLoader.class);
private static final Map<String, Class<?>> loadedClasses = new HashMap<>();
private static final Map<String, ClassLoader> udfLoadedClasses = new ConcurrentHashMap<>();
private static final ExpiringMap<String, ClassLoader> udfLoadedClasses = new ExpiringMap<String, ClassLoader>();
private static final String CLASS_SUFFIX = ".class";
private static final String LOAD_PACKAGE = "org.apache.doris";

Expand All @@ -65,22 +66,18 @@ public void loadAllScannerJars() {
}

public static ClassLoader getUdfClassLoader(String functionSignature) {
if (udfLoadedClasses.containsKey(functionSignature)) {
return udfLoadedClasses.get(functionSignature);
}
return null;
return udfLoadedClasses.get(functionSignature);
}

public static synchronized void cacheClassLoader(String functionSignature, ClassLoader classLoader) {
public static synchronized void cacheClassLoader(String functionSignature, ClassLoader classLoader,
long expirationTime) {
LOG.info("cacheClassLoader for: " + functionSignature);
udfLoadedClasses.put(functionSignature, classLoader);
udfLoadedClasses.put(functionSignature, classLoader, expirationTime * 60 * 1000L);
}

public synchronized void removeUdfClassLoader(String functionSignature) {
if (udfLoadedClasses.containsKey(functionSignature)) {
LOG.info("removeUdfClassLoader for: " + functionSignature);
udfLoadedClasses.remove(functionSignature);
}
LOG.info("removeUdfClassLoader for: " + functionSignature);
udfLoadedClasses.remove(functionSignature);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.common.jni.utils;

import org.apache.log4j.Logger;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public class ExpiringMap<K, V> {
private final ConcurrentHashMap<K, V> map = new ConcurrentHashMap<>(); // key --> value
private final ConcurrentHashMap<K, Long> ttlMap = new ConcurrentHashMap<>(); // key --> ttl interval
// key --> expirationTime(ttl interval + currentTimeMillis)
private final ConcurrentHashMap<K, Long> expirationMap = new ConcurrentHashMap<>();
private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
private static final long DEFAULT_INTERVAL_TIME = 10 * 60 * 1000L; // 10 minutes
public static final Logger LOG = Logger.getLogger(ExpiringMap.class);

public ExpiringMap() {
startExpirationTask();
}

public void put(K key, V value, long expirationTimeMs) {
long expirationTime = System.currentTimeMillis() + expirationTimeMs;
map.put(key, value);
expirationMap.put(key, expirationTime);
ttlMap.put(key, expirationTimeMs);
}

public V get(K key) {
Long expirationTime = expirationMap.get(key);
if (expirationTime == null || System.currentTimeMillis() > expirationTime) {
map.remove(key);
expirationMap.remove(key);
ttlMap.remove(key);
return null;
}
// reset time again
long ttl = ttlMap.get(key);
long newExpirationTime = System.currentTimeMillis() + ttl;
expirationMap.put(key, newExpirationTime);
return map.get(key);
}

private void startExpirationTask() {
scheduler.scheduleAtFixedRate(() -> {
long now = System.currentTimeMillis();
for (K key : expirationMap.keySet()) {
if (expirationMap.get(key) <= now) {
map.remove(key);
expirationMap.remove(key);
ttlMap.remove(key);
}
}
}, DEFAULT_INTERVAL_TIME, DEFAULT_INTERVAL_TIME, TimeUnit.MINUTES);
}

public void remove(K key) {
map.remove(key);
expirationMap.remove(key);
ttlMap.remove(key);
}

public void shutdown() {
scheduler.shutdown();
try {
if (!scheduler.awaitTermination(60, TimeUnit.SECONDS)) {
scheduler.shutdownNow();
}
} catch (InterruptedException e) {
scheduler.shutdownNow();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private Method findPrepareMethod(Method[] methods) {
return null; // Method not found
}

public ClassLoader getClassLoader(String jarPath, String signature)
public ClassLoader getClassLoader(String jarPath, String signature, long expirationTimeMs)
throws MalformedURLException, FileNotFoundException {
ClassLoader loader = null;
if (jarPath == null) {
Expand All @@ -155,7 +155,7 @@ public ClassLoader getClassLoader(String jarPath, String signature)
classLoader = UdfUtils.getClassLoader(jarPath, parent);
loader = classLoader;
if (isStaticLoad) {
ScannerLoader.cacheClassLoader(signature, loader);
ScannerLoader.cacheClassLoader(signature, loader, expirationTimeMs);
}
}
}
Expand All @@ -174,7 +174,11 @@ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type fun
LOG.debug("Loading UDF '" + className + "' from " + jarPath);
}
isStaticLoad = request.getFn().isSetIsStaticLoad() && request.getFn().is_static_load;
ClassLoader loader = getClassLoader(jarPath, request.getFn().getSignature());
long expirationTimeMs = 360L; // default is 6 hours
if (request.getFn().isSetExpirationTime()) {
expirationTimeMs = request.getFn().getExpirationTime();
}
ClassLoader loader = getClassLoader(jarPath, request.getFn().getSignature(), expirationTimeMs);
Class<?> c = Class.forName(className, true, loader);
methodAccess = MethodAccess.get(c);
Constructor<?> ctor = c.getConstructor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public class CreateFunctionStmt extends DdlStmt {
public static final String IS_RETURN_NULL = "always_nullable";
// iff is static load, BE will be cache the udf class load, so only need load once
public static final String IS_STATIC_LOAD = "static_load";
public static final String EXPIRATION_TIME = "expiration_time";
private static final Logger LOG = LogManager.getLogger(CreateFunctionStmt.class);

private SetType type = SetType.DEFAULT;
Expand All @@ -121,6 +122,7 @@ public class CreateFunctionStmt extends DdlStmt {
private Function function;
private String checksum = "";
private boolean isStaticLoad = false;
private long expirationTime = 360; // default 6 hours = 360 minutes
// now set udf default NullableMode is ALWAYS_NULLABLE
// if not, will core dump when input is not null column, but need return null
// like https://github.com/apache/doris/pull/14002/files
Expand Down Expand Up @@ -293,6 +295,14 @@ private void analyzeCommon(Analyzer analyzer) throws AnalysisException {
if (staticLoad != null && staticLoad) {
isStaticLoad = true;
}
String expirationTimeString = properties.get(EXPIRATION_TIME);
if (expirationTimeString != null) {
long timeMinutes = Long.parseLong(expirationTimeString);
if (timeMinutes <= 0) {
throw new AnalysisException("expirationTime should greater than zero: ");
}
this.expirationTime = timeMinutes;
}
}
}

Expand Down Expand Up @@ -448,6 +458,7 @@ private void analyzeUdf() throws AnalysisException {
function.setChecksum(checksum);
function.setNullableMode(returnNullMode);
function.setStaticLoad(isStaticLoad);
function.setExpirationTime(expirationTime);
}

private void analyzeJavaUdaf(String clazz) throws AnalysisException {
Expand Down
13 changes: 13 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public enum NullableMode {
protected boolean isUDTFunction = false;
// iff true, this udf function is static load, and BE need cache class load.
protected boolean isStaticLoad = false;
protected long expirationTime = 360; // default 6 hours;

// Only used for serialization
protected Function() {
Expand Down Expand Up @@ -203,6 +204,7 @@ public Function(Function other) {
this.isGlobal = other.isGlobal;
this.isUDTFunction = other.isUDTFunction;
this.isStaticLoad = other.isStaticLoad;
this.expirationTime = other.expirationTime;
}

public void setNestedFunction(Function nestedFunction) {
Expand Down Expand Up @@ -572,6 +574,7 @@ public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] re
fn.setVectorized(vectorized);
fn.setIsUdtfFunction(isUDTFunction);
fn.setIsStaticLoad(isStaticLoad);
fn.setExpirationTime(expirationTime);
return fn;
}

Expand Down Expand Up @@ -682,6 +685,7 @@ protected void writeFields(DataOutput output) throws IOException {
output.writeUTF(nullableMode.toString());
output.writeBoolean(isUDTFunction);
output.writeBoolean(isStaticLoad);
output.writeLong(expirationTime);
}

@Override
Expand Down Expand Up @@ -724,6 +728,7 @@ public void readFields(DataInput input) throws IOException {
}
if (Env.getCurrentEnvJournalVersion() >= FeMetaVersion.VERSION_134) {
isStaticLoad = input.readBoolean();
expirationTime = input.readLong();
}
}

Expand Down Expand Up @@ -812,6 +817,14 @@ public boolean isStaticLoad() {
return this.isStaticLoad;
}

public void setExpirationTime(long expirationTime) {
this.expirationTime = expirationTime;
}

public long getExpirationTime() {
return this.expirationTime;
}

// Try to serialize this function and write to nowhere.
// Just for checking if we forget to implement write() method for some Exprs.
// To avoid FE exist when writing edit log.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ public class JavaUdf extends ScalarFunction implements ExplicitlyCastableSignatu
private final String closeFn;
private final String checkSum;
private final boolean isStaticLoad;
private final long expirationTime;

/**
* Constructor of UDF
*/
public JavaUdf(String name, long functionId, String dbName, TFunctionBinaryType binaryType,
FunctionSignature signature,
NullableMode nullableMode, String objectFile, String symbol, String prepareFn, String closeFn,
String checkSum, boolean isStaticLoad, Expression... args) {
String checkSum, boolean isStaticLoad, long expirationTime, Expression... args) {
super(name, args);
this.dbName = dbName;
this.functionId = functionId;
Expand All @@ -77,6 +78,7 @@ public JavaUdf(String name, long functionId, String dbName, TFunctionBinaryType
this.closeFn = closeFn;
this.checkSum = checkSum;
this.isStaticLoad = isStaticLoad;
this.expirationTime = expirationTime;
}

@Override
Expand Down Expand Up @@ -106,7 +108,8 @@ public NullableMode getNullableMode() {
public JavaUdf withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == this.children.size());
return new JavaUdf(getName(), functionId, dbName, binaryType, signature, nullableMode,
objectFile, symbol, prepareFn, closeFn, checkSum, isStaticLoad, children.toArray(new Expression[0]));
objectFile, symbol, prepareFn, closeFn, checkSum, isStaticLoad, expirationTime,
children.toArray(new Expression[0]));
}

/**
Expand Down Expand Up @@ -135,7 +138,7 @@ public static void translateToNereidsFunction(String dbName, org.apache.doris.ca
scalar.getSymbolName(),
scalar.getPrepareFnSymbol(),
scalar.getCloseFnSymbol(),
scalar.getChecksum(), scalar.isStaticLoad(),
scalar.getChecksum(), scalar.isStaticLoad(), scalar.getExpirationTime(),
virtualSlots);

JavaUdfBuilder builder = new JavaUdfBuilder(udf);
Expand Down Expand Up @@ -165,6 +168,7 @@ public Function getCatalogFunction() {
expr.setChecksum(checkSum);
expr.setId(functionId);
expr.setStaticLoad(isStaticLoad);
expr.setExpirationTime(expirationTime);
return expr;
} catch (Exception e) {
throw new AnalysisException(e.getMessage(), e.getCause());
Expand Down
1 change: 1 addition & 0 deletions gensrc/thrift/Types.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ struct TFunction {
13: optional bool vectorized = false
14: optional bool is_udtf_function = false
15: optional bool is_static_load = false
16: optional i64 expiration_time //minutes
}

enum TJdbcOperation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ suite("test_javaudf_static_load_test") {
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.StaticIntTest",
"static_load"="true",
"expiration_time"="10",
"type"="JAVA_UDF"
); """

Expand Down Expand Up @@ -89,6 +90,7 @@ suite("test_javaudf_static_load_test") {
"file"="file://${jarPath}",
"symbol"="org.apache.doris.udf.StaticIntTest",
"static_load"="true",
"expiration_time"="10",
"type"="JAVA_UDF"
); """
qt_select11 """ SELECT static_load_test(); """
Expand Down

0 comments on commit 60fce37

Please sign in to comment.