Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix arbitrary file write during tarfile extraction in luigi/contrib/lsf_runner.py and luigi/contrib/sge_runner.py #3309

Merged
merged 1 commit into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions luigi/contrib/lsf_runner.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def do_work_on_compute_node(work_dir):
Expand All @@ -52,10 +52,8 @@
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)

Check warning on line 56 in luigi/contrib/lsf_runner.py

View check run for this annotation

Codecov / codecov/patch

luigi/contrib/lsf_runner.py#L55-L56

Added lines #L55 - L56 were not covered by tests
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/sge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import sys
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def _do_work_on_compute_node(work_dir, tarball=True):
Expand Down Expand Up @@ -64,10 +64,8 @@
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)

Check warning on line 68 in luigi/contrib/sge_runner.py

View check run for this annotation

Codecov / codecov/patch

luigi/contrib/sge_runner.py#L67-L68

Added lines #L67 - L68 were not covered by tests
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
97 changes: 97 additions & 0 deletions luigi/safe_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while
mitigating path traversal vulnerabilities, which can occur when files inside the archive are
crafted to escape the intended extraction directory.

The `SafeExtractor` ensures that the extracted file paths are validated before extraction to
prevent malicious archives from extracting files outside the intended directory.

Classes:
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks.

Usage Example:
extractor = SafeExtractor("/desired/directory")
extractor.safe_extract("archive.tar")
"""

import os
import tarfile


class SafeExtractor:
"""
A class to safely extract tar files, ensuring that no path traversal
vulnerabilities are exploited.

Attributes:
path (str): The directory to extract files into.

Methods:
_is_within_directory(directory, target):
Checks if a target path is within a given directory.

safe_extract(tar_path, members=None, \\*, numeric_owner=False):
Safely extracts the contents of a tar file to the specified directory.
"""

def __init__(self, path="."):
"""
Initializes the SafeExtractor with the specified directory path.

Args:
path (str): The directory to extract files into. Defaults to the current directory.
"""
self.path = path

@staticmethod
def _is_within_directory(directory, target):
"""
Checks if a target path is within a given directory.

Args:
directory (str): The directory to check against.
target (str): The target path to check.

Returns:
bool: True if the target path is within the directory, False otherwise.
"""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
"""
Safely extracts the contents of a tar file to the specified directory.

Args:
tar_path (str): The path to the tar file to extract.
members (list, optional): A list of members to extract. Defaults to None.
numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False.

Raises:
RuntimeError: If a path traversal attempt is detected.
"""
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
member_path = os.path.join(self.path, member.name)
if not self._is_within_directory(self.path, member_path):
raise RuntimeError("Attempted Path Traversal in Tar File")
tar.extractall(self.path, members, numeric_owner=numeric_owner)
125 changes: 125 additions & 0 deletions test/safe_extractor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Safe Extractor Test
=============

Tests for the Safe Extractor class in luigi.safe_extractor module.
"""

import os
import shutil
import tarfile
import tempfile
import unittest

from luigi.safe_extractor import SafeExtractor


class TestSafeExtract(unittest.TestCase):
"""
Unit test class for testing the SafeExtractor module.
"""

def setUp(self):
"""Set up a temporary directory for test files."""
self.temp_dir = tempfile.mkdtemp()
self.test_file_template = 'test_file_{}.txt'
self.tar_file_name = 'test.tar'
self.tar_file_name_with_traversal = f'traversal_{self.tar_file_name}'

def tearDown(self):
"""Clean up the temporary directory after each test."""
shutil.rmtree(self.temp_dir)

def create_test_tar(self, tar_path, file_count=1, with_traversal=False):
"""
Create a tar file containing test files.

Args:
tar_path (str): Path where the tar file will be created.
file_count (int): Number of test files to include.
with_traversal (bool): If True, creates a tar file with path traversal vulnerability.
"""
# Default content for the test files
file_contents = [f'This is {self.test_file_template.format(i)}' for i in range(file_count)]

with tarfile.open(tar_path, 'w') as tar:
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)

# Write content to each test file
with open(file_path, 'w') as f:
f.write(file_contents[i])

# If path traversal is enabled, create malicious paths
archive_name = f'../../{file_name}' if with_traversal else file_name

# Add the file to the tar archive
tar.add(file_path, arcname=archive_name)

def verify_extracted_files(self, file_count):
"""
Verify that the correct files were extracted and their contents match expectations.

Args:
file_count (int): Number of files to verify.
"""
for i in range(file_count):
file_name = self.test_file_template.format(i)
file_path = os.path.join(self.temp_dir, file_name)

# Check if the file exists
self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.")

# Check if the file content is correct
with open(file_path, 'r') as f:
content = f.read()
expected_content = f'This is {file_name}'
self.assertEqual(content, expected_content, f"Content mismatch in {file_name}.")

def test_safe_extract(self):
"""Test normal safe extraction of tar files."""
tar_path = os.path.join(self.temp_dir, self.tar_file_name)

# Create a tar file with 3 files
self.create_test_tar(tar_path, file_count=3)

# Initialize SafeExtractor and perform extraction
extractor = SafeExtractor(self.temp_dir)
extractor.safe_extract(tar_path)

# Verify that all 3 files were extracted correctly
self.verify_extracted_files(3)

def test_safe_extract_with_traversal(self):
"""Test safe extraction for tar files with path traversal (should raise an error)."""
tar_path = os.path.join(self.temp_dir, self.tar_file_name_with_traversal)

# Create a tar file with a path traversal file
self.create_test_tar(tar_path, file_count=1, with_traversal=True)

# Initialize SafeExtractor and expect RuntimeError due to path traversal
extractor = SafeExtractor(self.temp_dir)
with self.assertRaises(RuntimeError):
extractor.safe_extract(tar_path)


if __name__ == '__main__':
unittest.main()
Loading