forked from spotify/luigi
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix arbitrary file write during tarfile extraction
Fixes spotify#3302 spotify#3301
- Loading branch information
1 parent
74e6e63
commit 1ba1030
Showing
9 changed files
with
230 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright 2012-2015 Spotify AB | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
""" | ||
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while | ||
mitigating path traversal vulnerabilities, which can occur when files inside the archive are | ||
crafted to escape the intended extraction directory. | ||
The `SafeExtractor` ensures that the extracted file paths are validated before extraction to | ||
prevent malicious archives from extracting files outside the intended directory. | ||
Classes: | ||
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks. | ||
Usage Example: | ||
extractor = SafeExtractor("/desired/directory") | ||
extractor.safe_extract("archive.tar") | ||
""" | ||
|
||
import os | ||
import tarfile | ||
|
||
|
||
class SafeExtractor:
    """
    A class to safely extract tar files, ensuring that no path traversal
    vulnerabilities are exploited.

    Every member of the archive is validated before anything is written:
    the member's own destination path, and for symbolic/hard links the
    path the link points at, must both stay inside the extraction directory.

    Attributes:
        path (str): The directory to extract files into.

    Methods:
        _is_within_directory(directory, target):
            Checks if a target path is within a given directory.
        safe_extract(tar_path, members=None, \\*, numeric_owner=False):
            Safely extracts the contents of a tar file to the specified directory.
    """

    def __init__(self, path="."):
        """
        Initializes the SafeExtractor with the specified directory path.

        Args:
            path (str): The directory to extract files into. Defaults to the current directory.
        """
        self.path = path

    @staticmethod
    def _is_within_directory(directory, target):
        """
        Checks if a target path is within a given directory.

        The comparison is done on whole path components, so a sibling such as
        ``/data/out_evil`` is *not* considered to be inside ``/data/out``.

        Args:
            directory (str): The directory to check against.
            target (str): The target path to check.

        Returns:
            bool: True if the target path is the directory itself or lies below it, False otherwise.
        """
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            # commonpath works per path component; commonprefix is a plain
            # character-wise string comparison and would accept siblings that
            # merely share a name prefix with the directory.
            common = os.path.commonpath([abs_directory, abs_target])
        except ValueError:
            # Raised when the paths cannot share a root (e.g. different
            # drives on Windows), so the target cannot be inside the directory.
            return False
        return common == abs_directory

    def _link_target_path(self, member):
        """
        Computes where a link member would point once extracted.

        Args:
            member (tarfile.TarInfo): A symbolic-link or hard-link member.

        Returns:
            str: The (not yet normalised) path the link refers to.
        """
        if member.issym():
            # Symbolic link targets are resolved relative to the directory
            # that contains the link itself.
            base = os.path.join(self.path, os.path.dirname(member.name))
        else:
            # Hard link targets name another member, relative to the archive root.
            base = self.path
        # An absolute linkname makes os.path.join discard ``base``, which the
        # caller's containment check then rejects.
        return os.path.join(base, member.linkname)

    def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
        """
        Safely extracts the contents of a tar file to the specified directory.

        All members of the archive are validated (even when only a subset is
        requested through ``members``) before extraction starts, so a rejected
        archive leaves the destination untouched.

        Args:
            tar_path (str): The path to the tar file to extract.
            members (list, optional): A list of members to extract. Defaults to None.
            numeric_owner (bool, optional): If True, the numeric uid/gid stored in the
                archive are used instead of the user/group names. Defaults to False.

        Raises:
            ValueError: If a path traversal attempt is detected.
        """
        with tarfile.open(tar_path, 'r') as tar:
            for member in tar.getmembers():
                member_path = os.path.join(self.path, member.name)
                if not self._is_within_directory(self.path, member_path):
                    raise ValueError("Attempted Path Traversal in Tar File")
                # A link that points outside the directory lets a later member
                # be written through it, so link targets are checked as well.
                if member.issym() or member.islnk():
                    if not self._is_within_directory(self.path, self._link_target_path(member)):
                        raise ValueError("Attempted Path Traversal in Tar File")
            tar.extractall(self.path, members, numeric_owner=numeric_owner)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright 2012-2015 Spotify AB | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
""" | ||
Safe Extractor Test | ||
=================== | ||
Tests for the Safe Extractor class in luigi.safe_extractor module. | ||
""" | ||
|
||
import os | ||
import shutil | ||
import tarfile | ||
import tempfile | ||
import unittest | ||
|
||
from luigi.safe_extractor import SafeExtractor | ||
|
||
|
||
class TestSafeExtract(unittest.TestCase):
    """
    Unit test class for testing the SafeExtractor module.

    Fixture files are written to ``source_dir`` and archives are extracted into
    ``extract_dir`` so that a file found after extraction really came out of
    the archive rather than being left over from building it.
    """

    def setUp(self):
        """Set up temporary directories for test files."""
        self.temp_dir = tempfile.mkdtemp()
        self.source_dir = os.path.join(self.temp_dir, 'source')
        self.extract_dir = os.path.join(self.temp_dir, 'extracted')
        os.mkdir(self.source_dir)
        os.mkdir(self.extract_dir)
        self.test_file_template = 'test_file_{}.txt'
        self.tar_file_name = 'test.tar'

    def tearDown(self):
        """Clean up the temporary directory after each test."""
        shutil.rmtree(self.temp_dir)

    def create_tar(self, tar_path, file_count=1, file_contents=None, archive_name=None, with_traversal=False):
        """
        Helper method to create a tar file with test data.

        Args:
            tar_path (str): Path to the tar file to be created.
            file_count (int): Number of test files to include in the tar.
            file_contents (list): Optional list of file contents, one entry per file.
            archive_name (str): Optional name used for every file added to the tar.
            with_traversal (bool): If True, creates a path traversal tarball.
        """
        file_contents = file_contents or [f'This is {self.test_file_template.format(i)}' for i in range(file_count)]
        with tarfile.open(tar_path, 'w') as tar:
            for i in range(file_count):
                file_name = self.test_file_template.format(i)
                file_path = os.path.join(self.source_dir, file_name)
                with open(file_path, 'w') as f:
                    f.write(file_contents[i])

                # Derive the member name per file. Rebinding ``archive_name``
                # here would make every later file reuse the first file's name.
                default_name = f'../../{file_name}' if with_traversal else file_name
                member_name = archive_name or default_name
                tar.add(file_path, arcname=member_name)

    def verify_extracted_files(self, file_count):
        """
        Helper method to verify files extracted from tar.

        Args:
            file_count (int): Number of files to verify.
        """
        for i in range(file_count):
            file_name = self.test_file_template.format(i)
            file_path = os.path.join(self.extract_dir, file_name)
            self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.")
            with open(file_path, 'r') as f:
                content = f.read()
            self.assertEqual(content, f'This is {file_name}', f"Content mismatch in {file_name}.")

    def test_safe_extract(self):
        """Test normal safe extraction of tar files."""
        self.run_safe_extract_test()

    def test_safe_extract_with_traversal(self):
        """Test safe extraction for tar files with path traversal."""
        self.run_safe_extract_test(with_traversal=True, expect_error=True)

    def run_safe_extract_test(self, with_traversal=False, expect_error=False):
        """
        Run a safe extract test with optional path traversal.

        Args:
            with_traversal (bool): If True, test for path traversal.
            expect_error (bool): If True, expect an error during extraction.
        """
        file_count = 1 if with_traversal else 3
        tar_path = os.path.join(self.temp_dir, self.tar_file_name)
        self.create_tar(tar_path, file_count=file_count, with_traversal=with_traversal)

        extractor = SafeExtractor(self.extract_dir)
        if expect_error:
            with self.assertRaises(ValueError):
                extractor.safe_extract(tar_path)
        else:
            extractor.safe_extract(tar_path)
            self.verify_extracted_files(file_count)
|
||
|
||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()