Skip to content

Commit

Permalink
Fix arbitrary file write during tarfile extraction
Browse files Browse the repository at this point in the history
Fixes #3302 #3301
  • Loading branch information
Ali-Razmjoo committed Sep 6, 2024
1 parent 74e6e63 commit 1ba1030
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 23 deletions.
4 changes: 2 additions & 2 deletions luigi/contrib/lsf.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def track_job(job_id):
- "EXIT"
based on the LSF documentation
"""
cmd = "bjobs -noheader -o stat {}".format(job_id)
cmd = ["bjobs", "-noheader", "-o", "stat", str(job_id)]
track_job_proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, shell=True)
cmd, stdout=subprocess.PIPE, shell=False)
status = track_job_proc.communicate()[0].strip('\n')
return status

Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/lsf_runner.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def do_work_on_compute_node(work_dir):
Expand All @@ -52,10 +52,8 @@ def extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
2 changes: 1 addition & 1 deletion luigi/contrib/pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init_token(self):

request_json = json.dumps({'username': self.__openpai.username, 'password': self.__openpai.password,
'expiration': self.__openpai.expiration})
logger.debug('Get token request {0}'.format(request_json))
logger.debug('Requesting token from OpenPai')
response = rs.post(urljoin(self.__openpai.pai_url, '/api/v1/token'),
headers={'Content-Type': 'application/json'}, data=request_json)
logger.debug('Get token response {0}'.format(response.text))
Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/sge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import sys
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def _do_work_on_compute_node(work_dir, tarball=True):
Expand Down Expand Up @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
2 changes: 1 addition & 1 deletion luigi/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def acquire_for(pid_dir, num_available=1, kill_signal=None):
# Create a pid file if it does not exist
try:
os.mkdir(pid_dir)
os.chmod(pid_dir, 0o777)
os.chmod(pid_dir, 0o700)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
Expand Down
97 changes: 97 additions & 0 deletions luigi/safe_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while
mitigating path traversal vulnerabilities, which can occur when files inside the archive are
crafted to escape the intended extraction directory.
The `SafeExtractor` ensures that the extracted file paths are validated before extraction to
prevent malicious archives from extracting files outside the intended directory.
Classes:
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks.
Usage Example:
extractor = SafeExtractor("/desired/directory")
extractor.safe_extract("archive.tar")
"""

import os
import tarfile


class SafeExtractor:
"""
A class to safely extract tar files, ensuring that no path traversal
vulnerabilities are exploited.
Attributes:
path (str): The directory to extract files into.
Methods:
_is_within_directory(directory, target):
Checks if a target path is within a given directory.
safe_extract(tar_path, members=None, \\*, numeric_owner=False):
Safely extracts the contents of a tar file to the specified directory.
"""

def __init__(self, path="."):
"""
Initializes the SafeExtractor with the specified directory path.
Args:
path (str): The directory to extract files into. Defaults to the current directory.
"""
self.path = path

@staticmethod
def _is_within_directory(directory, target):
"""
Checks if a target path is within a given directory.
Args:
directory (str): The directory to check against.
target (str): The target path to check.
Returns:
bool: True if the target path is within the directory, False otherwise.
"""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
"""
Safely extracts the contents of a tar file to the specified directory.
Args:
tar_path (str): The path to the tar file to extract.
members (list, optional): A list of members to extract. Defaults to None.
numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False.
Raises:
ValueError: If a path traversal attempt is detected.
"""
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
member_path = os.path.join(self.path, member.name)
if not self._is_within_directory(self.path, member_path):
raise ValueError("Attempted Path Traversal in Tar File")
tar.extractall(self.path, members, numeric_owner=numeric_owner)
12 changes: 5 additions & 7 deletions test/contrib/lsf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@
wrappers
"""

import subprocess
import logging
import os
import os.path
from glob import glob
import subprocess
import unittest
import logging
from glob import glob

import pytest
from mock import patch

import luigi
from luigi.contrib.lsf import LSFJobTask

import pytest

DEFAULT_HOME = ''

LOGGER = logging.getLogger('luigi-interface')
Expand All @@ -56,7 +56,6 @@ def on_lsf_master():


class TestJobTask(LSFJobTask):

'''Simple SGE job: write a test file to NSF shared drive and waits a minute'''

i = luigi.Parameter()
Expand All @@ -72,7 +71,6 @@ def output(self):

@pytest.mark.contrib
class TestSGEJob(unittest.TestCase):

'''Test from SGE master node'''

@patch('subprocess.Popen')
Expand Down
4 changes: 2 additions & 2 deletions test/lock_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_acquiring_partially_taken_lock(self):
self.assertTrue(acquired)

s = os.stat(self.pid_file)
self.assertEqual(s.st_mode & 0o777, 0o777)
self.assertEqual(s.st_mode & 0o700, 0o700)

def test_acquiring_lock_from_missing_process(self):
fake_pid = 99999
Expand All @@ -111,7 +111,7 @@ def test_acquiring_lock_from_missing_process(self):
self.assertTrue(acquired)

s = os.stat(self.pid_file)
self.assertEqual(s.st_mode & 0o777, 0o777)
self.assertEqual(s.st_mode & 0o700, 0o700)

@mock.patch('os.kill')
def test_take_lock_with_kill(self, kill_fn):
Expand Down
116 changes: 116 additions & 0 deletions test/safe_extractor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Safe Extractor Test
=============
Tests for the Safe Extractor class in luigi.safe_extractor module.
"""

import os
import shutil
import tarfile
import tempfile
import unittest

from luigi.safe_extractor import SafeExtractor


class TestSafeExtract(unittest.TestCase):
"""
Unit test class for testing the SafeExtractor module.
"""

def setUp(self):
"""Set up temporary directory for test files."""
self.temp_dir = tempfile.mkdtemp()
self.test_file_template = 'test_file_{}.txt'
self.tar_file_name = 'test.tar'

def tearDown(self):
"""Clean up the temporary directory after each test."""
shutil.rmtree(self.temp_dir)

def create_tar(self, tar_path, file_count=1, file_contents=None, archive_name=None, with_traversal=False):
"""
Helper method to create a tar file with test data.
Args:
tar_path (str): Path to the tar file to be created.
file_count (int): Number of test files to include in the tar.
file_contents (list): Optional list of file contents.
archive_name (str): Optional name for files added to tar.
with_traversal (bool): If True, creates a path traversal tarball.
"""
file_contents = file_contents or [f'This is {self.test_file_template.format(i)}' for i in range(file_count)]
with tarfile.open(tar_path, 'w') as tar:
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)
with open(file_path, 'w') as f:
f.write(file_contents[i])

# Handle path traversal if requested
archive_name = archive_name or (f'../../{file_name}' if with_traversal else file_name)
tar.add(file_path, arcname=archive_name)

def verify_extracted_files(self, file_count):
"""
Helper method to verify files extracted from tar.
Args:
file_count (int): Number of files to verify.
"""
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)
self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.")
with open(file_path, 'r') as f:
content = f.read()
self.assertEqual(content, f'This is {file_name}', f"Content mismatch in {file_name}.")

def test_safe_extract(self):
"""Test normal safe extraction of tar files."""
self.run_safe_extract_test()

def test_safe_extract_with_traversal(self):
"""Test safe extraction for tar files with path traversal."""
self.run_safe_extract_test(with_traversal=True, expect_error=True)

def run_safe_extract_test(self, with_traversal=False, expect_error=False):
"""
Run a safe extract test with optional path traversal.
Args:
with_traversal (bool): If True, test for path traversal.
expect_error (bool): If True, expect an error during extraction.
"""
tar_path = os.path.join(self.temp_dir, self.tar_file_name)
self.create_tar(tar_path, file_count=3 if not with_traversal else 1, with_traversal=with_traversal)

extractor = SafeExtractor(self.temp_dir)
if expect_error:
with self.assertRaises(ValueError):
extractor.safe_extract(tar_path)
else:
extractor.safe_extract(tar_path)
self.verify_extracted_files(3)


if __name__ == '__main__':
unittest.main()

0 comments on commit 1ba1030

Please sign in to comment.