Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: oneDNN wrapper based on oneDNN_jll
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 9, 2024
1 parent 40d9192 commit 475ac00
Show file tree
Hide file tree
Showing 11 changed files with 8,049 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ scripts
test_ext

benchmarks/results

deps
1 change: 1 addition & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ numer = "numer"
nd = "nd"
Ba = "Ba"
skipt = "skipt"
abd = "abd"
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.0"
version = "1.3.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
Expand All @@ -29,6 +30,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
oneDNN_jll = "3523a63d-8698-5b6f-b2c2-68eaa6bde0f0"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand All @@ -55,6 +57,7 @@ AMDGPU = "0.9.6, 1"
AppleAccelerate = "0.4"
ArrayInterface = "7.9"
BLISBLAS = "0.1"
CEnum = "0.5.0"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15.0"
Expand Down Expand Up @@ -85,3 +88,4 @@ Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
cuDNN = "1.3"
julia = "1.10"
oneDNN_jll = "3.5.3"
7 changes: 7 additions & 0 deletions generators/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
oneDNN_jll = "3523a63d-8698-5b6f-b2c2-68eaa6bde0f0"

[compat]
Clang = "0.18"
oneDNN_jll = "3.5.3"
1 change: 1 addition & 0 deletions generators/epilogue.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#! format: on
212 changes: 212 additions & 0 deletions generators/generator.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
[general]
# it could also be an expression as long as `Meta.parse` can parse this string successfully.
# basically, it should be the `expression` in the following code:
# ccall((function_name, expression), returntype, (argtype1, ...), argvalue1, ...)
library_name = "libdnnl"

# this entry allows you to specify different library names for different headers.
# in the following example:
# library_names = {"config.h" = "libclang_config", "libclang_p.*.h" = "libclang_patch"}
# those functions in the `config.h` will be generated as:
# ccall((function_name, libclang_config), returntype, (argtype1, ...), argvalue1, ...)
library_names = {}

# output file path relative to the working directory
output_file_path = "../src/onednn/lib.jl"

# if these are set, common file (types and constants) and API file (functions) will be separated
# this is for compatibility, so prologue and epilogue are not supported.
# output_api_file_path = "api.jl"
# output_common_file_path = "common.jl"

# if this entry is not empty, the generator will print the code below to the `output_file_path`.
# module module_name
#
# end # module
module_name = "Lib"

# if this entry is not empty, the generator will print the code below to the `output_file_path`.
# using jll_pkg_name
# export jll_pkg_name
jll_pkg_name = "oneDNN_jll"

# for packages that have extra JLL package dependencies
jll_pkg_extra = []

# identifiers that starts with the string listed in this entry will be exported.
export_symbol_prefixes = ["CX", "clang_"]

# the code in the following file will be copy-pasted to `output_file_path` before the generated code.
# this is often used for applying custom patches, e.g. adding missing definitions.
prologue_file_path = "prologue.jl"

# the code in the following file will be copy-pasted to `output_file_path` after the generated code.
# this is often used for applying custom patches.
epilogue_file_path = "epilogue.jl"

# node with an id in the `output_ignorelist` will be ignored in the printing passes.
# this is very useful for custom editing.
output_ignorelist = [
"CINDEX_EXPORTS",
"CINDEX_VERSION",
"CINDEX_VERSION_STRING",
"CINDEX_LINKAGE",
"CINDEX_DEPRECATED",
"LLVM_CLANG_C_STRICT_PROTOTYPES_BEGIN",
"LLVM_CLANG_C_STRICT_PROTOTYPES_END",
"LLVM_CLANG_C_EXTERN_C_BEGIN",
"LLVM_CLANG_C_EXTERN_C_END"
]

# Julia's `@enum` do not allow duplicated values, so by default, C enums are translated to
# CEnum.jl's `@cenum`.
# if this entry is true, `@enum` is used and those duplicated enum constants are just commented.
use_julia_native_enum_type = false

# use `@cenum` but do not print `using CEnum`.
# this is useful in the case of using `CEnum` directly in the source tree instead of using `CEnum` as a dependency
print_using_CEnum = false

# Print enums directly as integers without @(c)enum wrapper
# Override above two options
print_enum_as_integer = false

# use deterministic symbol instead of `gensym`-generated `var"##XXX"`
use_deterministic_symbol = true

# by default, only those declarations in the local header file are processed.
# those declarations in the system headers will be treated specially and will be generated if necessary.
# if you'd like to generate all of the symbols in the system headers, please set this option to false.
is_local_header_only = true

# set this option to false if you'd like to ignore the symbols(even if necessary) in the system headers.
generate_isystem_symbols = true

# if this option is set to true, C code with a style of
# ```c
# typedef struct {
# int x;
# } my_struct;
# ```
# will be generated as:
# ```julia
# struct my_struct
# x::Cint
# end
# ```
# instead of
# ```julia
# struct var"##Ctag#NUM"
# x::Cint
# end
# const my_struct = var"##Ctag#NUM"
# ```
smart_de_anonymize = true

# if set to true, static functions will be ignored
skip_static_functions = false

# EXPERIMENTAL
# if this option is set to true, those structs that are not necessary to be an
# immutable struct will be generated as a mutable struct.
# this option is default to false, do read the paragraph below before using this feature.
auto_mutability = false

# add inner constructor `Foo() = new()`
auto_mutability_with_new = true

# if you feel like certain structs should not be generated as mutable struct, please add them in the following list.
# for example, if a C function accepts a `Vector` of some type as its argument like:
# void foo(mutable_type *list, int n);
# when calling this function via `ccall`, passing a `Vector{mutable_type}(undef, n)` to the first
# argument will trigger a crash, the reason is mutable structs are not stored inline within a `Vector`,
# one should use `Ref{NTuple{n,mutable_type}}()` instead.
# this is not convenient and that's where the `auto_mutability_ignorelist` comes in.
auto_mutability_ignorelist = []

# opposite to `auto_mutability_ignorelist` and has a higher priority
auto_mutability_includelist = []

# if set to "raw", extract and dump raw c comment;
# if set to "doxygen", parse and format doxygen comment.
# note: by default, Clang only parses doxygen comment, pass `-fparse-all-comments` to Clang in order to parse non-doxygen comments.
extract_c_comment_style = "doxygen"

# Pass a function to explicitly generate documentation. It will be called like
# `callback_documentation(node::ExprNode, doc::Vector{String})` if it is
# set. The `doc` argument will contain the docs parsed from the headers if
# `extract_c_comment_style` is set, otherwise it will be an empty vector.
#
# Do *not* set this in the TOML file, it should be set in the generator script
# to a function that takes in an ExprNode and returns a String[] (string
# vector).
# callback_documentation = ""

# if set to true, single line comment will be printed as """comment""" instead of """\ncomment\n"""
fold_single_line_comment = false

# if set to "outofline", documentation of struct fields will be collected at the "Fields" section of the struct
# if set to "inline", documentation of struct fields will go right above struct definition
struct_field_comment_style = "outofline"

# if set to "outofline", documentation of enumerators will be collected at the "Enumerators" section of the enum
enumerator_comment_style = "outofline"

# if set to true, C function prototype will be included in documentation
show_c_function_prototype = false

[codegen]
# map C's bool to Julia's Bool instead of `Cuchar` a.k.a `UInt8`.
use_julia_bool = true

# set this to true if the C routine always expects a NUL-terminated string.
# TODO: support filtering
always_NUL_terminated_string = true

# generate strictly typed function
is_function_strictly_typed = false

# if true, opaque pointers in function arguments will be translated to `Ptr{Cvoid}`.
opaque_func_arg_as_PtrCvoid = false

# if true, opaque types are translated to `mutable struct` instead of `Cvoid`.
opaque_as_mutable_struct = true

# if true, use Julia 1.5's new `@ccall` macro
use_ccall_macro = true

# if true, variadic functions are wrapped with `@ccall` macro. Otherwise variadic functions are ignored.
wrap_variadic_function = false

# generate getproperty/setproperty! methods for the types in the following list
field_access_method_list = []

# the generator will prefix the function argument names in the following list with a "_" to
# prevent the generated symbols from conflicting with the symbols defined and exported in Base.
function_argument_conflict_symbols = []

# emit constructors for all custom-layout structs like bitfield in the list,
# or set to `true` to do so for all such structs
add_record_constructors = []

[codegen.macro]
# it‘s highly recommended to set this entry to "basic".
# if you'd like to skip all of the macros, please set this entry to "disable".
# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive".
macro_mode = "basic"

# function-like macros in the following list will always be translated.
functionlike_macro_includelist = [
"CINDEX_VERSION_ENCODE",
]

# if true, the generator prints the following message as comments.
# "# Skipping MacroDefinition: ..."
add_comment_for_skipped_macro = true

# if true, ignore any macros that is suffixed with "_H" or in the `ignore_header_guards_with_suffixes` list
ignore_header_guards = true
ignore_header_guards_with_suffixes = []

# if true, ignore those pure definition macros in the C code
ignore_pure_definition = true
8 changes: 8 additions & 0 deletions generators/prologue.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using CEnum: @cenum

const NULL = C_NULL

# This file is automatically generated by Clang.jl. Don't edit it manually. If needed,
# look at the "generators/" directory and modify the relevant files there.

#! format: off
61 changes: 61 additions & 0 deletions generators/wrap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using Clang.Generators
using oneDNN_jll

cur_dir = pwd()

cd(@__DIR__)

include_dir = joinpath(oneDNN_jll.artifact_dir, "include")

options = load_options(joinpath(@__DIR__, "generator.toml"))

onednn_headers = [
joinpath(include_dir, "dnnl.h"),
joinpath(include_dir, "dnnl_types.h"),
joinpath(include_dir, "dnnl_config.h"),
joinpath(include_dir, "dnnl_version.h")
]

args = get_default_args()
push!(args, "-I$include_dir")

ctx = create_context(onednn_headers, args, options)

# run generator
build!(ctx, BUILDSTAGE_NO_PRINTING)

function rewrite!(e::Expr)
# const DNNL_RUNTIME_SIZE_VAL = size_t(DNNL_RUNTIME_DIM_VAL)
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
e.args[1].args[1] == :DNNL_RUNTIME_SIZE_VAL && e.args[1].args[2] isa Expr &&
e.args[1].args[2].head == :call && e.args[1].args[2].args[1] == :size_t &&
e.args[1].args[2].args[2] == :DNNL_RUNTIME_DIM_VAL
e.args[1].args[2] = unsigned(typemin(Int64))
return
end
# const DNNL_RUNTIME_DIM_VAL = INT64_MIN
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
e.args[1].args[1] == :DNNL_RUNTIME_DIM_VAL && e.args[1].args[2] == :INT64_MIN
e.args[1].args[2] = typemin(Int64)
return
end
# const DNNL_RUNTIME_S32_VAL = DNNL_RUNTIME_S32_VAL_REP
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
e.args[1].args[1] == :DNNL_RUNTIME_S32_VAL && e.args[1].args[2] == :DNNL_RUNTIME_S32_VAL_REP
e.args[1].args[2] = 0
return
end
return
end

function rewrite!(dag::ExprDAG)
for node in get_nodes(dag), expr in get_exprs(node)
rewrite!(expr)
end
end

rewrite!(ctx.dag)

build!(ctx, BUILDSTAGE_PRINTING_ONLY)

cd(cur_dir)
4 changes: 4 additions & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ const CRC = ChainRulesCore

include("utils.jl")
include("traits.jl")

include("onednn/oneDNN.jl")

include("impl/Impl.jl")

include("api/API.jl")

@compat(public,
Expand Down
Loading

0 comments on commit 475ac00

Please sign in to comment.