From ad0bc0e73d811d04a81900b44949bce13d7697b7 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 28 Jun 2024 11:00:43 +0000 Subject: [PATCH] [stdlib] Add trait `CollectionElementNew` to a few structs Signed-off-by: gabrieldemarmiesse --- stdlib/src/builtin/dtype.mojo | 13 ++++++++++++- stdlib/src/builtin/error.mojo | 20 +++++++++++++++++++- stdlib/src/builtin/object.mojo | 14 ++++++++++++-- stdlib/src/builtin/string_literal.mojo | 10 ++++++++++ 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/stdlib/src/builtin/dtype.mojo b/stdlib/src/builtin/dtype.mojo index 5b9d4b649d..52f118e3e4 100644 --- a/stdlib/src/builtin/dtype.mojo +++ b/stdlib/src/builtin/dtype.mojo @@ -27,7 +27,9 @@ alias _mIsFloat = UInt8(1 << 6) @value @register_passable("trivial") -struct DType(Stringable, Formattable, Representable, KeyElement): +struct DType( + Stringable, Formattable, Representable, KeyElement, CollectionElementNew +): """Represents DType and provides methods for working with it.""" alias type = __mlir_type.`!kgen.dtype` @@ -83,6 +85,15 @@ struct DType(Stringable, Formattable, Representable, KeyElement): of the hardware's pointer type (32-bit on 32-bit machines and 64-bit on 64-bit machines).""" + @always_inline + fn __init__(inout self, *, other: Self): + """Copy this DType. + + Args: + other: The DType to copy. + """ + self = other + @always_inline("nodebug") fn __str__(self) -> String: """Gets the name of the DType. diff --git a/stdlib/src/builtin/error.mojo b/stdlib/src/builtin/error.mojo index d64bfbeedb..4c582bdd8a 100644 --- a/stdlib/src/builtin/error.mojo +++ b/stdlib/src/builtin/error.mojo @@ -26,7 +26,14 @@ from memory.memory import _free @register_passable -struct Error(Stringable, Boolable, Representable, Formattable): +struct Error( + Stringable, + Boolable, + Representable, + Formattable, + CollectionElement, + CollectionElementNew, +): """This type represents an Error.""" var data: UnsafePointer[UInt8] @@ -104,6 +111,17 @@ struct Error(Stringable, Boolable, Representable, Formattable): dest[length] = 0 return Error {data: dest, loaded_length: -length} + fn __init__(*, other: Self) -> Self: + """Copy the object. + + Args: + other: The value to copy. + + Returns: + The copied `Error`. + """ + return other + fn __del__(owned self): """Releases memory if allocated.""" if self.loaded_length < 0: diff --git a/stdlib/src/builtin/object.mojo b/stdlib/src/builtin/object.mojo index e313b7a895..615c7b9cd4 100644 --- a/stdlib/src/builtin/object.mojo +++ b/stdlib/src/builtin/object.mojo @@ -28,10 +28,11 @@ from utils import StringRef, Variant, unroll @register_passable("trivial") -struct _NoneMarker: +struct _NoneMarker(CollectionElementNew): """This is a trivial class to indicate that an object is `None`.""" - pass + fn __init__(inout self, *, other: Self): + pass @register_passable("trivial") @@ -316,6 +317,15 @@ struct _ObjectImpl(CollectionElement, Stringable): fn __init__(inout self, value: _RefCountedAttrsDictRef): self.value = Self.type(value) + @always_inline + fn __init__(inout self, *, other: Self): + """Copy the object. + + Args: + other: The value to copy. + """ + self = other.value + @always_inline fn __copyinit__(inout self, existing: Self): self = existing.value diff --git a/stdlib/src/builtin/string_literal.mojo b/stdlib/src/builtin/string_literal.mojo index 97ab53487d..89151d7864 100644 --- a/stdlib/src/builtin/string_literal.mojo +++ b/stdlib/src/builtin/string_literal.mojo @@ -41,6 +41,7 @@ struct StringLiteral( Boolable, Formattable, Comparable, + CollectionElementNew, ): """This type represents a string literal. @@ -68,6 +69,15 @@ struct StringLiteral( """ self.value = value + @always_inline("nodebug") + fn __init__(inout self, *, other: Self): + """Copy constructor. + + Args: + other: The string literal to copy. + """ + self = other + # ===-------------------------------------------------------------------===# # Operator dunders # ===-------------------------------------------------------------------===#