
Merge pull request #8 from openai/main
pp
1proprogrammerchant authored Aug 11, 2023
2 parents d077673 + 0f91775 commit 9d340b0
Showing 231 changed files with 28,615 additions and 2,379 deletions.
22 changes: 19 additions & 3 deletions .github/workflows/integration-tests.yml
@@ -27,7 +27,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]'
echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]'
echo '::set-output name=matrix-optional::[]'
else
echo '::set-output name=matrix-required::["ubuntu-latest"]'
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
@@ -50,6 +50,8 @@ jobs:
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}"
- name: Clear cache
run: |
@@ -79,8 +81,22 @@ jobs:
fi
lit -v "${LIT_TEST_DIR}"
- name: Run python tests on CUDA
if: ${{ env.BACKEND == 'CUDA'}}
- name: Enable MMAV3 and TMA
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}}
run: |
echo "ENABLE_TMA=1" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}"
- name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
- name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime
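For context: ENABLE_TMA and ENABLE_MMA_V3 are plain environment variables — defaulted to 0 above and flipped to 1 only on the H100 runner before the Python tests rerun. A minimal C++ sketch of how a consumer might read them (the getBoolEnv helper is hypothetical, not Triton's actual API):

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical helper: treat "1" as enabled, anything else as disabled.
static bool getBoolEnv(const char *name) {
  const char *v = std::getenv(name);
  return v != nullptr && std::strcmp(v, "1") == 0;
}

int main() {
  bool tma = getBoolEnv("ENABLE_TMA");      // 1 only on the H100 job here
  bool mmaV3 = getBoolEnv("ENABLE_MMA_V3"); // 1 only on the H100 job here
  return (tma == mmaV3) ? 0 : 1; // this workflow only ever sets them together
}
```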
3 changes: 3 additions & 0 deletions .gitignore
@@ -24,3 +24,6 @@ venv.bak/
# JetBrains project files
.idea
cmake-build-*

# Third-party binaries
ptxas
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -209,6 +209,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO
3 changes: 3 additions & 0 deletions bin/CMakeLists.txt
@@ -9,6 +9,7 @@ target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
@@ -29,6 +30,7 @@ target_link_libraries(triton-reduce PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
@@ -48,6 +50,7 @@ llvm_update_compile_flags(triton-translate)
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO
6 changes: 6 additions & 0 deletions bin/RegisterTritonDialects.h
@@ -1,10 +1,13 @@
#pragma once
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include "triton/Conversion/NVGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"

@@ -23,15 +26,18 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::registerTritonNvidiaGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
mlir::triton::registerConvertNVGPUToLLVMPass();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect>();
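For context, a sketch of how a driver consumes this header — the wiring below mirrors a typical MLIR opt tool and is illustrative, not this repository's exact triton-opt source (the include path is assumed):

```cpp
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "RegisterTritonDialects.h" // path assumed for illustration

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  // After this PR, this call also registers the TritonNvidiaGPU dialect and
  // passes plus the NVGPU->LLVM conversion pass.
  registerTritonDialects(registry);
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Triton optimizer driver\n", registry));
}
```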
8 changes: 6 additions & 2 deletions bin/triton-translate.cpp
@@ -14,6 +14,7 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
@@ -38,6 +39,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
mlir::DialectRegistry registry;
registry
.insert<TritonDialect, triton::gpu::TritonGPUDialect,
triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::math::MathDialect, arith::ArithDialect, scf::SCFDialect>();

context.appendDialectRegistry(registry);
@@ -121,8 +123,10 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}

llvm::LLVMContext llvmContext;
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), false /*isRocm*/);
mlir::triton::gpu::TMAMetadataTy tmaInfos;
auto llvmir = translateTritonGPUToLLVMIR(
&llvmContext, *module, SMArch.getValue(), tmaInfos, false /*isRocm*/);

if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
9 changes: 9 additions & 0 deletions docs/meetups/08-22-2023.md
@@ -0,0 +1,9 @@
#### Agenda:

##### Announcements:
1. Triton conference registration opens soon; the conference takes place on September 20th at the Microsoft Silicon Valley Campus.

##### Items:
1. H100 updates
2. Linalg updates
3. Open discussion
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -192,3 +192,4 @@ Iterators
:nosignatures:

static_range
multiple_of
13 changes: 7 additions & 6 deletions include/triton/Analysis/Allocation.h
@@ -9,6 +9,7 @@

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include <atomic>
#include <limits>

@@ -147,17 +148,17 @@ class Allocation {
BufferKind kind;
BufferId id;
size_t size;
size_t alignment;
size_t offset;

bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }

BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind)
: kind(kind), id(InvalidBufferId), size(0), offset(0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
BufferT() : BufferT(BufferKind::Explicit, 0) {}
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}
};

/// Op -> Scratch Buffer
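The new alignment member records each buffer's required alignment alongside its size and offset. A minimal sketch of the offset rounding this enables (the alignUp helper is illustrative, not code from this PR):

```cpp
#include <cstddef>

// Illustrative helper: round an offset up to the next multiple of the
// buffer's alignment before placing it.
static size_t alignUp(size_t offset, size_t alignment) {
  return ((offset + alignment - 1) / alignment) * alignment;
}

// With BufferT's default alignment of 4: alignUp(13, 4) == 16, so a buffer
// requested at offset 13 would be placed at 16 instead.
```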
2 changes: 1 addition & 1 deletion include/triton/Analysis/AxisInfo.h
@@ -21,7 +21,7 @@ namespace mlir {
/// This lattice value represents known information on the axes of a lattice.
class AxisInfo {
public:
typedef SmallVector<int64_t, 4> DimVectorT;
typedef SmallVector<int64_t> DimVectorT;

public:
/// Default constructor
11 changes: 10 additions & 1 deletion include/triton/Analysis/Utility.h
@@ -3,6 +3,7 @@

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
@@ -121,7 +122,11 @@ bool isSingleValue(Value value);

bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

Type getElementType(Value value);
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
@@ -324,6 +329,10 @@ template <typename T> class CallGraph {
FuncDataMapT funcMap;
SmallVector<FunctionOpInterface> roots;
};
// Create a basic DataFlowSolver with constant and dead code analysis included.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();

triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);

} // namespace mlir

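A sketch of how the new createDataFlowSolver() helper is meant to be consumed, per its comment above — it comes preloaded with constant and dead-code analyses, and callers add their own before running (SomeAnalysis is a stand-in name):

```cpp
#include "mlir/Analysis/DataFlowFramework.h"
#include "triton/Analysis/Utility.h"

// Illustrative only: run the preloaded constant/dead-code analyses on `op`.
mlir::LogicalResult analyze(mlir::Operation *op) {
  std::unique_ptr<mlir::DataFlowSolver> solver = mlir::createDataFlowSolver();
  // solver->load<SomeAnalysis>(); // hypothetical caller-specific analysis
  return solver->initializeAndRun(op);
}
```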
1 change: 1 addition & 0 deletions include/triton/Conversion/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(TritonToTritonGPU)
add_subdirectory(TritonGPUToLLVM)
add_subdirectory(NVGPUToLLVM)
3 changes: 3 additions & 0 deletions include/triton/Conversion/NVGPUToLLVM/CMakeLists.txt
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name NVGPUToLLVM)
add_public_tablegen_target(NVGPUConversionPassIncGen)
19 changes: 19 additions & 0 deletions include/triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h
@@ -0,0 +1,19 @@
#ifndef TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H
#define TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H

#include <memory>

namespace mlir {

class ModuleOp;
template <typename T> class OperationPass;

namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createConvertNVGPUToLLVMPass();

} // namespace triton

} // namespace mlir

#endif
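A minimal sketch of running the new pass, assuming a standard PassManager setup (not code from this PR):

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h"

// Illustrative: lower NVGPU ops in `module` to the LLVM dialect.
mlir::LogicalResult lowerNVGPU(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::triton::createConvertNVGPUToLLVMPass());
  return pm.run(module);
}
```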
16 changes: 16 additions & 0 deletions include/triton/Conversion/NVGPUToLLVM/Passes.h
@@ -0,0 +1,16 @@
#ifndef NVGPU_CONVERSION_PASSES_H
#define NVGPU_CONVERSION_PASSES_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton/Conversion/NVGPUToLLVM/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
20 changes: 20 additions & 0 deletions include/triton/Conversion/NVGPUToLLVM/Passes.td
@@ -0,0 +1,20 @@
#ifndef NVGPU_CONVERSION_PASSES
#define NVGPU_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"


def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert NVGPU to LLVM";
let description = [{

}];
let constructor = "mlir::triton::createConvertNVGPUToLLVMPass()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::LLVM::LLVMDialect",
"mlir::NVVM::NVVMDialect",
"mlir::triton::nvgpu::NVGPUDialect"];
}

#endif
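Given the pass argument string above, the lowering should also be reachable from triton-opt as `-convert-nv-gpu-to-llvm` once the registration hook in Passes.h is linked in.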
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h
@@ -151,6 +151,12 @@ struct PTXBuilder {
// aggressive optimizations that may lead to incorrect results.
Operand *newOperand(StringRef constraint, bool init = false);

// Create a new operand that is tied to a previous operand. In this case the
// asm would be permitted to write to an input register. Instead of providing
// constraint code for this operand, the constraint code of the tied operand
// is used.
Operand *newOperand(unsigned operandIndex);

// Create a constant integer operand.
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
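A sketch of the tied-operand overload next to the existing one, using only the two signatures declared above (the namespace and builder setup are assumptions):

```cpp
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"

// Assumed usage: tie a second operand to operand #0 so the asm may write to
// that input register; the tied operand inherits #0's constraint code.
void addTiedOperands(mlir::triton::PTXBuilder &builder) {
  auto *out = builder.newOperand("=r");                // operand #0, written
  auto *tied = builder.newOperand(/*operandIndex=*/0); // reuses #0's "=r"
  (void)out;
  (void)tied;
}
```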
4 changes: 4 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Passes.td
@@ -19,13 +19,17 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::tensor::TensorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::NVVM::NVVMDialect"];

let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">,
Option<"TmaMetadata", "tma-metadata",
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
"tma metadata to the runtime">,
Option<"isROCM", "is-rocm",
"bool", /*default*/"false",
"compile for ROCM-compatible LLVM">,
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h
@@ -3,6 +3,8 @@

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Target/PTX/TmaMetadata.h"

#include <memory>

namespace mlir {
@@ -12,9 +14,10 @@ template <typename T> class OperationPass;

namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = false);
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass(
int computeCapability = 80,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr,
bool isROCM = false);

} // namespace triton

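The triton-translate.cpp hunk earlier shows the new calling convention; as a standalone sketch of the updated factory (the header path and compute capability 90 are illustrative):

```cpp
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" // path assumed

// Illustrative: the TMA metadata out-parameter is filled during lowering
// and later handed to the runtime.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
makeLoweringPass(mlir::triton::gpu::TMAMetadataTy &tmaInfos) {
  return mlir::triton::createConvertTritonGPUToLLVMPass(
      /*computeCapability=*/90, &tmaInfos, /*isROCM=*/false);
}
```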
1 change: 1 addition & 0 deletions include/triton/Conversion/TritonToTritonGPU/Passes.h
@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_PASSES_H

#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Target/PTX/TmaMetadata.h"

namespace mlir {
namespace triton {
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonToTritonGPU/Passes.td
@@ -25,6 +25,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"32",
"number of threads per warp">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of ctas in a cga">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"compute capability">
];
}

include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h
@@ -11,6 +11,9 @@ template <typename T> class OperationPass;
namespace triton {

constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
constexpr static char AttrComputeCapabilityName[] =
"triton_gpu.compute-capability";

constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";

@@ -19,7 +22,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();

// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32,
int numCTAs = 1, int computeCapability = 80);

} // namespace triton
} // namespace mlir
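Putting the updated factory to work, a hedged end-to-end sketch — the option values (4 warps, compute capability 90) are examples only, with defaults as in the declaration above:

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

// Illustrative pipeline: lower Triton IR to TritonGPU with the new
// numCTAs / computeCapability knobs added in this PR.
mlir::LogicalResult lowerToTritonGPU(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(
      /*numWarps=*/4, /*threadsPerWarp=*/32, /*numCTAs=*/1,
      /*computeCapability=*/90));
  return pm.run(module);
}
```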
2 changes: 2 additions & 0 deletions include/triton/Dialect/CMakeLists.txt
@@ -1,2 +1,4 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)
add_subdirectory(TritonNvidiaGPU)
add_subdirectory(NVGPU)