Skip to content

Commit

Permalink
tweak words
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 29, 2023
1 parent 2b7042b commit db8aadb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/src/gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ julia> x |> cpu

In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways:

1. Iterating over the batches in a [DataLoader](@ref) object transferring each one of the training batches at a time to the GPU. This is recommended for large datasets. Done hand, it might look like this:
1. Iterating over the batches in a [`DataLoader`](@ref) object transferring each one of the training batches at a time to the GPU. This is recommended for large datasets. Done hand, it might look like this:
```julia
train_loader = Flux.DataLoader((X, Y), batchsize=64, shuffle=true)
# ... model definition, optimiser setup
Expand All @@ -153,7 +153,7 @@ In order to train the model using the GPU both model and the training data have
This is equivalent to `DataLoader(MLUtils.mapobs(gpu, (X, Y)); keywords...)`.
Something similar can also be done with [`CUDA.CuIterator`](https://cuda.juliagpu.org/stable/usage/memory/#Batching-iterator), `gpu_train_loader = CUDA.CuIterator(train_loader)`. However, this only works with a limited number of data types: `first(train_loader)` should be a tuple (or `NamedTuple`) of arrays.

2. Transferring all training data to the GPU at once before creating the [DataLoader](@ref) object. This is usually performed for smaller datasets which are sure to fit in the available GPU memory.
2. Transferring all training data to the GPU at once before creating the `DataLoader`. This is usually performed for smaller datasets which are sure to fit in the available GPU memory.
```julia
gpu_train_loader = Flux.DataLoader((X, Y) |> gpu, batchsize = 32)
# ...
Expand Down

0 comments on commit db8aadb

Please sign in to comment.