Skip to content

Commit

Permalink
fixes path handling and cache support rootid in
Browse files Browse the repository at this point in the history
path
  • Loading branch information
sehnem committed Nov 11, 2023
1 parent 2b48baa commit 1dab0cb
Showing 1 changed file with 51 additions and 32 deletions.
83 changes: 51 additions & 32 deletions gdrivefs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

from fsspec.spec import AbstractFileSystem, AbstractBufferedFile
from fsspec.utils import infer_storage_options
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from google.auth.credentials import AnonymousCredentials
Expand Down Expand Up @@ -31,7 +32,7 @@ def _finfo_from_response(f, path_prefix=None):
name = _normalize_path(path_prefix, f['name'])
else:
name = f['name']
info = {'name': name.lstrip('/'),
info = {'name': name,
'size': int(f.get('size', 0)),
'type': ftype}
f.update(info)
Expand All @@ -40,35 +41,30 @@ def _finfo_from_response(f, path_prefix=None):

class GoogleDriveFileSystem(AbstractFileSystem):
protocol = "gdrive"
root_marker = ''
root_marker = ""

def __init__(self, root_file_id=None, token="browser",
access="full_control", spaces='drive', creds=None,
access="full_control", spaces='drive',
**kwargs):
"""
Access to Google Drive as a file-system
:param root_file_id: str or None
If you have a share, drive or folder ID to treat as the FS root, enter
it here. Otherwise, you will get your default drive
:param token: str
One of "anon", "browser", "cache", "service_account". Using "browser" will prompt a URL to
be put in a browser, and cache the response for future use with token="cache".
"browser" will remove any previously cached token file if it exists.
:param access: str
One of "full_control", "read_only
:param spaces:
Category of files to search, can be 'drive', 'appDataFolder' and 'photos'.
Of these, only the first is general
:param creds: None or dict
Required just for "service_account" token, a dict containing the service account
credentials obtainend in GCP console. The dict content is the same as the json file
downloaded from GCP console. More details can be found here:
One of "anon", "browser", "cache", the path for a service account json file or a dict
with service account credentials. Using "browser" will prompt a URL to be put in a
browser, and cache the response for future use with token="cache". "browser" will
remove any previously cached token file if it exists. The service account credentials
are obtained in the GCP console. More details can be found here:
https://cloud.google.com/iam/docs/service-account-creds#key-types
This credential can be useful when integrating with other GCP services, and when you
don't want the user to be prompted to authenticate.
The files need to be shared with the service account email address, that can be found
in the json file.
:param access: str
One of "full_control", "read_only"
:param spaces:
Category of files to search, can be 'drive', 'appDataFolder' and 'photos'.
Of these, only the first is general
:param kwargs:
Passed to parent
"""
Expand All @@ -77,8 +73,8 @@ def __init__(self, root_file_id=None, token="browser",
self.scopes = [scope_dict[access]]
self.token = token
self.spaces = spaces
self.root_file_id = root_file_id or 'root'
self.creds = creds
if not self.root_file_id:
self.root_file_id = root_file_id or 'root'
self.connect(method=token)

def connect(self, method=None):
Expand All @@ -88,10 +84,17 @@ def connect(self, method=None):
cred = self._connect_cache()
elif method == 'anon':
cred = AnonymousCredentials()
elif method is "service_account":
cred = self._connect_service_account()
else:
raise ValueError(f"Invalid connection method `{method}`.")
if isinstance(method, dict):
sa_creds = method
elif isinstance(method, str):
try:
with open(method) as f:
sa_creds = json.load(f)
except:
raise ValueError(f"Invalid connection method or path `{method}`.")
cred = self._connect_service_account(sa_creds)

srv = build('drive', 'v3', credentials=cred)
self._drives = srv.drives()
self.service = srv.files()
Expand All @@ -107,11 +110,25 @@ def _connect_cache(self):
return pydata_google_auth.get_user_credentials(
self.scopes, use_local_webserver=True
)
def _connect_service_account(self):

def _connect_service_account(self, sa_creds):
return service_account.Credentials.from_service_account_info(
info=self.creds,
info=sa_creds,
scopes=self.scopes)

@classmethod
def _strip_protocol(cls, path):
    """Normalise ``path`` and pick up an optional ``RootId`` from it.

    The path is first split with ``infer_storage_options``. If the class
    has no ``root_file_id`` yet, a ``RootId=<id>`` value found in the URL
    query (or inline in the path as ``?RootId=<id>``) is stored on the
    class. Otherwise a leading ``<root_file_id>`` path component is removed.
    The remainder is handed to the parent implementation.

    :param path: str
        Path or URL, optionally carrying a ``RootId`` query parameter.
    :return: whatever the parent ``_strip_protocol`` returns for the
        cleaned path.
    """
    inferred_url = infer_storage_options(path)
    # Keep only what precedes an inline "?RootId=" marker. Previously the
    # result of this partition was discarded, so the statement was a no-op.
    path, _, inline_root_id = inferred_url["path"].partition("?RootId=")

    root_file_id = getattr(cls, "root_file_id", None)
    if not root_file_id:
        query = inferred_url.get("url_query") or ""
        _, found, query_root_id = query.partition("RootId=")
        # Stop at the next "&" so further query parameters are not swallowed;
        # a query without "RootId=" no longer becomes the root id.
        new_root_id = query_root_id.split("&", 1)[0] if found else inline_root_id
        if new_root_id:
            # NOTE(review): this is stored on the class, so it is shared by
            # every instance -- confirm that is intended.
            cls.root_file_id = new_root_id
    elif path == root_file_id or path.startswith(root_file_id + "/"):
        # Remove the root id only as a whole leading component.
        # str.lstrip() treats its argument as a set of characters and would
        # eat unrelated leading characters of ordinary paths.
        path = path[len(root_file_id):]
    return super()._strip_protocol(path)

@property
def drives(self):
if self._drives is not None:
Expand Down Expand Up @@ -157,6 +174,8 @@ def rmdir(self, path):
def _info_by_id(self, file_id, path_prefix=None):
    """Fetch the metadata of ``file_id`` and convert it to an info dict.

    :param file_id: str
        Identifier passed as ``fileId`` to the files service.
    :param path_prefix: str or None
        Forwarded unchanged to ``_finfo_from_response``.
    :return: the value produced by ``_finfo_from_response``.
    """
    request = self.service.get(fileId=file_id, fields=fields)
    metadata = request.execute()
    is_root = metadata["id"] == self.root_file_id
    if is_root:
        # The configured root is reported under its id rather than its name.
        metadata["name"] = metadata["id"]
    return _finfo_from_response(metadata, path_prefix)

def export(self, path, mime_type):
Expand All @@ -168,13 +187,12 @@ def export(self, path, mime_type):
return self.service.export(fileId=file_id, mimeType=mime_type).execute()

def ls(self, path, detail=False, trashed=False):
path = self._strip_protocol(path)
if path in [None, '/']:
path = ""
path = "/" if not path else path
path = "/" + self._strip_protocol(path).lstrip("/")
files = self._ls_from_cache(path)
if not files:
if path == "":
file_id = self.root_file_id
if path == "/":
file_id = "root"
else:
file_id = self.path_to_file_id(path, trashed=trashed)
files = self._list_directory_by_id(file_id, trashed=trashed,
Expand Down Expand Up @@ -203,7 +221,6 @@ def _list_directory_by_id(self, file_id, trashed=False, path_prefix=None):
pageToken=page_token).execute()
for f in response.get('files', []):
all_files.append(_finfo_from_response(f, path_prefix))
more = response.get('incompleteSearch', False)
page_token = response.get('nextPageToken', None)
if page_token is None:
break
Expand All @@ -215,6 +232,8 @@ def path_to_file_id(self, path, parent=None, trashed=False):
return self.root_file_id
if parent is None:
parent = self.root_file_id
if path.lstrip('/')==parent:
return parent
top_file_id = self._get_directory_child_by_name(items[0], parent,
trashed=trashed)
if len(items) == 1:
Expand Down

0 comments on commit 1dab0cb

Please sign in to comment.