Merge branch 'peft' into peft_xinhao

xinhaoc authored Feb 25, 2024
2 parents 99bcadf + 9075d3f commit dd1366f
Showing 173 changed files with 6,373 additions and 1,851 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -17,3 +17,5 @@ python/flexflow/core/legion_cffi_header.py
/inference/tokenizer/*
/inference/prompt/*
/inference/output/*

/tests/inference/python_test_configs/*.json
24 changes: 20 additions & 4 deletions .github/README.md
@@ -79,7 +79,12 @@ ssms=[]
ssm = ff.SSM("JackFram/llama-68m")
ssms.append(ssm)
```
Next, we declare the generation configuration and compile both the LLM and SSMs. Note that all SSMs should run in the **beam search** mode, and the LLM should run in the **tree verification** mode to verify the speculated tokens from SSMs.
Next, we declare the generation configuration and compile both the LLM and SSMs. Note that all SSMs should run in the **beam search** mode, and the LLM should run in the **tree verification** mode to verify the speculated tokens from SSMs. You can also use the following arguments to specify the serving configuration when compiling LLMs and SSMs:

* max\_requests\_per\_batch: the maximum number of requests to serve in a batch (default: 16)
* max\_seq\_length: the maximum number of tokens in a request (default: 256)
* max\_tokens\_per\_batch: the maximum number of tokens to process in a batch (default: 128)

```python
# Create the sampling configs
generation_config = ff.GenerationConfig(
@@ -91,11 +96,17 @@ for ssm in ssms:
ssm.compile(generation_config)

# Compile the LLM for inference and load the weights into memory
llm.compile(generation_config, ssms=ssms)
llm.compile(generation_config,
            max_requests_per_batch = 16,
            max_seq_length = 256,
            max_tokens_per_batch = 128,
            ssms=ssms)
```
Finally, we call `llm.generate` to generate the output, which is organized as a list of `GenerationResult`, each of which includes the output tokens and text.
Next, we call `llm.start_server()` to start an LLM server on a separate background thread, which allows users to perform computations in parallel with LLM serving. Finally, we call `llm.generate` to generate the output, which is organized as a list of `GenerationResult`, each of which includes the output tokens and text. After all serving requests are processed, you can either call `llm.stop_server()` to terminate the background thread or simply exit the Python program, which automatically terminates the background server thread.
```python
llm.start_server()
result = llm.generate("Here are some travel tips for Tokyo:\n")
llm.stop_server() # This invocation is optional
```
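Each `GenerationResult` exposes the generated tokens and the decoded text; a minimal sketch of reading the text back, assuming the `output_text` bytes field used by the serving examples later on this page:

```python
# Decode and print the generated text (assumes `result` from the call above)
print(result.output_text.decode("utf-8"))
```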

### Incremental decoding
@@ -124,10 +135,15 @@ generation_config = ff.GenerationConfig(
)

# Compile the LLM for inference and load the weights into memory
llm.compile(generation_config)
llm.compile(generation_config,
            max_requests_per_batch = 16,
            max_seq_length = 256,
            max_tokens_per_batch = 128)

# Generation begins!
llm.start_server()
result = llm.generate("Here are some travel tips for Tokyo:\n")
llm.stop_server() # This invocation is optional
```

</details>
12 changes: 6 additions & 6 deletions .github/workflows/gpu-ci.yml
@@ -222,7 +222,7 @@ jobs:
CONDA: "3"
needs: inference-tests
container:
image: ghcr.io/flexflow/flexflow-environment-cuda:latest
image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest
options: --gpus all --shm-size=8192m
steps:
- name: Install updated git version
@@ -243,7 +243,7 @@

- name: Build and Install FlexFlow
run: |
export PATH=/opt/conda/bin:$PATH
export PATH=$CONDA_PREFIX/bin:$PATH
export FF_HOME=$(pwd)
export FF_BUILD_ALL_EXAMPLES=ON
export FF_BUILD_ALL_INFERENCE_EXAMPLES=ON
@@ -252,18 +252,18 @@
- name: Check FlexFlow Python interface (pip)
run: |
export PATH=/opt/conda/bin:$PATH
export PATH=$CONDA_PREFIX/bin:$PATH
export FF_HOME=$(pwd)
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
./tests/python_interface_test.sh after-installation
- name: Run multi-gpu tests
run: |
export PATH=/opt/conda/bin:$PATH
export PATH=$CONDA_PREFIX/bin:$PATH
export CUDNN_DIR=/usr/local/cuda
export CUDA_DIR=/usr/local/cuda
export FF_HOME=$(pwd)
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
# C++ tests
./tests/cpp_gpu_tests.sh 4
# Python tests
1 change: 1 addition & 0 deletions .gitignore
@@ -191,3 +191,4 @@ hf_peft_tensors

Untitled-1.ipynb
Untitled-2.ipynb
tests/inference/python_test_configs/*.json
23 changes: 14 additions & 9 deletions CMakeLists.txt
@@ -14,6 +14,7 @@ endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_LIST_DIR}/cmake)
set(FLEXFLOW_ROOT ${CMAKE_CURRENT_LIST_DIR})
set(CMAKE_CXX_FLAGS "-std=c++17 ${CMAKE_CXX_FLAGS} -fPIC -UNDEBUG")
set(CMAKE_HIP_FLAGS "-std=c++17 ${CMAKE_HIP_FLAGS} -fPIC -UNDEBUG")

# set std 17
#set(CMAKE_CXX_STANDARD 17)
@@ -51,6 +52,7 @@ endif()

# do not disable assertions even if in release mode
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG")
set(CMAKE_HIP_FLAGS_RELEASE "${CMAKE_HIP_FLAGS_RELEASE} -UNDEBUG")

if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
set(LIBEXT ".so")
@@ -157,6 +159,7 @@ endif()

# HIP
if (FF_GPU_BACKEND STREQUAL "hip_rocm" OR FF_GPU_BACKEND STREQUAL "hip_cuda")
enable_language(HIP)
include(hip)
endif()

@@ -261,14 +264,14 @@ if(NOT BUILD_LEGION_ONLY)
LIST_DIRECTORIES False
${FLEXFLOW_ROOT}/include/*.h)

list(APPEND FLEXFLOW_HDR ${FLEXFLOW_ROOT}/inference/file_loader.h)
#list(APPEND FLEXFLOW_HDR ${FLEXFLOW_ROOT}/inference/file_loader.h)

file(GLOB_RECURSE FLEXFLOW_SRC
LIST_DIRECTORIES False
${FLEXFLOW_ROOT}/src/*.cc)

list(REMOVE_ITEM FLEXFLOW_SRC "${FLEXFLOW_ROOT}/src/runtime/cpp_driver.cc")
list(APPEND FLEXFLOW_SRC ${FLEXFLOW_ROOT}/inference/file_loader.cc)
#list(APPEND FLEXFLOW_SRC ${FLEXFLOW_ROOT}/inference/file_loader.cc)

set(FLEXFLOW_CPP_DRV_SRC
${FLEXFLOW_ROOT}/src/runtime/cpp_driver.cc)
@@ -299,7 +302,10 @@ if(NOT BUILD_LEGION_ONLY)
LIST_DIRECTORIES False
${FLEXFLOW_ROOT}/src/*.cpp)

if(BUILD_SHARED_LIBS)
set_source_files_properties(${FLEXFLOW_GPU_SRC} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${FLEXFLOW_SRC} PROPERTIES LANGUAGE HIP)

if(BUILD_SHARED_LIBS)
add_library(flexflow SHARED ${FLEXFLOW_GPU_SRC} ${FLEXFLOW_SRC})
else()
add_library(flexflow STATIC ${FLEXFLOW_GPU_SRC} ${FLEXFLOW_SRC})
@@ -407,6 +413,7 @@

# python related
if (FF_USE_PYTHON)
find_package(Python COMPONENTS Interpreter Development)
# create flexflow_cffi_header.py
add_custom_command(TARGET flexflow
PRE_BUILD
@@ -418,13 +425,13 @@ if(NOT BUILD_LEGION_ONLY)
# generate the Legion Python bindings library. When building from pip, we need to do this post-install to prevent Legion from overwriting the path to the Legion shared library
add_custom_command(TARGET flexflow
POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/deps/legion/bindings/python/setup.py build --cmake-build-dir ${Legion_BINARY_DIR}/runtime --prefix ${Legion_BINARY_DIR} --build-lib=${Legion_BINARY_DIR}/bindings/python ${Legion_PYTHON_EXTRA_INSTALL_ARGS}
COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/deps/legion/bindings/python/setup.py build --cmake-build-dir ${Legion_BINARY_DIR}/runtime --prefix ${Legion_BINARY_DIR} --build-lib=${Legion_BINARY_DIR}/bindings/python ${Legion_PYTHON_EXTRA_INSTALL_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/deps/legion/bindings/python
)
# create flexflow_python interpreter. When building from pip, we install the FF_HOME/python/flexflow_python script instead.
add_custom_command(TARGET flexflow
PRE_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${FLEXFLOW_ROOT}/python/flexflow_python_build.py --build-dir ${CMAKE_BINARY_DIR}
COMMAND ${Python_EXECUTABLE} ${FLEXFLOW_ROOT}/python/flexflow_python_build.py --build-dir ${CMAKE_BINARY_DIR}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "Creating flexflow_python interpreter..."
)
@@ -474,9 +481,6 @@ if(NOT BUILD_LEGION_ONLY)
endif()

if(FF_BUILD_ALL_INFERENCE_EXAMPLES OR FF_BUILD_TOKENIZER)
if (FF_GPU_BACKEND STREQUAL "hip_rocm")
SET(SPM_USE_BUILTIN_PROTOBUF OFF CACHE BOOL "Use builtin version of protobuf to compile SentencePiece")
endif()
# Ensure Rust is installed
execute_process(COMMAND rustc --version
RESULT_VARIABLE RUST_COMMAND_RESULT
@@ -564,7 +568,8 @@ if(NOT BUILD_LEGION_ONLY)
install(TARGETS flexflow DESTINATION ${LIB_DEST})
# install python
if (FF_USE_PYTHON)
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import site, os; print([pkg for func in (site.getsitepackages(), site.getusersitepackages()) for pkg in ([func] if isinstance(func, str) else func) if os.access(pkg, os.W_OK)][0])" OUTPUT_VARIABLE PY_DEST OUTPUT_STRIP_TRAILING_WHITESPACE)
find_package(Python COMPONENTS Interpreter Development)
execute_process(COMMAND ${Python_EXECUTABLE} -c "import site, os; print([pkg for func in (site.getsitepackages(), site.getusersitepackages()) for pkg in ([func] if isinstance(func, str) else func) if os.access(pkg, os.W_OK)][0])" OUTPUT_VARIABLE PY_DEST OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT FF_BUILD_FROM_PYPI)
install(
DIRECTORY ${FLEXFLOW_ROOT}/python/flexflow/
72 changes: 69 additions & 3 deletions SERVE.md
@@ -182,14 +182,80 @@ FlexFlow Serve supports int4 and int8 quantization. The compressed tensors are s
### Prompt Datasets
We provide five prompt datasets for evaluating FlexFlow Serve: [Chatbot instruction prompts](https://specinfer.s3.us-east-2.amazonaws.com/prompts/chatbot.json), [ChatGPT Prompts](https://specinfer.s3.us-east-2.amazonaws.com/prompts/chatgpt.json), [WebQA](https://specinfer.s3.us-east-2.amazonaws.com/prompts/webqa.json), [Alpaca](https://specinfer.s3.us-east-2.amazonaws.com/prompts/alpaca.json), and [PIQA](https://specinfer.s3.us-east-2.amazonaws.com/prompts/piqa.json).




## Python Interface Features and Interaction Methods

FlexFlow Serve provides a comprehensive Python interface for serving with low latency and high performance. The interface makes it easy to deploy and interact with the serving platform for a variety of applications, from chatbots and prompt templates to retrieval augmented generation and API services.

### Chatbot with Gradio

The Python interface allows setting up a chatbot application using Gradio, enabling interactive dialogues with users through a user-friendly web interface.

#### Implementation Steps
1. **FlexFlow Initialization:** Configure and initialize FlexFlow Serve with the desired settings and the specific LLM.
```python
import gradio as gr
import flexflow.serve as ff

ff.init(num_gpus=2, memory_per_gpu=14000, ...)
```
2. **Gradio Interface Setup:** Implement a function to generate responses from user inputs and set up the Gradio Chat Interface for interaction.
```python
def generate_response(user_input, history):  # gr.ChatInterface passes (message, history)
    result = llm.generate(user_input)
    return result.output_text.decode('utf-8')
```
3. **Running the Interface:** Launch the Gradio interface to interact with the LLM through a web-based chat interface.
```python
iface = gr.ChatInterface(fn=generate_response)
iface.launch()
```
4. **Shutdown:** Properly stop the FlexFlow server after the interaction is complete by calling `llm.stop_server()`; a complete end-to-end sketch follows below.
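Putting the four steps together, a minimal sketch; the model name, the `zero_copy_memory_per_node` value, and the other resource settings are illustrative assumptions rather than requirements:

```python
import gradio as gr
import flexflow.serve as ff

# 1. Initialize FlexFlow Serve (resource values are illustrative)
ff.init(num_gpus=2, memory_per_gpu=14000, zero_copy_memory_per_node=30000)

# 2. Create, compile, and start the LLM (model name is an example)
llm = ff.LLM("meta-llama/Llama-2-7b-hf")
generation_config = ff.GenerationConfig(do_sample=True, temperature=0.9, topp=0.8, topk=1)
llm.compile(generation_config,
            max_requests_per_batch=16,
            max_seq_length=256,
            max_tokens_per_batch=128)
llm.start_server()

# 3. Wire the LLM into a Gradio chat interface
def generate_response(message, history):
    result = llm.generate(message)
    return result.output_text.decode("utf-8")

iface = gr.ChatInterface(fn=generate_response)
iface.launch()

# 4. Stop the background server once the interface is closed
llm.stop_server()
```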



### LangChain Use Cases
FlexFlow Serve supports LangChain use cases, including dynamic prompt-template handling and Retrieval Augmented Generation (RAG), enabling the customization of model responses based on structured input templates.

#### Implementation Steps
1. **FlexFlow Initialization**: Start by initializing FlexFlow Serve with the appropriate configurations.
2. **LLM Setup**: Compile and load the LLM for text generation.
3. **Prompt Template/RAG Setup**: Configure prompt templates to guide the model's responses.
4. **Response Generation**: Use the LLM with the prompt template to generate responses, as in the sketch below.
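A minimal sketch of steps 3 and 4, assuming a FlexFlow `llm` that has already been compiled and started (as in the Gradio example above) and a LangChain version that provides `langchain.prompts.PromptTemplate`; the template text and inputs are illustrative:

```python
from langchain.prompts import PromptTemplate

# 3. Prompt template that structures the model's input
template = PromptTemplate(
    input_variables=["context", "question"],
    template="Use the context to answer the question.\nContext: {context}\nQuestion: {question}\nAnswer:",
)

# 4. Format the template and generate a response with the FlexFlow LLM
prompt = template.format(
    context="FlexFlow Serve is a low-latency, high-performance LLM serving framework.",
    question="What is FlexFlow Serve?",
)
result = llm.generate(prompt)
print(result.output_text.decode("utf-8"))
```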


### Python FastAPI Entrypoint
FlexFlow Serve also supports deploying and managing LLMs with FastAPI, offering a RESTful API for generating responses from models.

```python
@app.on_event("startup")
async def startup_event():
    global llm
    # Initialize and compile the LLM model
    llm.compile(
        generation_config,
        # ... other params as needed
    )
    llm.start_server()

@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    # ... exception handling
    full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8')
    # ... split prompt and response text for returning results
    return {"prompt": prompt_request.prompt, "response": full_output}
```
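The snippet above elides its scaffolding. One plausible completion is sketched below; the `PromptRequest` model, the initialization values, the model name, and the `uvicorn` module path are illustrative assumptions:

```python
from fastapi import FastAPI
from pydantic import BaseModel
import flexflow.serve as ff

app = FastAPI()

class PromptRequest(BaseModel):
    prompt: str

# Created at module scope; compiled and started in the startup handler above.
ff.init(num_gpus=2, memory_per_gpu=14000, zero_copy_memory_per_node=30000)
llm = ff.LLM("meta-llama/Llama-2-7b-hf")
generation_config = ff.GenerationConfig(do_sample=True, temperature=0.9)

# Run with, e.g.:  uvicorn entrypoint:app --host 0.0.0.0 --port 8000
```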




## TODOs

FlexFlow Serve is still under active development. We currently focus on the following tasks and strongly welcome all contributions, from bug fixes to new features and extensions.

* AMD benchmarking. We are actively working on benchmarking FlexFlow Serve on AMD GPUs and comparing it with the performance on NVIDIA GPUs.
* Chatbot prompt templates and multi-round conversations
* Support for FastAPI server
* Integration with LangChain for document question answering

## Acknowledgements
This project was initiated by members from CMU, Stanford, and UCSD. We will continue developing and supporting FlexFlow Serve. Please cite FlexFlow Serve as:
4 changes: 2 additions & 2 deletions cmake/hip.cmake
@@ -2,11 +2,11 @@ if (NOT FF_HIP_ARCH STREQUAL "")
if (FF_HIP_ARCH STREQUAL "all")
set(FF_HIP_ARCH "gfx900,gfx902,gfx904,gfx906,gfx908,gfx909,gfx90a,gfx90c,gfx940,gfx1010,gfx1011,gfx1012,gfx1013,gfx1030,gfx1031,gfx1032,gfx1033,gfx1034,gfx1035,gfx1036,gfx1100,gfx1101,gfx1102,gfx1103")
endif()
string(REPLACE "," " " HIP_ARCH_LIST "${FF_HIP_ARCH}")
string(REPLACE "," ";" HIP_ARCH_LIST "${FF_HIP_ARCH}")
endif()

message(STATUS "FF_HIP_ARCH: ${FF_HIP_ARCH}")
if(FF_GPU_BACKEND STREQUAL "hip_rocm")
set(HIP_CLANG_PATH ${ROCM_PATH}/llvm/bin CACHE STRING "Path to the clang compiler by ROCM" FORCE)
#set(HIP_CLANG_PATH ${ROCM_PATH}/llvm/bin CACHE STRING "Path to the clang compiler by ROCM" FORCE)
set(GPU_TARGETS "${FF_HIP_ARCH}" CACHE STRING "The GPU TARGETs")
endif()
4 changes: 2 additions & 2 deletions cmake/pip_install/CMakeLists.txt
@@ -1,10 +1,10 @@
# Use setup.py script to re-install the Python bindings library with the right library paths
if (FF_USE_PYTHON)
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import site, os; print([pkg for func in (site.getsitepackages(), site.getusersitepackages()) for pkg in ([func] if isinstance(func, str) else func) if os.access(pkg, os.W_OK)][0])" OUTPUT_VARIABLE PY_DEST OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${Python_EXECUTABLE} -c "import site, os; print([pkg for func in (site.getsitepackages(), site.getusersitepackages()) for pkg in ([func] if isinstance(func, str) else func) if os.access(pkg, os.W_OK)][0])" OUTPUT_VARIABLE PY_DEST OUTPUT_STRIP_TRAILING_WHITESPACE)
if(FF_BUILD_FROM_PYPI)
install(CODE "execute_process(COMMAND ${CMAKE_COMMAND} -E echo \"Editing path to Legion library using path: ${PY_DEST}/flexflow/lib \")")
# CMAKE_CURRENT_SOURCE_DIR=/usr/FlexFlow/cmake/pip_install
# Legion_BINARY_DIR=/usr/FlexFlow/build/<something>/deps/legion
install(CODE "execute_process(COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/../../deps/legion/bindings/python/setup.py install --cmake-build-dir ${Legion_BINARY_DIR}/runtime --prefix ${PY_DEST}/flexflow ${Legion_PYTHON_EXTRA_INSTALL_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../deps/legion/bindings/python)")
install(CODE "execute_process(COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/../../deps/legion/bindings/python/setup.py install --cmake-build-dir ${Legion_BINARY_DIR}/runtime --prefix ${PY_DEST}/flexflow ${Legion_PYTHON_EXTRA_INSTALL_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../deps/legion/bindings/python)")
endif()
endif()
7 changes: 5 additions & 2 deletions config/config.inc
@@ -190,6 +190,8 @@ if [ -n "$ROCM_PATH" ]; then
SET_ROCM_PATH="-DROCM_PATH=${ROCM_PATH}"
fi

ADD_ROCM_TO_PATH=""

# set GPU backend
if [ -n "$FF_GPU_BACKEND" ]; then
SET_FF_GPU_BACKEND="-DFF_GPU_BACKEND=${FF_GPU_BACKEND}"
@@ -222,7 +224,8 @@ if [ -n "$FF_GPU_BACKEND" ]; then
chmod +x "$(pwd)/nvidia_hipcc"
SET_CXX="-DCMAKE_CXX_COMPILER=$(pwd)/nvidia_hipcc -DCMAKE_CXX_LINKER=$(pwd)/nvidia_hipcc"
else
SET_CXX="-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -DCMAKE_CXX_LINKER=/opt/rocm/bin/hipcc"
ADD_ROCM_TO_PATH="PATH=${PATH}:${ROCM_PATH}/bin"
#SET_CXX="-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -DCMAKE_CXX_LINKER=/opt/rocm/bin/hipcc"
fi
fi
fi
@@ -232,7 +235,7 @@ CMAKE_FLAGS="-DCUDA_USE_STATIC_CUDA_RUNTIME=OFF -DLegion_HIJACK_CUDART=OFF ${SET

function run_cmake() {
SRC_LOCATION=${SRC_LOCATION:=`dirname $0`/../}
CMAKE_COMMAND="${SET_CC_FLAGS} ${SET_NVCC_FLAGS} ${SET_LD_FLAGS} ${SET_CUDA_LIB_PATH} cmake ${CMAKE_FLAGS} $* ${SRC_LOCATION}"
CMAKE_COMMAND="${SET_CC_FLAGS} ${SET_NVCC_FLAGS} ${SET_LD_FLAGS} ${SET_CUDA_LIB_PATH} ${ADD_ROCM_TO_PATH} cmake ${CMAKE_FLAGS} $* ${SRC_LOCATION}"
echo $CMAKE_COMMAND
eval $CMAKE_COMMAND
}
2 changes: 1 addition & 1 deletion deps/legion
Submodule legion updated from 626b55 to 24e8c4
5 changes: 1 addition & 4 deletions docker/flexflow-environment/Dockerfile
@@ -74,11 +74,8 @@ RUN if [ "$FF_GPU_BACKEND" = "hip_cuda" ] || [ "$FF_GPU_BACKEND" = "hip_rocm" ]
rm ./${AMD_GPU_SCRIPT_NAME}; \
amdgpu-install -y --usecase=hip,rocm --no-dkms; \
apt-get install -y hip-dev hipblas miopen-hip rocm-hip-sdk rocm-device-libs; \
# Install protobuf v3.20.x manually
# Install protobuf dependencies
apt-get update -y && sudo apt-get install -y pkg-config zip g++ zlib1g-dev autoconf automake libtool make; \
git clone -b 3.20.x https://github.com/protocolbuffers/protobuf.git; cd protobuf/ ; git submodule update --init --recursive; \
./autogen.sh; ./configure; cores_available=$(nproc --all); n_build_cores=$(( cores_available -1 )); \
if (( n_build_cores < 1 )) ; then n_build_cores=1 ; fi; make -j $n_build_cores; make install; ldconfig; cd .. ; \
else \
echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping installing HIP dependencies"; \
fi