-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* port CUDA apis * update * update * fix ut * update * update * update * more unittest * update * update * update Co-authored-by: Yi Wang <[email protected]>
- Loading branch information
1 parent
5a7e4a1
commit 84a41d0
Showing
10 changed files
with
256 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
// Copyright 2020, GoTorch Authors | ||
#include <iostream> | ||
#include <sstream> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "torch/script.h" | ||
#include "torch/torch.h" | ||
|
||
#ifdef WITH_CUDA | ||
#include "c10/cuda/CUDAStream.h" | ||
#endif | ||
|
||
// FIXME(shendiaomo): including cgotorch.h before torch/torch.h will fail | ||
#include "cgotorch/cgotorch.h" | ||
|
||
// IsCUDAAvailable reports whether libtorch detects a usable CUDA runtime.
bool IsCUDAAvailable() {
  return torch::cuda::is_available();
}
|
||
// IsCUDNNAvailable reports whether libtorch detects a usable cuDNN library.
bool IsCUDNNAvailable() {
  return torch::cuda::cudnn_is_available();
}
|
||
// CUDA_GetCurrentCUDAStream stores the current CUDA stream of *device into
// *stream as an opaque handle.  Returns nullptr on success, or an
// exception_str message on failure; when the library was built without
// -DWITH_CUDA it always returns an error message.
const char *CUDA_GetCurrentCUDAStream(CUDAStream *stream, Device *device) {
#ifdef WITH_CUDA
  try {
    // NOTE(review): the heap-allocated CUDAStream wrapper is handed to the
    // caller as a void*; no matching delete is visible in this file —
    // confirm the Go side owns and eventually frees it.
    *stream = static_cast<void *>(new at::cuda::CUDAStream(
        at::cuda::getCurrentCUDAStream((*device)->index())));
    return nullptr;
  } catch (const std::exception &e) {
    return exception_str(e.what());
  }
#else
  return exception_str("CUDA API needs -DWITH_CUDA on building libcgotorch.so");
#endif
}
|
||
// CUDA_GetCUDAStreamFromPool obtains a normal-priority stream from the CUDA
// stream pool of *device and returns it through *stream as an opaque handle.
// Returns nullptr on success or an exception_str message on failure.
const char *CUDA_GetCUDAStreamFromPool(CUDAStream *stream, Device *device) {
#ifndef WITH_CUDA
  return exception_str("CUDA API needs -DWITH_CUDA on building libcgotorch.so");
#else
  try {
    at::cuda::CUDAStream pooled = at::cuda::getStreamFromPool(
        /*isHighPriority=*/false, (*device)->index());
    *stream = static_cast<void *>(new at::cuda::CUDAStream(pooled));
    return nullptr;
  } catch (const std::exception &e) {
    return exception_str(e.what());
  }
#endif
}
|
||
// CUDA_SetCurrentCUDAStream makes the given stream the current stream of its
// device.  Returns nullptr on success or an exception_str message on failure.
const char *CUDA_SetCurrentCUDAStream(CUDAStream stream) {
#ifndef WITH_CUDA
  return exception_str("CUDA API needs -DWITH_CUDA on building libcgotorch.so");
#else
  try {
    auto *s = static_cast<at::cuda::CUDAStream *>(stream);
    at::cuda::setCurrentCUDAStream(*s);
    return nullptr;
  } catch (const std::exception &e) {
    return exception_str(e.what());
  }
#endif
}
|
||
// CUDA_Synchronize blocks the calling thread until all work queued on the
// given stream has completed.  Returns nullptr on success or an
// exception_str message on failure.
const char *CUDA_Synchronize(CUDAStream stream) {
#ifndef WITH_CUDA
  return exception_str("CUDA API needs -DWITH_CUDA on building libcgotorch.so");
#else
  try {
    auto *s = static_cast<at::cuda::CUDAStream *>(stream);
    s->synchronize();
    return nullptr;
  } catch (const std::exception &e) {
    return exception_str(e.what());
  }
#endif
}
|
||
// CUDA_Query writes 1 into *result when all work queued on the stream has
// completed, 0 otherwise.  Returns nullptr on success or an exception_str
// message on failure.
const char *CUDA_Query(CUDAStream stream, int8_t *result) {
#ifndef WITH_CUDA
  return exception_str("CUDA API needs -DWITH_CUDA on building libcgotorch.so");
#else
  try {
    auto *s = static_cast<at::cuda::CUDAStream *>(stream);
    *result = s->query() ? 1 : 0;
    return nullptr;
  } catch (const std::exception &e) {
    return exception_str(e.what());
  }
#endif
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package gotorch | ||
|
||
// #cgo CFLAGS: -I ${SRCDIR}/cgotorch -I ${SRCDIR}/cgotorch/libtorch/include | ||
// #cgo LDFLAGS: -L ${SRCDIR}/cgotorch -Wl,-rpath ${SRCDIR}/cgotorch -lcgotorch | ||
// #cgo LDFLAGS: -L ${SRCDIR}/cgotorch/libtorch/lib -Wl,-rpath ${SRCDIR}/cgotorch/libtorch/lib -lc10 -ltorch -ltorch_cpu | ||
// #include "cgotorch.h" | ||
import "C" | ||
import "unsafe" | ||
|
||
// CUDAStream wraps an opaque handle (C.CUDAStream, a void* on the C side)
// to an Nvidia CUDA stream managed by libtorch.
type CUDAStream struct {
	P C.CUDAStream // opaque pointer to an at::cuda::CUDAStream
}
|
||
// Query returns true if all tasks completed on this CUDA stream | ||
func (s CUDAStream) Query() bool { | ||
var b int8 | ||
MustNil(unsafe.Pointer(C.CUDA_Query(s.P, (*C.int8_t)(&b)))) | ||
return b != 0 | ||
} | ||
|
||
// Synchronize blocks the calling thread until all tasks previously queued
// on this CUDA stream have completed.
func (s CUDAStream) Synchronize() {
	// MustNil panics if the C call returned a non-nil error string.
	MustNil(unsafe.Pointer(C.CUDA_Synchronize(s.P)))
}
|
||
// GetCurrentCUDAStream returns the current stream on device | ||
func GetCurrentCUDAStream(device Device) CUDAStream { | ||
var stream C.CUDAStream | ||
MustNil(unsafe.Pointer(C.CUDA_GetCurrentCUDAStream(&stream, &device.T))) | ||
return CUDAStream{stream} | ||
} | ||
|
||
// SetCurrentCUDAStream makes stream the current CUDA stream on its device.
// It panics (via MustNil) when the underlying C call reports an error, e.g.
// when the library was built without CUDA support.
func SetCurrentCUDAStream(stream CUDAStream) {
	MustNil(unsafe.Pointer(C.CUDA_SetCurrentCUDAStream(stream.P)))
}
|
||
// NewCUDAStream returns a new CUDA stream from the pool | ||
func NewCUDAStream(device Device) CUDAStream { | ||
var stream C.CUDAStream | ||
MustNil(unsafe.Pointer(C.CUDA_GetCUDAStreamFromPool(&stream, &device.T))) | ||
return CUDAStream{stream} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package gotorch_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
torch "github.com/wangkuiyi/gotorch" | ||
) | ||
|
||
func getDefaultDevice() torch.Device { | ||
var device torch.Device | ||
if torch.IsCUDAAvailable() { | ||
device = torch.NewDevice("cuda") | ||
} else { | ||
device = torch.NewDevice("cpu") | ||
} | ||
return device | ||
} | ||
func TestCUDAStreamPanics(t *testing.T) { | ||
a := assert.New(t) | ||
device := getDefaultDevice() | ||
if torch.IsCUDAAvailable() { | ||
a.NotPanics(func() { | ||
torch.GetCurrentCUDAStream(device) | ||
}) | ||
} else { | ||
a.Panics(func() { | ||
torch.GetCurrentCUDAStream(device) | ||
}) | ||
a.Panics(func() { | ||
torch.NewCUDAStream(device) | ||
}) | ||
} | ||
} | ||
|
||
func TestMultiCUDAStream(t *testing.T) { | ||
if !torch.IsCUDAAvailable() { | ||
t.Skip("skip TestMultiCUDAStream which only run on CUDA device") | ||
} | ||
a := assert.New(t) | ||
device := getDefaultDevice() | ||
currStream := torch.GetCurrentCUDAStream(device) | ||
defer torch.SetCurrentCUDAStream(currStream) | ||
// create a new CUDA stream | ||
stream := torch.NewCUDAStream(device) | ||
// switch to the new CUDA stream | ||
torch.SetCurrentCUDAStream(stream) | ||
// copy Tensor from host to device async | ||
input := torch.RandN([]int64{100, 200}, true).PinMemory() | ||
input.CUDA(device, true /**nonBlocking=true**/) | ||
// wait until all tasks completed | ||
stream.Synchronize() | ||
// make sure all tasks completed | ||
a.True(stream.Query()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters