botocore/utils.py (40 changes: 35 additions & 5 deletions)
@@ -22,6 +22,7 @@
import random
import re
import socket
import tempfile
import time
import warnings
import weakref
@@ -3577,11 +3578,40 @@ def __setitem__(self, cache_key, value):
)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir, exist_ok=True)
with os.fdopen(
os.open(full_key, os.O_WRONLY | os.O_CREAT, 0o600), 'w'
) as f:
f.truncate()
f.write(file_content)
        temp_fd = temp_path = None
        try:
            temp_fd, temp_path = tempfile.mkstemp(
                dir=self._working_dir, suffix='.tmp'
            )
            # mkstemp already creates the temporary file with restrictive
            # (0o600) permissions, so no explicit chmod/fchmod is needed.
            with os.fdopen(temp_fd, 'w') as f:
                # The file object now owns the descriptor; clear temp_fd so
                # the error handler below does not close it a second time.
                temp_fd = None
                f.write(file_content)
                f.flush()
                os.fsync(f.fileno())

            # Atomically move the fully written temp file into place so
            # readers never observe a partially written cache entry.
            os.replace(temp_path, full_key)
            temp_path = None

        except Exception:
            if temp_fd is not None:
                try:
                    os.close(temp_fd)
                except OSError:
                    pass
            if temp_path is not None and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
            raise

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
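As background for the change above, the write-to-a-temporary-file-then-rename pattern it adopts can be sketched on its own as below. This is a minimal illustration rather than botocore code: the helper name atomic_write_text and its parameters are invented for the example, and it assumes the temporary file and the destination sit on the same filesystem, which is what makes os.replace atomic.

import os
import tempfile


def atomic_write_text(target_path, content):
    """Write content so readers never observe a partially written file."""
    directory = os.path.dirname(target_path) or '.'
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(content)
            f.flush()
            os.fsync(f.fileno())  # flush data to disk before the rename
        os.replace(tmp_path, target_path)  # atomic on the same filesystem
    except Exception:
        # Best-effort cleanup; the existing destination file is untouched.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

Because os.replace either installs the complete new file or leaves the old one in place, a crash or concurrent reader never sees a half-written cache entry; that property is what the tests below exercise.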
tests/unit/test_utils.py (89 changes: 89 additions & 0 deletions)
@@ -14,6 +14,9 @@
import datetime
import io
import operator
import os
import shutil
import tempfile
from contextlib import contextmanager
from sys import getrefcount

@@ -59,6 +62,7 @@
InstanceMetadataFetcher,
InstanceMetadataRegionFetcher,
InvalidArnException,
JSONFileCache,
S3ArnParamHandler,
S3EndpointSetter,
S3RegionRedirectorv2,
@@ -3679,3 +3683,88 @@ def test_get_token_from_environment_returns_none(
):
monkeypatch.delenv(env_var, raising=False)
assert get_token_from_environment(signing_name) is None


class TestJSONFileCacheAtomicWrites(unittest.TestCase):
"""Test atomic write operations in JSONFileCache."""

def setUp(self):
self.temp_dir = tempfile.mkdtemp()
self.cache = JSONFileCache(working_dir=self.temp_dir)

def tearDown(self):
shutil.rmtree(self.temp_dir, ignore_errors=True)

@mock.patch('os.replace')
def test_uses_tempfile_and_replace_for_atomic_write(self, mock_replace):
self.cache['test_key'] = {'data': 'test_value'}
mock_replace.assert_called_once()

        temp_path, final_path = mock_replace.call_args[0]

        assert temp_path.endswith('.tmp')
        assert final_path == os.path.join(self.temp_dir, 'test_key.json')

def test_concurrent_writes_to_multiple_temp_files(self):
"""Test concurrent writes to same key don't cause corruption."""
import threading

errors = []

def write_worker(thread_id):
try:
key = f'concurrent_test_{thread_id}'
for i in range(3):
self.cache[key] = {'thread': thread_id, 'iteration': i}
except Exception as e:
errors.append(f'Thread {thread_id}: {e}')

threads = [
threading.Thread(target=write_worker, args=(i,)) for i in range(3)
]

for thread in threads:
thread.start()
for thread in threads:
thread.join()

self.assertEqual(len(errors), 0, f'Concurrent write errors: {errors}')

        for thread_id in range(3):
            key = f'concurrent_test_{thread_id}'
            final_data = self.cache[key]
            self.assertIsInstance(final_data, dict)
            # Each thread writes only its own key, so the cache must hold
            # the last value that thread wrote (iteration 2).
            self.assertEqual(
                final_data, {'thread': thread_id, 'iteration': 2}
            )

def test_atomic_write_preserves_data_on_failure(self):
"""Test write failures don't corrupt existing data."""
key = 'atomic_test'
original_data = {'status': 'original'}

self.cache[key] = original_data

        # Simulate a serialization failure; patch.object restores _dumps
        # even if the assertions below fail.
        with mock.patch.object(
            self.cache, '_dumps', side_effect=ValueError('Write failed')
        ):
            with self.assertRaises(ValueError):
                self.cache[key] = {'status': 'should_fail'}

# Verify original data intact
self.assertEqual(self.cache[key], original_data)

def test_no_temp_files_after_write(self):
"""Test temporary files cleaned up after writes."""
self.cache['test'] = {'data': 'value'}

temp_files = [
f for f in os.listdir(self.temp_dir) if f.endswith('.tmp')
]
self.assertEqual(
len(temp_files), 0, f'Temp files not cleaned: {temp_files}'
)
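
For reference, exercising the cache outside the test suite could look like the sketch below. It is a hedged example that relies on the dict-style interface shown in the tests above and additionally assumes the usual mapping-style membership and deletion helpers; it also uses a throwaway directory rather than the default cache location.

import shutil
import tempfile

from botocore.utils import JSONFileCache

work_dir = tempfile.mkdtemp()
try:
    cache = JSONFileCache(working_dir=work_dir)
    # Each assignment is written atomically via a temp file and os.replace.
    cache['credentials'] = {'token': 'abc123', 'expiry': '2030-01-01'}

    assert cache['credentials']['token'] == 'abc123'

    # Membership and deletion, assuming mapping-style support on the cache.
    if 'credentials' in cache:
        del cache['credentials']
finally:
    shutil.rmtree(work_dir, ignore_errors=True)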