diff --git a/botocore/utils.py b/botocore/utils.py index 3b5d9e7a20..30cc6013f0 100644 --- a/botocore/utils.py +++ b/botocore/utils.py @@ -22,6 +22,7 @@ import random import re import socket +import tempfile import time import warnings import weakref @@ -3577,11 +3578,40 @@ def __setitem__(self, cache_key, value): ) if not os.path.isdir(self._working_dir): os.makedirs(self._working_dir, exist_ok=True) - with os.fdopen( - os.open(full_key, os.O_WRONLY | os.O_CREAT, 0o600), 'w' - ) as f: - f.truncate() - f.write(file_content) + try: + temp_fd, temp_path = tempfile.mkstemp( + dir=self._working_dir, suffix='.tmp' + ) + # if os.name == 'posix': + # os.fchmod(temp_fd, 0o600) + # else: + # os.chmod(temp_path, stat.S_IREAD | stat.S_IWRITE) + # if os.chmod: + # os.chmod(temp_path, stat.S_IREAD | stat.S_IWRITE) + + # if hasattr(os, 'fchmod'): + # os.fchmod(temp_fd, 0o600) + with os.fdopen(temp_fd, 'w') as f: + temp_fd = None + f.write(file_content) + f.flush() + os.fsync(f.fileno()) + + os.replace(temp_path, full_key) + temp_path = None + + except Exception: + if temp_fd is not None: + try: + os.close(temp_fd) + except OSError: + pass + if temp_path is not None and os.path.exists(temp_path): + try: + os.unlink(temp_path) + except OSError: + pass + raise def _convert_cache_key(self, cache_key): full_path = os.path.join(self._working_dir, cache_key + '.json') diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index ad37324ad4..1ec33c58d1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -14,6 +14,9 @@ import datetime import io import operator +import os +import shutil +import tempfile from contextlib import contextmanager from sys import getrefcount @@ -59,6 +62,7 @@ InstanceMetadataFetcher, InstanceMetadataRegionFetcher, InvalidArnException, + JSONFileCache, S3ArnParamHandler, S3EndpointSetter, S3RegionRedirectorv2, @@ -3679,3 +3683,88 @@ def test_get_token_from_environment_returns_none( ): monkeypatch.delenv(env_var, raising=False) assert 
get_token_from_environment(signing_name) is None


class TestJSONFileCacheAtomicWrites(unittest.TestCase):
    """Test atomic write operations in JSONFileCache.

    Every test gets its own temporary working directory, which is removed
    again in ``tearDown``.
    """

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
        self.cache = JSONFileCache(working_dir=self.temp_dir)

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _leftover_temp_files(self):
        """Return the names of all ``.tmp`` files in the working directory."""
        return [
            name for name in os.listdir(self.temp_dir) if name.endswith('.tmp')
        ]

    @mock.patch('os.replace')
    def test_uses_tempfile_and_replace_for_atomic_write(self, mock_replace):
        """A write stages a ``.tmp`` file and renames it onto the key path."""
        self.cache['test_key'] = {'data': 'test_value'}
        mock_replace.assert_called_once()

        temp_path, final_path = mock_replace.call_args[0]

        self.assertTrue(
            temp_path.endswith('.tmp'),
            f'Unexpected temp file name: {temp_path}',
        )
        # The staging file has to live next to the destination, otherwise
        # the rename could cross filesystems and stop being atomic.
        self.assertTrue(
            os.path.samefile(os.path.dirname(temp_path), self.temp_dir)
        )
        self.assertEqual(
            final_path, self.cache._convert_cache_key('test_key')
        )

    def test_concurrent_writes_to_multiple_temp_files(self):
        """Test concurrent writers using distinct keys don't interfere."""
        import threading

        iterations = 3
        errors = []

        def write_worker(thread_id):
            try:
                key = f'concurrent_test_{thread_id}'
                for i in range(iterations):
                    self.cache[key] = {'thread': thread_id, 'iteration': i}
            except Exception as e:
                errors.append(f'Thread {thread_id}: {e}')

        threads = [
            threading.Thread(target=write_worker, args=(i,)) for i in range(3)
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        self.assertEqual(len(errors), 0, f'Concurrent write errors: {errors}')

        for thread_id in range(3):
            key = f'concurrent_test_{thread_id}'
            # Each key has exactly one writer that writes sequentially, so
            # the stored entry must be that writer's last iteration.
            self.assertEqual(
                self.cache[key],
                {'thread': thread_id, 'iteration': iterations - 1},
            )

        self.assertEqual(self._leftover_temp_files(), [])

    def test_atomic_write_preserves_data_on_failure(self):
        """Test serialization failures don't corrupt existing data."""
        key = 'atomic_test'
        original_data = {'status': 'original'}

        self.cache[key] = original_data

        # patch.object restores ``_dumps`` even if the assertion below fails.
        with mock.patch.object(
            self.cache, '_dumps', side_effect=ValueError('Write failed')
        ):
            with self.assertRaises(ValueError):
                self.cache[key] = {'status': 'should_fail'}

        # Verify original data intact
        self.assertEqual(self.cache[key], original_data)

    def test_replace_failure_preserves_data_and_cleans_up(self):
        """Test a failed rename keeps the old entry and removes the temp file."""
        key = 'atomic_test'
        original_data = {'status': 'original'}

        self.cache[key] = original_data

        # Fail at the very last step, after the temp file has been written.
        with mock.patch('os.replace', side_effect=OSError('Replace failed')):
            with self.assertRaises(OSError):
                self.cache[key] = {'status': 'should_fail'}

        self.assertEqual(self.cache[key], original_data)
        self.assertEqual(self._leftover_temp_files(), [])

    def test_no_temp_files_after_write(self):
        """Test temporary files cleaned up after writes."""
        self.cache['test'] = {'data': 'value'}

        temp_files = self._leftover_temp_files()
        self.assertEqual(
            temp_files, [], f'Temp files not cleaned: {temp_files}'
        )