Skip to content

Commit

Permalink
Merge pull request #529 from singnet/key-encryption
Browse files Browse the repository at this point in the history
Key and mnemonic encryption
  • Loading branch information
Arondondon authored Nov 22, 2024
2 parents 10d2d7e + 9ab148d commit 95f50be
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ jobs:
export FORMER_SNET_TEST_INFURA_KEY=${{ secrets.FORM_INF_KEY }}
export PIP_BREAK_SYSTEM_PACKAGES=1
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
# sh -ex ./snet/cli/test/utils/run_all_functional.sh
# sh -ex ./snet/cli/test/utils/run_all_functional.sh
python3 ./snet/cli/test/functional_tests/test_entry_point.py
python3 ./snet/cli/test/functional_tests/func_tests.py
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ trezor==0.13.8
ledgerblue==0.1.48
snet.contracts==0.1.1
lighthouseweb3==0.1.4
cryptography==43.0.3
23 changes: 14 additions & 9 deletions snet/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def add_identity_options(parser, config):
p.set_defaults(fn="list")

p = subparsers.add_parser("create",
help="Create a new identity")
help="Create a new identity. For 'mnemonic' and 'key' identity_type, "
"secret encryption is enabled by default.")
p.set_defaults(fn="create")
p.add_argument("identity_name",
help="Name of identity to create",
Expand All @@ -135,6 +136,10 @@ def add_identity_options(parser, config):
help="Type of identity to create from {}".format(
get_identity_types()),
metavar="IDENTITY_TYPE")
p.add_argument("-de", "--do-not-encrypt",
default=False,
action="store_true",
help="Do not encrypt the identity's private key or mnemonic. For 'key' and 'mnemonic' identity_type.")
p.add_argument("--mnemonic",
help="BIP39 mnemonic for 'mnemonic' identity_type")
p.add_argument("--private-key",
Expand Down Expand Up @@ -451,7 +456,8 @@ def add_contract_function_options(parser, contract_name):
fns.append({
"name": fn["name"],
"named_inputs": [(i["name"], i["type"]) for i in fn["inputs"] if i["name"] != ""],
"positional_inputs": [i["type"] for i in fn["inputs"] if i["name"] == ""]
"positional_inputs": [i["type"] for i in fn["inputs"] if i["name"] == ""],
"stateMutability": fn["stateMutability"]
})

if len(fns) > 0:
Expand All @@ -462,7 +468,10 @@ def add_contract_function_options(parser, contract_name):
for fn in fns:
fn_p = subparsers.add_parser(
fn["name"], help="{} function".format(fn["name"]))
fn_p.set_defaults(fn="call")
if fn["stateMutability"] == "view":
fn_p.set_defaults(fn="call")
else:
fn_p.set_defaults(fn="transact")
fn_p.set_defaults(contract_function=fn["name"])
for i in fn["positional_inputs"]:
fn_p.add_argument(i,
Expand All @@ -473,12 +482,8 @@ def add_contract_function_options(parser, contract_name):
fn_p.add_argument("contract_named_input_{}".format(i[0]),
type=type_converter(i[1]),
metavar="{}_{}".format(i[0].lstrip("_"), i[1].upper()))
fn_p.add_argument("--transact",
action="store_const",
const="transact",
dest="fn",
help="Invoke contract function as transaction")
add_transaction_arguments(fn_p)
if fn["stateMutability"] != "view":
add_transaction_arguments(fn_p)


def add_contract_identity_arguments(parser, names_and_destinations=(("", "at"),)):
Expand Down
65 changes: 50 additions & 15 deletions snet/cli/commands/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from lighthouseweb3 import Lighthouse
import yaml
import web3
from cryptography.fernet import InvalidToken
from snet.contracts import get_contract_def

from snet.cli.contract import Contract
from snet.cli.identity import KeyIdentityProvider, KeyStoreIdentityProvider, LedgerIdentityProvider, \
MnemonicIdentityProvider, RpcIdentityProvider, TrezorIdentityProvider, get_kws_for_identity_type
from snet.cli.metadata.organization import OrganizationMetadata, PaymentStorageClient, Payment, Group
from snet.cli.utils.config import get_contract_address, get_field_from_args_or_session, \
read_default_contract_address
read_default_contract_address, decrypt_secret
from snet.cli.utils.ipfs_utils import get_from_ipfs_and_checkhash, \
hash_to_bytesuri, publish_file_in_ipfs, publish_file_in_filecoin
from snet.cli.utils.utils import DefaultAttributeObject, get_web3, is_valid_url, serializable, type_converter, \
Expand Down Expand Up @@ -174,6 +175,32 @@ def get_identity(self):
if identity_type == "keystore":
return KeyStoreIdentityProvider(self.w3, self.config.get_session_field("keystore_path"))

def check_ident(self):
identity_type = self.config.get_session_field("identity_type")
if get_kws_for_identity_type(identity_type)[0][1] and not self.ident.private_key:
if identity_type == "key":
secret = self.config.get_session_field("private_key")
else:
secret = self.config.get_session_field("mnemonic")
decrypted_secret = self._get_decrypted_secret(secret)
self.ident.set_secret(decrypted_secret)

def _get_decrypted_secret(self, secret):
decrypted_secret = None
try:
password = getpass.getpass("Password: ")
decrypted_secret = decrypt_secret(secret, password)
except InvalidToken:
self._printout("Wrong password! Try again")
if not decrypted_secret:
try:
password = getpass.getpass("Password: ")
decrypted_secret = decrypt_secret(secret, password)
except InvalidToken:
self._printerr("Wrong password! Operation failed.")
exit(1)
return decrypted_secret

def get_contract_argser(self, contract_address, contract_function, contract_def, **kwargs):
def f(*positional_inputs, **named_inputs):
args_dict = self.args.__dict__.copy()
Expand Down Expand Up @@ -243,7 +270,19 @@ def create(self):
if self.args.network:
identity["network"] = self.args.network
identity["default_wallet_index"] = self.args.wallet_index
self.config.add_identity(identity_name, identity, self.out_f)

password = None
if not self.args.do_not_encrypt and get_kws_for_identity_type(identity_type)[0][1]:
self._printout("For 'mnemonic' and 'key' identity_type, secret encryption is enabled by default, "
"so you need to come up with a password that you then need to enter on every transaction. "
"To disable encryption, use the '-de' or '--do-not-encrypt' argument.")
password = getpass.getpass("Password: ")
self._ensure(password is not None, "Password cannot be empty")
pwd_confirm = getpass.getpass("Confirm password: ")
self._ensure(password == pwd_confirm, "Passwords do not match")

self.config.add_identity(identity_name, identity, self.out_f, password)


def list(self):
for identity_section in filter(lambda x: x.startswith("identity."), self.config.sections()):
Expand Down Expand Up @@ -302,11 +341,11 @@ def unset(self):


class SessionShowCommand(BlockchainCommand):

def show(self):
rez = self.config.session_to_dict()
key = "network.%s" % rez['session']['network']
self.populate_contract_address(rez, key)

# we don't want to who private_key and mnemonic
for d in rez.values():
d.pop("private_key", None)
Expand Down Expand Up @@ -348,6 +387,7 @@ def call(self):
return result

def transact(self):
self.check_ident()
contract_address = get_contract_address(self, self.args.contract_name,
"--at is required to specify target contract address")

Expand Down Expand Up @@ -402,7 +442,8 @@ def add_group(self):
raise Exception(f"Invalid {endpoint} endpoint passed")

payment_storage_client = PaymentStorageClient(self.args.payment_channel_connection_timeout,
self.args.payment_channel_request_timeout, self.args.endpoints)
self.args.payment_channel_request_timeout,
self.args.endpoints)
payment = Payment(self.args.payment_address, self.args.payment_expiration_threshold,
self.args.payment_channel_storage_type, payment_storage_client)
group_id = base64.b64encode(secrets.token_bytes(32))
Expand All @@ -424,8 +465,7 @@ def remove_group(self):
raise e

existing_groups = org_metadata.groups
updated_groups = [
group for group in existing_groups if not group_id == group.group_id]
updated_groups = [group for group in existing_groups if not group_id == group.group_id]
org_metadata.groups = updated_groups
org_metadata.save_pretty(metadata_file)

Expand All @@ -437,17 +477,13 @@ def set_changed_values_for_group(self, group):
if self.args.payment_address:
group.update_payment_address(self.args.payment_address)
if self.args.payment_expiration_threshold:
group.update_payment_expiration_threshold(
self.args.payment_expiration_threshold)
group.update_payment_expiration_threshold(self.args.payment_expiration_threshold)
if self.args.payment_channel_storage_type:
group.update_payment_channel_storage_type(
self.args.payment_channel_storage_type)
group.update_payment_channel_storage_type(self.args.payment_channel_storage_type)
if self.args.payment_channel_connection_timeout:
group.update_connection_timeout(
self.args.payment_channel_connection_timeout)
group.update_connection_timeout(self.args.payment_channel_connection_timeout)
if self.args.payment_channel_request_timeout:
group.update_request_timeout(
self.args.payment_channel_request_timeout)
group.update_request_timeout(self.args.payment_channel_request_timeout)

def update_group(self):
group_id = self.args.group_id
Expand Down Expand Up @@ -667,7 +703,6 @@ def get_path(err):
return {"status": 0, "msg": "Organization metadata is valid and ready to publish."}

def create(self):

self._metadata_validate()

metadata_file = self.args.metadata_file
Expand Down
5 changes: 5 additions & 0 deletions snet/cli/commands/mpe_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
class MPEAccountCommand(BlockchainCommand):

def print_account(self):
self.check_ident()
self._printout(self.ident.address)

def print_agix_and_mpe_balances(self):
""" Print balance of ETH, AGIX, and MPE wallet """
self.check_ident()
if self.args.account:
account = self.args.account
else:
Expand All @@ -24,6 +26,7 @@ def print_agix_and_mpe_balances(self):
self._printout(" MPE: %s"%cogs2stragix(mpe_cogs))

def deposit_to_mpe(self):
self.check_ident()
amount = self.args.amount
mpe_address = self.get_mpe_address()

Expand All @@ -33,7 +36,9 @@ def deposit_to_mpe(self):
self.transact_contract_command("MultiPartyEscrow", "deposit", [amount])

def withdraw_from_mpe(self):
self.check_ident()
self.transact_contract_command("MultiPartyEscrow", "withdraw", [self.args.amount])

def transfer_in_mpe(self):
self.check_ident()
self.transact_contract_command("MultiPartyEscrow", "transfer", [self.args.receiver, self.args.amount])
3 changes: 2 additions & 1 deletion snet/cli/commands/mpe_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pathlib import Path

from eth_abi.codec import ABICodec
from web3._utils.encoding import pad_hex
from web3._utils.events import get_event_data
from snet.contracts import get_contract_def, get_contract_deployment_block

Expand Down Expand Up @@ -504,8 +503,10 @@ def _print_channels(self, channels, filters: list[str] = None):
def get_address_from_arg_or_ident(self, arg):
if arg:
return arg
self.check_ident()
return self.ident.address


def print_channels_filter_sender(self):
# we don't need to return other channel fields if we only need channel_id or if we'll sync channels state
return_only_id = self.args.only_id or not self.args.do_not_sync
Expand Down
5 changes: 4 additions & 1 deletion snet/cli/commands/mpe_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _get_endpoint_from_metadata_or_args(self, metadata):
return endpoints[0]

def call_server_lowlevel(self):

self.check_ident()

self._init_or_update_registered_org_if_needed()
self._init_or_update_registered_service_if_needed()
Expand Down Expand Up @@ -263,6 +263,8 @@ def _get_channel_state_statelessly(self, grpc_channel, channel_id):
return server["current_nonce"], server["current_signed_amount"], unspent_amount

def print_channel_state_statelessly(self):
self.check_ident()

grpc_channel = open_grpc_channel(self.args.endpoint)

current_nonce, current_amount, unspent_amount = self._get_channel_state_statelessly(
Expand Down Expand Up @@ -308,6 +310,7 @@ def call_server_statelessly_with_params(self, params, group_name):
return self._call_server_via_grpc_channel(grpc_channel, channel_id, server_state["current_nonce"], server_state["current_signed_amount"] + price, params, service_metadata)

def call_server_statelessly(self):
self.check_ident()
group_name = self.args.group_name
params = self._get_call_params()
response = self.call_server_statelessly_with_params(params, group_name)
Expand Down
1 change: 0 additions & 1 deletion snet/cli/commands/mpe_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,6 @@ def extract_service_api_from_metadata(self):
service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash")
download_and_safe_extract_proto(service_api_source, self.args.protodir, self._get_ipfs_client())


def extract_service_api_from_registry(self):
metadata = self._get_service_metadata_from_registry()
service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash")
Expand Down
3 changes: 3 additions & 0 deletions snet/cli/commands/mpe_treasurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,20 @@ def _claim_in_progress_and_claim_channels(self, grpc_channel, channels):
self._blockchain_claim(payments)

def claim_channels(self):
self.check_ident()
grpc_channel = open_grpc_channel(self.args.endpoint)
self._claim_in_progress_and_claim_channels(grpc_channel, self.args.channels)

def claim_all_channels(self):
self.check_ident()
grpc_channel = open_grpc_channel(self.args.endpoint)
# we take list of all channels
unclaimed_payments = self._call_GetListUnclaimed(grpc_channel)
channels = [p["channel_id"] for p in unclaimed_payments]
self._claim_in_progress_and_claim_channels(grpc_channel, channels)

def claim_almost_expired_channels(self):
self.check_ident()
grpc_channel = open_grpc_channel(self.args.endpoint)
# we take list of all channels
unclaimed_payments = self._call_GetListUnclaimed(grpc_channel)
Expand Down
11 changes: 10 additions & 1 deletion snet/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
import sys

from snet.cli.utils.config import encrypt_secret

default_snet_folder = Path("~").expanduser().joinpath(".snet")
DEFAULT_NETWORK = "sepolia"

Expand Down Expand Up @@ -138,12 +140,19 @@ def set_network_field(self, network, key, value):
self._get_network_section(network)[key] = str(value)
self._persist()

def add_identity(self, identity_name, identity, out_f=sys.stdout):
def add_identity(self, identity_name, identity, out_f=sys.stdout, password=None):
identity_section = "identity.%s" % identity_name
if identity_section in self:
raise Exception("Identity section %s already exists in config" % identity_section)
if "network" in identity and identity["network"] not in self.get_all_networks_names():
raise Exception("Network %s is not in config" % identity["network"])

if password:
if "mnemonic" in identity:
identity["mnemonic"] = encrypt_secret(identity["mnemonic"], password)
elif "private_key" in identity:
identity["private_key"] = encrypt_secret(identity["private_key"], password)

self[identity_section] = identity
self._persist()
# switch to it, if it was the first identity
Expand Down
17 changes: 16 additions & 1 deletion snet/cli/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def sign_message_after_solidity_keccak(self, message):
class KeyIdentityProvider(IdentityProvider):
def __init__(self, w3, private_key):
self.w3 = w3
if private_key.startswith("::"):
self.private_key = None
self.address = None
return
self.set_secret(private_key)

def set_secret(self, private_key):
self.private_key = normalize_private_key(private_key)
self.address = get_address_from_private(self.private_key)

Expand Down Expand Up @@ -109,8 +116,16 @@ def sign_message_after_solidity_keccak(self, message):
class MnemonicIdentityProvider(IdentityProvider):
def __init__(self, w3, mnemonic, index):
self.w3 = w3
self.index = index
if mnemonic.startswith("::"):
self.private_key = None
self.address = None
return
self.set_secret(mnemonic)

def set_secret(self, mnemonic):
Account.enable_unaudited_hdwallet_features()
account = Account.from_mnemonic(mnemonic, account_path=f"m/44'/60'/0'/0/{index}")
account = Account.from_mnemonic(mnemonic, account_path=f"m/44'/60'/0'/0/{self.index}")
self.private_key = account.key.hex()
self.address = account.address

Expand Down
Loading

0 comments on commit 95f50be

Please sign in to comment.