Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Viaduct 2 Frontend #799

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .idea/kotlinScripting.xml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.github.aplcornell.viaduct.backends

import io.github.aplcornell.viaduct.backends.aby.ABYBackend
import io.github.aplcornell.viaduct.backends.cleartext.CleartextBackend
import io.github.aplcornell.viaduct.backends.commitment.CommitmentBackend

/** Combines all back ends that support circuit code generation. */
object CircuitCodeGenerationBackend : Backend by listOf(CleartextBackend, ABYBackend).unions()
object CircuitCodeGenerationBackend : Backend by listOf(CleartextBackend, ABYBackend, CommitmentBackend).unions()
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,30 @@ package io.github.aplcornell.viaduct.backends.cleartext

import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.asTypeName
import io.github.aplcornell.viaduct.circuitcodegeneration.AbstractCodeGenerator
import io.github.aplcornell.viaduct.circuitcodegeneration.Argument
import io.github.aplcornell.viaduct.circuitcodegeneration.CodeGeneratorContext
import io.github.aplcornell.viaduct.circuitcodegeneration.UnsupportedCommunicationException
import io.github.aplcornell.viaduct.circuitcodegeneration.kotlinType
import io.github.aplcornell.viaduct.circuitcodegeneration.receiveExpected
import io.github.aplcornell.viaduct.circuitcodegeneration.receiveReplicated
import io.github.aplcornell.viaduct.circuitcodegeneration.typeTranslator
import io.github.aplcornell.viaduct.runtime.commitment.Commitment
import io.github.aplcornell.viaduct.runtime.commitment.Committed
import io.github.aplcornell.viaduct.syntax.BinaryOperator
import io.github.aplcornell.viaduct.syntax.Host
import io.github.aplcornell.viaduct.syntax.Protocol
import io.github.aplcornell.viaduct.syntax.UnaryOperator
import io.github.aplcornell.viaduct.syntax.circuit.OperatorNode
import io.github.aplcornell.viaduct.syntax.operators.Maximum
import io.github.aplcornell.viaduct.syntax.operators.Minimum
import io.github.aplcornell.viaduct.backends.commitment.Commitment as CommitmentProtocol

class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCodeGenerator(context) {

override fun operatorApplication(protocol: Protocol, op: OperatorNode, arguments: List<CodeBlock>): CodeBlock =
when (op.operator) {
Minimum ->
Expand Down Expand Up @@ -112,6 +121,128 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
}
}

private fun createCommitment(
source: Protocol,
target: Protocol,
argument: Argument,
builder: CodeBlock.Builder,
): CodeBlock {
require(context.host in source.hosts + target.hosts)
if (source !is Local) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}
require(source.hosts.size == 1 && source.host in source.hosts)
require(target is CommitmentProtocol)
if (target.cleartextHost != source.host || target.cleartextHost in target.hashHosts) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}

val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value))
val sendingHost = target.cleartextHost
val receivingHosts = target.hashHosts
return when (context.host) {
sendingHost -> {
val tempName1 = context.newTemporary("CommitTemp")
val tempName2 = context.newTemporary("CommitTemp")
builder.addStatement(
"val %N = %T(%L)",
tempName1,
(Committed::class).asTypeName().parameterizedBy(argType),
argument.value,
)
builder.addStatement(
"val %N = %N.%M()",
tempName2,
tempName1,
MemberName(Committed.Companion::class.asClassName(), "commitment"),
)
receivingHosts.forEach {
builder.addStatement("%L", context.send(CodeBlock.of("%N", tempName2), it))
}
CodeBlock.of("%N", tempName1)
}

in receivingHosts -> {
val tempName3 = context.newTemporary("CommitTemp")
builder.addStatement(
"val %N = %L",
tempName3,
context.receive((Commitment::class).asTypeName().parameterizedBy(argType), source.host),
)
CodeBlock.of("%N", tempName3)
}

else -> throw IllegalStateException()
}
}

private fun openCommitment(
source: Protocol,
target: Protocol,
argument: Argument,
builder: CodeBlock.Builder,
): CodeBlock {
require(source is CommitmentProtocol)
if (target !is Cleartext) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}
require(context.host in source.hosts + target.hosts)
if (source.hashHosts != target.hosts || source.cleartextHost in source.hashHosts) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}

val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value))
val sendingHost = source.cleartextHost
val receivingHosts = target.hosts
return when (context.host) {
sendingHost -> {
receivingHosts.forEach {
builder.addStatement("%L", context.send(argument.value, it))
}
CodeBlock.of("%L.value", argument.value)
}
in receivingHosts -> {
val tempName1 = context.newTemporary("CommitTemp")
builder.addStatement(
"val %N = %L",
tempName1,
context.receive((Committed::class).asTypeName().parameterizedBy(argType), source.cleartextHost),
)
val tempName2 = context.newTemporary("CommitTemp")
builder.addStatement(
"val %N = %L",
tempName2,
argument.value,
)
val tempName3 = context.newTemporary("CommitTemp")
builder.addStatement(
"val %N = %N.%N(%N)",
tempName3,
tempName2,
"open",
tempName1,
)

val peers = receivingHosts.filter { it != context.host }
if (peers.isNotEmpty()) {
for (host in peers) builder.addStatement("%L", context.send(CodeBlock.of(tempName3), host))
builder.addStatement(
"%L",
receiveExpected(
CodeBlock.of(tempName3),
context.host,
argType,
peers,
context,
),
)
}
CodeBlock.of("%N", tempName3)
}
else -> throw IllegalStateException()
}
}

override fun import(
protocol: Protocol,
arguments: List<Argument>,
Expand All @@ -127,6 +258,13 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
CodeBlock.of("")
}
}
is CommitmentProtocol -> {
if (context.host in protocol.hosts + arg.protocol.hosts) {
openCommitment(arg.protocol, protocol, arg, builder)
} else {
CodeBlock.of("")
}
}

else -> throw UnsupportedCommunicationException(arg.protocol, protocol, arg.sourceLocation)
}
Expand All @@ -149,6 +287,13 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
CodeBlock.of("")
}
}
is CommitmentProtocol -> {
if (context.host in protocol.hosts + arg.protocol.hosts) {
createCommitment(protocol, arg.protocol, arg, builder)
} else {
CodeBlock.of("")
}
}

else -> throw UnsupportedCommunicationException(protocol, arg.protocol, arg.sourceLocation)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ object CommitmentBackend : Backend {

override fun codeGenerator(context: CodeGeneratorContext): CodeGenerator = CommitmentDispatchCodeGenerator(context)

override fun circuitCodeGenerator(context: CircuitCodeGeneratorContext): CircuitCodeGenerator = TODO()
override fun circuitCodeGenerator(context: CircuitCodeGeneratorContext): CircuitCodeGenerator = CommitmentCircuitCodeGenerator(context)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.github.aplcornell.viaduct.backends.commitment

import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.TypeName
import com.squareup.kotlinpoet.asTypeName
import io.github.aplcornell.viaduct.circuitcodegeneration.AbstractCodeGenerator
import io.github.aplcornell.viaduct.circuitcodegeneration.Argument
import io.github.aplcornell.viaduct.circuitcodegeneration.CodeGeneratorContext
import io.github.aplcornell.viaduct.circuitcodegeneration.UnsupportedCommunicationException
import io.github.aplcornell.viaduct.circuitcodegeneration.typeTranslator
import io.github.aplcornell.viaduct.runtime.commitment.Committed
import io.github.aplcornell.viaduct.syntax.Protocol
import io.github.aplcornell.viaduct.syntax.types.ValueType
import io.github.aplcornell.viaduct.runtime.commitment.Commitment as CommitmentValue

/**
* Backend code generator for the commitment protocol for the circuit IR.
*
* Throws an UnsupportedCommunicationException when used in an input program as a computation protocol.
* This is because the commitment protocol is only a storage format and not a computation protocol.
*/
class CommitmentCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCodeGenerator(context) {
override fun paramType(protocol: Protocol, sourceType: ValueType): TypeName {
require(protocol is Commitment)
return when (context.host) {
protocol.cleartextHost -> (Committed::class).asTypeName().parameterizedBy(typeTranslator(sourceType))
in protocol.hashHosts -> (CommitmentValue::class).asTypeName().parameterizedBy(typeTranslator(sourceType))
else -> throw IllegalStateException()
}
}

override fun storageType(protocol: Protocol, sourceType: ValueType): TypeName {
return super.storageType(protocol, sourceType)
}

override fun import(protocol: Protocol, arguments: List<Argument>): Pair<CodeBlock, List<CodeBlock>> {
throw UnsupportedCommunicationException(arguments.first().protocol, protocol, arguments.first().sourceLocation)
}

override fun export(protocol: Protocol, arguments: List<Argument>): Pair<CodeBlock, List<CodeBlock>> {
throw UnsupportedCommunicationException(arguments.first().protocol, protocol, arguments.first().sourceLocation)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.github.aplcornell.viaduct.syntax.source

import io.github.aplcornell.viaduct.prettyprinting.Document
import io.github.aplcornell.viaduct.prettyprinting.bracketed
import io.github.aplcornell.viaduct.prettyprinting.plus
import io.github.aplcornell.viaduct.syntax.Arguments
import io.github.aplcornell.viaduct.syntax.SourceLocation
import io.github.aplcornell.viaduct.syntax.ValueTypeNode

class ArrayTypeNode(
val elementType: ValueTypeNode,
val shape: Arguments<IndexExpressionNode>,
override val sourceLocation: SourceLocation,
) : Node() {
override val children: Iterable<Node>
get() = shape

override fun toDocument(): Document = elementType + shape.bracketed()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package io.github.aplcornell.viaduct.syntax.source

import io.github.aplcornell.viaduct.prettyprinting.Document
import io.github.aplcornell.viaduct.prettyprinting.bracketed
import io.github.aplcornell.viaduct.prettyprinting.plus
import io.github.aplcornell.viaduct.prettyprinting.times
import io.github.aplcornell.viaduct.prettyprinting.tupled
import io.github.aplcornell.viaduct.syntax.Arguments
import io.github.aplcornell.viaduct.syntax.Operator
import io.github.aplcornell.viaduct.syntax.SourceLocation
import io.github.aplcornell.viaduct.syntax.surface.keyword
import io.github.aplcornell.viaduct.syntax.values.Value

/** A computation that produces a result. */
sealed class ExpressionNode : Node()
sealed class IndexExpressionNode : ExpressionNode()

/** A literal constant. */
class LiteralNode(
val value: Value,
override val sourceLocation: SourceLocation,
) : IndexExpressionNode() {
override val children: Iterable<Nothing>
get() = listOf()

override fun toDocument(): Document = value.toDocument()
}

class ReferenceNode(
val name: VariableNode,
override val sourceLocation: SourceLocation,
) : IndexExpressionNode() {
override val children: Iterable<Nothing>
get() = listOf()

override fun toDocument(): Document = name.toDocument()
}

class LookupNode(
val variable: VariableNode,
val indices: Arguments<IndexExpressionNode>,
override val sourceLocation: SourceLocation,
) : ExpressionNode() {
override val children: Iterable<Node>
get() = indices

override fun toDocument(): Document = variable + indices.bracketed()
}

/** An n-ary operator applied to n arguments. */
class OperatorApplicationNode(
val operator: OperatorNode,
val arguments: Arguments<ExpressionNode>,
override val sourceLocation: SourceLocation,
) : ExpressionNode() {
override val children: Iterable<Node>
get() = listOf(operator) + arguments

override fun toDocument(): Document = Document("(") + operator.operator.toDocument(arguments) + ")"
}

class OperatorNode(
val operator: Operator,
override val sourceLocation: SourceLocation,
) : Node() {
override val children: Iterable<Nothing>
get() = listOf()

override fun toDocument(): Document = Document("::$operator")
}

/**
* @param defaultValue to be used when the list is empty
* @param operator must be associative
*/
class ReduceNode(
val operator: OperatorNode,
val defaultValue: ExpressionNode,
val indices: IndexParameterNode,
val body: ExpressionNode,
override val sourceLocation: SourceLocation,
) : ExpressionNode() {
override val children: Iterable<Node>
get() = listOf(operator, defaultValue, indices, body)

override fun toDocument(): Document {
return keyword("reduce") + listOf(operator, defaultValue).tupled() * "{" * indices * "->" * body * " }"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.github.aplcornell.viaduct.syntax.source

import io.github.aplcornell.viaduct.prettyprinting.Document
import io.github.aplcornell.viaduct.prettyprinting.times
import io.github.aplcornell.viaduct.syntax.SourceLocation

class IndexParameterNode(
override val name: VariableNode,
val bound: IndexExpressionNode,
override val sourceLocation: SourceLocation,
) : Node(), VariableDeclarationNode {
override val children: Iterable<Node>
get() = listOf(bound)

override fun toDocument(): Document = name.toDocument() * "<" * bound
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.github.aplcornell.viaduct.syntax.source

import io.github.aplcornell.viaduct.attributes.TreeNode
import io.github.aplcornell.viaduct.prettyprinting.PrettyPrintable
import io.github.aplcornell.viaduct.syntax.HasSourceLocation

sealed class Node : TreeNode<Node>, HasSourceLocation, PrettyPrintable
Loading
Loading