From 87dd0b0f00b1f7d9f274c9754096abdbd47e786f Mon Sep 17 00:00:00 2001
From: Arif Wider
Date: Tue, 4 Sep 2018 15:01:03 +0200
Subject: [PATCH] downloading from Google Cloud Storage instead of AWS S3
 (and assuming local credentials for now)

---
 requirements.txt             |  1 -
 run_decisiontree_pipeline.sh |  1 +
 src/download_data.py         | 25 +++++++++++++++++++++++++
 src/splitter.py              | 15 ---------------
 4 files changed, 26 insertions(+), 16 deletions(-)
 create mode 100644 src/download_data.py

diff --git a/requirements.txt b/requirements.txt
index d3f6a28..3d8b568 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -117,7 +117,6 @@ reflink==0.2.0
 requests==2.19.1
 rope==0.11.0
 rsa==3.4.2
-s3fs==0.1.2
 schema==0.6.8
 scikit-learn==0.19.1
 scipy==1.0.0
diff --git a/run_decisiontree_pipeline.sh b/run_decisiontree_pipeline.sh
index 71cdd35..1561343 100755
--- a/run_decisiontree_pipeline.sh
+++ b/run_decisiontree_pipeline.sh
@@ -2,5 +2,6 @@
 
 set -e
 
+python3 src/download_data.py
 python3 src/splitter.py
 python3 src/decision_tree.py
diff --git a/src/download_data.py b/src/download_data.py
new file mode 100644
index 0000000..5a8985c
--- /dev/null
+++ b/src/download_data.py
@@ -0,0 +1,25 @@
+import os
+from google.cloud import storage
+
+def load_data():
+    gcs_bucket = "continuous-intelligence"
+    key = "store47-2016.csv"
+
+    # Create the target directory on first run.
+    if not os.path.exists('data/raw'):
+        os.makedirs('data/raw')
+
+    # Only download if the file is not already present locally.
+    if not os.path.exists('data/raw/' + key):
+        client = storage.Client()
+        bucket = client.get_bucket(gcs_bucket)
+        blob = bucket.get_blob(key)
+        blob.download_to_filename('data/raw/' + key)
+
+
+def main():
+    print("Loading data...")
+    load_data()
+    print("Finished downloading")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/splitter.py b/src/splitter.py
index 8e65d12..a3cdd92 100644
--- a/src/splitter.py
+++ b/src/splitter.py
@@ -1,18 +1,5 @@
 import os
 import pandas as pd
-import s3fs
-
-def load_data():
-    s3bucket = "twde-datalab/"
-    key = "raw/store47-2016.csv"
-
-    if not os.path.exists('data/raw'):
-        os.makedirs('data/raw')
-
-    if not os.path.exists("data/" + key):
-        print("Downloading data...")
-        s3 = s3fs.S3FileSystem(anon=True)
-        s3.get(s3bucket + key, "data/" + key)
 
 def get_validation_period(latest_date_train, days_back=15):
     # for Kaggle we want from Wednesday to Thursday for a 15 day period
@@ -37,8 +24,6 @@ def write_data(table, filename):
 
 
 def main():
-    # load_data()
-    print("Loading data...")
     train = pd.read_csv("data/raw/store47-2016.csv")
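
Note on the "local credentials" assumption: the new download step relies on
the client library's default credential discovery, because storage.Client()
is called with no arguments. Application Default Credentials are resolved
from the local environment, for example a GOOGLE_APPLICATION_CREDENTIALS
variable pointing at a service-account key file, or a prior
`gcloud auth application-default login`. Below is a minimal sketch of
passing credentials explicitly instead; the key-file path is a hypothetical
placeholder, not something this commit ships:

    from google.cloud import storage

    # Hypothetical key-file path; the commit itself assumes local
    # credentials and calls storage.Client() with no arguments.
    client = storage.Client.from_service_account_json("path/to/key.json")

    # If the bucket is publicly readable, an anonymous client also works,
    # mirroring the anon=True behaviour of the removed s3fs code:
    # client = storage.Client.create_anonymous_client()

    bucket = client.get_bucket("continuous-intelligence")
    blob = bucket.get_blob("store47-2016.csv")  # returns None if missing
    if blob is None:
        raise RuntimeError("store47-2016.csv not found in bucket")
    blob.download_to_filename("data/raw/store47-2016.csv")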