Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix python enum ambiguities #19324

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions compiler/bindings/python/iree/compiler/dialects/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._flow_ops_gen import *
from ._flow_enum_gen import *


@register_attribute_builder("builtin.FLOW_CollectiveElementTypeAttr")
def _flow_collectiveelementtypeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.FLOW_CollectiveReductionOpAttr")
def _flow_collectivereductionopattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
90 changes: 90 additions & 0 deletions compiler/bindings/python/iree/compiler/dialects/hal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,93 @@

from ._hal_ops_gen import *
from ._hal_enum_gen import *


@register_attribute_builder("builtin.HAL_AccessScopeBitfieldAttr")
def _hal_accessscopebitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_BufferUsageBitfieldAttr")
def _hal_bufferusagebitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_CallingConventionAttr")
def _hal_callingconventionattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_CollectiveElementTypeAttr")
def _hal_collectiveelementtypeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_CollectiveKindAttr")
def _hal_collectivekindattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_CollectiveReductionOpAttr")
def _hal_collectivereductionopattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_CommandBufferModeBitfieldAttr")
def _hal_commandbuffermodebitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_CommandCategoryBitfieldAttr")
def _hal_commandcategorybitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_DescriptorFlagsAttr")
def _hal_descriptorflagsattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_DispatchFlagsAttr")
def _hal_dispatchflagsattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))


@register_attribute_builder("builtin.HAL_ExecutionBarrierFlagBitfieldAttr")
def _hal_executionbarrierflagbitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_ExecutionStageBitfieldAttr")
def _hal_executionstagebitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_FenceFlagBitfieldAttr")
def _hal_fenceflagbitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_MemoryAccessBitfieldAttr")
def _hal_memoryaccessbitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_MemoryModelAttr")
def _hal_memorymodelattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_MemoryTypeBitfieldAttr")
def _hal_memorytypebitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.HAL_PipelineLayoutFlagsAttr")
def _hal_pipelinelayoutflagsattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.anonymous_538")
def _anonymous_538(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
11 changes: 11 additions & 0 deletions compiler/bindings/python/iree/compiler/dialects/iree_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._iree_codegen_ops_gen import *
from ._iree_codegen_enum_gen import *
from .._mlir_libs._ireeCompilerDialects.iree_codegen import *


@register_attribute_builder("builtin.DispatchLoweringPassPipelineEnum")
def _dispatchloweringpasspipelineenum(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.WorkgroupIdEnum")
def _workgroupidenum(x, context):
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
51 changes: 51 additions & 0 deletions compiler/bindings/python/iree/compiler/dialects/iree_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,57 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._iree_gpu_ops_gen import *
from ._iree_gpu_enum_gen import *
from .._mlir_libs._ireeCompilerDialects.iree_gpu import *


@register_attribute_builder("builtin.IREEGPU_ComputeBitwidths")
def _ireegpu_computebitwidths(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_DotProductOps")
def _ireegpu_dotproductops(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_MMAFragment")
def _ireegpu_mmafragment(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_MMAIntrinsic")
def _ireegpu_mmaintrinsic(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_MMAScope")
def _ireegpu_mmascope(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_ReorderWorkgroupsStrategy")
def _ireegpu_reorderworkgroupsstrategy(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_StorageBitwidths")
def _ireegpu_storagebitwidths(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_SubgroupOps")
def _ireegpu_subgroupops(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_TilingLevel")
def _ireegpu_tilinglevel(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.IREEGPU_VirtualMMAIntrinsic")
def _ireegpu_virtualmmaintrinsic(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
36 changes: 36 additions & 0 deletions compiler/bindings/python/iree/compiler/dialects/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,41 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._stream_ops_gen import *
from ._stream_enum_gen import *


@register_attribute_builder("builtin.Stream_CollectiveElementTypeAttr")
def _stream_collectiveelementtypeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.Stream_CollectiveKindAttr")
def _stream_collectivekindattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.Stream_CollectiveReductionOpAttr")
def _stream_collectivereductionopattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.Stream_FavorAttr")
def _stream_favorattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.Stream_LifetimeAttr")
def _stream_lifetimeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.Stream_MemoryModelAttr")
def _stream_memorymodelattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))


@register_attribute_builder("builtin.Stream_ResourceAccessBitfieldAttr")
def _stream_resourceaccessbitfieldattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
16 changes: 16 additions & 0 deletions compiler/bindings/python/iree/compiler/dialects/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._vm_ops_gen import *
from ._vm_enum_gen import *


@register_attribute_builder("builtin.VM_CoreOpcodeAttr")
def _vm_coreopcodeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(8, context=context), int(x))


@register_attribute_builder("builtin.VM_ExtF32OpcodeAttr")
def _vm_extf32opcodeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(8, context=context), int(x))


@register_attribute_builder("builtin.VM_ExtF64OpcodeAttr")
def _vm_extf64opcodeattr(x, context):
return IntegerAttr.get(IntegerType.get_signless(8, context=context), int(x))
8 changes: 8 additions & 0 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,11 @@ def compilation_info():
assert compilation_info is not None
assert compilation_info.lowering_config == lowering_config
assert compilation_info.translation_info == translation_info


@run
def enum_collision():
from iree.compiler.dialects import linalg, vector

linalg_iter_type_e = linalg._iteratortype(0, None)
vector_iter_type_e = vector._vector_iteratortype(0, None)
Loading