Skip to content

Commit

Permalink
Validate field references
Browse files Browse the repository at this point in the history
Drop the requirement that a non-terminal `yield` step is a
record. Such a step does not produce any named fields, but
can still be used anonymously; for example, the following
query is valid:

  from e in scott.emp
    yield e.deptno
    compute sum

If the step after the `yield` attempts to reference fields by
name, no name will be valid. When

  from e in scott.emp
    yield e.deptno * 2
    yield x - 1

gives the error "unbound variable or constructor: x", the
query can be made valid by using a record to name the field:

  from e in scott.emp
    yield {x = e.deptno * 2}
    yield x - 1
  • Loading branch information
julianhyde committed Jan 31, 2024
1 parent 295d60f commit 333ae51
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 62 deletions.
8 changes: 7 additions & 1 deletion src/main/java/net/hydromatic/morel/compile/Compiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,20 @@ && getOnlyElement(bindings).id.type.equals(elementType)) {
// Note that we don't use nextFactory.
final Code yieldCode = compile(cx, yield.exp);
return () -> Codes.collectRowSink(yieldCode);
} else {
} else if (yield.exp instanceof Core.Tuple) {
final Core.Tuple tuple = (Core.Tuple) yield.exp;
final RecordLikeType recordType = tuple.type();
final ImmutableSortedMap.Builder<String, Code> mapCodes =
ImmutableSortedMap.orderedBy(RecordType.ORDERING);
forEach(tuple.args, recordType.argNameTypes().keySet(), (exp, name) ->
mapCodes.put(name, compile(cx, exp)));
return () -> Codes.yieldRowSink(mapCodes.build(), nextFactory.get());
} else {
final ImmutableSortedMap.Builder<String, Code> mapCodes =
ImmutableSortedMap.orderedBy(RecordType.ORDERING);
final Binding binding = yield.bindings.get(0);
mapCodes.put(binding.id.name, compile(cx, yield.exp));
return () -> Codes.yieldRowSink(mapCodes.build(), nextFactory.get());
}

case ORDER:
Expand Down
35 changes: 32 additions & 3 deletions src/main/java/net/hydromatic/morel/compile/TypeMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@
package net.hydromatic.morel.compile;

import net.hydromatic.morel.ast.AstNode;
import net.hydromatic.morel.type.RecordLikeType;
import net.hydromatic.morel.type.RecordType;
import net.hydromatic.morel.type.Type;
import net.hydromatic.morel.type.TypeSystem;
import net.hydromatic.morel.type.TypeVar;
import net.hydromatic.morel.util.PairList;
import net.hydromatic.morel.util.Unifier;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedSet;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;

import static net.hydromatic.morel.util.Pair.forEach;
import static net.hydromatic.morel.util.Static.transform;
Expand Down Expand Up @@ -86,13 +90,13 @@ public Type getType(AstNode node) {
return termToType(term);
}

/** Returns an AST node's type, or null if no type is known. */
public @Nullable Type getTypeOpt(AstNode node) {
  // A node with no registered term has no deduced type.
  final Unifier.Term term = nodeTypeTerms.get(node);
  return term == null ? null : termToType(term);
}

/** Returns whether the type of an AST node will be a type variable. */
/** Returns whether an AST node's type will be a type variable. */
public boolean typeIsVariable(AstNode node) {
final Unifier.Term term = nodeTypeTerms.get(node);
if (term instanceof Unifier.Variable) {
Expand All @@ -111,6 +115,31 @@ public boolean hasType(AstNode node) {
return nodeTypeTerms.containsKey(node);
}

/** Returns the field names if an AST node has a type that is a record or a
 * tuple, otherwise null. */
@Nullable SortedSet<String> typeFieldNames(AstNode node) {
  // The term may be a sequence or a variable. We materialize a type only
  // when the term is a variable: materializing a type for every sequence
  // allocated lots of temporary type variables, and created a lot of noise
  // in ref logs.
  final Unifier.Term term = nodeTypeTerms.get(node);
  if (term instanceof Unifier.Sequence) {
    // E.g. "record:a:b" becomes record type "{a:t0, b:t1}".
    final List<String> names =
        TypeResolver.fieldList((Unifier.Sequence) term);
    return names == null
        ? null
        : ImmutableSortedSet.copyOf(RecordType.ORDERING, names);
  }
  if (term instanceof Unifier.Variable) {
    final Type type = termToType(term);
    if (type instanceof RecordLikeType) {
      return (SortedSet<String>)
          ((RecordLikeType) type).argNameTypes().keySet();
    }
  }
  return null;
}

/** Visitor that converts type terms into actual types. */
private static class TermToTypeConverter
implements Unifier.TermVisitor<Type> {
Expand Down
114 changes: 68 additions & 46 deletions src/main/java/net/hydromatic/morel/compile/TypeResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.calcite.util.Holder;
Expand All @@ -57,14 +56,14 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.NavigableSet;
import java.util.Objects;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
Expand All @@ -83,6 +82,7 @@
import static net.hydromatic.morel.util.Static.transformEager;

import static java.lang.String.join;
import static java.util.Objects.requireNonNull;

/** Resolves the type of an expression. */
@SuppressWarnings("StaticPseudoFunctionalStyleMethod")
Expand All @@ -105,7 +105,7 @@ public class TypeResolver {
static final String PROGRESSIVE_LABEL = "z$dummy";

/** Creates a TypeResolver; the type system must not be null. */
private TypeResolver(TypeSystem typeSystem) {
  this.typeSystem = requireNonNull(typeSystem);
}

/** Deduces the datatype of a declaration. */
Expand Down Expand Up @@ -189,7 +189,7 @@ private Resolved deduceType_(Environment env, Ast.Decl decl) {
if (type.isProgressive()) {
progressive.set(true);
}
});
}, apply -> {}, apply -> {});
if (progressive.get()) {
node2.accept(FieldExpander.create(typeSystem, env));
} else {
Expand All @@ -203,21 +203,50 @@ private Resolved deduceType_(Environment env, Ast.Decl decl) {
* an unresolved type. Throws if there are unresolved field references. */
private static void checkNoUnresolvedFieldRefs(Ast.Decl decl,
    TypeMap typeMap) {
  // Each callback throws; which one fires depends on why the field
  // reference could not be resolved.
  forEachUnresolvedField(decl, typeMap,
      // Argument's type is still a type variable: flex record.
      apply -> {
        throw new TypeException("unresolved flex record (can't tell "
            + "what fields there are besides " + apply.fn + ")",
            apply.arg.pos);
      },
      // Argument's type is known but is not a record or tuple.
      apply -> {
        throw new TypeException("reference to field "
            + ((Ast.RecordSelector) apply.fn).name
            + " of non-record type " + typeMap.getType(apply.arg),
            apply.arg.pos);
      },
      // Argument is a record/tuple but has no field with this name.
      apply -> {
        throw new TypeException("no field '"
            + ((Ast.RecordSelector) apply.fn).name
            + "' in type '" + typeMap.getType(apply.arg) + "'",
            apply.arg.pos);
      });
}

private static void forEachUnresolvedField(Ast.Decl decl, TypeMap typeMap,
Consumer<Ast.Apply> consumer) {
Consumer<Ast.Apply> variableConsumer,
Consumer<Ast.Apply> notRecordTypeConsumer,
Consumer<Ast.Apply> noFieldConsumer) {
decl.accept(
new Visitor() {
@Override protected void visit(Ast.Apply apply) {
if (apply.fn.op == Op.RECORD_SELECTOR
&& typeMap.typeIsVariable(apply.arg)) {
consumer.accept(apply);
if (apply.fn.op == Op.RECORD_SELECTOR) {
final Ast.RecordSelector recordSelector =
(Ast.RecordSelector) apply.fn;
if (typeMap.typeIsVariable(apply.arg)) {
variableConsumer.accept(apply);
} else {
final Collection<String> fieldNames =
typeMap.typeFieldNames(apply.arg);
if (fieldNames == null) {
notRecordTypeConsumer.accept(apply);
} else {
if (!fieldNames.contains(recordSelector.name)) {
// "#f r" is valid if "r" is a record type with a field "f"
noFieldConsumer.accept(apply);
}
}
}
}
super.visit(apply);
}
Expand All @@ -226,8 +255,8 @@ private static void forEachUnresolvedField(Ast.Decl decl, TypeMap typeMap,

private <E extends AstNode> E reg(E node,
Unifier.Variable variable, Unifier.Term term) {
Objects.requireNonNull(node);
Objects.requireNonNull(term);
requireNonNull(node);
requireNonNull(term);
map.put(node, term);
if (variable != null) {
equiv(term, variable);
Expand All @@ -238,7 +267,6 @@ private <E extends AstNode> E reg(E node,
private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
final List<Ast.Exp> args2;
final Unifier.Variable v2;
Unifier.Variable v3 = null;
final Map<Ast.IdPat, Unifier.Term> termMap;
switch (node.op) {
case BOOL_LITERAL:
Expand Down Expand Up @@ -353,6 +381,7 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
// "(from exp: v50 as id: v60 [, exp: v51 as id: v61]...
// [where filterExp: v5] [yield yieldExp: v4]): v"
final Ast.From from = (Ast.From) node;
Unifier.Variable v3 = unifier.variable();
TypeEnv env3 = env;
final Map<Ast.Id, Unifier.Variable> fieldVars = new LinkedHashMap<>();
final List<Ast.FromStep> fromSteps = new ArrayList<>();
Expand All @@ -362,12 +391,8 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
if (step.i != from.steps.size() - 1) {
switch (step.e.op) {
case COMPUTE:
throw new AssertionError("'compute' step must be last in 'from'");
case YIELD:
if (((Ast.Yield) step.e).exp.op != Op.RECORD) {
throw new AssertionError("'yield' step that is not last in 'from'"
+ " must be a record expression");
}
throw new IllegalArgumentException(
"'compute' step must be last in 'from'");
}
}
env3 = p.left;
Expand All @@ -378,12 +403,10 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
v3 = unifier.variable();
yieldExp2 = deduceType(env3, from.implicitYieldExp, v3);
} else {
Objects.requireNonNull(v3);
requireNonNull(v3);
yieldExp2 = null;
}
final Ast.From from2 =
from.copy(fromSteps,
from.implicitYieldExp != null ? yieldExp2 : null);
final Ast.From from2 = from.copy(fromSteps, yieldExp2);
return reg(from2, v,
from.isCompute() ? v3 : unifier.apply(LIST_TY_CON, v3));

Expand Down Expand Up @@ -478,6 +501,7 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
private Pair<TypeEnv, Unifier.Variable> deduceStepType(TypeEnv env,
Ast.FromStep step, Unifier.Variable v, final TypeEnv env2,
Map<Ast.Id, Unifier.Variable> fieldVars, List<Ast.FromStep> fromSteps) {
requireNonNull(v);
switch (step.op) {
case SCAN:
final Ast.Scan scan = (Ast.Scan) step;
Expand Down Expand Up @@ -512,6 +536,7 @@ private Pair<TypeEnv, Unifier.Variable> deduceStepType(TypeEnv env,
fieldVars.put(ast.id(Pos.ZERO, e.getKey().name),
(Unifier.Variable) e.getValue());
}
v = fieldVar(fieldVars);
final Ast.Exp scanCondition2;
if (scan.condition != null) {
final Unifier.Variable v5 = unifier.variable();
Expand Down Expand Up @@ -582,8 +607,6 @@ private Pair<TypeEnv, Unifier.Variable> deduceStepType(TypeEnv env,
final Ast.Group group = (Ast.Group) step;
validateGroup(group);
TypeEnv env3 = env;
final Map<Ast.Id, Unifier.Variable> inFieldVars =
ImmutableMap.copyOf(fieldVars);
fieldVars.clear();
final PairList<Ast.Id, Ast.Exp> groupExps = PairList.of();
for (Map.Entry<Ast.Id, Ast.Exp> groupExp : group.groupExps) {
Expand All @@ -605,18 +628,16 @@ private Pair<TypeEnv, Unifier.Variable> deduceStepType(TypeEnv env,
final Ast.Exp aggregateFn2 =
deduceType(env2, aggregate.aggregate, v9);
final Ast.Exp arg2;
final Unifier.Term term;
final Unifier.Variable v10;
if (aggregate.argument == null) {
arg2 = null;
term = fieldRecord(inFieldVars);
v10 = v;
} else {
final Unifier.Variable v10 = unifier.variable();
v10 = unifier.variable();
arg2 = deduceType(env2, aggregate.argument, v10);
term = v10;
}
reg(aggregate.aggregate, null, v9);
equiv(
unifier.apply(FN_TY_CON, unifier.apply(LIST_TY_CON, term), v8),
equiv(unifier.apply(FN_TY_CON, unifier.apply(LIST_TY_CON, v10), v8),
v9);
env3 = env3.bind(id.name, v8);
fieldVars.put(id, v8);
Expand Down Expand Up @@ -648,16 +669,16 @@ private void validateGroup(Ast.Group group) {
}
}

private Unifier.Term fieldRecord(Map<Ast.Id, Unifier.Variable> fieldVars) {
private Unifier.Variable fieldVar(Map<Ast.Id, Unifier.Variable> fieldVars) {
switch (fieldVars.size()) {
case 0:
return toTerm(PrimitiveType.UNIT);
return equiv(toTerm(PrimitiveType.UNIT), unifier.variable());
case 1:
return Iterables.getOnlyElement(fieldVars.values());
default:
final TreeMap<String, Unifier.Variable> map = new TreeMap<>();
fieldVars.forEach((k, v) -> map.put(k.name, v));
return record(map);
return equiv(record(map), unifier.variable());
}
}

Expand Down Expand Up @@ -1245,8 +1266,9 @@ private Ast.Exp prefix(TypeEnv env, Ast.PrefixCall call, Unifier.Variable v) {
ast.apply(ast.id(Pos.ZERO, call.op.opName), call.a), v);
}

private void equiv(Unifier.Term term, Unifier.Variable atom) {
terms.add(new TermVariable(term, atom));
private Unifier.Variable equiv(Unifier.Term term, Unifier.Variable v) {
terms.add(new TermVariable(term, v));
return v;
}

private void equiv(Unifier.Term term, Unifier.Term term2) {
Expand Down Expand Up @@ -1397,9 +1419,9 @@ private static class BindTypeEnv implements TypeEnv {

BindTypeEnv(String definedName,
    Function<TypeSystem, Unifier.Term> termFactory, TypeEnv parent) {
  // All three components are required; fail fast on null.
  this.definedName = requireNonNull(definedName);
  this.termFactory = requireNonNull(termFactory);
  this.parent = requireNonNull(parent);
}

@Override public Unifier.Term get(TypeSystem typeSystem, String name,
Expand Down Expand Up @@ -1442,7 +1464,7 @@ private class TypeEnvHolder implements BiConsumer<String, Type> {
private TypeEnv typeEnv;

TypeEnvHolder(TypeEnv typeEnv) {
  // Initial environment must be non-null; it grows via accept().
  this.typeEnv = requireNonNull(typeEnv);
}

@Override public void accept(String name, Type type) {
Expand All @@ -1468,9 +1490,9 @@ public static class Resolved {
private Resolved(Environment env,
Ast.Decl originalNode, Ast.Decl node, TypeMap typeMap) {
this.env = env;
this.originalNode = Objects.requireNonNull(originalNode);
this.node = Objects.requireNonNull(node);
this.typeMap = Objects.requireNonNull(typeMap);
this.originalNode = requireNonNull(originalNode);
this.node = requireNonNull(node);
this.typeMap = requireNonNull(typeMap);
Preconditions.checkArgument(originalNode instanceof Ast.FunDecl
? node instanceof Ast.ValDecl
: originalNode.getClass() == node.getClass());
Expand Down
Loading

0 comments on commit 333ae51

Please sign in to comment.