From 5850c445ace3c0f51243e573ffc4f6fa138e97bb Mon Sep 17 00:00:00 2001 From: Ali Razmjoo Date: Wed, 4 Sep 2024 15:19:54 +0200 Subject: [PATCH] Fix arbitrary file write during tarfile extraction Fixes #3302 #3301 --- luigi/contrib/lsf_runner.py | 8 ++-- luigi/contrib/sge_runner.py | 8 ++-- luigi/safe_extractor.py | 96 +++++++++++++++++++++++++++++++++++++ test/contrib/lsf_test.py | 44 +++++++++++++++++ 4 files changed, 146 insertions(+), 10 deletions(-) mode change 100755 => 100644 luigi/contrib/lsf_runner.py create mode 100644 luigi/safe_extractor.py diff --git a/luigi/contrib/lsf_runner.py b/luigi/contrib/lsf_runner.py old mode 100755 new mode 100644 index 5a6c8b5699..f483e7bf45 --- a/luigi/contrib/lsf_runner.py +++ b/luigi/contrib/lsf_runner.py @@ -28,7 +28,7 @@ except ImportError: import pickle import logging -import tarfile +from luigi.safe_extractor import SafeExtractor def do_work_on_compute_node(work_dir): @@ -52,10 +52,8 @@ def extract_packages_archive(work_dir): curdir = os.path.abspath(os.curdir) os.chdir(work_dir) - tar = tarfile.open(package_file) - for tarinfo in tar: - tar.extract(tarinfo) - tar.close() + extractor = SafeExtractor(work_dir) + extractor.safe_extract(package_file) if '' not in sys.path: sys.path.insert(0, '') diff --git a/luigi/contrib/sge_runner.py b/luigi/contrib/sge_runner.py index f0621fb475..2600f2d6dc 100755 --- a/luigi/contrib/sge_runner.py +++ b/luigi/contrib/sge_runner.py @@ -36,7 +36,7 @@ import sys import pickle import logging -import tarfile +from luigi.safe_extractor import SafeExtractor def _do_work_on_compute_node(work_dir, tarball=True): @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir): curdir = os.path.abspath(os.curdir) os.chdir(work_dir) - tar = tarfile.open(package_file) - for tarinfo in tar: - tar.extract(tarinfo) - tar.close() + extractor = SafeExtractor(work_dir) + extractor.safe_extract(package_file) if '' not in sys.path: sys.path.insert(0, '') diff --git a/luigi/safe_extractor.py b/luigi/safe_extractor.py new file mode 100644 index 0000000000..b4c279b193 --- /dev/null +++ b/luigi/safe_extractor.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides a class `SafeExtractor` that offers a secure way to extract tar files while +mitigating path traversal vulnerabilities, which can occur when files inside the archive are +crafted to escape the intended extraction directory. + +The `SafeExtractor` ensures that the extracted file paths are validated before extraction to +prevent malicious archives from extracting files outside the intended directory. + +Classes: + SafeExtractor: A class to securely extract tar files with protection against path traversal attacks. + +Usage Example: + extractor = SafeExtractor("/desired/directory") + extractor.safe_extract("archive.tar") +""" + +import os +import tarfile + + +class SafeExtractor: + """ + A class to safely extract tar files, ensuring that no path traversal + vulnerabilities are exploited. + + Attributes: + path (str): The directory to extract files into. + + Methods: + _is_within_directory(directory, target): + Checks if a target path is within a given directory. + + safe_extract(tar_path, members=None, *, numeric_owner=False): + Safely extracts the contents of a tar file to the specified directory. + """ + + def __init__(self, path="."): + """ + Initializes the SafeExtractor with the specified directory path. + + Args: + path (str): The directory to extract files into. Defaults to the current directory. + """ + self.path = path + + def _is_within_directory(self, directory, target): + """ + Checks if a target path is within a given directory. + + Args: + directory (str): The directory to check against. + target (str): The target path to check. + + Returns: + bool: True if the target path is within the directory, False otherwise. + """ + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == abs_directory + + def safe_extract(self, tar_path, members=None, *, numeric_owner=False): + """ + Safely extracts the contents of a tar file to the specified directory. + + Args: + tar_path (str): The path to the tar file to extract. + members (list, optional): A list of members to extract. Defaults to None. + numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False. + + Raises: + ValueError: If a path traversal attempt is detected. + """ + with tarfile.open(tar_path, 'r') as tar: + for member in tar.getmembers(): + member_path = os.path.join(self.path, member.name) + if not self._is_within_directory(self.path, member_path): + raise ValueError("Attempted Path Traversal in Tar File") + tar.extractall(self.path, members, numeric_owner=numeric_owner) diff --git a/test/contrib/lsf_test.py b/test/contrib/lsf_test.py index dfbb7c0af5..b9fb0aaa49 100644 --- a/test/contrib/lsf_test.py +++ b/test/contrib/lsf_test.py @@ -33,8 +33,12 @@ import luigi from luigi.contrib.lsf import LSFJobTask +import tarfile +import tempfile +import shutil import pytest +from luigi.safe_extractor import SafeExtractor DEFAULT_HOME = '' @@ -103,5 +107,45 @@ def tearDown(self): pass +class TestSafeExtract(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_safe_extract(self): + tar_path = os.path.join(self.temp_dir, 'test.tar') + with tarfile.open(tar_path, 'w') as tar: + for i in range(3): + file_path = os.path.join(self.temp_dir, f'test_file_{i}.txt') + with open(file_path, 'w') as f: + f.write(f'This is test file {i}') + tar.add(file_path, arcname=f'test_file_{i}.txt') + + extractor = SafeExtractor(self.temp_dir) + extractor.safe_extract(tar_path) + + for i in range(3): + file_path = os.path.join(self.temp_dir, f'test_file_{i}.txt') + self.assertTrue(os.path.exists(file_path)) + with open(file_path, 'r') as f: + content = f.read() + self.assertEqual(content, f'This is test file {i}') + + def test_safe_extract_with_traversal(self): + tar_path = os.path.join(self.temp_dir, 'test.tar') + with tarfile.open(tar_path, 'w') as tar: + file_path = os.path.join(self.temp_dir, 'test_file.txt') + with open(file_path, 'w') as f: + f.write('This is a test file') + tar.add(file_path, arcname='../../test_file.txt') + + extractor = SafeExtractor(self.temp_dir) + with self.assertRaises(ValueError): + extractor.safe_extract(tar_path) + + if __name__ == '__main__': unittest.main()