Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vivianyyd committed Dec 3, 2023
1 parent cd6a791 commit 3454602
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ nonterminal LabelExpression label_expr;
nonterminal List parameter_list, nonempty_parameter_list;
nonterminal ParameterNode parameter;

nonterminal List<VariableBindingNode> var_binding_list, nonempty_var_binding_list;
nonterminal VariableBindingNode var_binding;

nonterminal Located<Protocol> protocol;
nonterminal Located<Protocol> protocol_annot;
nonterminal Located<ProtocolName> protocol_name;
Expand Down Expand Up @@ -282,35 +279,6 @@ return ::=
:}
;

var_binding_list ::=
nonempty_var_binding_list:var_bindings {:
RESULT = var_bindings;
:}
| /* empty */ {:
RESULT = CollectionsKt.mutableListOf();
:}
;

nonempty_var_binding_list ::=
nonempty_var_binding_list:var_bindings COMMA var_binding:r {:
var_bindings.add(r);
RESULT = var_bindings;
:}
| var_binding:r {:
RESULT = CollectionsKt.mutableListOf(r);
:}
;

var_binding ::=
variable:name AT protocol:p {:
RESULT = new VariableBindingNode(
name,
p,
location(nameleft, pright)
);
:}
;

func_block ::=
OPEN_BRACE:begin stmt_list:statements return:ret CLOSE_BRACE:end {:
RESULT = new RoutineBlockNode(statements, ret, location(beginleft, endright));
Expand All @@ -335,9 +303,10 @@ stmt_list ::=
;

stmt ::=
VAL:begin var_binding_list:varbindings EQ:e command:rhs {:
RESULT = new LetNode(
new Arguments(varbindings, location(varbindingsleft, eright)),
VAL:begin variable:var AT protocol:p EQ:e command:rhs {:
RESULT = new CommandLetNode(
var,
p,
rhs,
location(beginleft, rhsright)
);
Expand All @@ -352,31 +321,22 @@ stmt ::=
location(beginleft, eright)
);
:}
| IF:begin OPEN_PAREN index_expr:guard CLOSE_PAREN flow_block:thenbranch ELSE flow_block:elsebranch {:
RESULT = new IfNode(guard, thenbranch, elsebranch, location(beginleft, elsebranchright));
:}
| LOOP:begin flow_block:body {:
RESULT = new LoopNode(body, location(beginleft, bodyright));
:}
| BREAK:b {: RESULT = new BreakNode(location(bleft,bright)); :}
;

command ::=
IDENT:funcname LT:ibegin index_expr_list:inds GT:iend
OPEN_PAREN:abegin reference_list:args CLOSE_PAREN:aend {:
RESULT = new CallNode(
new Located(new FunctionName(funcname), location(funcnameleft, funcnameright)),
new Arguments(inds, location(ibeginleft, iendright)),
new Arguments(args, location(abeginleft, aendright)),
location(funcnameleft, aendright)
);
:}
| host:sender PERIOD INPUT LT array_type:type GT OPEN_PAREN CLOSE_PAREN:end {:
host:sender PERIOD INPUT LT array_type:type GT OPEN_PAREN CLOSE_PAREN:end {:
RESULT = new InputNode(type, sender, location(senderleft, endright));
:}
| host:recipient PERIOD OUTPUT LT array_type:type GT OPEN_PAREN reference:message CLOSE_PAREN:end {:
RESULT = new OutputNode(type, message, recipient, location(recipientleft, endright));
:}
| IF:begin OPEN_PAREN index_expr:guard CLOSE_PAREN flow_block:thenbranch ELSE flow_block:elsebranch {:
RESULT = new IfNode(guard, thenbranch, elsebranch, location(beginleft, elsebranchright));
:}
| LOOP:begin flow_block:body {:
RESULT = new LoopNode(body, location(beginleft, bodyright));
:}
| BREAK:b {: RESULT = new BreakNode(location(bleft,bright)); :}
| DECLASSIFY:begin reference:e optional_from_label:from TO label:to {:
RESULT = new DeclassificationNode(e, from, to, location(beginleft, toright));
:}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,17 @@ import io.github.aplcornell.viaduct.syntax.ObjectVariable
import io.github.aplcornell.viaduct.syntax.ProtocolNode
import io.github.aplcornell.viaduct.syntax.Temporary
import io.github.aplcornell.viaduct.syntax.precircuit.BlockNode
import io.github.aplcornell.viaduct.syntax.precircuit.CallNode
import io.github.aplcornell.viaduct.syntax.precircuit.CommandLetNode
import io.github.aplcornell.viaduct.syntax.precircuit.ComputeLetNode
import io.github.aplcornell.viaduct.syntax.precircuit.FunctionDeclarationNode
import io.github.aplcornell.viaduct.syntax.precircuit.HostDeclarationNode
import io.github.aplcornell.viaduct.syntax.precircuit.LetNode
import io.github.aplcornell.viaduct.syntax.precircuit.LookupNode
import io.github.aplcornell.viaduct.syntax.precircuit.Node
import io.github.aplcornell.viaduct.syntax.precircuit.ParameterNode
import io.github.aplcornell.viaduct.syntax.precircuit.ProgramNode
import io.github.aplcornell.viaduct.syntax.precircuit.ReduceNode
import io.github.aplcornell.viaduct.syntax.precircuit.ReferenceNode
import io.github.aplcornell.viaduct.syntax.precircuit.Variable
import io.github.aplcornell.viaduct.syntax.precircuit.VariableBindingNode
import io.github.aplcornell.viaduct.syntax.precircuit.VariableDeclarationNode
import io.github.aplcornell.viaduct.syntax.precircuit.VariableNode
import io.github.aplcornell.viaduct.syntax.precircuit.VariableReferenceNode
Expand All @@ -37,7 +35,7 @@ import kotlin.reflect.KProperty
* Associates each use of a [Name] with its declaration, and every [Name] declaration with the
* set of its uses.
*
* For example, [Temporary] variables are associated with [LetNode]s, [ObjectVariable]s with
* For example, [Temporary] variables are associated with [CommandLetNode]s, [ObjectVariable]s with
* [DeclarationNode]s, and [JumpLabel]s with [InfiniteLoopNode]s.
* */
class NameAnalysis private constructor(private val tree: Tree<Node, ProgramNode>) {
Expand Down Expand Up @@ -80,7 +78,7 @@ class NameAnalysis private constructor(private val tree: Tree<Node, ProgramNode>
when (node) {
is ComputeLetNode -> listOf(Pair(node.name, node))

is LetNode -> node.bindings.map { binding -> Pair(binding.name, binding) }
is CommandLetNode -> listOf(Pair(node.name, node))

else -> listOf()
}
Expand Down Expand Up @@ -133,13 +131,10 @@ class NameAnalysis private constructor(private val tree: Tree<Node, ProgramNode>
thisRef.contextIn
}

/** Returns the statement that defines the [Variable] in [node]. */
/** Returns the node that defines the [Variable] in [node]. */
fun declaration(node: VariableReferenceNode): VariableDeclarationNode =
(node as Node).variableDeclarations[node.name]

fun declaration(node: CallNode): FunctionDeclarationNode =
node.functionDeclarations[node.name]

/** Returns the funtion declaration that contains [parameter]. */
fun functionDeclaration(parameter: ParameterNode): FunctionDeclarationNode =
tree.parent(parameter) as FunctionDeclarationNode
Expand Down Expand Up @@ -177,11 +172,9 @@ class NameAnalysis private constructor(private val tree: Tree<Node, ProgramNode>
fun check(node: Node) {
// Check that name references are valid
when (node) {
is VariableBindingNode -> node.protocol.check()
is ComputeLetNode -> node.protocol.check()
is ReferenceNode -> declaration(node)
is LookupNode -> declaration(node)
is CallNode -> declaration(node)
else -> {}
}
// Check that there are no name clashes
Expand All @@ -193,7 +186,7 @@ class NameAnalysis private constructor(private val tree: Tree<Node, ProgramNode>
}
}

is LetNode -> node.bindings.forEach { node.variableDeclarations.put(it.name, it) }
is CommandLetNode -> node.variableDeclarations.put(node.name, node)
is ReduceNode -> node.variableDeclarations.put(node.indices.name, node.indices)
is ProgramNode -> {
// Forcing these thunks
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package io.github.aplcornell.viaduct.precircuitanalysis

import io.github.aplcornell.viaduct.syntax.Protocol
import io.github.aplcornell.viaduct.syntax.precircuit.CommandLetNode
import io.github.aplcornell.viaduct.syntax.precircuit.ComputeLetNode
import io.github.aplcornell.viaduct.syntax.precircuit.Node
import io.github.aplcornell.viaduct.syntax.precircuit.VariableBindingNode

fun Node.protocols(): Iterable<Protocol> {
val protocols = mutableSetOf<Protocol>()

fun visit(node: Node) {
when (node) {
is ComputeLetNode -> protocols.add(node.protocol.value)
is VariableBindingNode -> protocols.add(node.protocol.value)
is CommandLetNode -> protocols.add(node.protocol.value)
else -> {}
}
node.children.forEach(::visit)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,44 +1,135 @@
package io.github.aplcornell.viaduct.reordering

import io.github.aplcornell.viaduct.attributes.attribute
import io.github.aplcornell.viaduct.precircuitanalysis.NameAnalysis
import io.github.aplcornell.viaduct.syntax.precircuit.ArrayTypeNode
import io.github.aplcornell.viaduct.syntax.precircuit.BlockNode
import io.github.aplcornell.viaduct.syntax.precircuit.BreakNode
import io.github.aplcornell.viaduct.syntax.precircuit.CommandLetNode
import io.github.aplcornell.viaduct.syntax.precircuit.ComputeLetNode
import io.github.aplcornell.viaduct.syntax.precircuit.LetNode
import io.github.aplcornell.viaduct.syntax.precircuit.Node
import io.github.aplcornell.viaduct.syntax.precircuit.DeclassificationNode
import io.github.aplcornell.viaduct.syntax.precircuit.DowngradeNode
import io.github.aplcornell.viaduct.syntax.precircuit.EndorsementNode
import io.github.aplcornell.viaduct.syntax.precircuit.ExpressionNode
import io.github.aplcornell.viaduct.syntax.precircuit.FunctionDeclarationNode
import io.github.aplcornell.viaduct.syntax.precircuit.IfNode
import io.github.aplcornell.viaduct.syntax.precircuit.InputNode
import io.github.aplcornell.viaduct.syntax.precircuit.LiteralNode
import io.github.aplcornell.viaduct.syntax.precircuit.LookupNode
import io.github.aplcornell.viaduct.syntax.precircuit.LoopNode
import io.github.aplcornell.viaduct.syntax.precircuit.OperatorApplicationNode
import io.github.aplcornell.viaduct.syntax.precircuit.OutputNode
import io.github.aplcornell.viaduct.syntax.precircuit.ProgramNode
import io.github.aplcornell.viaduct.syntax.precircuit.ReduceNode
import io.github.aplcornell.viaduct.syntax.precircuit.ReferenceNode
import io.github.aplcornell.viaduct.syntax.precircuit.ReturnNode
import io.github.aplcornell.viaduct.syntax.precircuit.StatementNode
import io.github.aplcornell.viaduct.syntax.precircuit.VariableReferenceNode

val Node.uses: List<VariableReferenceNode> by attribute { TODO() }

class DependencyGraph(program: ProgramNode) {
private val nameAnalysis: NameAnalysis = NameAnalysis.get(program)
private val nodeToDependencies = mutableMapOf<StatementNode, MutableList<StatementNode>>()
private val nodeToDependents = mutableMapOf<StatementNode, MutableList<StatementNode>>()

init {
program.declarations.filterIsInstance<FunctionDeclarationNode>().forEach { buildDependencyGraph(it.body) }
println("HERE IS THE DEP MAP AFTER INIT")
nodeToDependencies.forEach { (k, _) -> println(k.toDocument().print()) }
}

private fun uses(node: StatementNode): List<VariableReferenceNode> = when (node) {
is ComputeLetNode -> uses(node.type) + uses(node.value)
is CommandLetNode -> when (val command = node.command) {
is InputNode -> uses(command.type)
is OutputNode -> uses(command.type) + listOf(command.message)
is DowngradeNode -> uses(command.expression)
}
is ReturnNode -> node.values.flatMap { uses(it) }
is BreakNode -> listOf()
is IfNode -> uses(node.guard) + node.thenBranch.flatMap { uses(it) } + node.elseBranch.flatMap { uses(it) }
is LoopNode -> node.body.flatMap { uses(it) }
}

private fun uses(node: ExpressionNode): List<VariableReferenceNode> = when (node) {
is LiteralNode -> listOf()
is ReferenceNode -> listOf(node)
is LookupNode -> listOf(node) + node.indices.flatMap { uses(it) }
is OperatorApplicationNode -> node.arguments.flatMap { uses(it) }
is ReduceNode -> uses(node.defaultValue) + uses(node.body)
}

private fun uses(node: ArrayTypeNode): List<VariableReferenceNode> = node.shape.flatMap { uses(it) }

private fun addDependencies(node: StatementNode, dependencies: List<StatementNode>) {
nodeToDependencies.getOrPut(node) { mutableListOf() }
nodeToDependents.getOrPut(node) { mutableListOf() }
dependencies.forEach {
nodeToDependencies[node]!!.add(it)
nodeToDependents.getOrPut(it) { mutableListOf() }.add(node)
}
}

fun dependents(statement: StatementNode) = nodeToDependents[statement]!!
fun dependencies(statement: StatementNode): List<StatementNode> {
println("getting dependencies of " + statement.toDocument().print())
return nodeToDependencies[statement]!! //?: listOf()
}

fun BlockNode<StatementNode>.buildDependencyGraph(block: BlockNode<StatementNode>): Map<StatementNode, List<StatementNode>> {
val dependencyGraph = block.statements.associateWith { listOf<StatementNode>() }.toMutableMap()
private fun dataDependencies(stmt: StatementNode): List<StatementNode> {
return uses(stmt).map { nameAnalysis.declaration(it) }.mapNotNull {
// TODO Extremely hacky. A better way is to have two Contexts in NameAnalysis;
// one for contextAfter (let nodes only) and one for contextChildren (other variable declaration types)
when (it) {
is StatementNode -> it
else -> null
}
}
}

/** Fills in [nodeToDependencies] for [this]. Reordering will only occur within blocks. */
// TODO Kind of weird: Data dependency edges extend outside of blocks but we only include security dependencies
// which are within the same block due to our assumption that we'll only reorder within blocks
private fun buildDependencyGraph(block: BlockNode<StatementNode>) {
val prevInputs: List<StatementNode> = listOf()
val prevOutputs: List<StatementNode> = listOf()
val prevDeclassifies: List<StatementNode> = listOf()
val prevEndorses: List<StatementNode> = listOf()
block.forEach { stmt ->
println("ADDing data dependencies for " + stmt.toDocument().print())
addDependencies(stmt, dataDependencies(stmt))
println("HERE IS THE DEP MAP AFTER ADDING")
nodeToDependencies.forEach { (k, _) -> println(k.toDocument().print()) }

this.forEach { stmt ->
when (stmt) {
is ComputeLetNode -> {
val dataDeps = stmt.uses.map { nameAnalysis.declaration(it) }
addDependencies(stmt, listOf())
} // Only data dependencies matter
is CommandLetNode -> {
when (stmt.command) {
is InputNode, is OutputNode ->
// Shouldn't change interface to the user
addDependencies(stmt, prevInputs + prevOutputs)

is DeclassificationNode -> addDependencies(stmt, prevEndorses)
is EndorsementNode -> {
addDependencies(stmt, listOf())
} // TODO No need to depend on declassifies? Does it break robust declassification
}
}
is ReturnNode -> // Cannot reorder with any side effects
addDependencies(stmt, prevInputs + prevOutputs + prevDeclassifies + prevEndorses)
is BreakNode ->
// Cannot reorder with any side effects
addDependencies(stmt, prevInputs + prevOutputs + prevDeclassifies + prevEndorses)
is IfNode -> { // TODO: the if's dependencies should be the union of its childrens
addDependencies(stmt, listOf())
buildDependencyGraph(stmt.thenBranch)
buildDependencyGraph(stmt.elseBranch)
}
is LoopNode -> {
addDependencies(stmt, listOf())
buildDependencyGraph(stmt.body)
}
is LetNode -> TODO()
is ReturnNode -> TODO()
}

}
/*
for each statement in block
data dependencies = declarations(stmt)
if stmt is output, iodependencies = all previous inputs (this can be relaxed but let's do something easy for now)
if stmt is declassify, securitydependencies = all previous endorses
if stmt is endorse, securitydependencies = all previous declassifies
if bob happy reveal data
endorse bob input
cannot move endorse before the if??
*/
return dependencyGraph
}
}
Loading

0 comments on commit 3454602

Please sign in to comment.