diff --git a/gdrivefs/core.py b/gdrivefs/core.py index f82f78d..a4890a3 100644 --- a/gdrivefs/core.py +++ b/gdrivefs/core.py @@ -3,6 +3,7 @@ import os from fsspec.spec import AbstractFileSystem, AbstractBufferedFile +from fsspec.utils import infer_storage_options from googleapiclient.discovery import build from googleapiclient.errors import HttpError from google.auth.credentials import AnonymousCredentials @@ -31,7 +32,7 @@ def _finfo_from_response(f, path_prefix=None): name = _normalize_path(path_prefix, f['name']) else: name = f['name'] - info = {'name': name.lstrip('/'), + info = {'name': name, 'size': int(f.get('size', 0)), 'type': ftype} f.update(info) @@ -40,35 +41,30 @@ def _finfo_from_response(f, path_prefix=None): class GoogleDriveFileSystem(AbstractFileSystem): protocol = "gdrive" - root_marker = '' + root_marker = "" def __init__(self, root_file_id=None, token="browser", - access="full_control", spaces='drive', creds=None, + access="full_control", spaces='drive', **kwargs): """ Access to dgrive as a file-system - :param root_file_id: str or None - If you have a share, drive or folder ID to treat as the FS root, enter - it here. Otherwise, you will get your default drive :param token: str - One of "anon", "browser", "cache", "service_account". Using "browser" will prompt a URL to - be put in a browser, and cache the response for future use with token="cache". - "browser" will remove any previously cached token file if it exists. - :param access: str - One of "full_control", "read_only - :param spaces: - Category of files to search, can be 'drive', 'appDataFolder' and 'photos'. - Of these, only the first is general - :param creds: None or dict - Required just for "service_account" token, a dict containing the service account - credentials obtainend in GCP console. The dict content is the same as the json file - downloaded from GCP console. More details can be found here: + One of "anon", "browser", "cache", the path for a service account json file or a dict + with service account credentials. Using "browser" will prompt a URL to be put in a + browser, and cache the response for future use with token="cache". "browser" will + remove any previously cached token file if it exists. The service account credentials + are obtainend in GCP console. More details can be found here: https://cloud.google.com/iam/docs/service-account-creds#key-types This credential can be usful when integrating with other GCP services, and when you don't want the user to be prompted to authenticate. The files need to be shared with the service account email address, that can be found in the json file. + :param access: str + One of "full_control", "read_only + :param spaces: + Category of files to search, can be 'drive', 'appDataFolder' and 'photos'. + Of these, only the first is general :param kwargs: Passed to parent """ @@ -77,8 +73,8 @@ def __init__(self, root_file_id=None, token="browser", self.scopes = [scope_dict[access]] self.token = token self.spaces = spaces - self.root_file_id = root_file_id or 'root' - self.creds = creds + if not self.root_file_id: + self.root_file_id = root_file_id or 'root' self.connect(method=token) def connect(self, method=None): @@ -88,10 +84,17 @@ def connect(self, method=None): cred = self._connect_cache() elif method == 'anon': cred = AnonymousCredentials() - elif method is "service_account": - cred = self._connect_service_account() else: - raise ValueError(f"Invalid connection method `{method}`.") + if isinstance(method, dict): + sa_creds = method + elif isinstance(method, str): + try: + with open(method) as f: + sa_creds = json.load(f) + except: + raise ValueError(f"Invalid connection method or path `{method}`.") + cred = self._connect_service_account(sa_creds) + srv = build('drive', 'v3', credentials=cred) self._drives = srv.drives() self.service = srv.files() @@ -107,11 +110,25 @@ def _connect_cache(self): return pydata_google_auth.get_user_credentials( self.scopes, use_local_webserver=True ) - - def _connect_service_account(self): + + def _connect_service_account(self, sa_creds): return service_account.Credentials.from_service_account_info( - info=self.creds, + info=sa_creds, scopes=self.scopes) + + @classmethod + def _strip_protocol(cls, path): + inferred_url = infer_storage_options(path) + path = inferred_url["path"] + path.partition("?RootId=") + if not getattr(cls, "root_file_id", None): + query = inferred_url.get("url_query") + if query: + cls.root_file_id = query.split("RootId=")[-1] + else: + path = path.lstrip(cls.root_file_id) + return super()._strip_protocol(path) + @property def drives(self): if self._drives is not None: @@ -157,6 +174,8 @@ def rmdir(self, path): def _info_by_id(self, file_id, path_prefix=None): response = self.service.get(fileId=file_id, fields=fields, ).execute() + if response["id"] == self.root_file_id: + response["name"] = response["id"] return _finfo_from_response(response, path_prefix) def export(self, path, mime_type): @@ -168,13 +187,12 @@ def export(self, path, mime_type): return self.service.export(fileId=file_id, mimeType=mime_type).execute() def ls(self, path, detail=False, trashed=False): - path = self._strip_protocol(path) - if path in [None, '/']: - path = "" + path = "/" if not path else path + path = "/" + self._strip_protocol(path).lstrip("/") files = self._ls_from_cache(path) if not files: - if path == "": - file_id = self.root_file_id + if path == "/": + file_id = "root" else: file_id = self.path_to_file_id(path, trashed=trashed) files = self._list_directory_by_id(file_id, trashed=trashed, @@ -203,7 +221,6 @@ def _list_directory_by_id(self, file_id, trashed=False, path_prefix=None): pageToken=page_token).execute() for f in response.get('files', []): all_files.append(_finfo_from_response(f, path_prefix)) - more = response.get('incompleteSearch', False) page_token = response.get('nextPageToken', None) if page_token is None: break @@ -215,6 +232,8 @@ def path_to_file_id(self, path, parent=None, trashed=False): return self.root_file_id if parent is None: parent = self.root_file_id + if path.lstrip('/')==parent: + return parent top_file_id = self._get_directory_child_by_name(items[0], parent, trashed=trashed) if len(items) == 1: