diff --git a/src/poktroll_clients/__init__.py b/src/poktroll_clients/__init__.py index 79ce7d4..bc7f70c 100644 --- a/src/poktroll_clients/__init__.py +++ b/src/poktroll_clients/__init__.py @@ -1,12 +1,12 @@ # Add generated protobuf types to the module path. from os import path +from .go_memory import ffi, libpoktroll_clients, go_ref, check_err, check_ref, GoManagedMem +from .depinject import Supply, SupplyMany +from .events_query_client import EventsQueryClient from .block_client import BlockClient, BlockQueryClient from .tx_context import TxContext from .tx_client import TxClient -from .events_query_client import EventsQueryClient -from .depinject import Supply, SupplyMany -from .go_memory import go_ref __all__ = [ 'BlockClient', @@ -16,5 +16,10 @@ 'EventsQueryClient', 'Supply', 'SupplyMany', + 'ffi', 'go_ref', + 'check_err', + 'check_ref', + 'GoManagedMem', + 'libpoktroll_clients', ] diff --git a/src/poktroll_clients/block_client.py b/src/poktroll_clients/block_client.py index dd6456a..c5c6df4 100644 --- a/src/poktroll_clients/block_client.py +++ b/src/poktroll_clients/block_client.py @@ -25,7 +25,7 @@ class BlockQueryClient(GoManagedMem): TODO_IN_THIS_COMMIT: comment """ - self_ref: go_ref + go_ref: go_ref err_ptr: ffi.CData def __init__(self, query_node_rpc_url: str): diff --git a/src/poktroll_clients/events_query_client.py b/src/poktroll_clients/events_query_client.py index 34a37e7..a51eb49 100644 --- a/src/poktroll_clients/events_query_client.py +++ b/src/poktroll_clients/events_query_client.py @@ -11,8 +11,8 @@ class EventsQueryClient(GoManagedMem): err_ptr: ffi.CData def __init__(self, query_node_rpc_websocket_url: str): - go_ref = libpoktroll_clients.NewEventsQueryClient(query_node_rpc_websocket_url.encode('utf-8')) - super().__init__(go_ref) + self_ref = libpoktroll_clients.NewEventsQueryClient(query_node_rpc_websocket_url.encode('utf-8')) + super().__init__(self_ref) def EventsBytes(self, query: str) -> go_ref: return libpoktroll_clients.EventsQueryClientEventsBytes(self.go_ref, query.encode('utf-8')) diff --git a/src/poktroll_clients/ffi.py b/src/poktroll_clients/ffi.py index c5f496c..b239e47 100644 --- a/src/poktroll_clients/ffi.py +++ b/src/poktroll_clients/ffi.py @@ -30,6 +30,7 @@ long long int __align; } pthread_cond_t; + // TODO: convert to snake case typedef struct AsyncContext { pthread_mutex_t mutex; pthread_cond_t cond; @@ -45,6 +46,7 @@ typedef void (*error_callback)(AsyncContext* ctx, const char* error); typedef void (*cleanup_callback)(AsyncContext* ctx); + // TODO: convert to snake case typedef struct AsyncOperation { AsyncContext* ctx; success_callback on_success; @@ -78,6 +80,8 @@ serialized_proto* messages; size_t num_messages; } proto_message_array; + + serialized_proto* GetGoProtoAsSerializedProto(go_ref go_proto_ref, char **err); go_ref NewEventsQueryClient(const char* comet_websocket_url); go_ref EventsQueryClientEventsBytes(go_ref selfRef, const char* query); @@ -91,6 +95,88 @@ go_ref NewTxClient(go_ref deps_ref, char *signing_key_name, char **err); go_ref TxClient_SignAndBroadcast(AsyncOperation* op, go_ref self_ref, serialized_proto *msg); go_ref TxClient_SignAndBroadcastMany(AsyncOperation* op, go_ref self_ref, proto_message_array *msgs); + + // Params update methods (all modules) + go_ref TxClient_UpdateSharedParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateApplicationParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateGatewayParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateSupplierParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateSessionParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateServiceParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateProofParams(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateTokenomicsParams(AsyncOperation* op, go_ref self_ref, char *params); + + // Param (individual) update methods (all modules) + go_ref TxClient_UpdateSharedParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateApplicationParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateGatewayParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateSupplierParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateSessionParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateServiceParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateProofParam(AsyncOperation* op, go_ref self_ref, char *params); + go_ref TxClient_UpdateTokenomicsParam(AsyncOperation* op, go_ref self_ref, char *params); + + // Application module message methods + go_ref TxClient_StakeApplication(AsyncOperation* op, go_ref self_ref, char *address, char *stake, proto_message_array *services); + go_ref TxClient_UnstakeApplication(AsyncOperation* op, go_ref self_ref, char *address, char *stake, proto_message_array *services); + go_ref TxClient_DelegateToGateway(AsyncOperation* op, go_ref self_ref, char *address, char *stake, proto_message_array *services); + go_ref TxClient_UndelegateFromGateway(AsyncOperation* op, go_ref self_ref, char *address, char *stake, proto_message_array *services); + go_ref TxClient_TransferApplication(AsyncOperation* op, go_ref self_ref, char *address, char *stake, proto_message_array *services); + + // Gateway module message methods + go_ref TxClient_StakeGateway(AsyncOperation* op, go_ref self_ref, char *address, char *stake); + go_ref TxClient_UnstakeGateway(AsyncOperation* op, go_ref self_ref, char *address, char *stake); + + // Supplier module message methods + go_ref TxClient_StakeSupplier(AsyncOperation* op, go_ref self_ref, char *address, char *stake); + go_ref TxClient_UnstakeSupplier(AsyncOperation* op, go_ref self_ref, char *address, char *stake); + + // Service module message methods + go_ref TxClient_AddService(AsyncOperation* op, go_ref self_ref, char *owner_address, serialized_proto *service); + + // Proof module message methods + go_ref TxClient_CreateClaim(AsyncOperation* op, go_ref self_ref, char *owner_address, char *session_header, char *root_hash, char *proof); + go_ref TxClient_SubmitProof(AsyncOperation* op, go_ref self_ref, char *owner_address, char *session_header, char *proof); + + go_ref NewQueryClient(go_ref deps_ref, char *query_node_rpc_url, char **err); + + // Params query methods (all modules) + // go_ref QueryClient_GetSharedParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetSharedParams(go_ref self_ref, char **err); + go_ref QueryClient_GetApplicationParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetGatewayParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetSupplierParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetSessionParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetServiceParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetProofParams(AsyncOperation* op, go_ref self_ref); + go_ref QueryClient_GetTokenomicsParams(AsyncOperation* op, go_ref self_ref); + + // Application module query methods + go_ref QueryClient_GetApplication(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllApplications(AsyncOperation* op, go_ref self_ref, char *address); + + // Gateway module query methods + go_ref QueryClient_GetGateway(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllGateways(AsyncOperation* op, go_ref self_ref, char *address); + + // Supplier module query methods + go_ref QueryClient_GetSupplier(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllSuppliers(AsyncOperation* op, go_ref self_ref, char *address); + + // Session module query methods + go_ref QueryClient_GetSession(AsyncOperation* op, go_ref self_ref, char *address); + + // Service module query methods + go_ref QueryClient_GetService(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllServices(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetRelayMiningDifficulty(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllRelayMiningDifficulties(AsyncOperation* op, go_ref self_ref, char *address); + + // Proof module query methods + go_ref QueryClient_GetClaim(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllClaims(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetProof(AsyncOperation* op, go_ref self_ref, char *address); + go_ref QueryClient_GetAllProofs(AsyncOperation* op, go_ref self_ref, char *address); """) diff --git a/src/poktroll_clients/go_memory.py b/src/poktroll_clients/go_memory.py index 7b7d36c..f19577e 100644 --- a/src/poktroll_clients/go_memory.py +++ b/src/poktroll_clients/go_memory.py @@ -10,7 +10,7 @@ # TODO_IN_THIS_COMMIT: switch to an err_msg[] array -def check_err(err_ptr: ffi.CData): +def check_err(err_ptr: ffi.CData) -> None: """ TODO_IN_THIS_COMMIT: comment... """ @@ -18,7 +18,7 @@ def check_err(err_ptr: ffi.CData): raise FFIError(ffi.string(err_ptr[0])) -def check_ref(go_ref: go_ref): +def check_ref(go_ref: go_ref) -> None: if go_ref < 1: raise FFIError("unexpected emtpy go_ref") @@ -34,15 +34,15 @@ class GoManagedMem: go_ref: go_ref err_ptr: ffi.CData = ffi.new("char **") - def __init__(self, go_ref: go_ref): + def __init__(self, self_ref: go_ref): """ Constructor for GoManagedMem. Stores the Go-managed memory reference. """ - self.go_ref = go_ref + self.go_ref = self_ref check_err(self.err_ptr) - check_ref(go_ref) + check_ref(self_ref) def __del__(self): """ diff --git a/src/poktroll_clients/protobuf.py b/src/poktroll_clients/protobuf.py index 70a37dc..d92917f 100644 --- a/src/poktroll_clients/protobuf.py +++ b/src/poktroll_clients/protobuf.py @@ -1,6 +1,11 @@ +import importlib from dataclasses import dataclass from typing import List +from google.protobuf import symbol_database, message +from google.protobuf.json_format import MessageToDict + +from poktroll_clients import go_ref, libpoktroll_clients, check_err from poktroll_clients.ffi import ffi @@ -13,6 +18,21 @@ class SerializedProto: type_url: str data: bytes + @staticmethod + def from_c_struct(c_serialized_proto: ffi.CData): + return SerializedProto( + type_url=(ffi.string(c_serialized_proto.type_url, c_serialized_proto.type_url_length).decode('utf-8')), + data=(bytes(ffi.buffer(c_serialized_proto.data, c_serialized_proto.data_length))), + ) + + def __init__(self, c_serialized_proto: ffi.CData = None, type_url: str = "", data: bytes = b""): + self.type_url = type_url + self.data = data + + if c_serialized_proto is not None: + self.type_url = ffi.string(c_serialized_proto.type_url, c_serialized_proto.type_url_length).decode('utf-8') + self.data = bytes(ffi.buffer(c_serialized_proto.data, c_serialized_proto.data_length)) + def to_c_struct(self) -> ffi.CData: """ Converts the Python protobuf data to a C struct while preserving the underlying memory. @@ -68,3 +88,81 @@ def to_c_struct(self) -> ffi.CData: proto_message_array.messages[i].data_length = c_msg.data_length return proto_message_array + + +def get_serialized_proto(go_proto_ref: go_ref) -> SerializedProto: + """ + TODO_IN_THIS_COMMIT: move and comment... + """ + err_ptr = ffi.new("char **") + + c_serialized_proto = libpoktroll_clients.GetGoProtoAsSerializedProto(go_proto_ref, err_ptr) + + check_err(err_ptr) + + return SerializedProto.from_c_struct(c_serialized_proto) + + +def deserialize_protobuf(serialized_data: bytes, type_url: str) -> message.Message: + """ + Deserialize protocol buffer data given a type URL. + + Args: + serialized_data: Bytes containing the serialized protobuf message + type_url: Type URL in format "type.googleapis.com/package.MessageType" + or "package.MessageType" + Returns: + dict: Deserialized protobuf message as a dictionary + + Raises: + ValueError: If type URL is invalid or message type cannot be found + ImportError: If the protobuf module cannot be imported + """ + try: + # First, import the module containing the protobuf classes + # This ensures the types are registered in the symbol database + type_url = type_url.lstrip("/") + poktroll_namespace = type_url.rsplit(".", 1)[0] + package_filename = f"{type_url.rsplit('.', 1)[1].lower()}_pb2" + package_module = f"poktroll_clients.proto.{poktroll_namespace}.{package_filename}" + importlib.import_module(package_module) + except ImportError as e: + raise ImportError(f"Could not import protobuf module {package_module}: {str(e)}") + + # Extract the full message type from the type URL + if '/' in type_url: + _, full_type = type_url.split('/', 1) + else: + full_type = type_url + + # Split into package and message type to validate format + parts = full_type.split('.') + if len(parts) < 2: + raise ValueError("Invalid type URL format") + + try: + # Get the message class from the symbol database + db = symbol_database.Default() + message_class = db.GetSymbol(full_type) + + # Create a new message instance and parse the data + message = message_class() + message.ParseFromString(serialized_data) + + return message + # # Convert to dictionary for easier handling + # return MessageToDict(message) + + except KeyError as e: + raise ValueError( + f"Could not find message type: {full_type}. Make sure it's registered in the symbol database.") from e + except Exception as e: + raise ValueError(f"Error deserializing protobuf: {str(e)}") from e + + +def get_proto_from_go_ref(go_proto_ref: go_ref) -> message.Message: + """ + TODO_IN_THIS_COMMIT: move and comment... + """ + serialized_proto = get_serialized_proto(go_proto_ref) + return deserialize_protobuf(serialized_proto.data, serialized_proto.type_url) diff --git a/src/poktroll_clients/query_client.py b/src/poktroll_clients/query_client.py new file mode 100644 index 0000000..37d682d --- /dev/null +++ b/src/poktroll_clients/query_client.py @@ -0,0 +1,165 @@ +import asyncio +from typing import Tuple, Dict + +from atomics import INTEGRAL, atomic, INT +from cffi import FFIError + +from poktroll_clients import ( + go_ref, + ffi, + libpoktroll_clients, + GoManagedMem, BlockQueryClient, Supply, check_err, check_ref, +) +from poktroll_clients.proto.poktroll.shared.params_pb2 import Params as SharedParams +from poktroll_clients.proto.poktroll.application.params_pb2 import Params as ApplicationParams +from poktroll_clients.proto.poktroll.gateway.params_pb2 import Params as GatewayParams +from poktroll_clients.proto.poktroll.supplier.params_pb2 import Params as SupplierParams +from poktroll_clients.proto.poktroll.session.params_pb2 import Params as SessionParams +from poktroll_clients.proto.poktroll.service.params_pb2 import Params as ServiceParams +from poktroll_clients.proto.poktroll.proof.params_pb2 import Params as ProofParams +from poktroll_clients.proto.poktroll.tokenomics.params_pb2 import Params as TokenomicsParams +from poktroll_clients.protobuf import get_proto_from_go_ref + + +class QueryClient(GoManagedMem): + """ + TODO_IN_THIS_COMMIT: comment + """ + + go_ref: go_ref + err_ptr: ffi.CData + # _callback_idx: INTEGRAL = atomic(width=8, atype=INT) + # _callback_fns: Dict[int, Tuple[ffi.CData, ffi.CData, ffi.CData]] = {} + + def __init__(self, query_node_rpc_url: str, deps_ref: go_ref = -1): + if deps_ref == -1: + deps_ref = _new_query_client_depinject_config(query_node_rpc_url) + + self_ref = libpoktroll_clients.NewQueryClient(deps_ref, + query_node_rpc_url.encode('utf-8'), + self.err_ptr) + super().__init__(self_ref) + + # async def get_shared_params(self) -> asyncio.Future[SharedParams]: + # op, future = self._new_async_operation() + # + # err_ch_ref = libpoktroll_clients.QueryClient_GetSharedParams( # <-- line 71 + # op, + # self.go_ref, + # ) + # + # if err_ch_ref == -1: + # error_msg = ffi.string(op.ctx.error_msg).decode('utf-8') + # future.set_exception(FFIError(error_msg)) + # + # return await future + + def get_shared_params(self) -> asyncio.Future[SharedParams]: + response_ref = libpoktroll_clients.QueryClient_GetSharedParams(self.go_ref, self.err_ptr) + check_err(self.err_ptr) + check_ref(response_ref) + + return get_proto_from_go_ref(response_ref) + + def get_application_params(self) -> ApplicationParams: + response_ref = libpoktroll_clients.QueryClient_GetApplicationParams(self.go_ref, self.err_ptr) + + def get_supplier_params(self) -> SupplierParams: + response_ref = libpoktroll_clients.QueryClient_GetSupplierParams(self.go_ref, self.err_ptr) + + def get_gateway_params(self) -> GatewayParams: + response_ref = libpoktroll_clients.QueryClient_GetGatewayParams(self.go_ref, self.err_ptr) + + def get_session_params(self) -> SessionParams: + response_ref = libpoktroll_clients.QueryClient_GetSessionParams(self.go_ref, self.err_ptr) + + def get_service_params(self) -> ServiceParams: + response_ref = libpoktroll_clients.QueryClient_GetServiceParams(self.go_ref, self.err_ptr) + + def get_proof_params(self) -> ProofParams: + response_ref = libpoktroll_clients.QueryClient_GetProofParams(self.go_ref, self.err_ptr) + + def get_tokenomics_params(self) -> TokenomicsParams: + response_ref = libpoktroll_clients.QueryClient_GetTokenomicsParams(self.go_ref, self.err_ptr) + + # TODO_CONSIDERATION: support an async API as well? + + # def _new_async_operation(self) -> Tuple[ffi.CData, asyncio.Future]: + # """ + # Creates a new AsyncOperation with callbacks and associated Future. + # The callbacks are protected from garbage collection by storing in self._callback_fns. + # + # TODO_IN_THIS_COMMIT: & de-duplicate w/ TxClient... + # """ + # + # try: + # loop = asyncio.get_running_loop() + # except RuntimeError: + # loop = asyncio.new_event_loop() + # asyncio.set_event_loop(loop) + # + # future = loop.create_future() + # + # # Create AsyncContext + # ctx = ffi.new("AsyncContext *") + # next_callback_idx = self._callback_idx.fetch_inc() + # + # # Define callbacks + # @ffi.callback("void(AsyncContext*, const void*)") + # def success_cb(ctx: ffi.CData, response_ref: go_ref): + # try: + # print("success_cb") + # # serialized_proto_response = ffi.cast("serialized_proto*", response_ref) + # response = get_proto_from_go_ref(response_ref).params + # loop.call_soon_threadsafe(future.set_result, response) + # finally: + # self._free_callback(next_callback_idx) + # + # @ffi.callback("void(AsyncContext*, const char*)") + # def error_cb(ctx, error): + # print("error_cb") + # try: + # error_str = ffi.string(error).decode('utf-8') + # loop.call_soon_threadsafe(future.set_exception, Exception(error_str)) + # except Exception as e: + # future.set_exception(e) + # finally: + # self._free_callback(next_callback_idx) + # + # @ffi.callback("void(AsyncContext*)") + # def cleanup_cb(ctx): + # self._free_callback(next_callback_idx) + # + # # Create AsyncOperation + # op = ffi.new("AsyncOperation *") + # op.ctx = ctx + # op.on_success = success_cb + # op.on_error = error_cb + # op.cleanup = cleanup_cb + # + # # Store callbacks to protect from garbage collection + # self._callback_fns[next_callback_idx] = (success_cb, error_cb, cleanup_cb) + # + # return op, future + + def _free_callback(self, callback_idx: int): + """ + Clean up stored callbacks. + """ + self._callback_fns.pop(callback_idx) + + +def _new_query_client_depinject_config( + query_node_rpc_url: str, +) -> go_ref: + """ + TODO_IN_THIS_COMMIT: comment + """ + + # TODO_IN_THIS_COMMIT: add more detail to the error messages, + # explaining the expected format, with an example. + if not query_node_rpc_url: + raise ValueError("query_node_rpc_url must be specified") + + block_query_client = BlockQueryClient(query_node_rpc_url) + return Supply(block_query_client.go_ref) diff --git a/src/poktroll_clients/tx_client.py b/src/poktroll_clients/tx_client.py index ef96443..3157da4 100644 --- a/src/poktroll_clients/tx_client.py +++ b/src/poktroll_clients/tx_client.py @@ -7,12 +7,19 @@ from cffi import FFIError -from poktroll_clients.events_query_client import EventsQueryClient -from poktroll_clients.block_client import BlockClient, BlockQueryClient -from poktroll_clients.depinject import SupplyMany -from poktroll_clients.tx_context import TxContext -from poktroll_clients.ffi import ffi, libpoktroll_clients -from poktroll_clients.go_memory import GoManagedMem, go_ref, check_err, check_ref +from poktroll_clients import ( + EventsQueryClient, + BlockClient, + BlockQueryClient, + SupplyMany, + TxContext, + libpoktroll_clients, + ffi, + go_ref, + GoManagedMem, + check_err, + check_ref, +) from poktroll_clients.protobuf import SerializedProto, ProtoMessageArray @@ -78,6 +85,8 @@ def _new_async_operation(self) -> Tuple[ffi.CData, asyncio.Future]: """ Creates a new AsyncOperation with callbacks and associated Future. The callbacks are protected from garbage collection by storing in self._callback_fns. + + TODO_IN_THIS_COMMIT: & de-duplicate w/ QueryClient... """ try: diff --git a/tests/test_query_client.py b/tests/test_query_client.py new file mode 100644 index 0000000..6289399 --- /dev/null +++ b/tests/test_query_client.py @@ -0,0 +1,27 @@ +from pprint import pprint + +import pytest + +from poktroll_clients.proto.poktroll.shared.params_pb2 import Params as SharedParams +from poktroll_clients.query_client import QueryClient + + +# @pytest.mark.asyncio +def test_query_client(): + query_client = QueryClient("http://127.0.0.1:26657") + + # shared_params = await query_client.get_shared_params() + shared_params = query_client.get_shared_params() + + expected_shared_params = SharedParams( + num_blocks_per_session=10, + grace_period_end_offset_blocks=1, + claim_window_open_offset_blocks=1, + claim_window_close_offset_blocks=4, + proof_window_close_offset_blocks=4, + supplier_unbonding_period_sessions=1, + application_unbonding_period_sessions=1, + compute_units_to_tokens_multiplier=42, + ) + + assert shared_params == expected_shared_params