From e65eee11e3bec2534a6abe5f8b9bc676203fad6c Mon Sep 17 00:00:00 2001 From: ehsanmok <6980212+ehsanmok@users.noreply.github.com> Date: Fri, 27 Sep 2024 14:30:49 -0700 Subject: [PATCH] Fixes and content modification --- .../generic_safe_buffer.mojo | 21 ++++++++++------- .../hands_on_with_mojo_24_5/safe_buffer.mojo | 16 +++++-------- .../test_generic_safe_buffer.mojo | 1 + .../unsafe_buffer.mojo | 23 ++++++++----------- 4 files changed, 29 insertions(+), 32 deletions(-) diff --git a/blogs/hands_on_with_mojo_24_5/generic_safe_buffer.mojo b/blogs/hands_on_with_mojo_24_5/generic_safe_buffer.mojo index 3db74bb..5392cfc 100644 --- a/blogs/hands_on_with_mojo_24_5/generic_safe_buffer.mojo +++ b/blogs/hands_on_with_mojo_24_5/generic_safe_buffer.mojo @@ -36,14 +36,13 @@ struct SafeBuffer[T: CollectionElement]: fn __del__(owned self): self._data.free() - fn _get_ref(ref [_]self: Self, index: Int) -> Reference[Optional[T], __lifetime_of(self)]: - return Reference[Optional[T], __lifetime_of(self)](self._data[index]) - fn write(inout self, index: Int, value: Optional[T]): - self._get_ref(index)[] = value + debug_assert(0 <= index < self.size, "index must be within the buffer") + self._data[index] = value fn read(self, index: Int) -> Optional[T]: - return self._get_ref(index)[] + debug_assert(0 <= index < self.size, "index must be within the buffer") + return self._data[index] fn __str__[U: StringableFormattableCollectionElement](self: SafeBuffer[U]) -> String: ret = String() @@ -73,6 +72,7 @@ struct SafeBuffer[T: CollectionElement]: fn take(inout self, index: Int) -> Optional[T] as output: output = self.read(index) self.write(index, Optional[T](None)) + return fn process_buffers[T: CollectionElement](buffer1: SafeBuffer[T], inout buffer2: SafeBuffer[T]): @@ -82,9 +82,14 @@ fn process_buffers[T: CollectionElement](buffer1: SafeBuffer[T], inout buffer2: struct NotStringableNorFormattable(CollectionElement): - fn __init__(inout self): ... - fn __copyinit__(inout self, existing: Self): ... - fn __moveinit__(inout self, owned existing: Self): ... + fn __init__(inout self): + ... + + fn __copyinit__(inout self, existing: Self): + ... + + fn __moveinit__(inout self, owned existing: Self): + ... def main(): diff --git a/blogs/hands_on_with_mojo_24_5/safe_buffer.mojo b/blogs/hands_on_with_mojo_24_5/safe_buffer.mojo index d504a90..b842317 100644 --- a/blogs/hands_on_with_mojo_24_5/safe_buffer.mojo +++ b/blogs/hands_on_with_mojo_24_5/safe_buffer.mojo @@ -22,14 +22,13 @@ struct SafeBuffer(Stringable, Formattable): fn __del__(owned self): self._data.free() - fn _get_ref(ref [_]self: Self, index: Int) -> Reference[UInt8, __lifetime_of(self)]: - return Reference[UInt8, __lifetime_of(self)](self._data[index]) - fn write(inout self, index: Int, value: UInt8): - self._get_ref(index)[] = value + debug_assert(0 <= index < self.size, "index must be within the buffer") + self._data[index] = value fn read(self, index: Int) -> UInt8: - return self._get_ref(index)[] + debug_assert(0 <= index < self.size, "index must be within the buffer") + return self._data[index] fn __str__(self) -> String: return String.format_sequence(self) @@ -53,14 +52,11 @@ fn process_buffers(buffer1: SafeBuffer, inout buffer2: SafeBuffer): def main(): sb = SafeBuffer(10) sb.write(0, 255) - sb.write(1, 128) - print("safe buffer outputs:") + print("value at index 0 after getting set to 255:") print(sb.read(0)) - print(sb.read(1)) buffer1 = SafeBuffer.initialize_with_value(size=10, value=128) buffer2 = SafeBuffer(10) # process_buffers(buffer1, buffer1) # <-- argument exclusivity detects such errors at compile time process_buffers(buffer1, buffer2) - print("buffer2:") - print(buffer2) + print("buffer2:", buffer2) diff --git a/blogs/hands_on_with_mojo_24_5/test_generic_safe_buffer.mojo b/blogs/hands_on_with_mojo_24_5/test_generic_safe_buffer.mojo index 99eb58e..ea2327e 100644 --- a/blogs/hands_on_with_mojo_24_5/test_generic_safe_buffer.mojo +++ b/blogs/hands_on_with_mojo_24_5/test_generic_safe_buffer.mojo @@ -2,6 +2,7 @@ from testing import assert_equal from generic_safe_buffer import SafeBuffer + def test_buffer(): buffer = SafeBuffer[String].initialize_with_value(size=5, value=String("hi")) val = buffer.take(2).value() diff --git a/blogs/hands_on_with_mojo_24_5/unsafe_buffer.mojo b/blogs/hands_on_with_mojo_24_5/unsafe_buffer.mojo index 122c874..5787a0b 100644 --- a/blogs/hands_on_with_mojo_24_5/unsafe_buffer.mojo +++ b/blogs/hands_on_with_mojo_24_5/unsafe_buffer.mojo @@ -1,28 +1,23 @@ +from memory import memset_zero + + struct UnsafeBuffer: var data: UnsafePointer[UInt8] var size: Int fn __init__(inout self, size: Int): self.data = UnsafePointer[UInt8].alloc(size) + memset_zero(self.data, size) self.size = size - fn write(inout self, index: Int, value: UInt8): - # note `self.data` is uninitialized so we have to use `init_pointee_copy/move` - # methods to safely initialize the allocated memory - self.data.init_pointee_copy(value) - - fn read(self, index: Int) -> UInt8: - return self.data[index] - fn __del__(owned self): self.data.free() def main(): ub = UnsafeBuffer(10) - ub.write(0, 255) - ub.write(1, 128) - print("unsafe buffer outputs:") - print(ub.read(0)) - print("the data of the unsafe buffer is freed here bc there's no lifetime associate with it") - print(ub.read(1)) + print("initial value at index 0:") + print(ub.data[0]) + ub.data[0] = 255 + print("value at index 0 after getting set to 255:") + print(ub.data[0])