-
Notifications
You must be signed in to change notification settings - Fork 1
/
verifier_ext.rs
108 lines (94 loc) · 3.09 KB
/
verifier_ext.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use crate::Runtime;
use codec::{Decode, Encode};
use frame_support::log::{error, trace};
use pallet_contracts::chain_extension::{
ChainExtension, Environment, Ext, InitState, RetVal, SysConfig,
};
use risc0_zkvm::{serde::from_slice, SessionReceipt};
use sp_core::crypto::UncheckedFrom;
use sp_runtime::DispatchError;
use sp_std::prelude::Vec;
/// Chain extension exposing RISC Zero receipt handling to contracts.
///
/// Calls are dispatched on the extension `func_id`:
/// - `1` verifies a receipt against an image id;
/// - `2` is an experimental receipt deserialization path.
///
/// The type is stateless; all input comes from the contract's call buffer.
#[derive(Default)]
pub struct Risc0VerifierExtension;
/// Result type returned by the chain extension handlers in this file.
type DispatchResult = core::result::Result<(), DispatchError>;
fn convert_err(err_msg: &'static str) -> impl FnOnce(DispatchError) -> DispatchError {
move |err| {
trace!(
target: "runtime",
"Risc0 error:{:?}",
err
);
DispatchError::Other(err_msg)
}
}
/// Handler selected by the chain extension's `func_id`.
///
/// The numeric ids (`1`, `2`) are assigned by the `TryFrom<u16>` impl, not by
/// these variants' discriminants.
enum FuncId {
    /// Verify a receipt against an image id.
    Verify,
    /// Deserialize a receipt and write a re-encoded form back to the contract.
    Deserialize,
}
/// Input payload a contract passes to the extension, SCALE-decoded from the
/// extension's input buffer.
///
/// Field order matters: the derived `Decode` reads `slice` first and then
/// `image_id`, so the calling contract must encode them in that order.
#[derive(Debug, PartialEq, Encode, Decode)]
struct ProofData {
    /// Receipt in risc0's `u32`-word serialization; turned back into a
    /// `SessionReceipt` with `risc0_zkvm::serde::from_slice`.
    slice: Vec<u32>,
    /// Image id that `verify` checks the receipt against.
    image_id: [u32; 8],
}
impl TryFrom<u16> for FuncId {
type Error = DispatchError;
fn try_from(func_id: u16) -> Result<Self, Self::Error> {
let id = match func_id {
1 => FuncId::Verify,
2 => FuncId::Deserialize,
_ => {
error!("Called an unregistered `func_id`: {:}", func_id);
return Err(DispatchError::Other("Unimplemented func_id"));
},
};
Ok(id)
}
}
/// Chain Extension function for the verification
/// Note! Not weight is charged, but it should be calculated
fn verify<E: Ext>(env: Environment<E, InitState>) -> DispatchResult
where
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
{
let mut buffer = env.buf_in_buf_out();
let proof_data: ProofData = buffer.read_as_unbounded(buffer.in_len())?;
let receipt: SessionReceipt = from_slice(&proof_data.slice)
.map_err(|_| DispatchError::Other("Error during proof deserialization"))?;
receipt.verify(proof_data.image_id).map_err(|_| DispatchError::Other("Proof is invalid"))
}
/// An experiment to just deserialize the input from the slice and encode it using SCALE
/// requires the implementation of `Decode` for `SessionReceipt`
fn deserialize_proof<E: Ext>(env: Environment<E, InitState>) -> DispatchResult
where
<E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
{
let mut buffer = env.buf_in_buf_out();
let proof_data: ProofData = buffer.read_as_unbounded(buffer.in_len())?;
let receipt: SessionReceipt = from_slice(&proof_data.slice)
.map_err(|_| DispatchError::Other("Error during proof deserialization"))?;
let receipt_wrapped = ReceiptWrapper(receipt);
let bytes = receipt_wrapped.encode();
buffer.write(&bytes, false, None).map_err(
convert_err("Error writing receipt to buffer")
)
}
#[derive(Debug)]
struct ReceiptWrapper(SessionReceipt);
impl Encode for ReceiptWrapper {}
/// `ChainExtension` implementation for this runtime.
impl ChainExtension<Runtime> for Risc0VerifierExtension {
    /// Decodes `env.func_id()` into a [`FuncId`] and runs the matching handler.
    ///
    /// Any error is returned to the caller as a `DispatchError`: an unknown id,
    /// malformed input, or a failed verification. On success the contract
    /// receives `RetVal::Converging(0)`.
    fn call<E: Ext>(&mut self, env: Environment<E, InitState>) -> Result<RetVal, DispatchError>
    where
        <E::T as SysConfig>::AccountId: UncheckedFrom<<E::T as SysConfig>::Hash> + AsRef<[u8]>,
    {
        let outcome = match FuncId::try_from(env.func_id())? {
            FuncId::Verify => verify::<E>(env),
            FuncId::Deserialize => deserialize_proof::<E>(env),
        };
        outcome.map(|()| RetVal::Converging(0))
    }

    /// The extension is unconditionally enabled.
    fn enabled() -> bool {
        true
    }
}