From c1ffcde9539c5ba5ead5e2cd0cbc6245978269cc Mon Sep 17 00:00:00 2001 From: Koutheir Attouchi Date: Wed, 17 Jul 2024 23:51:32 -0400 Subject: [PATCH] Use `Option>` instead of `*const DatabaseError`. Both types have the same ABI, but the first one is safer to deal with. --- src/idioms/ffi/errors.md | 46 +++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/src/idioms/ffi/errors.md b/src/idioms/ffi/errors.md index 3547b8f0..a1cf96f7 100644 --- a/src/idioms/ffi/errors.md +++ b/src/idioms/ffi/errors.md @@ -55,45 +55,43 @@ pub mod errors { pub mod c_api { use super::errors::DatabaseError; + use core::ptr; #[no_mangle] - pub extern "C" fn db_error_description(e: *const DatabaseError) -> *mut libc::c_char { - let error: &DatabaseError = unsafe { - // SAFETY: pointer lifetime is greater than the current stack frame - &*e - }; + pub extern "C" fn db_error_description( + e: Option>, + ) -> Option> { + // SAFETY: we assume that the lifetime of `e` is greater than + // the current stack frame. + let error = unsafe { e?.as_ref() }; let error_str: String = match error { DatabaseError::IsReadOnly => { - format!("cannot write to read-only database"); + format!("cannot write to read-only database") } DatabaseError::IOError(e) => { - format!("I/O Error: {e}"); + format!("I/O Error: {e}") } DatabaseError::FileCorrupted(s) => { - format!("File corrupted, run repair: {}", &s); + format!("File corrupted, run repair: {}", &s) } }; - let c_error = unsafe { - // SAFETY: copying error_str to an allocated buffer with a NUL - // character at the end - let mut malloc: *mut u8 = libc::malloc(error_str.len() + 1) as *mut _; - - if malloc.is_null() { - return std::ptr::null_mut(); - } - - let src = error_str.as_bytes().as_ptr(); - - std::ptr::copy_nonoverlapping(src, malloc, error_str.len()); + let error_bytes = error_str.as_bytes(); - std::ptr::write(malloc.add(error_str.len()), 0); - - malloc as *mut libc::c_char + let c_error = unsafe { + // SAFETY: copying error_bytes to an allocated buffer with a '\0' + // byte at the end. + let buffer = ptr::NonNull::::new(libc::malloc(error_bytes.len() + 1).cast())?; + + buffer + .as_ptr() + .copy_from_nonoverlapping(error_bytes.as_ptr(), error_bytes.len()); + buffer.as_ptr().add(error_bytes.len()).write(0_u8); + buffer }; - c_error + Some(c_error.cast()) } } ```