From d2e3f00843cd438546999b51e9d1bb9c92fbb2db Mon Sep 17 00:00:00 2001 From: "Nathan Voxland (Activeloop)" <151186252+nvoxland-al@users.noreply.github.com> Date: Wed, 21 Feb 2024 21:26:26 +0000 Subject: [PATCH] Ensure username is stored for use in commit messages (#2772) Ensure username is stored for use in commit messages and other places, even if token was passed as an argument to deeplake --- deeplake/api/dataset.py | 1 - deeplake/client/client.py | 3 +- deeplake/core/dataset/dataset.py | 12 ++++ deeplake/core/tests/test_dataset.py | 68 ++++++++++++++++++++ deeplake/core/version_control/commit_node.py | 6 +- deeplake/util/tests/test_version_control.py | 8 +-- deeplake/util/version_control.py | 2 +- 7 files changed, 90 insertions(+), 10 deletions(-) create mode 100644 deeplake/core/tests/test_dataset.py diff --git a/deeplake/api/dataset.py b/deeplake/api/dataset.py index 818cd0c128..7a97a95624 100644 --- a/deeplake/api/dataset.py +++ b/deeplake/api/dataset.py @@ -2,7 +2,6 @@ import os import deeplake -import jwt import pathlib import posixpath import warnings diff --git a/deeplake/client/client.py b/deeplake/client/client.py index 41f819490c..7c2d0271d8 100644 --- a/deeplake/client/client.py +++ b/deeplake/client/client.py @@ -61,7 +61,6 @@ def __init__(self, token: Optional[str] = None): ) self.version = deeplake.__version__ - self._token_from_env = False self.auth_header = None self.token = ( token @@ -76,7 +75,7 @@ def __init__(self, token: Optional[str] = None): if orgs == ["public"]: self.token = token or self.get_token() self.auth_header = f"Bearer {self.token}" - if self._token_from_env: + else: username = self.get_user_profile()["name"] if get_reporting_config().get("username") != username: save_reporting_config(True, username=username) diff --git a/deeplake/core/dataset/dataset.py b/deeplake/core/dataset/dataset.py index 67bd78fa30..9ce8d91f3b 100644 --- a/deeplake/core/dataset/dataset.py +++ b/deeplake/core/dataset/dataset.py @@ -10,6 +10,8 @@ import pathlib import numpy as np from time import time, sleep + +from jwt import DecodeError from tqdm import tqdm import deeplake @@ -332,6 +334,16 @@ def maybe_flush(self): self._flush_vc_info() self.storage.flush() + @property + def username(self) -> str: + if not self.token: + return "public" + + try: + return jwt.decode(self.token, options={"verify_signature": False})["id"] + except DecodeError: + return "public" + @property def num_samples(self) -> int: """Returns the length of the smallest tensor. diff --git a/deeplake/core/tests/test_dataset.py b/deeplake/core/tests/test_dataset.py new file mode 100644 index 0000000000..d7c6e57c13 --- /dev/null +++ b/deeplake/core/tests/test_dataset.py @@ -0,0 +1,68 @@ +import os + +from deeplake.client.config import DEEPLAKE_AUTH_TOKEN +from deeplake.core import LRUCache +from deeplake.core.storage.memory import MemoryProvider + +from deeplake.core.dataset import Dataset + + +def test_token_and_username(hub_cloud_dev_token): + assert DEEPLAKE_AUTH_TOKEN not in os.environ + + ds = Dataset( + storage=LRUCache( + cache_storage=MemoryProvider(), cache_size=0, next_storage=MemoryProvider() + ) + ) + assert ds.token is None + assert ds.username == "public" + + # invalid tokens come through as "public" + ds = Dataset( + token="invalid_value", + storage=LRUCache( + cache_storage=MemoryProvider(), cache_size=0, next_storage=MemoryProvider() + ), + ) + assert ds.token == "invalid_value" + assert ds.username == "public" + + # valid tokens come through correctly + ds = Dataset( + token=hub_cloud_dev_token, + storage=LRUCache( + cache_storage=MemoryProvider(), cache_size=0, next_storage=MemoryProvider() + ), + ) + assert ds.token == hub_cloud_dev_token + assert ds.username == "testingacc2" + + # When env is set, it takes precedence over None for the token but not over a set token + try: + os.environ[DEEPLAKE_AUTH_TOKEN] = hub_cloud_dev_token + ds = Dataset( + storage=LRUCache( + cache_storage=MemoryProvider(), + cache_size=0, + next_storage=MemoryProvider(), + ) + ) + assert ds.token == hub_cloud_dev_token + assert ds.username == "testingacc2" + + ds = Dataset( + token="invalid_value", + storage=LRUCache( + cache_storage=MemoryProvider(), + cache_size=0, + next_storage=MemoryProvider(), + ), + ) + assert ds.token == "invalid_value" + assert ds.username == "public" + + finally: + os.environ.pop(DEEPLAKE_AUTH_TOKEN) + + assert DEEPLAKE_AUTH_TOKEN not in os.environ diff --git a/deeplake/core/version_control/commit_node.py b/deeplake/core/version_control/commit_node.py index d0116beb1e..a30600e637 100644 --- a/deeplake/core/version_control/commit_node.py +++ b/deeplake/core/version_control/commit_node.py @@ -33,12 +33,14 @@ def copy(self): node.total_samples_processed = self.total_samples_processed return node - def add_successor(self, node: "CommitNode", message: Optional[str] = None): + def add_successor( + self, node: "CommitNode", author: str, message: Optional[str] = None + ): """Adds a successor (a type of child) to the node, used for commits.""" node.parent = self self.children.append(node) self.commit_message = message - self.commit_user_name = get_user_name() + self.commit_user_name = author self.commit_time = datetime.utcnow() def merge_from(self, node: "CommitNode"): diff --git a/deeplake/util/tests/test_version_control.py b/deeplake/util/tests/test_version_control.py index d971bb0dde..1b1e41e05f 100644 --- a/deeplake/util/tests/test_version_control.py +++ b/deeplake/util/tests/test_version_control.py @@ -14,10 +14,10 @@ def test_merge_commit_node_map(): b = CommitNode("main", "b") c = CommitNode("main", "c") e = CommitNode("main", "e") - root.add_successor(a, "commit a") - root.add_successor(b, "commit b") - a.add_successor(c, "commit c") - c.add_successor(e, "commit e") + root.add_successor(a, "me", "commit a") + root.add_successor(b, "me", "commit b") + a.add_successor(c, "me", "commit c") + c.add_successor(e, "me", "commit e") map1 = { FIRST_COMMIT_ID: root, "a": a, diff --git a/deeplake/util/version_control.py b/deeplake/util/version_control.py index 3e0a08f3a0..ebe371d6ef 100644 --- a/deeplake/util/version_control.py +++ b/deeplake/util/version_control.py @@ -175,7 +175,7 @@ def commit( hash = generate_hash() version_state["commit_id"] = hash new_node = CommitNode(version_state["branch"], hash) - stored_commit_node.add_successor(new_node, message) + stored_commit_node.add_successor(new_node, dataset.username, message) stored_commit_node.is_checkpoint = is_checkpoint stored_commit_node.total_samples_processed = total_samples_processed version_state["commit_node"] = new_node