diff --git a/core/src/main/java/org/jruby/RubyObjectSpace.java b/core/src/main/java/org/jruby/RubyObjectSpace.java index 506bb8bfbbd..e10650c7ab4 100644 --- a/core/src/main/java/org/jruby/RubyObjectSpace.java +++ b/core/src/main/java/org/jruby/RubyObjectSpace.java @@ -37,6 +37,7 @@ import java.util.Iterator; import java.util.Map; +import java.util.stream.Stream; import org.jruby.anno.JRubyMethod; import org.jruby.anno.JRubyModule; @@ -51,6 +52,7 @@ import org.jruby.util.Inspector; import org.jruby.util.Numeric; import org.jruby.util.collections.WeakValuedIdentityMap; +import org.jruby.util.collections.WeakValuedMap; @JRubyModule(name="ObjectSpace") public class RubyObjectSpace { @@ -231,59 +233,74 @@ public WeakMap(Ruby runtime, RubyClass cls) { @JRubyMethod(name = "[]") public IRubyObject op_aref(ThreadContext context, IRubyObject key) { - IRubyObject value = map.get(key); + Map<IRubyObject, IRubyObject> weakMap = getWeakMapFor(key); + IRubyObject value = weakMap.get(key); if (value != null) return value; return context.nil; } + private Map<IRubyObject, IRubyObject> getWeakMapFor(IRubyObject key) { + if (key instanceof RubyFixnum || key instanceof RubyFloat) { + return valueMap; + } + + return identityMap; + } + @JRubyMethod(name = "[]=") public IRubyObject op_aref(ThreadContext context, IRubyObject key, IRubyObject value) { Ruby runtime = context.runtime; - map.put(key, value); + Map<IRubyObject, IRubyObject> weakMap = getWeakMapFor(key); + weakMap.put(key, value); return runtime.newFixnum(System.identityHashCode(value)); } @JRubyMethod(name = "key?") public IRubyObject key_p(ThreadContext context, IRubyObject key) { - return RubyBoolean.newBoolean(context, map.get(key) != null); + Map<IRubyObject, IRubyObject> weakMap = getWeakMapFor(key); + return RubyBoolean.newBoolean(context, weakMap.get(key) != null); } @JRubyMethod(name = "keys") public IRubyObject keys(ThreadContext context) { return context.runtime.newArrayNoCopy( - map.entrySet() - .stream() + getEntryStream() .filter(entry -> entry.getValue() != null) - .map(entry -> entry.getKey()) + .map(Map.Entry::getKey) .toArray(IRubyObject[]::new)); } + private Stream<Map.Entry<IRubyObject, IRubyObject>> getEntryStream() { + return Stream.concat(identityMap.entrySet().stream(), valueMap.entrySet().stream()); + } + @JRubyMethod(name = "values") public IRubyObject values(ThreadContext context) { return context.runtime.newArrayNoCopy( - map.values() - .stream() + getEntryStream() + .map(Map.Entry::getValue) .filter(ref -> ref != null) .toArray(IRubyObject[]::new)); } @JRubyMethod(name = {"length", "size"}) public IRubyObject size(ThreadContext context) { - return context.runtime.newFixnum(map.size()); + return context.runtime.newFixnum(identityMap.size() + valueMap.size()); } @JRubyMethod(name = {"include?", "member?"}) public IRubyObject member_p(ThreadContext context, IRubyObject key) { - return RubyBoolean.newBoolean(context, map.containsKey(key)); + return RubyBoolean.newBoolean(context, getWeakMapFor(key).containsKey(key)); } @JRubyMethod(name = {"each", "each_pair"}) public IRubyObject each(ThreadContext context, Block block) { - map.forEach((key, value) -> { + getEntryStream().forEach((entry) -> { + IRubyObject value = entry.getValue(); if (value != null) { - block.yieldSpecific(context, key, value); + block.yieldSpecific(context, entry.getKey(), value); } }); @@ -292,23 +309,23 @@ public IRubyObject each(ThreadContext context, Block block) { @JRubyMethod(name = "each_key") public IRubyObject each_key(ThreadContext context, Block block) { - for (Map.Entry<IRubyObject, IRubyObject> entry : map.entrySet()) { + getEntryStream().forEach((entry) -> { if (entry.getValue() != null) { block.yieldSpecific(context, entry.getKey()); } - } + }); return this; } @JRubyMethod(name = "each_value") public IRubyObject each_value(ThreadContext context, Block block) { - for (Map.Entry<IRubyObject, IRubyObject> entry : map.entrySet()) { + getEntryStream().forEach((entry) -> { IRubyObject value = entry.getValue(); if (value != null) { block.yieldSpecific(context, value); } - } + }); return this; } @@ -320,7 +337,7 @@ public IRubyObject inspect(ThreadContext context) { RubyString part = inspectPrefix(runtime.getCurrentContext(), metaClass.getRealClass(), inspectHashCode()); int base = part.length(); - map.entrySet().forEach(entry -> { + getEntryStream().forEach(entry -> { if (entry.getValue() != null) { if (part.length() == base) { part.cat(Inspector.COLON_SPACE); @@ -339,6 +356,7 @@ public IRubyObject inspect(ThreadContext context) { return part; } - private final WeakValuedIdentityMap<IRubyObject, IRubyObject> map = new WeakValuedIdentityMap<IRubyObject, IRubyObject>(); + private final WeakValuedIdentityMap<IRubyObject, IRubyObject> identityMap = new WeakValuedIdentityMap<>(); + private final WeakValuedMap<IRubyObject, IRubyObject> valueMap = new WeakValuedMap<>(); } }