diff --git a/.github/workflows/build-aar.yml b/.github/workflows/build-aar.yml new file mode 100644 index 00000000..1fcd3453 --- /dev/null +++ b/.github/workflows/build-aar.yml @@ -0,0 +1,104 @@ +name: Build Android AAR + +on: + push: + branches: + - main + - android-sdk + pull_request: + +jobs: + build: + name: Build AAR + runs-on: ubuntu-latest + if: github.event_name == 'push' + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up JDK 17 + uses: actions/setup-java@v3 + with: + distribution: 'temurin' + java-version: '17' + + - name: Cache Gradle packages + uses: actions/cache@v3 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*', '**/gradle-wrapper.properties') }} + restore-keys: | + ${{ runner.os }}-gradle- + + - name: Navigate to android Directory and Build AAR + run: | + echo "Navigating to the example directory..." + cd android/llama.android + echo "Starting Gradle build process in $(pwd)..." + ./gradlew assembleRelease --stacktrace --info + shell: bash + + - name: Rename and upload AAR + run: | + echo "Navigating to the android directory to find AAR output..." + cd android/llama.android + mkdir -p ../artifacts + ls -ld ../artifacts || echo "Artifacts directory does not exist." + AAR_PATH=$(find ./llama/build/outputs/aar -type f -name "*.aar" | head -n 1) + if [ -z "$AAR_PATH" ]; then + echo "No AAR file found. Build might have failed." + exit 1 + fi + BRANCH_NAME=${{ github.ref_name }} + CUSTOM_NAME="com-nexa-${BRANCH_NAME}-${{ github.run_number }}.aar" + echo "Found AAR at $AAR_PATH, renaming to $CUSTOM_NAME..." + mv "$AAR_PATH" "../artifacts/$CUSTOM_NAME" + shell: bash + + - name: Upload AAR as an artifact + uses: actions/upload-artifact@v3 + with: + name: custom-aar-${{ github.ref_name }}-${{ github.run_number }} + path: android/artifacts/ + + release: + name: Create GitHub Release + needs: build + runs-on: ubuntu-latest + if: github.event_name == 'push' && contains(github.ref, 'main') + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Download Artifacts + uses: actions/download-artifact@v3 + with: + name: custom-aar-${{ github.ref_name }}-${{ github.run_number }} + path: release-artifacts + + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: v${{ github.run_number }} + release_name: "Release v${{ github.run_number }}" + body: | + This is an automated release containing the latest AAR build. 
+ - **Branch:** ${{ github.ref_name }} + - **Build Number:** ${{ github.run_number }} + draft: false + prerelease: false + + - name: Upload AAR to Release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: release-artifacts/com-nexa-${{ github.ref_name }}-${{ github.run_number }}.aar + asset_name: com-nexa-${{ github.ref_name }}-${{ github.run_number }}.aar + asset_content_type: application/java-archive \ No newline at end of file diff --git a/.github/workflows/build-wheels-vulkan-win.yaml b/.github/workflows/build-wheels-vulkan-win.yaml index ac362195..ca68208d 100644 --- a/.github/workflows/build-wheels-vulkan-win.yaml +++ b/.github/workflows/build-wheels-vulkan-win.yaml @@ -1,11 +1,8 @@ name: Build Wheels (Vulkan) (Windows) - on: workflow_dispatch: - permissions: contents: write - jobs: define_matrix: name: Define Build Matrix @@ -15,7 +12,6 @@ jobs: defaults: run: shell: pwsh - steps: - name: Define Job Output id: set-matrix @@ -26,10 +22,8 @@ jobs: 'vulkan_version' = @("1.3.261.1") 'releasetag' = @("basic") } - $matrixOut = ConvertTo-Json $matrix -Compress Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT - build_wheels: name: Build Wheel ${{ matrix.os }} Python ${{ matrix.pyver }} needs: define_matrix @@ -42,64 +36,58 @@ jobs: env: VULKAN_VERSION: ${{ matrix.vulkan_version }} RELEASE_TAG: ${{ matrix.releasetag }} - steps: - name: Add MSBuild to PATH if: runner.os == 'Windows' uses: microsoft/setup-msbuild@v2 with: vs-version: "[16.11,16.12)" - + msbuild-architecture: x64 - name: Checkout Repository uses: actions/checkout@v4 with: submodules: "recursive" - - name: Install Vulkan SDK run: | curl.exe -o $env:RUNNER_TEMP\VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${{ env.VULKAN_VERSION }}/windows/VulkanSDK-${{ env.VULKAN_VERSION }}-Installer.exe" & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install Add-Content -Path $env:GITHUB_ENV -Value "VULKAN_SDK=C:\VulkanSDK\${{ env.VULKAN_VERSION }}" Add-Content -Path $env:GITHUB_PATH -Value "C:\VulkanSDK\${{ env.VULKAN_VERSION }}\Bin" - - name: Setup Python uses: actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} architecture: 'x64' cache: "pip" - - name: Install Ninja Build System run: choco install ninja -y - - name: Install Build Dependencies run: | python -m pip install --upgrade pip python -m pip install build wheel setuptools cmake ninja - # Install additional dependencies if needed python -m pip install scikit-build - - name: Build Wheel run: | # Set environment variables for CMake and Vulkan - $env:CMAKE_ARGS="-DGGML_VULKAN=ON" + $env:CMAKE_ARGS="-DGGML_VULKAN=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=cl" $env:VULKAN_SDK="C:\VulkanSDK\${{ env.VULKAN_VERSION }}" $env:PATH="$env:VULKAN_SDK\Bin;$env:PATH" - + + # Set MSVC compiler flags to fix Windows SDK header issues + $env:CFLAGS="/D_CRT_SECURE_NO_WARNINGS /DWIN32_LEAN_AND_MEAN /DNOMINMAX /D_WIN32_WINNT=0x0601" + $env:CXXFLAGS="/D_CRT_SECURE_NO_WARNINGS /DWIN32_LEAN_AND_MEAN /DNOMINMAX /D_WIN32_WINNT=0x0601" + # Build the wheel python -m build --wheel - - name: Upload Wheel Artifact uses: actions/upload-artifact@v4 with: path: dist/*.whl name: llama-vulkan-wheel-python${{ matrix.pyver }}.whl - - name: Create GitHub Release uses: softprops/action-gh-release@v2 with: files: dist/* - # Set tag name to -vulkan - tag_name: ${{ github.ref_name }}-vulkan${{ 
env.VULKAN_VERSION }} + tag_name: ${{ github.ref_name }}-vulkan env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8ce64f0f..c63c3e1f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,6 +29,14 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Cache pip packages + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install dependencies run: | python -m pip install --upgrade pip @@ -61,4 +69,4 @@ jobs: - name: Run tests run: | python -m pytest tests - shell: bash + shell: bash \ No newline at end of file diff --git a/.github/workflows/generate-index-from-release.yaml b/.github/workflows/generate-index-from-release.yaml index 11741cab..b993847a 100644 --- a/.github/workflows/generate-index-from-release.yaml +++ b/.github/workflows/generate-index-from-release.yaml @@ -3,7 +3,7 @@ name: Wheels Index on: # Trigger on new release workflow_run: - workflows: ["Release", "Build Wheels (CUDA)", "Build Wheels (Metal)", "Build Wheels (ROCm)", "Build Wheels (Vulkan)"] + workflows: ["Build Wheels (CPU)", "Build Wheels (CUDA)", "Build Wheels (Metal)", "Build Wheels (ROCm)", "Build Wheels (Vulkan)"] types: - completed diff --git a/.gitignore b/.gitignore index 9063bffa..a22e3c43 100644 --- a/.gitignore +++ b/.gitignore @@ -90,4 +90,13 @@ build_*/ .cache/ # tests -quantization_test.py \ No newline at end of file +quantization_test.py + +# Swift +.swiftpm/ +UserInterfaceState.xcuserstate +xcuserdata/ +*.xcworkspace/xcuserdata/ +*.playground/playground.xcworkspace/xcuserdata/ +*.generated.plist +.build/ \ No newline at end of file diff --git a/CLI.md b/CLI.md index 5c4f4ab4..5f219047 100644 --- a/CLI.md +++ b/CLI.md @@ -31,7 +31,7 @@ options: ### List Local Models -List all models on your local computer. +List all models on your local computer. You can use `nexa run ` to run any model shown in the list. ``` nexa list @@ -46,11 +46,12 @@ nexa pull MODEL_PATH usage: nexa pull [-h] model_path positional arguments: - model_path Path or identifier for the model in Nexa Model Hub, or Hugging Face repo ID when using -hf flag + model_path Path or identifier for the model in Nexa Model Hub, Hugging Face repo ID when using -hf flag, or ModelScope model ID when using -ms flag options: -h, --help show this help message and exit -hf, --huggingface Pull model from Hugging Face Hub + -ms, --modelscope Pull model from ModelScope Hub -o, --output_path OUTPUT_PATH Custom output path for the pulled model ``` @@ -96,11 +97,13 @@ Run a model on your local computer. If the model file is not yet downloaded, it By default, `nexa` will run gguf models. To run onnx models, use `nexa onnx MODEL_PATH` +You can run any model shown in `nexa list` command. 
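To make the pull, list, and run flow concrete, here is a minimal sketch; the model identifiers are simply ones already used as examples elsewhere in these docs, so substitute your own:

```bash
# Pull a model by its Nexa Model Hub identifier
nexa pull llama3.2

# Or pull a GGUF repo from ModelScope with the -ms flag
nexa pull -ms Qwen/Qwen2.5-Coder-7B-Instruct-GGUF

# See what is available locally, then run any entry from that list
nexa list
nexa run llama3.2
```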
+ #### Run Text-Generation Model ``` nexa run MODEL_PATH -usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -109,9 +112,10 @@ options: -h, --help show this help message and exit -pf, --profiling Enable profiling logs for the inference process -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf - -lp, --local_path Indicate that the model path provided is the local path, must be used with -mt - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] - -hf, --huggingface Load model from Hugging Face Hub, must be used with -mt + -lp, --local_path Indicate that the model path provided is the local path + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub Text generation options: -t, --temperature TEMPERATURE @@ -135,7 +139,7 @@ nexa run llama2 ``` nexa run MODEL_PATH -usage: nexa run [-h] [-i2i] [-ns NUM_INFERENCE_STEPS] [-np NUM_IMAGES_PER_PROMPT] [-H HEIGHT] [-W WIDTH] [-g GUIDANCE_SCALE] [-o OUTPUT] [-s RANDOM_SEED] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-i2i] [-ns NUM_INFERENCE_STEPS] [-np NUM_IMAGES_PER_PROMPT] [-H HEIGHT] [-W WIDTH] [-g GUIDANCE_SCALE] [-o OUTPUT] [-s RANDOM_SEED] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -143,9 +147,10 @@ positional arguments: options: -h, --help show this help message and exit -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf - -lp, --local_path Indicate that the model path provided is the local path, must be used with -mt - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] - -hf, --huggingface Load model from Hugging Face Hub, must be used with -mt + -lp, --local_path Indicate that the model path provided is the local path + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub Image generation options: -i2i, --img2img Whether to run image-to-image generation @@ -180,7 +185,7 @@ nexa run sd1-4 ``` nexa run MODEL_PATH -usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -189,9 +194,10 @@ options: -h, --help show this help message and exit -pf, --profiling 
Enable profiling logs for the inference process -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf - -lp, --local_path Indicate that the model path provided is the local path, must be used with -mt - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] - -hf, --huggingface Load model from Hugging Face Hub, must be used with -mt + -lp, --local_path Indicate that the model path provided is the local path + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub VLM generation options: -t, --temperature TEMPERATURE @@ -215,7 +221,7 @@ nexa run nanollava ``` nexa run MODEL_PATH -usage: nexa run [-h] [-o OUTPUT_DIR] [-b BEAM_SIZE] [-l LANGUAGE] [--task TASK] [-t TEMPERATURE] [-c COMPUTE_TYPE] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-o OUTPUT_DIR] [-b BEAM_SIZE] [-l LANGUAGE] [--task TASK] [-t TEMPERATURE] [-c COMPUTE_TYPE] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -223,9 +229,10 @@ positional arguments: options: -h, --help show this help message and exit -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf - -lp, --local_path Indicate that the model path provided is the local path, must be used with -mt - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] - -hf, --huggingface Load model from Hugging Face Hub, must be used with -mt + -lp, --local_path Indicate that the model path provided is the local path + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub Automatic Speech Recognition options: -b, --beam_size BEAM_SIZE @@ -249,7 +256,7 @@ nexa run faster-whisper-tiny ``` nexa embed MODEL_PATH -usage: nexa embed [-h] [-lp] [-hf] [-n] [-nt] model_path prompt +usage: nexa embed [-h] [-lp] [-hf] [-ms] [-n] [-nt] model_path prompt positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -257,8 +264,9 @@ positional arguments: options: -h, --help show this help message and exit - -lp, --local_path Indicate that the model path provided is the local path, must be used with -mt - -hf, --huggingface Load model from Hugging Face Hub, must be used with -mt + -lp, --local_path Indicate that the model path provided is the local path + -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub -n, --normalize Normalize the embeddings -nt, --no_truncate Not truncate the embeddings ``` @@ -274,6 +282,10 @@ nexa embed sentence-transformers/all-MiniLM-L6-v2:gguf-fp16 "I love Nexa AI." >> ### Convert and quantize a Hugging Face Model to GGUF +Additional package `nexa-gguf` is required to run this command. + +You can install it by `pip install "nexaai[convert]"` or `pip install nexa-gguf`. 
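As a minimal sketch of that setup (the model ID is simply the one used in the README's convert example), installing the extra and running a first conversion might look like this; the full usage and options follow below:

```bash
# Install the optional GGUF conversion dependencies (either form works)
pip install "nexaai[convert]"

# Convert a Hugging Face model to GGUF
nexa convert HuggingFaceTB/SmolLM2-135M-Instruct
```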
+ ``` nexa convert HF_MODEL_PATH [ftype] [output_file] usage: nexa convert [-h] [-t NTHREAD] [--convert_type CONVERT_TYPE] [--bigendian] [--use_temp_file] [--no_lazy] @@ -312,6 +324,7 @@ options: --only_copy Only copy tensors (ignores ftype, allow_requantize, and quantize_output_tensor) --pure Quantize all tensors to the default type --keep_split Quantize to the same number of shards + -ms --modelscope Load model from ModelScope Hub ``` #### Example @@ -335,16 +348,17 @@ Start a local server using models on your local computer. ``` nexa server MODEL_PATH -usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in S3 options: -h, --help show this help message and exit - -lp, --local_path Indicate that the model path provided is the local path, must be used with -mt - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] - -hf, --huggingface Load model from Hugging Face Hub, must be used with -mt + -lp, --local_path Indicate that the model path provided is the local path + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub --host HOST Host to bind the server to --port PORT Port to bind the server to --reload Enable automatic reloading on code changes diff --git a/CMakeLists.txt b/CMakeLists.txt index 4670bff2..41738eb8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,9 +3,43 @@ project(nexa_gguf) include(ExternalProject) +# Platform-specific settings +if(WIN32) + # Windows-specific settings + add_definitions(-D_CRT_SECURE_NO_WARNINGS) + # OpenMP is optional on Windows + find_package(OpenMP QUIET) + if(NOT OpenMP_FOUND) + message(STATUS "OpenMP not found - OpenMP support will be disabled") + set(OpenMP_C_FLAGS "") + set(OpenMP_CXX_FLAGS "") + set(OpenMP_EXE_LINKER_FLAGS "") + endif() +elseif(APPLE) + # macOS-specific settings + find_package(OpenMP QUIET) + if(NOT OpenMP_FOUND) + message(STATUS "OpenMP not found - OpenMP support will be disabled") + set(OpenMP_C_FLAGS "") + set(OpenMP_CXX_FLAGS "") + set(OpenMP_EXE_LINKER_FLAGS "") + endif() +else() + # Linux and other Unix systems + find_package(OpenMP REQUIRED) +endif() + set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_CXX_STANDARD 17) +# Windows-specific configurations +if(WIN32) + add_definitions(-D_CRT_SECURE_NO_WARNINGS) + add_definitions(-DNOMINMAX) + add_definitions(-D_WIN32_WINNT=0x0A00) # Target Windows 10 or later + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + # Function to collect all user-defined options function(get_all_options output_var) get_cmake_property(variables VARIABLES) @@ -101,10 +135,11 @@ if(STABLE_DIFFUSION_BUILD) -DBUILD_SHARED_LIBS=ON -DSD_METAL=${GGML_METAL} -DSD_CUBLAS=${GGML_CUDA} + -DSD_HIPBLAS=${GGML_HIPBLAS} + -DSD_VULKAN=${GGML_VULKAN} BUILD_ALWAYS 1 BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release -- ${MSBUILD_ARGS} INSTALL_COMMAND ${CMAKE_COMMAND} --build . 
--config Release --target install - LOG_INSTALL 1 ) endif() @@ -114,6 +149,18 @@ if(LLAMA_BUILD) set(LLAMA_CUDA ${GGML_CUDA}) set(LLAMA_METAL ${GGML_METAL}) + if(WIN32) + # Add Windows-specific definitions and flags for llama.cpp + list(APPEND COMMON_CMAKE_OPTIONS + -DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=ON + -DLLAMA_NATIVE=OFF # Disable native CPU optimizations on Windows + -DLLAMA_DISABLE_CXXABI=ON # Disable cxxabi.h dependency + ) + + # Add compile definition for all targets + add_compile_definitions(LLAMA_DISABLE_CXXABI) + endif() + ExternalProject_Add(llama_project SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/dependency/llama.cpp BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/llama_build @@ -123,8 +170,12 @@ if(LLAMA_BUILD) -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/llama_install -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_CXX_STANDARD=17 + -DBUILD_SHARED_LIBS=ON -DLLAMA_CUDA=${LLAMA_CUDA} -DLLAMA_METAL=${LLAMA_METAL} + -DCMAKE_C_FLAGS=${OpenMP_C_FLAGS} + -DCMAKE_CXX_FLAGS=${OpenMP_CXX_FLAGS} + -DCMAKE_EXE_LINKER_FLAGS=${OpenMP_EXE_LINKER_FLAGS} -DGGML_AVX=$,$>>,OFF,ON> -DGGML_AVX2=$,$>>,OFF,ON> -DGGML_FMA=$,$>>,OFF,ON> @@ -137,8 +188,13 @@ if(LLAMA_BUILD) endif() # bark_cpp project -option(BARK_BUILD "Build bark.cpp" ON) +# Temporarily disabled since version v0.0.9.3 +option(BARK_BUILD "Build bark.cpp" OFF) if(BARK_BUILD) + # Filter out HIPBLAS and Vulkan options for bark.cpp since it doesn't support them + set(BARK_CMAKE_OPTIONS ${USER_DEFINED_OPTIONS}) + list(FILTER BARK_CMAKE_OPTIONS EXCLUDE REGEX "GGML_HIPBLAS|GGML_VULKAN") + ExternalProject_Add(bark_project SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/dependency/bark.cpp BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/bark_build @@ -149,7 +205,7 @@ if(BARK_BUILD) -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_CXX_STANDARD=17 -DGGML_CUDA=${GGML_CUDA} - -DGGML_METAL=${GGML_METAL} + -DGGML_METAL=OFF -DBUILD_SHARED_LIBS=ON -DBARK_BUILD_EXAMPLES=OFF BUILD_ALWAYS 1 diff --git a/Package.swift b/Package.swift new file mode 100644 index 00000000..a5ffa87f --- /dev/null +++ b/Package.swift @@ -0,0 +1,32 @@ +// swift-tools-version: 6.0 + +import PackageDescription + +let package = Package( + name: "NexaSwift", + platforms: [ + .macOS(.v15), + .iOS(.v18), + .watchOS(.v11), + .tvOS(.v18), + .visionOS(.v2) + ], + products: [ + .library(name: "NexaSwift", targets: ["NexaSwift"]), + ], + dependencies: [ + .package(url: "https://github.com/ggerganov/llama.cpp.git", branch: "master") + ], + targets: [ + .target( + name: "NexaSwift", + dependencies: [ + .product(name: "llama", package: "llama.cpp") + ], + path: "swift/Sources/NexaSwift"), + .testTarget( + name: "NexaSwiftTests", + dependencies: ["NexaSwift"], + path: "swift/Tests/NexaSwiftTests"), + ] +) diff --git a/README.md b/README.md index d685f756..17887873 100644 --- a/README.md +++ b/README.md @@ -1,59 +1,50 @@ -
- -

Nexa SDK

- -[![MacOS][MacOS-image]][release-url] [![Linux][Linux-image]][release-url] [![Windows][Windows-image]][release-url] - -[![GitHub Release](https://img.shields.io/github/v/release/NexaAI/nexa-sdk)](https://github.com/NexaAI/nexa-sdk/releases/latest) [![Build workflow](https://img.shields.io/github/actions/workflow/status/NexaAI/nexa-sdk/ci.yaml?label=CI&logo=github)](https://github.com/NexaAI/nexa-sdk/actions/workflows/ci.yaml?query=branch%3Amain) ![GitHub License](https://img.shields.io/github/license/NexaAI/nexa-sdk) - - - -[![Discord](https://dcbadge.limes.pink/api/server/thRu2HaK4D?style=flat&compact=true)](https://discord.gg/thRu2HaK4D) + -[On-device Model Hub](https://model-hub.nexa4ai.com/) / [Nexa SDK Documentation](https://docs.nexaai.com/) +

Nexa SDK - Local On-Device Inference Framework

[release-url]: https://github.com/NexaAI/nexa-sdk/releases [Windows-image]: https://img.shields.io/badge/windows-0078D4?logo=windows [MacOS-image]: https://img.shields.io/badge/-MacOS-black?logo=apple [Linux-image]: https://img.shields.io/badge/-Linux-333?logo=ubuntu -
+[![MacOS][MacOS-image]][release-url] [![Linux][Linux-image]][release-url] [![Windows][Windows-image]][release-url] [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FNexaAI%2Fnexa-sdk%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/NexaAI/nexa-sdk) [![Build workflow](https://img.shields.io/github/actions/workflow/status/NexaAI/nexa-sdk/ci.yaml?label=CI&logo=github)](https://github.com/NexaAI/nexa-sdk/actions/workflows/ci.yaml?query=branch%3Amain) ![GitHub License](https://img.shields.io/github/license/NexaAI/nexa-sdk) [![GitHub Release](https://img.shields.io/github/v/release/NexaAI/nexa-sdk)](https://github.com/NexaAI/nexa-sdk/releases/latest) -Nexa SDK is a comprehensive toolkit for supporting **ONNX** and **GGML** models. It supports text generation, image generation, vision-language models (VLM), and speech-to-text (ASR), and text-to-speech (TTS) capabilities. Additionally, it offers an OpenAI-compatible API server with JSON schema mode for function calling and streaming support, and a user-friendly Streamlit UI. Users can run Nexa SDK in any device with Python environment, and GPU acceleration is supported, including CUDA, Metal, and ROCm. An executable version is also available. +[**On-Device Model Hub**](https://nexa.ai/models) | [**Documentation**](https://docs.nexa.ai/) | [**Discord**](https://discord.gg/thRu2HaK4D) | [**Blogs**](https://nexa.ai/blogs) | [**X (Twitter)**](https://x.com/nexa_ai) - +**Nexa SDK** is a local on-device inference framework for ONNX and GGML models, supporting text generation, image generation, vision-language models (VLM), audio-language models, speech-to-text (ASR), and text-to-speech (TTS) capabilities. Installable via Python Package or Executable Installer. + +### Features + +- **Device Support:** CPU, GPU (CUDA, Metal, ROCm), iOS +- **Server:** OpenAI-compatible API, JSON schema for function calling and streaming support +- **Local UI:** Streamlit for interactive model deployment and testing ## Latest News 🔥 -- [2024/10] Support embedding model: `nexa embed ` -- [2024/10] Support pull and run supported Computer Vision models in GGUF format from HuggingFace: `nexa run -hf -mt COMPUTER_VISION` -- [2024/10] Support VLM in local server. -- [2024/10] Added option to customize maximum context window for NLP and VLM models. -- [2024/10] Support running model from user's local path -- [2024/10] Added LoRA support for NLP models. 
-- [2024/10] Added support for whisper-large-v3-turbo: `nexa run faster-whisper-large-turbo` -- [2024/10] Added support for AMD-Llama-135m: `nexa run AMD-Llama-135m:fp16` -- [2024/09] Nexa now has executables for easy installation: [Install Nexa SDK](https://nexaai.com/download-sdk) ✨ -- [2024/09] Added support for Llama 3.2 models: `nexa run llama3.2` -- [2024/09] Added support for Qwen2.5, Qwen2.5-coder and Qwen2.5-Math models: `nexa run qwen2.5` -- [2024/09] Support pull and run NLP models in GGUF format from HuggingFace: `nexa run -hf -mt NLP` -- [2024/09] Added support for ROCm -- [2024/09] Added support for Phi-3.5 models: `nexa run phi3.5` -- [2024/09] Added support for OpenELM models: `nexa run openelm` -- [2024/09] Introduced logits API support for more advanced model interactions -- [2024/09] Added support for Flux models: `nexa run flux` -- [2024/09] Added support for Stable Diffusion 3 model: `nexa run sd3` -- [2024/09] Added support for Stable Diffusion 2.1 model: `nexa run sd2-1` +- Support Nexa AI's own vision language model (0.9B parameters): `nexa run omniVLM` and audio language model (2.9B parameters): `nexa run omniaudio` +- Support audio language model: `nexa run qwen2audio`, **we are the first open-source toolkit to support audio language model with GGML tensor library.** +- Support iOS Swift binding for local inference on **iOS mobile** devices. +- Support embedding model: `nexa embed ` +- Support pull and run supported Computer Vision models in GGUF format from HuggingFace or ModelScope: `nexa run -hf -mt COMPUTER_VISION` or `nexa run -ms -mt COMPUTER_VISION` +- Support pull and run NLP models in GGUF format from HuggingFace or ModelScope: `nexa run -hf -mt NLP` or `nexa run -ms -mt NLP` Welcome to submit your requests through [issues](https://github.com/NexaAI/nexa-sdk/issues/new/choose), we ship weekly. -## Installation - Executable +## Install Option 1: Executable Installer -### macOS +

+ + macOS Installer + +

-[Download](https://public-storage.nexa4ai.com/nexa-sdk-executable-installer/nexa-macos-installer.pkg) +

+ + Windows Installer + +

-### Linux + Linux Installer ```bash curl -fsSL https://public-storage.nexa4ai.com/install.sh | sh @@ -70,32 +61,24 @@ nexa-exe -### Windows - -Coming soon. Install with Python package below 👇 - -## Installation - Python Package - -We have released pre-built wheels for various Python versions, platforms, and backends for convenient installation on our [index page](https://nexaai.github.io/nexa-sdk/whl/). +## Install Option 2: Python Package -> [!NOTE] -> -> 1. If you want to use ONNX model, just replace `pip install nexaai` with `pip install "nexaai[onnx]"` in provided commands. -> 2. If you want to convert and quantize huggingface models to GGUF models, just replace `pip install nexaai` with `pip install "nexaai[nexa-gguf]"`. -> 3. For Chinese developers, we recommend you to use Tsinghua Open Source Mirror as extra index url, just replace `--extra-index-url https://pypi.org/simple` with `--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple` in provided commands. +We have released pre-built wheels for various Python versions, platforms, and backends for convenient installation on our [index page](https://github.nexa.ai/whl/). -#### CPU +
CPU ```bash -pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/cpu --extra-index-url https://pypi.org/simple --no-cache-dir +pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/cpu --extra-index-url https://pypi.org/simple --no-cache-dir ``` -#### GPU (Metal) +
+ +
Apple GPU (Metal) For the GPU version supporting **Metal (macOS)**: ```bash -CMAKE_ARGS="-DGGML_METAL=ON -DSD_METAL=ON" pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/metal --extra-index-url https://pypi.org/simple --no-cache-dir +CMAKE_ARGS="-DGGML_METAL=ON -DSD_METAL=ON" pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/metal --extra-index-url https://pypi.org/simple --no-cache-dir ```
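As an optional sanity check (nothing here is required), running any small model from the Model Hub will confirm the Metal build works end to end, for example:

```bash
# Text generation goes through the llama.cpp Metal backend (GGML_METAL=ON)
nexa run llama3.2

# Image generation goes through the stable-diffusion.cpp Metal backend (SD_METAL=ON)
nexa run sd1-4
```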
@@ -108,37 +91,38 @@ wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge bash Miniforge3-MacOSX-arm64.sh conda create -n nexasdk python=3.10 conda activate nexasdk -CMAKE_ARGS="-DGGML_METAL=ON -DSD_METAL=ON" pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/metal --extra-index-url https://pypi.org/simple --no-cache-dir +CMAKE_ARGS="-DGGML_METAL=ON -DSD_METAL=ON" pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/metal --extra-index-url https://pypi.org/simple --no-cache-dir ``` +
-#### GPU (CUDA) +
Nvidia GPU (CUDA) To install with CUDA support, make sure you have [CUDA Toolkit 12.0 or later](https://developer.nvidia.com/cuda-12-0-0-download-archive) installed. For **Linux**: ```bash -CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON" pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir +CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON" pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir ``` For **Windows PowerShell**: ```bash -$env:CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON"; pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir +$env:CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON"; pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir ``` For **Windows Command Prompt**: ```bash -set CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON" & pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir +set CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON" & pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir ``` For **Windows Git Bash**: ```bash -CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON" pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir +CMAKE_ARGS="-DGGML_CUDA=ON -DSD_CUBLAS=ON" pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/cu124 --extra-index-url https://pypi.org/simple --no-cache-dir ```
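Before installing, it can help to confirm that the toolkit requirement above is actually met; `nvidia-smi` and `nvcc` are standard NVIDIA tools, not part of Nexa SDK:

```bash
# Check that the driver sees your GPU
nvidia-smi

# Check the CUDA Toolkit version on PATH (should be 12.0 or later)
nvcc --version
```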
@@ -156,39 +140,45 @@ CMAKE_ARGS="-DCMAKE_CXX_FLAGS=-fopenmp" pip install nexaai
-#### GPU (ROCm) +
+ +
AMD GPU (ROCm) To install with ROCm support, make sure you have [ROCm 6.2.1 or later](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.1/install/quick-start.html) installed. For **Linux**: ```bash -CMAKE_ARGS="-DGGML_HIPBLAS=on" pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/rocm621 --extra-index-url https://pypi.org/simple --no-cache-dir +CMAKE_ARGS="-DGGML_HIPBLAS=on" pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/rocm621 --extra-index-url https://pypi.org/simple --no-cache-dir ``` -#### GPU (Vulkan) +
+ +
GPU (Vulkan) To install with Vulkan support, make sure you have [Vulkan SDK 1.3.261.1 or later](https://vulkan.lunarg.com/sdk/home) installed. For **Windows PowerShell**: ```bash -$env:CMAKE_ARGS="-DGGML_VULKAN=on"; pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/vulkan --extra-index-url https://pypi.org/simple --no-cache-dir +$env:CMAKE_ARGS="-DGGML_VULKAN=on"; pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/vulkan --extra-index-url https://pypi.org/simple --no-cache-dir ``` For **Windows Command Prompt**: ```bash -set CMAKE_ARGS="-DGGML_VULKAN=on" & pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/vulkan --extra-index-url https://pypi.org/simple --no-cache-dir +set CMAKE_ARGS="-DGGML_VULKAN=on" & pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/vulkan --extra-index-url https://pypi.org/simple --no-cache-dir ``` For **Windows Git Bash**: ```bash -CMAKE_ARGS="-DGGML_VULKAN=on" pip install nexaai --prefer-binary --index-url https://nexaai.github.io/nexa-sdk/whl/vulkan --extra-index-url https://pypi.org/simple --no-cache-dir +CMAKE_ARGS="-DGGML_VULKAN=on" pip install nexaai --prefer-binary --index-url https://github.nexa.ai/whl/vulkan --extra-index-url https://pypi.org/simple --no-cache-dir ``` -### Local Build +
+ +
Local Build How to clone this repo @@ -208,105 +198,109 @@ Then you can build and install the package pip install -e . ``` -## Features - -- **Model Support:** - - - **ONNX & GGML models** - - **Conversion Engine** - - **Inference Engine**: - - **Text Generation** - - **Image Generation** - - **Vision-Language Models (VLM)** - - **Speech-to-Text (ASR)** - -Detailed API documentation is available [here](https://docs.nexaai.com/). +
-- **Server:** - - OpenAI-compatible API - - JSON schema mode for function calling - - Streaming support -- **Streamlit UI** for interactive model deployment and testing +## Differentiation Below is our differentiation from other similar tools: -| **Feature** | **[Nexa SDK](https://github.com/NexaAI/nexa-sdk)** | **[ollama](https://github.com/ollama/ollama)** | **[Optimum](https://github.com/huggingface/optimum)** | **[LM Studio](https://github.com/lmstudio-ai)** | -| -------------------------- | :------------------------------------------------: | :--------------------------------------------: | :---------------------------------------------------: | :---------------------------------------------: | -| **GGML Support** | ✅ | ✅ | ❌ | ✅ | -| **ONNX Support** | ✅ | ❌ | ✅ | ❌ | -| **Text Generation** | ✅ | ✅ | ✅ | ✅ | -| **Image Generation** | ✅ | ❌ | ❌ | ❌ | -| **Vision-Language Models** | ✅ | ✅ | ✅ | ✅ | -| **Text-to-Speech** | ✅ | ❌ | ✅ | ❌ | -| **Server Capability** | ✅ | ✅ | ✅ | ✅ | -| **User Interface** | ✅ | ❌ | ❌ | ✅ | +| **Feature** | **[Nexa SDK](https://github.com/NexaAI/nexa-sdk)** | **[ollama](https://github.com/ollama/ollama)** | **[Optimum](https://github.com/huggingface/optimum)** | **[LM Studio](https://github.com/lmstudio-ai)** | +| --------------------------- | :------------------------------------------------: | :--------------------------------------------: | :---------------------------------------------------: | :---------------------------------------------: | +| **GGML Support** | ✅ | ✅ | ❌ | ✅ | +| **ONNX Support** | ✅ | ❌ | ✅ | ❌ | +| **Text Generation** | ✅ | ✅ | ✅ | ✅ | +| **Image Generation** | ✅ | ❌ | ❌ | ❌ | +| **Vision-Language Models** | ✅ | ✅ | ✅ | ✅ | +| **Audio-Language Models** | ✅ | ❌ | ❌ | ❌ | +| **Text-to-Speech** | ✅ | ❌ | ✅ | ❌ | +| **Server Capability** | ✅ | ✅ | ✅ | ✅ | +| **User Interface** | ✅ | ❌ | ❌ | ✅ | +| **Executable Installation** | ✅ | ✅ | ❌ | ✅ | ## Supported Models & Model Hub Our on-device model hub offers all types of quantized models (text, image, audio, multimodal) with filters for RAM, file size, Tasks, etc. to help you easily explore models with UI. 
Explore on-device models at [On-device Model Hub](https://model-hub.nexa4ai.com/) -Supported models (full list at [Model Hub](https://nexa.ai/models)): +Supported model examples (full list at [Model Hub](https://nexa.ai/models)): | Model | Type | Format | Command | | ------------------------------------------------------------------------------------------------------- | --------------- | --------- | -------------------------------------- | -| [octopus-v2](https://www.nexaai.com/NexaAI/Octopus-v2/gguf-q4_0/readme) | NLP | GGUF | `nexa run octopus-v2` | -| [octopus-v4](https://www.nexaai.com/NexaAI/Octopus-v4/gguf-q4_0/readme) | NLP | GGUF | `nexa run octopus-v4` | -| [gpt2](https://nexaai.com/openai/gpt2/gguf-q4_0/readme) | NLP | GGUF | `nexa run gpt2` | -| [tinyllama](https://www.nexaai.com/TinyLlama/TinyLlama-1.1B-Chat-v1.0/gguf-fp16/readme) | NLP | GGUF | `nexa run tinyllama` | -| [llama2](https://www.nexaai.com/meta/Llama2-7b-chat/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run llama2` | -| [llama2-uncensored](https://www.nexaai.com/georgesung/Llama2-7b-chat-uncensored/gguf-q4_0/readme) | NLP | GGUF | `nexa run llama2-uncensored` | -| [llama2-function-calling](https://www.nexaai.com/Trelis/Llama2-7b-function-calling/gguf-q4_K_M/readme) | NLP | GGUF | `nexa run llama2-function-calling` | -| [llama3](https://www.nexaai.com/meta/Llama3-8B-Instruct/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run llama3` | -| [llama3.1](https://www.nexaai.com/meta/Llama3.1-8B-Instruct/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run llama3.1` | -| [llama3.2](https://nexaai.com/meta/Llama3.2-3B-Instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run llama3.2` | -| [llama3-uncensored](https://www.nexaai.com/Orenguteng/Llama3-8B-Lexi-Uncensored/gguf-q4_K_M/readme) | NLP | GGUF | `nexa run llama3-uncensored` | -| [gemma](https://www.nexaai.com/google/gemma-1.1-2b-instruct/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run gemma` | -| [gemma2](https://www.nexaai.com/google/gemma-2-2b-instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run gemma2` | -| [qwen1.5](https://www.nexaai.com/Qwen/Qwen1.5-7B-Instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run qwen1.5` | -| [qwen2](https://www.nexaai.com/Qwen/Qwen2-1.5B-Instruct/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run qwen2` | -| [qwen2.5](https://www.nexaai.com/Qwen/Qwen2.5-1.5B-Instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run qwen2.5` | -| [mathqwen](https://nexaai.com/Qwen/Qwen2.5-Math-1.5B-Instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run mathqwen` | -| [codeqwen](https://www.nexaai.com/Qwen/CodeQwen1.5-7B-Instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run codeqwen` | -| [mistral](https://www.nexaai.com/mistralai/Mistral-7B-Instruct-v0.3/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run mistral` | -| [dolphin-mistral](https://www.nexaai.com/CognitiveComputations/dolphin-2.8-mistral-7b/gguf-q4_0/readme) | NLP | GGUF | `nexa run dolphin-mistral` | -| [codegemma](https://www.nexaai.com/google/codegemma-2b/gguf-q4_0/readme) | NLP | GGUF | `nexa run codegemma` | -| [codellama](https://www.nexaai.com/meta/CodeLlama-7b-Instruct/gguf-q2_K/readme) | NLP | GGUF | `nexa run codellama` | -| [deepseek-coder](https://www.nexaai.com/DeepSeek/deepseek-coder-1.3b-instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run deepseek-coder` | -| [phi2](https://www.nexaai.com/microsoft/Phi-2/gguf-q4_0/readme) | NLP | GGUF | `nexa run phi2` | -| [phi3](https://www.nexaai.com/microsoft/Phi-3-mini-128k-instruct/gguf-q4_0/readme) | NLP | GGUF/ONNX | `nexa run phi3` | -| 
[phi3.5](https://nexaai.com/microsoft/Phi-3.5-mini-instruct/gguf-q4_0/readme) | NLP | GGUF | `nexa run phi3.5` | -| [openelm](https://nexaai.com/apple/OpenELM-3B/gguf-q4_K_M/readme) | NLP | GGUF | `nexa run openelm` | -| [AMD-Llama-135m](https://nexaai.com/amd/AMD-Llama-135m/gguf-fp16/readme) | NLP | GGUF | `nexa run AMD-Llama-135m:fp16` | +| [omniaudio](https://nexa.ai/NexaAI/omniaudio/gguf-q4_0/readme) | AudioLM | GGUF | `nexa run omniaudio` | +| [qwen2audio](https://nexa.ai/Qwen/Qwen2-Audio-7.8B-Instruct/gguf-q4_K_M/readme) | AudioLM | GGUF | `nexa run qwen2audio` | +| [octopus-v2](https://www.nexaai.com/NexaAI/Octopus-v2/gguf-q4_0/readme) | Function Call | GGUF | `nexa run octopus-v2` | +| [octo-net](https://www.nexaai.com/NexaAI/Octo-net/gguf-q4_0/readme) | Text | GGUF | `nexa run octo-net` | +| [omniVLM](https://nexa.ai/NexaAI/omniVLM/gguf-fp16/readme) | Multimodal | GGUF | `nexa run omniVLM` | | [nanollava](https://www.nexaai.com/qnguyen3/nanoLLaVA/gguf-fp16/readme) | Multimodal | GGUF | `nexa run nanollava` | | [llava-phi3](https://www.nexaai.com/xtuner/llava-phi-3-mini/gguf-q4_0/readme) | Multimodal | GGUF | `nexa run llava-phi3` | | [llava-llama3](https://www.nexaai.com/xtuner/llava-llama-3-8b-v1.1/gguf-q4_0/readme) | Multimodal | GGUF | `nexa run llava-llama3` | | [llava1.6-mistral](https://www.nexaai.com/liuhaotian/llava-v1.6-mistral-7b/gguf-q4_0/readme) | Multimodal | GGUF | `nexa run llava1.6-mistral` | | [llava1.6-vicuna](https://www.nexaai.com/liuhaotian/llava-v1.6-vicuna-7b/gguf-q4_0/readme) | Multimodal | GGUF | `nexa run llava1.6-vicuna` | -| [stable-diffusion-v1-4](https://www.nexaai.com/runwayml/stable-diffusion-v1-4/gguf-q4_0/readme) | Computer Vision | GGUF | `nexa run sd1-4` | -| [stable-diffusion-v1-5](https://www.nexaai.com/runwayml/stable-diffusion-v1-5/gguf-q4_0/readme) | Computer Vision | GGUF/ONNX | `nexa run sd1-5` | -| [stable-diffusion-v2-1](https://nexaai.com/StabilityAI/stable-diffusion-v2-1/gguf-q4_0/readme) | Computer Vision | GGUF | `nexa run sd2-1` | -| [stable-diffusion-3-medium](https://nexaai.com/StabilityAI/stable-diffusion-3-medium/gguf-q4_0/readme) | Computer Vision | GGUF | `nexa run sd3` | -| [FLUX.1-schnell](https://nexaai.com/BlackForestLabs/FLUX.1-schnell/gguf-q4_0/readme) | Computer Vision | GGUF | `nexa run flux` | -| [lcm-dreamshaper](https://www.nexaai.com/SimianLuo/lcm-dreamshaper-v7/gguf-fp16/readme) | Computer Vision | GGUF/ONNX | `nexa run lcm-dreamshaper` | -| [hassaku-lcm](https://nexaai.com/stablediffusionapi/hassaku-hentai-model-v13-LCM/gguf-fp16/readme) | Computer Vision | GGUF | `nexa run hassaku-lcm` | -| [anything-lcm](https://www.nexaai.com/Linaqruf/anything-v30-LCM/gguf-fp16/readme) | Computer Vision | GGUF | `nexa run anything-lcm` | -| [faster-whisper-tiny](https://www.nexaai.com/Systran/faster-whisper-tiny/bin-cpu-fp16/readme) | Audio | BIN | `nexa run faster-whisper-tiny` | -| [faster-whisper-small](https://www.nexaai.com/Systran/faster-whisper-small/bin-cpu-fp16/readme) | Audio | BIN | `nexa run faster-whisper-small` | -| [faster-whisper-medium](https://www.nexaai.com/Systran/faster-whisper-medium/bin-cpu-fp16/readme) | Audio | BIN | `nexa run faster-whisper-medium` | -| [faster-whisper-base](https://www.nexaai.com/Systran/faster-whisper-base/bin-cpu-fp16/readme) | Audio | BIN | `nexa run faster-whisper-base` | -| [faster-whisper-large](https://www.nexaai.com/Systran/faster-whisper-large-v3/bin-cpu-fp16/readme) | Audio | BIN | `nexa run faster-whisper-large` | -| 
[whisper-large-v3-turbo](https://nexaai.com/Systran/faster-whisper-large-v3-turbo/bin-cpu-fp16/readme) | Audio | BIN | `nexa run faster-whisper-large-turbo` | -| [whisper-tiny.en](https://nexaai.com/openai/whisper-tiny.en/onnx-cpu-fp32/readme) | Audio | ONNX | `nexa run whisper-tiny.en` | -| [whisper-tiny](https://nexaai.com/openai/whisper-tiny/onnx-cpu-fp32/readme) | Audio | ONNX | `nexa run whisper-tiny` | -| [whisper-small.en](https://nexaai.com/openai/whisper-small.en/onnx-cpu-fp32/readme) | Audio | ONNX | `nexa run whisper-small.en` | -| [whisper-small](https://nexaai.com/openai/whisper-small/onnx-cpu-fp32/readme) | Audio | ONNX | `nexa run whisper-small` | -| [whisper-base.en](https://nexaai.com/openai/whisper-base.en/onnx-cpu-fp32/readme) | Audio | ONNX | `nexa run whisper-base.en` | -| [whisper-base](https://nexaai.com/openai/whisper-base/onnx-cpu-fp32/readme) | Audio | ONNX | `nexa run whisper-base` | +| [llama3.2](https://nexaai.com/meta/Llama3.2-3B-Instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run llama3.2` | +| [llama3-uncensored](https://www.nexaai.com/Orenguteng/Llama3-8B-Lexi-Uncensored/gguf-q4_K_M/readme) | Text | GGUF | `nexa run llama3-uncensored` | +| [gemma2](https://www.nexaai.com/google/gemma-2-2b-instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run gemma2` | +| [qwen2.5](https://www.nexaai.com/Qwen/Qwen2.5-1.5B-Instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run qwen2.5` | +| [mathqwen](https://nexaai.com/Qwen/Qwen2.5-Math-1.5B-Instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run mathqwen` | +| [codeqwen](https://www.nexaai.com/Qwen/CodeQwen1.5-7B-Instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run codeqwen` | +| [mistral](https://www.nexaai.com/mistralai/Mistral-7B-Instruct-v0.3/gguf-q4_0/readme) | Text | GGUF/ONNX | `nexa run mistral` | +| [deepseek-coder](https://www.nexaai.com/DeepSeek/deepseek-coder-1.3b-instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run deepseek-coder` | +| [phi3.5](https://nexaai.com/microsoft/Phi-3.5-mini-instruct/gguf-q4_0/readme) | Text | GGUF | `nexa run phi3.5` | +| [openelm](https://nexaai.com/apple/OpenELM-3B/gguf-q4_K_M/readme) | Text | GGUF | `nexa run openelm` | +| [stable-diffusion-v2-1](https://nexaai.com/StabilityAI/stable-diffusion-v2-1/gguf-q4_0/readme) | Image Generation | GGUF | `nexa run sd2-1` | +| [stable-diffusion-3-medium](https://nexaai.com/StabilityAI/stable-diffusion-3-medium/gguf-q4_0/readme) | Image Generation | GGUF | `nexa run sd3` | +| [FLUX.1-schnell](https://nexaai.com/BlackForestLabs/FLUX.1-schnell/gguf-q4_0/readme) | Image Generation | GGUF | `nexa run flux` | +| [lcm-dreamshaper](https://www.nexaai.com/SimianLuo/lcm-dreamshaper-v7/gguf-fp16/readme) | Image Generation | GGUF/ONNX | `nexa run lcm-dreamshaper` | +| [whisper-large-v3-turbo](https://nexaai.com/Systran/faster-whisper-large-v3-turbo/bin-cpu-fp16/readme) | Speech-to-Text | BIN | `nexa run faster-whisper-large-turbo` | +| [whisper-tiny.en](https://nexaai.com/openai/whisper-tiny.en/onnx-cpu-fp32/readme) | Speech-to-Text | ONNX | `nexa run whisper-tiny.en` | | [mxbai-embed-large-v1](https://nexa.ai/mixedbread-ai/mxbai-embed-large-v1/gguf-fp16/readme) | Embedding | GGUF | `nexa embed mxbai` | | [nomic-embed-text-v1.5](https://nexa.ai/nomic-ai/nomic-embed-text-v1.5/gguf-fp16/readme) | Embedding | GGUF | `nexa embed nomic` | -| [all-MiniLM-L6-v2](https://nexa.ai/sentence-transformers/all-MiniLM-L6-v2/gguf-fp16/readme) | Embedding | GGUF | `nexa embed all-MiniLM-L6-v2:fp16` | | 
[all-MiniLM-L12-v2](https://nexa.ai/sentence-transformers/all-MiniLM-L12-v2/gguf-fp16/readme) | Embedding | GGUF | `nexa embed all-MiniLM-L12-v2:fp16` | +| [bark-small](https://nexa.ai/suno/bark-small/gguf-fp16/readme) | Text-to-Speech | GGUF | `nexa run bark-small:fp16` | + +## Run Models from 🤗 HuggingFace or 🤖 ModelScope -## CLI Reference +You can pull, convert (to .gguf), quantize and run [llama.cpp supported](https://github.com/ggerganov/llama.cpp#description) text generation models from HF or MS with Nexa SDK. + +### Run .gguf File + +Use `nexa run -hf ` or `nexa run -ms ` to run models with provided .gguf files: + +```bash +nexa run -hf Qwen/Qwen2.5-Coder-7B-Instruct-GGUF +``` + +```bash +nexa run -ms Qwen/Qwen2.5-Coder-7B-Instruct-GGUF +``` + +> **Note:** You will be prompted to select a single .gguf file. If your desired quantization version has multiple split files (like fp16-00001-of-00004), please use Nexa's conversion tool (see below) to convert and quantize the model locally. + +### Convert .safetensors Files + +Install [Nexa Python package](https://github.com/NexaAI/nexa-sdk?tab=readme-ov-file#install-option-2-python-package), and install Nexa conversion tool with `pip install "nexaai[convert]"`, then convert models from huggingface with `nexa convert `: + +```bash +nexa convert HuggingFaceTB/SmolLM2-135M-Instruct +``` + +Or you can convert models from ModelScope with `nexa convert -ms `: + +```bash +nexa convert -ms Qwen/Qwen2.5-7B-Instruct +``` + +> **Note:** Check our [leaderboard](https://nexa.ai/leaderboard) for performance benchmarks of different quantized versions of mainstream language models and [HuggingFace docs](https://huggingface.co/docs/optimum/en/concept_guides/quantization) to learn about quantization options. + +📋 You can view downloaded and converted models with `nexa list` + +## Documentation + +> [!NOTE] +> +> 1. If you want to use ONNX model, just replace `pip install nexaai` with `pip install "nexaai[onnx]"` in provided commands. +> 2. If you want to run benchmark evaluation, just replace `pip install nexaai` with `pip install "nexaai[eval]"` in provided commands. +> 3. If you want to convert and quantize huggingface models to GGUF models, just replace `pip install nexaai` with `pip install "nexaai[convert]"` in provided commands. +> 4. For Chinese developers, we recommend you to use Tsinghua Open Source Mirror as extra index url, just replace `--extra-index-url https://pypi.org/simple` with `--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple` in provided commands. + +### CLI Reference Here's a brief overview of the main CLI commands: @@ -325,15 +319,22 @@ Here's a brief overview of the main CLI commands: For detailed information on CLI commands and usage, please refer to the [CLI Reference](CLI.md) document. -## Start Local Server +### Start Local Server To start a local server using models on your local computer, you can use the `nexa server` command. For detailed information on server setup, API endpoints, and usage examples, please refer to the [Server Reference](SERVER.md) document. +### Swift Package + +**[Swift SDK](https://github.com/NexaAI/nexa-sdk/tree/main/swift):** Provides a Swifty API, allowing Swift developers to easily integrate and use llama.cpp models in their projects. 
+
+[**More Docs**](https://docs.nexa.ai/)
+
## Acknowledgements

We would like to thank the following projects:

- [llama.cpp](https://github.com/ggerganov/llama.cpp)
- [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp)
+- [bark.cpp](https://github.com/PABannier/bark.cpp)
- [optimum](https://github.com/huggingface/optimum)
diff --git a/SERVER.md b/SERVER.md
index 10462e8c..b75efa45 100644
--- a/SERVER.md
+++ b/SERVER.md
@@ -8,9 +8,10 @@ usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] model_path
### Options:
-- `-lp, --local_path`: Indicate that the model path provided is the local path, must be used with -mt
-- `-mt, --model_type`: Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO]
-- `-hf, --huggingface`: Load model from Hugging Face Hub, must be used with -mt
+- `-lp, --local_path`: Indicate that the model path provided is the local path
+- `-mt, --model_type`: Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO]
+- `-hf, --huggingface`: Load model from Hugging Face Hub
+- `-ms, --modelscope`: Load model from ModelScope Hub
- `--host`: Host to bind the server to
- `--port`: Port to bind the server to
- `--reload`: Enable automatic reloading on code changes
diff --git a/android/llama.android/.gitignore b/android/llama.android/.gitignore
new file mode 100644
index 00000000..347e252e
--- /dev/null
+++ b/android/llama.android/.gitignore
@@ -0,0 +1,33 @@
+# Gradle files
+.gradle/
+build/
+
+# Local configuration file (sdk path, etc)
+local.properties
+
+# Log/OS Files
+*.log
+
+# Android Studio generated files and folders
+captures/
+.externalNativeBuild/
+.cxx/
+*.apk
+output.json
+
+# IntelliJ
+*.iml
+.idea/
+misc.xml
+deploymentTargetDropDown.xml
+render.experimental.xml
+
+# Keystore files
+*.jks
+*.keystore
+
+# Google Services (e.g. APIs or Firebase)
+google-services.json
+
+# Android Profiling
+*.hprof
diff --git a/android/llama.android/README.md b/android/llama.android/README.md
new file mode 100644
index 00000000..aa91234c
--- /dev/null
+++ b/android/llama.android/README.md
@@ -0,0 +1,54 @@
+# Nexa
+
+**Nexa** is a Kotlin wrapper for the [llama.cpp](https://github.com/ggerganov/llama.cpp.git) library, offering a convenient Kotlin API for Android developers. It allows seamless integration of llama.cpp models into Android applications.
+**NOTE:** Currently, Nexa supports Vision-Language Model (VLM) inference capabilities.
+
+## Installation
+
+To add Nexa to your Android project, follow these steps:
+
+- Create a libs folder in your project’s root directory.
+- Copy the .aar file into the libs folder.
+- Add the dependency to your build.gradle file:
+
+```
+implementation files("libs/com.nexa.aar")
+```
+
+## Usage
+### 1. Initialize NexaVlmInference with model path and projector path
+
+Create a configuration and initialize NexaVlmInference with the paths to your model and projector files:
+
+```kotlin
+nexaVlmInference = NexaVlmInference(pathToModel,
+ mmprojectorPath, imagePath,
+ maxNewTokens = 128,
+ stopWords = listOf(""))
+nexaVlmInference.loadModel()
+```
+
+### 2. Completion API
+
+#### Streaming Mode
+
+```kotlin
+nexaVlmInference.createCompletionStream(prompt, imagePath)
+ ?.catch {
+ print(it.message)
+ }
+ ?.collect { print(it) }
+```
+
+### 3. Release all resources
+```kotlin
+nexaVlmInference.dispose()
+```
+
+## Quick Start
+
+Open the [android test project](./app-java) folder in Android Studio and run the project.
+ +## Download Models + +You can download models from the [Nexa AI ModelHub](https://nexa.ai/models). \ No newline at end of file diff --git a/android/llama.android/app-java/.gitignore b/android/llama.android/app-java/.gitignore new file mode 100644 index 00000000..42df58a2 --- /dev/null +++ b/android/llama.android/app-java/.gitignore @@ -0,0 +1,2 @@ +/build +!*.png \ No newline at end of file diff --git a/android/llama.android/app-java/build.gradle b/android/llama.android/app-java/build.gradle new file mode 100644 index 00000000..2729f317 --- /dev/null +++ b/android/llama.android/app-java/build.gradle @@ -0,0 +1,52 @@ +plugins { + id 'com.android.application' + id 'kotlin-android' +} + +android { + namespace 'ai.nexa.app_java' + compileSdk 34 + + defaultConfig { + applicationId "ai.nexa.app_java" + minSdk 33 + targetSdk 34 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_17 // or VERSION_1_8 + targetCompatibility JavaVersion.VERSION_17 // or VERSION_1_8 + } + + kotlinOptions { + jvmTarget = "17" // or "1.8" + } +} + +dependencies { + + implementation 'androidx.appcompat:appcompat:1.7.0' + implementation 'com.google.android.material:material:1.12.0' + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.2.1' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.6.1' + + implementation "org.jetbrains.kotlin:kotlin-stdlib:1.9.20" + implementation "org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3" + + implementation 'com.github.bumptech.glide:glide:4.16.0' + annotationProcessor 'com.github.bumptech.glide:compiler:4.16.0' + + implementation project(":llama") + // implementation files("libs/com.nexa.aar") +} \ No newline at end of file diff --git a/android/llama.android/app-java/proguard-rules.pro b/android/llama.android/app-java/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/android/llama.android/app-java/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/android/llama.android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java b/android/llama.android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java new file mode 100644 index 00000000..7f3c2198 --- /dev/null +++ b/android/llama.android/app-java/src/androidTest/java/ai/nexa/app_java/ExampleInstrumentedTest.java @@ -0,0 +1,26 @@ +package ai.nexa.app_java; + +import android.content.Context; + +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.*; + +/** + * Instrumented test, which will execute on an Android device. + * + * @see Testing documentation + */ +@RunWith(AndroidJUnit4.class) +public class ExampleInstrumentedTest { + @Test + public void useAppContext() { + // Context of the app under test. + Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + assertEquals("ai.nexa.app_java", appContext.getPackageName()); + } +} \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/AndroidManifest.xml b/android/llama.android/app-java/src/main/AndroidManifest.xml new file mode 100644 index 00000000..8aaea0a2 --- /dev/null +++ b/android/llama.android/app-java/src/main/AndroidManifest.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + > + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/ImagePathHelper.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/ImagePathHelper.java new file mode 100644 index 00000000..a8b0ef00 --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/ImagePathHelper.java @@ -0,0 +1,112 @@ +package ai.nexa.app_java; + +import android.content.Context; +import android.database.Cursor; +import android.net.Uri; +import android.provider.DocumentsContract; +import android.provider.MediaStore; +import android.util.Log; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +public class ImagePathHelper { + private static final String TAG = "MessageProcessor"; + private final Context context; + + public ImagePathHelper(Context context) { + this.context = context; + } + + public String getPathFromUri(String uriString) { + try { + Uri uri = Uri.parse(uriString); + + // Handle "content://" scheme + if ("content".equals(uri.getScheme())) { + // Handle Google Photos and other document providers + if (DocumentsContract.isDocumentUri(context, uri)) { + final String docId = DocumentsContract.getDocumentId(uri); + + // MediaStore documents + if ("com.android.providers.media.documents".equals(uri.getAuthority())) { + final String[] split = docId.split(":"); + final String type = split[0]; + Uri contentUri = null; + + if ("image".equals(type)) { + contentUri = MediaStore.Images.Media.EXTERNAL_CONTENT_URI; + } + + final String selection = "_id=?"; + final String[] selectionArgs = new String[]{split[1]}; + return getDataColumn(context, contentUri, selection, selectionArgs); + } + } + // MediaStore (general case) + return getDataColumn(context, uri, null, null); + } + // Handle "file://" scheme + else if ("file".equals(uri.getScheme())) { + return uri.getPath(); + } + // Handle absolute path + else if (new File(uriString).exists()) { + return 
uriString;
+            }
+
+            return null;
+        } catch (Exception e) {
+            Log.e(TAG, "Error getting path from URI: " + uriString, e);
+            return null;
+        }
+    }
+
+    public String copyUriToPrivateFile(Context context, String uriString) throws IOException {
+        // Convert the string back into a Uri
+        Uri uri = Uri.parse(uriString);
+
+        // App-private directory
+        File privateDir = context.getExternalFilesDir("images");
+        if (privateDir == null) {
+            throw new IOException("Private directory not available");
+        }
+
+        // Create the destination file
+        File destFile = new File(privateDir, "temp_image_" + System.currentTimeMillis() + ".jpg");
+
+        try (InputStream inputStream = context.getContentResolver().openInputStream(uri);
+             OutputStream outputStream = new FileOutputStream(destFile)) {
+
+            if (inputStream == null) {
+                throw new IOException("Failed to open URI input stream");
+            }
+
+            // Read the data and write it to the destination file
+            byte[] buffer = new byte[4096];
+            int bytesRead;
+            while ((bytesRead = inputStream.read(buffer)) != -1) {
+                outputStream.write(buffer, 0, bytesRead);
+            }
+        }
+
+        // Return the file path
+        return destFile.getAbsolutePath();
+    }
+
+    private String getDataColumn(Context context, Uri uri, String selection, String[] selectionArgs) {
+        final String[] projection = {MediaStore.Images.Media.DATA};
+        try (Cursor cursor = context.getContentResolver().query(uri, projection, selection, selectionArgs, null)) {
+            if (cursor != null && cursor.moveToFirst()) {
+                final int columnIndex = cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DATA);
+                return cursor.getString(columnIndex);
+            }
+        } catch (Exception e) {
+            Log.e(TAG, "Error getting data column", e);
+        }
+        return null;
+    }
+}
diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/KotlinFlowHelper.kt b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/KotlinFlowHelper.kt
new file mode 100644
index 00000000..0183ff14
--- /dev/null
+++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/KotlinFlowHelper.kt
@@ -0,0 +1,44 @@
+package ai.nexa.app_java
+
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.cancelChildren
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+
+class KotlinFlowHelper {
+    private val scope = CoroutineScope(Dispatchers.IO)
+
+    fun collectFlow(
+        flow: Flow<String>, // Added missing flow parameter
+        onToken: (String) -> Unit,
+        onComplete: (String) -> Unit,
+        onError: (String) -> Unit
+    ) {
+        scope.launch {
+            try {
+                val fullResponse = StringBuilder()
+                withContext(Dispatchers.IO) {
+                    flow.collect { value ->
+                        fullResponse.append(value)
+                        withContext(Dispatchers.Main) {
+                            onToken(value)
+                        }
+                    }
+                }
+                withContext(Dispatchers.Main) {
+                    onComplete(fullResponse.toString())
+                }
+            } catch (e: Exception) {
+                withContext(Dispatchers.Main) {
+                    onError(e.message ?: "Unknown error")
+                }
+            }
+        }
+    }
+
+    fun cancel() {
+        scope.coroutineContext.cancelChildren()
+    }
+}
\ No newline at end of file
diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/KotlinJavaUtils.kt b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/KotlinJavaUtils.kt
new file mode 100644
index 00000000..1fc8437c
--- /dev/null
+++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/KotlinJavaUtils.kt
@@ -0,0 +1,11 @@
+package ai.nexa.app_java
+
+import java.util.function.Consumer
+
+object KotlinJavaUtils {
+    @JvmStatic
+    fun toKotlinCallback(callback: Consumer<String>): (String) -> Unit = { value ->
+        callback.accept(value)
+        Unit
+    }
+}
\ No newline at end of file
diff --git
a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java new file mode 100644 index 00000000..e48ec8d5 --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/LlamaBridge.java @@ -0,0 +1,283 @@ +package ai.nexa.app_java; + +import android.content.Context; +import com.nexa.NexaVlmInference; +import android.util.Log; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import kotlin.Unit; +import kotlin.coroutines.Continuation; +import kotlin.jvm.functions.Function1; +import kotlinx.coroutines.BuildersKt; +import kotlinx.coroutines.CoroutineStart; +import kotlinx.coroutines.Dispatchers; +import kotlinx.coroutines.GlobalScope; +import kotlinx.coroutines.Job; +import kotlinx.coroutines.flow.Flow; +import kotlinx.coroutines.flow.FlowCollector; + +public class LlamaBridge { + private static final String TAG = "LlamaBridge"; + private final Context context; + private final ExecutorService executor; + private final MessageHandler messageHandler; + private final VlmModelManager modelManager; + private final ImagePathHelper imagePathHelper; + private NexaVlmInference nexaVlmInference; + private boolean isModelLoaded = false; + + private final KotlinFlowHelper flowHelper = new KotlinFlowHelper(); + + // Default inference parameters + private static final float DEFAULT_TEMPERATURE = 1.0f; + private static final int DEFAULT_MAX_TOKENS = 64; + private static final int DEFAULT_TOP_K = 50; + private static final float DEFAULT_TOP_P = 0.9f; + + public interface InferenceCallback { + void onStart(); + void onToken(String token); + void onComplete(String fullResponse); + void onError(String error); + } + + public LlamaBridge(Context context, MessageHandler messageHandler) { + this.context = context; + this.messageHandler = messageHandler; + this.executor = Executors.newSingleThreadExecutor(); + this.modelManager = new VlmModelManager(context); + this.imagePathHelper = new ImagePathHelper(context); + } + + public boolean areModelsAvailable() { + return modelManager.areModelsAvailable(); + } + + public void loadModel() { + executor.execute(() -> { + try { + if (!modelManager.areModelsAvailable()) { + throw new IOException("Required model files are not available"); + } + + String modelPath = modelManager.getTextModelPath(); + String projectorPath = modelManager.getMmProjModelPath(); + + Log.d(TAG, "Loading model from: " + modelPath); + Log.d(TAG, "Loading projector from: " + projectorPath); + + // Create with default values for optional parameters + nexaVlmInference = new NexaVlmInference( + modelPath, // modelPath + projectorPath, // projectorPath + "", // imagePath (empty string as default) + new ArrayList<>(Arrays.asList("")), // stopWords (empty list) + DEFAULT_TEMPERATURE, // temperature + DEFAULT_MAX_TOKENS, // maxNewTokens + DEFAULT_TOP_K, // topK + DEFAULT_TOP_P // topP + ); + nexaVlmInference.loadModel(); + isModelLoaded = true; + + Log.d(TAG, "Model loaded successfully."); +// messageHandler.addMessage(new MessageModal("Model loaded successfully", "assistant", null)); + } catch (Exception e) { + Log.e(TAG, "Failed to load model", e); + messageHandler.addMessage(new MessageModal("Error loading model: " + e.getMessage(), "assistant", null)); + } + }); + } + +// public void processMessage(String message, 
String imageUri, InferenceCallback callback) { +// if (!isModelLoaded) { +// callback.onError("Model not loaded yet"); +// return; +// } +// +// try { +// // Add user message first +// MessageModal userMessage = new MessageModal(message, "user", imageUri); +// messageHandler.addMessage(userMessage); +// +// // Create an initial empty assistant message +// MessageModal assistantMessage = new MessageModal("", "assistant", null); +// messageHandler.addMessage(assistantMessage); +// +// // Convert image URI to absolute path +// String imageAbsolutePath = imagePathHelper.getPathFromUri(imageUri); +// +// Flow flow = nexaVlmInference.createCompletionStream( +// message, +// imageAbsolutePath, +// new ArrayList<>(), +// DEFAULT_TEMPERATURE, +// DEFAULT_MAX_TOKENS, +// DEFAULT_TOP_K, +// DEFAULT_TOP_P +// ); +// +// if (flow != null) { +// CoroutineScope scope = CoroutineScopeKt.CoroutineScope(Dispatchers.getMain()); +// +// Job job = FlowKt.launchIn( +// FlowKt.onEach(flow, new Function2, Object>() { +// @Override +// public Object invoke(String token, Continuation continuation) { +// messageHandler.updateLastAssistantMessage(token); +// callback.onToken(token); +// return Unit.INSTANCE; +// } +// }), +// scope +// ); +// } else { +// messageHandler.finalizeLastAssistantMessage("Error: Failed to create completion stream"); +// callback.onError("Failed to create completion stream"); +// } +// } catch (Exception e) { +// Log.e(TAG, "Error processing message", e); +// messageHandler.finalizeLastAssistantMessage("Error: " + e.getMessage()); +// callback.onError(e.getMessage()); +// } +// } + + public void processMessage(String message, String imageUri, InferenceCallback callback) { + if (!isModelLoaded) { + callback.onError("Model not loaded yet"); + return; + } + + String imageAbsolutePath = null; + try { + imageAbsolutePath = imagePathHelper.copyUriToPrivateFile(context, imageUri); + } catch (IOException e) { + callback.onError("Failed to process image: " + e.getMessage()); + return; + } + + final String imagePath = imageAbsolutePath; + MessageModal assistantMessage = new MessageModal("", "bot", null); + messageHandler.addMessage(assistantMessage); + + try { + Flow flow = nexaVlmInference.createCompletionStream( + message, + imagePath, + new ArrayList<>(Arrays.asList("")), + DEFAULT_TEMPERATURE, + DEFAULT_MAX_TOKENS, + DEFAULT_TOP_K, + DEFAULT_TOP_P + ); + + callback.onStart(); + StringBuilder fullResponse = new StringBuilder(); + + Job collectJob = BuildersKt.launch( + GlobalScope.INSTANCE, + Dispatchers.getIO(), + CoroutineStart.DEFAULT, + (coroutineScope, continuation) -> { + flow.collect(new FlowCollector() { + @Override + public Object emit(String token, Continuation continuation) { + fullResponse.append(token); + callback.onToken(token); + return Unit.INSTANCE; + } + }, continuation); + callback.onComplete(fullResponse.toString()); + return Unit.INSTANCE; + } + ); + + collectJob.invokeOnCompletion(new Function1() { + @Override + public Unit invoke(Throwable throwable) { + if (throwable != null && !(throwable instanceof CancellationException)) { + callback.onError("Stream collection failed: " + throwable.getMessage()); + } + return Unit.INSTANCE; + } + }); + + } catch (Exception e) { + Log.e(TAG, "Inference failed", e); + callback.onError(e.getMessage()); + } + } + + public void cleanup() { + flowHelper.cancel(); + } + +// public void processMessageWithParams( +// String message, +// String imageUri, +// float temperature, +// int maxTokens, +// int topK, +// float topP, +// InferenceCallback 
callback) { +// +// if (!isModelLoaded) { +// callback.onError("Model not loaded yet"); +// return; +// } +// +// executor.execute(() -> { +// StringBuilder fullResponse = new StringBuilder(); +// try { +// callback.onStart(); +// +// Flow completionStream = nexaVlmInference.createCompletionStream( +// message, +// imageUri, +// new ArrayList<>(), +// temperature, +// maxTokens, +// topK, +// topP +// ); +// +// completionStream.collect(new FlowCollector() { +// @Override +// public Object emit(String value, Continuation continuation) { +// fullResponse.append(value); +// callback.onToken(value); +// return Unit.INSTANCE; +// } +// }); +// +// callback.onComplete(fullResponse.toString()); +// +// } catch (Exception e) { +// Log.e(TAG, "Inference failed", e); +// callback.onError(e.getMessage()); +// } +// }); +// } + + + public void shutdown() { + if (nexaVlmInference != null) { + executor.execute(() -> { + try { + nexaVlmInference.dispose(); + } catch (Exception e) { + Log.e(TAG, "Error closing inference", e); + } + nexaVlmInference = null; + isModelLoaded = false; + }); + } + executor.shutdown(); + } +} \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MainActivity.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MainActivity.java new file mode 100644 index 00000000..29be7214 --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MainActivity.java @@ -0,0 +1,345 @@ +package ai.nexa.app_java; + +import android.Manifest; +import android.content.Context; +import android.content.Intent; +import android.content.pm.PackageManager; +import android.net.Uri; +import android.os.Bundle; +import android.os.Message; +import android.provider.MediaStore; +import android.speech.RecognizerIntent; +import android.speech.SpeechRecognizer; +import android.util.Log; +import android.view.MotionEvent; +import android.view.View; +import android.view.inputmethod.InputMethodManager; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.LinearLayout; +import android.widget.TextView; +import android.widget.Toast; + +import androidx.annotation.NonNull; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.app.ActivityCompat; +import androidx.recyclerview.widget.LinearLayoutManager; +import androidx.recyclerview.widget.RecyclerView; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +public class MainActivity extends AppCompatActivity { + + private static final String TAG = "ChatApp"; + private static final int PICK_IMAGE_REQUEST = 30311; + private static final int REQUEST_RECORD_AUDIO_PERMISSION = 200; + private static final int READ_EXTERNAL_STORAGE_PERMISSION = 303; + + private RecyclerView chatsRV; + private ImageButton selectImageButton; + private ImageButton sendMsgIB; + private EditText userMsgEdt; + private String justSelectedImageUri; + + private LinearLayout linearLayout; + private TextView titleAfterChatTextView; + private RecyclerView recyclerView; + + private ArrayList messageModalArrayList; + private MessageRVAdapter messageRVAdapter; + private MessageHandler messageHandler; + private LlamaBridge llamaBridge; + private SpeechRecognizer speechRecognizer; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + Log.d(TAG, "onCreate: Starting MainActivity"); + + initializeViews(); + setupRecyclerView(); + 
initializeLlamaBridge(); + createSpeechRecognizerIntent(); + setupClickListeners(); + + Log.d(TAG, "onCreate: MainActivity setup complete"); + } + + private void initializeViews() { + chatsRV = findViewById(R.id.idRVChats); + selectImageButton = findViewById(R.id.btnUploadImage); + sendMsgIB = findViewById(R.id.idIBSend); + userMsgEdt = findViewById(R.id.idEdtMessage); + linearLayout = findViewById(R.id.idLayoutBeforeChat); + titleAfterChatTextView = findViewById(R.id.textView); + recyclerView = findViewById(R.id.idRVChats); + } + + private void setupRecyclerView() { + messageModalArrayList = new ArrayList<>(); + messageRVAdapter = new MessageRVAdapter(messageModalArrayList, this); + chatsRV.setLayoutManager(new LinearLayoutManager(this, RecyclerView.VERTICAL, false)); + chatsRV.setAdapter(messageRVAdapter); + messageHandler = new MessageHandler(messageModalArrayList, messageRVAdapter, recyclerView); + } + + private void initializeLlamaBridge() { + llamaBridge = new LlamaBridge(this, messageHandler); + if (!llamaBridge.areModelsAvailable()) { + Toast.makeText(this, "Required model files are not available", Toast.LENGTH_LONG).show(); + return; + } + llamaBridge.loadModel(); + } + + private void setupClickListeners() { + selectImageButton.setOnClickListener(v -> { + Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI); + startActivityForResult(intent, PICK_IMAGE_REQUEST); + }); + + sendMsgIB.setOnClickListener(v -> { + hideKeyboard(v); + sendTextMessage(); + }); + } + + private void updateChatBotDisplay() { + linearLayout.setVisibility(View.GONE); + titleAfterChatTextView.setVisibility(View.VISIBLE); + recyclerView.setVisibility(View.VISIBLE); + } + + private void sendTextMessage() { + updateChatBotDisplay(); + + String userMessage = userMsgEdt.getText().toString().trim(); + if (!userMessage.isEmpty()) { + Log.d(TAG, "Sending message: " + userMessage); + messageHandler.addMessage(new MessageModal(userMessage, "user", null)); + + if (justSelectedImageUri == null) { + messageHandler.addMessage(new MessageModal("Please select an image first.", "bot", null)); + return; + } + + // Use LlamaBridge for inference + llamaBridge.processMessage(userMessage, justSelectedImageUri, new LlamaBridge.InferenceCallback() { + @Override + public void onStart() { + // Optional: Show loading indicator + } + + @Override + public void onToken(String token) { + // Update the UI with each token as it comes in + runOnUiThread(() -> { + messageHandler.updateLastBotMessage(token); + }); + } + + @Override + public void onComplete(String fullResponse) { + // Final update with complete response + runOnUiThread(() -> { + messageHandler.finalizeLastBotMessage(fullResponse); + }); + } + + @Override + public void onError(String error) { + runOnUiThread(() -> { + Toast.makeText(MainActivity.this, "Error: " + error, Toast.LENGTH_SHORT).show(); + messageHandler.addMessage(new MessageModal("Error processing message: " + error, "assistant", null)); + }); + } + }); + + userMsgEdt.setText(""); // Clear the input field after sending + justSelectedImageUri = null; // Clear the image URI after sending + } else { + Toast.makeText(MainActivity.this, "Please enter your message.", Toast.LENGTH_SHORT).show(); + } + } + + private void sendImageAsMessage(String imageUri) { + updateChatBotDisplay(); + messageHandler.addMessage(new MessageModal("", "user", imageUri)); + justSelectedImageUri = imageUri; + } + + @Override + protected void onDestroy() { + super.onDestroy(); + if (llamaBridge != null) { + 
llamaBridge.shutdown(); + } + if (speechRecognizer != null) { + speechRecognizer.destroy(); + } + } + + private void createSpeechRecognizerIntent() { + requestMicrophonePermission(); + + ImageButton btnStart = findViewById(R.id.btnStart); + + speechRecognizer = SpeechRecognizer.createSpeechRecognizer(this); + + Intent speechRecognizerIntent = new Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH); + speechRecognizerIntent.putExtra(RecognizerIntent.EXTRA_LANGUAGE_MODEL, RecognizerIntent.LANGUAGE_MODEL_FREE_FORM); + speechRecognizerIntent.putExtra(RecognizerIntent.EXTRA_LANGUAGE, Locale.getDefault()); + speechRecognizerIntent.putExtra(RecognizerIntent.EXTRA_PARTIAL_RESULTS, true); + + speechRecognizer.setRecognitionListener(new android.speech.RecognitionListener() { + @Override + public void onReadyForSpeech(Bundle params) { + } + + @Override + public void onBeginningOfSpeech() { + } + + @Override + public void onRmsChanged(float rmsdB) { + } + + @Override + public void onBufferReceived(byte[] buffer) { + } + + @Override + public void onEndOfSpeech() { + } + + @Override + public void onError(int error) { + String errorMessage = getErrorText(error); + Log.d("SpeechRecognition", "Error occurred: " + errorMessage); + } + + public String getErrorText(int errorCode) { + String message; + switch (errorCode) { + case SpeechRecognizer.ERROR_AUDIO: + message = "Audio recording error"; + break; + case SpeechRecognizer.ERROR_CLIENT: + message = "Client side error"; + break; + case SpeechRecognizer.ERROR_INSUFFICIENT_PERMISSIONS: + message = "Insufficient permissions"; + break; + case SpeechRecognizer.ERROR_NETWORK: + message = "Network error"; + break; + case SpeechRecognizer.ERROR_NETWORK_TIMEOUT: + message = "Network timeout"; + break; + case SpeechRecognizer.ERROR_NO_MATCH: + message = "No match"; + break; + case SpeechRecognizer.ERROR_RECOGNIZER_BUSY: + message = "RecognitionService busy"; + break; + case SpeechRecognizer.ERROR_SERVER: + message = "Error from server"; + break; + case SpeechRecognizer.ERROR_SPEECH_TIMEOUT: + message = "No speech input"; + break; + default: + message = "Didn't understand, please try again."; + break; + } + return message; + } + + @Override + public void onResults(Bundle results) { + ArrayList matches = results.getStringArrayList(SpeechRecognizer.RESULTS_RECOGNITION); + if (matches != null && !matches.isEmpty()) { + userMsgEdt.setText(matches.get(0)); // Set the recognized text to the EditText + sendTextMessage(); + } + } + + @Override + public void onPartialResults(Bundle partialResults) { + // This is called for partial results + ArrayList partialMatches = partialResults.getStringArrayList(SpeechRecognizer.RESULTS_RECOGNITION); + if (partialMatches != null && !partialMatches.isEmpty()) { + userMsgEdt.setText(partialMatches.get(0)); // Update EditText with the partial result + } + } + + @Override + public void onEvent(int eventType, Bundle params) { + } + }); + + btnStart.setOnTouchListener(new View.OnTouchListener() { + @Override + public boolean onTouch(View v, MotionEvent event) { + switch (event.getAction()) { + case MotionEvent.ACTION_DOWN: + // Button is pressed + speechRecognizer.startListening(speechRecognizerIntent); + return true; // Return true to indicate the event was handled + case MotionEvent.ACTION_UP: + // Button is released + speechRecognizer.stopListening(); + return true; // Return true to indicate the event was handled + } + return false; // Return false for other actions + } + }); + } + + private void requestMicrophonePermission() { + 
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO_PERMISSION); + } + + @Override + public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + switch (requestCode) { + case READ_EXTERNAL_STORAGE_PERMISSION: + if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) { + Toast.makeText(this, "Read External Storage Permission Granted", Toast.LENGTH_SHORT).show(); + Intent intent = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI); + startActivityForResult(intent, PICK_IMAGE_REQUEST); + } else { + Toast.makeText(this, "Read External Storage Permission Denied", Toast.LENGTH_SHORT).show(); + } + break; + default: + break; + } + + } + + @Override + protected void onActivityResult(int requestCode, int resultCode, Intent data) { + super.onActivityResult(requestCode, resultCode, data); + if (requestCode == PICK_IMAGE_REQUEST && resultCode == RESULT_OK && data != null) { + Uri selectedImage = data.getData(); + if (selectedImage != null) { + String imageUriString = selectedImage.toString(); + sendImageAsMessage(imageUriString); + } + } + } + + public void hideKeyboard(View view) { + InputMethodManager inputMethodManager = (InputMethodManager) getSystemService(Context.INPUT_METHOD_SERVICE); + if (inputMethodManager != null) { + inputMethodManager.hideSoftInputFromWindow(view.getWindowToken(), InputMethodManager.HIDE_NOT_ALWAYS); + } + } + +} \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageHandler.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageHandler.java new file mode 100644 index 00000000..39720c1f --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageHandler.java @@ -0,0 +1,127 @@ +package ai.nexa.app_java; + +import androidx.recyclerview.widget.RecyclerView; +import android.os.Handler; +import android.os.Looper; + +import java.util.ArrayList; + +public class MessageHandler { + private final ArrayList messageModalArrayList; + private final MessageRVAdapter messageRVAdapter; + private final RecyclerView recyclerView; + private final Handler mainHandler; + + public MessageHandler(ArrayList messageModalArrayList, MessageRVAdapter messageRVAdapter, RecyclerView recyclerView) { + this.messageModalArrayList = messageModalArrayList; + this.messageRVAdapter = messageRVAdapter; + this.recyclerView = recyclerView; + this.mainHandler = new Handler(Looper.getMainLooper()); + } + + /** + * Add a new message to the chat + */ + public void addMessage(MessageModal message) { + ensureMainThread(() -> { + messageModalArrayList.add(message); + messageRVAdapter.notifyItemInserted(messageModalArrayList.size() - 1); + scrollToBottom(); + }); + } + + /** + * Update the last bot message with new token + */ + public void updateLastBotMessage(String newToken) { + ensureMainThread(() -> { + if (!messageModalArrayList.isEmpty()) { + int lastIndex = messageModalArrayList.size() - 1; + MessageModal lastMessage = messageModalArrayList.get(lastIndex); + + // If last message is from bot, update it + if ("bot".equals(lastMessage.getSender())) { + String currentMessage = lastMessage.getMessage(); + lastMessage.setMessage(currentMessage + newToken); + messageRVAdapter.notifyItemChanged(lastIndex); + } else { + // Create new bot message + MessageModal newMessage = 
new MessageModal(newToken, "bot", null); + messageModalArrayList.add(newMessage); + messageRVAdapter.notifyItemInserted(messageModalArrayList.size() - 1); + } + scrollToBottom(); + } + }); + } + + /** + * Finalize the last bot message with complete response + */ + public void finalizeLastBotMessage(String completeMessage) { + ensureMainThread(() -> { + if (!messageModalArrayList.isEmpty()) { + int lastIndex = messageModalArrayList.size() - 1; + MessageModal lastMessage = messageModalArrayList.get(lastIndex); + + if ("bot".equals(lastMessage.getSender())) { + lastMessage.setMessage(completeMessage); + messageRVAdapter.notifyItemChanged(lastIndex); + } else { + MessageModal newMessage = new MessageModal(completeMessage, "bot", null); + messageModalArrayList.add(newMessage); + messageRVAdapter.notifyItemInserted(messageModalArrayList.size() - 1); + } + scrollToBottom(); + } + }); + } + + /** + * Clear all messages from the chat + */ + public void clearMessages() { + ensureMainThread(() -> { + messageModalArrayList.clear(); + messageRVAdapter.notifyDataSetChanged(); + }); + } + + /** + * Get the last message in the chat + */ + public MessageModal getLastMessage() { + if (!messageModalArrayList.isEmpty()) { + return messageModalArrayList.get(messageModalArrayList.size() - 1); + } + return null; + } + + /** + * Check if the last message is from the bot + */ + public boolean isLastMessageFromBot() { + MessageModal lastMessage = getLastMessage(); + return lastMessage != null && "bot".equals(lastMessage.getSender()); + } + + /** + * Scroll the RecyclerView to the bottom + */ + private void scrollToBottom() { + if (messageModalArrayList.size() > 1) { + recyclerView.smoothScrollToPosition(messageModalArrayList.size() - 1); + } + } + + /** + * Ensure all UI updates happen on the main thread + */ + private void ensureMainThread(Runnable action) { + if (Looper.myLooper() == Looper.getMainLooper()) { + action.run(); + } else { + mainHandler.post(action); + } + } +} \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java new file mode 100644 index 00000000..1e60921b --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageModal.java @@ -0,0 +1,42 @@ +package ai.nexa.app_java; + +public class MessageModal { + + + private String message; + private String sender; + + private String imageUri; + + public MessageModal(String message, String sender, String imageUri) { + this.message = message; + this.sender = sender; + this.imageUri = imageUri; + } + + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public String getSender() { + return sender; + } + + public void setSender(String sender) { + this.sender = sender; + } + + public String getImageUri() { + return imageUri; + } + + public void setImageUri(String imageUri) { + this.imageUri = imageUri; + } +} + diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java new file mode 100644 index 00000000..90977681 --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/MessageRVAdapter.java @@ -0,0 +1,102 @@ +package ai.nexa.app_java; + +import android.content.Context; +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; 
+import android.widget.ImageView; +import android.widget.TextView; + +import androidx.annotation.NonNull; +import androidx.recyclerview.widget.RecyclerView; + +import com.bumptech.glide.Glide; + +import java.util.ArrayList; + +public class MessageRVAdapter extends RecyclerView.Adapter { + + private ArrayList messageModalArrayList; + private Context context; + + public MessageRVAdapter(ArrayList messageModalArrayList, Context context) { + this.messageModalArrayList = messageModalArrayList; + this.context = context; + } + + @NonNull + @Override + public RecyclerView.ViewHolder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) { + View view; + switch (viewType) { + case 0: + view = LayoutInflater.from(parent.getContext()).inflate(R.layout.user_msg, parent, false); + return new UserViewHolder(view); + case 1: + view = LayoutInflater.from(parent.getContext()).inflate(R.layout.bot_msg, parent, false); + return new BotViewHolder(view); + } + return null; + } + + @Override + public void onBindViewHolder(@NonNull RecyclerView.ViewHolder holder, int position) { + MessageModal modal = messageModalArrayList.get(position); + switch (modal.getSender()) { + case "user": + UserViewHolder userHolder = (UserViewHolder) holder; + if (modal.getImageUri() != null && !modal.getImageUri().isEmpty()) { + userHolder.userImage.setVisibility(View.VISIBLE); + userHolder.userTV.setVisibility(View.GONE); + Glide.with(userHolder.itemView.getContext()) + .load(modal.getImageUri()) + .into(userHolder.userImage); + } else { + userHolder.userImage.setVisibility(View.GONE); + userHolder.userTV.setVisibility(View.VISIBLE); + userHolder.userTV.setText(modal.getMessage()); + } + break; + case "bot": + ((BotViewHolder) holder).botTV.setText(modal.getMessage()); + break; + } + } + + @Override + public int getItemCount() { + return messageModalArrayList.size(); + } + + @Override + public int getItemViewType(int position) { + switch (messageModalArrayList.get(position).getSender()) { + case "user": + return 0; + case "bot": + return 1; + default: + return -1; + } + } + + public static class UserViewHolder extends RecyclerView.ViewHolder { + TextView userTV; + ImageView userImage; + + public UserViewHolder(@NonNull View itemView) { + super(itemView); + userTV = itemView.findViewById(R.id.idTVUser); + userImage = itemView.findViewById(R.id.idIVUserImage); + } + } + + public static class BotViewHolder extends RecyclerView.ViewHolder { + TextView botTV; + + public BotViewHolder(@NonNull View itemView) { + super(itemView); + botTV = itemView.findViewById(R.id.idTVBot); + } + } +} diff --git a/android/llama.android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java new file mode 100644 index 00000000..9ebd8d45 --- /dev/null +++ b/android/llama.android/app-java/src/main/java/ai/nexa/app_java/VlmModelManager.java @@ -0,0 +1,125 @@ +package ai.nexa.app_java; + +import android.content.Context; +import android.os.Environment; +import android.util.Log; + +import java.io.File; +import java.io.IOException; + +public class VlmModelManager { + private static final String TAG = "LlamaBridge"; + private static final String MODELS_DIR = "models"; + private static final String MODEL_TEXT_FILENAME = "nanollava-text-model-q4_0.gguf"; + private static final String MODEL_MMPROJ_FILENAME = "nanollava-mmproj-f16.gguf"; + + private final Context context; + private File textModelFile; + private File mmProjModelFile; + private final File externalModelDir; + 
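+    // The model files named above are located at runtime by findExistingModel(),
+    // which checks, in order: the app-specific directory under Android/data,
+    // the public Downloads folder, getExternalFilesDir(null)/models, and
+    // getFilesDir()/models, returning the first readable match.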
+ public VlmModelManager(Context context) { + this.context = context; + this.externalModelDir = new File(Environment.getExternalStorageDirectory(), + "Android/data/" + context.getPackageName() + "/files"); + } + + /** + * Search for model in common locations + * @param modelFilename The name of the model file to find + * @return File path to the model if found, null otherwise + */ + private String findExistingModel(String modelFilename) { + // List of possible locations to check + File[] locations = { + // External storage specific path + new File(externalModelDir, modelFilename), + // Downloads folder + new File(Environment.getExternalStoragePublicDirectory( + Environment.DIRECTORY_DOWNLOADS), modelFilename), + // App's private external storage + new File(context.getExternalFilesDir(null), MODELS_DIR + "/" + modelFilename), + // App's private internal storage + new File(context.getFilesDir(), MODELS_DIR + "/" + modelFilename) + }; + + for (File location : locations) { + if (location.exists() && location.canRead()) { + Log.d(TAG, "Found model at: " + location.getAbsolutePath()); + return location.getAbsolutePath(); + } + } + return null; + } + + /** + * Get text model path, searching in storage locations + * @return Path to the model file + * @throws IOException if model cannot be found or accessed + */ + public String getTextModelPath() throws IOException { + // If we already have a valid model file, return it + if (textModelFile != null && textModelFile.exists() && textModelFile.canRead()) { + return textModelFile.getAbsolutePath(); + } + + // Search for existing model + String path = findExistingModel(MODEL_TEXT_FILENAME); + if (path != null) { + textModelFile = new File(path); + return path; + } + + throw new IOException("Text model not found in any storage location"); + } + + /** + * Get mmproj model path, searching in storage locations + * @return Path to the model file + * @throws IOException if model cannot be found or accessed + */ + public String getMmProjModelPath() throws IOException { + // If we already have a valid model file, return it + if (mmProjModelFile != null && mmProjModelFile.exists() && mmProjModelFile.canRead()) { + return mmProjModelFile.getAbsolutePath(); + } + + // Search for existing model + String path = findExistingModel(MODEL_MMPROJ_FILENAME); + if (path != null) { + mmProjModelFile = new File(path); + return path; + } + + throw new IOException("MMProj model not found in any storage location"); + } + + /** + * Check if both required models exist in any location + * @return true if both models are found + */ + public boolean areModelsAvailable() { + try { + getTextModelPath(); + getMmProjModelPath(); + return true; + } catch (IOException e) { + Log.w(TAG, "Models not available: " + e.getMessage()); + return false; + } + } + + /** + * Get the directory containing the models + * @return File object for the models directory, or null if models aren't found + */ + public File getModelsDirectory() { + try { + String textModelPath = getTextModelPath(); + return new File(textModelPath).getParentFile(); + } catch (IOException e) { + Log.w(TAG, "Could not determine models directory: " + e.getMessage()); + return null; + } + } +} diff --git a/android/llama.android/app-java/src/main/res/drawable-hdpi/ic_menu_send.png b/android/llama.android/app-java/src/main/res/drawable-hdpi/ic_menu_send.png new file mode 100644 index 00000000..f34a9658 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/drawable-hdpi/ic_menu_send.png differ diff --git 
a/android/llama.android/app-java/src/main/res/drawable-mdpi/ic_menu_send.png b/android/llama.android/app-java/src/main/res/drawable-mdpi/ic_menu_send.png new file mode 100644 index 00000000..e83f6010 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/drawable-mdpi/ic_menu_send.png differ diff --git a/android/llama.android/app-java/src/main/res/drawable-v24/ic_launcher_foreground.xml b/android/llama.android/app-java/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 00000000..2b068d11 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/drawable-xhdpi/ic_menu_send.png b/android/llama.android/app-java/src/main/res/drawable-xhdpi/ic_menu_send.png new file mode 100644 index 00000000..882722eb Binary files /dev/null and b/android/llama.android/app-java/src/main/res/drawable-xhdpi/ic_menu_send.png differ diff --git a/android/llama.android/app-java/src/main/res/drawable-xxhdpi/ic_menu_send.png b/android/llama.android/app-java/src/main/res/drawable-xxhdpi/ic_menu_send.png new file mode 100644 index 00000000..08108e76 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/drawable-xxhdpi/ic_menu_send.png differ diff --git a/android/llama.android/app-java/src/main/res/drawable-xxxhdpi/ic_menu_send.png b/android/llama.android/app-java/src/main/res/drawable-xxxhdpi/ic_menu_send.png new file mode 100644 index 00000000..8f7eb62c Binary files /dev/null and b/android/llama.android/app-java/src/main/res/drawable-xxxhdpi/ic_menu_send.png differ diff --git a/android/llama.android/app-java/src/main/res/drawable/bg_send_message.xml b/android/llama.android/app-java/src/main/res/drawable/bg_send_message.xml new file mode 100644 index 00000000..972981d8 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/bg_send_message.xml @@ -0,0 +1,9 @@ + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/bot_message.xml b/android/llama.android/app-java/src/main/res/drawable/bot_message.xml new file mode 100644 index 00000000..8dda5f87 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/bot_message.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/gradient_background.xml b/android/llama.android/app-java/src/main/res/drawable/gradient_background.xml new file mode 100644 index 00000000..6d9a5345 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/gradient_background.xml @@ -0,0 +1,8 @@ + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/ic_bot.xml b/android/llama.android/app-java/src/main/res/drawable/ic_bot.xml new file mode 100644 index 00000000..660ed4e0 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/ic_bot.xml @@ -0,0 +1,171 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/ic_launcher.png b/android/llama.android/app-java/src/main/res/drawable/ic_launcher.png new file mode 100644 index 00000000..e3c90853 Binary files /dev/null and 
b/android/llama.android/app-java/src/main/res/drawable/ic_launcher.png differ diff --git a/android/llama.android/app-java/src/main/res/drawable/ic_launcher_background.xml b/android/llama.android/app-java/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 00000000..07d5da9c --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/ic_launcher_fav_background.xml b/android/llama.android/app-java/src/main/res/drawable/ic_launcher_fav_background.xml new file mode 100644 index 00000000..ca3826a4 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/ic_launcher_fav_background.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/ic_user.xml b/android/llama.android/app-java/src/main/res/drawable/ic_user.xml new file mode 100644 index 00000000..725adb58 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/ic_user.xml @@ -0,0 +1,22 @@ + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/input_text_box.xml b/android/llama.android/app-java/src/main/res/drawable/input_text_box.xml new file mode 100644 index 00000000..1c132b0b --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/input_text_box.xml @@ -0,0 +1,10 @@ + + + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/drawable/microphone.xml b/android/llama.android/app-java/src/main/res/drawable/microphone.xml new file mode 100644 index 00000000..75fe9341 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/microphone.xml @@ -0,0 +1,13 @@ + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/octopus_background.xml b/android/llama.android/app-java/src/main/res/drawable/octopus_background.xml new file mode 100644 index 00000000..ca3826a4 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/octopus_background.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/octopus_menu_send.xml b/android/llama.android/app-java/src/main/res/drawable/octopus_menu_send.xml new file mode 100644 index 00000000..4254a34f --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/octopus_menu_send.xml @@ -0,0 +1,11 @@ + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/octopus_original.xml b/android/llama.android/app-java/src/main/res/drawable/octopus_original.xml new file mode 100644 index 00000000..92048641 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/octopus_original.xml @@ -0,0 +1,171 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/ocutopus_v3_full_size.png b/android/llama.android/app-java/src/main/res/drawable/ocutopus_v3_full_size.png new file mode 100644 index 00000000..de1bb864 Binary files /dev/null and 
b/android/llama.android/app-java/src/main/res/drawable/ocutopus_v3_full_size.png differ diff --git a/android/llama.android/app-java/src/main/res/drawable/roundcorner.xml b/android/llama.android/app-java/src/main/res/drawable/roundcorner.xml new file mode 100644 index 00000000..5c795c41 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/roundcorner.xml @@ -0,0 +1,9 @@ + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/title.xml b/android/llama.android/app-java/src/main/res/drawable/title.xml new file mode 100644 index 00000000..a7bad4f8 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/title.xml @@ -0,0 +1,11 @@ + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/drawable/upload_image_icon.xml b/android/llama.android/app-java/src/main/res/drawable/upload_image_icon.xml new file mode 100644 index 00000000..f4a86832 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/drawable/upload_image_icon.xml @@ -0,0 +1,13 @@ + + + diff --git a/android/llama.android/app-java/src/main/res/font/abhaya_libre_bold.ttf b/android/llama.android/app-java/src/main/res/font/abhaya_libre_bold.ttf new file mode 100644 index 00000000..6f4a231d Binary files /dev/null and b/android/llama.android/app-java/src/main/res/font/abhaya_libre_bold.ttf differ diff --git a/android/llama.android/app-java/src/main/res/font/alegreya_sans_sc_extrabold.xml b/android/llama.android/app-java/src/main/res/font/alegreya_sans_sc_extrabold.xml new file mode 100644 index 00000000..8112a231 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/font/alegreya_sans_sc_extrabold.xml @@ -0,0 +1,7 @@ + + + diff --git a/android/llama.android/app-java/src/main/res/layout/activity_main.xml b/android/llama.android/app-java/src/main/res/layout/activity_main.xml new file mode 100644 index 00000000..625d9923 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/layout/activity_main.xml @@ -0,0 +1,115 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/layout/bot_msg.xml b/android/llama.android/app-java/src/main/res/layout/bot_msg.xml new file mode 100644 index 00000000..5ee58e1d --- /dev/null +++ b/android/llama.android/app-java/src/main/res/layout/bot_msg.xml @@ -0,0 +1,35 @@ + + + + + + + + + + diff --git a/android/llama.android/app-java/src/main/res/layout/user_msg.xml b/android/llama.android/app-java/src/main/res/layout/user_msg.xml new file mode 100644 index 00000000..20aa126a --- /dev/null +++ b/android/llama.android/app-java/src/main/res/layout/user_msg.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 00000000..036d09bc --- /dev/null +++ b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 00000000..036d09bc --- /dev/null +++ b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git 
a/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/octopus.xml b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/octopus.xml new file mode 100644 index 00000000..2e533e65 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/octopus.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/octopus_round.xml b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/octopus_round.xml new file mode 100644 index 00000000..2e533e65 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/mipmap-anydpi-v26/octopus_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher.png b/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher.png new file mode 100644 index 00000000..cf0c3458 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher_foreground.png b/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher_foreground.png new file mode 100644 index 00000000..8acbf0ea Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher_foreground.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher_round.png b/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher_round.png new file mode 100644 index 00000000..12580bdb Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-hdpi/ic_launcher_round.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus.webp b/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus.webp new file mode 100644 index 00000000..29daecf1 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus_foreground.webp b/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus_foreground.webp new file mode 100644 index 00000000..88bf3149 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus_foreground.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus_round.webp b/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus_round.webp new file mode 100644 index 00000000..0883ba1c Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-hdpi/octopus_round.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher.png b/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher.png new file mode 100644 index 00000000..b3990457 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher_foreground.png b/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher_foreground.png new file mode 100644 index 00000000..b8a59f47 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher_foreground.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher_round.png b/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher_round.png new file mode 100644 index 00000000..75aec75a 
Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-mdpi/ic_launcher_round.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus.webp b/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus.webp new file mode 100644 index 00000000..c192866e Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus_foreground.webp b/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus_foreground.webp new file mode 100644 index 00000000..34251871 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus_foreground.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus_round.webp b/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus_round.webp new file mode 100644 index 00000000..edb57427 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-mdpi/octopus_round.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher.png b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher.png new file mode 100644 index 00000000..a6324636 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher_foreground.png b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher_foreground.png new file mode 100644 index 00000000..98708fa5 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher_foreground.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher_round.png b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher_round.png new file mode 100644 index 00000000..30de3067 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/ic_launcher_round.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus.webp b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus.webp new file mode 100644 index 00000000..372b8bdb Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus_foreground.webp b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus_foreground.webp new file mode 100644 index 00000000..fcdd6ddf Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus_foreground.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus_round.webp b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus_round.webp new file mode 100644 index 00000000..5b864a66 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xhdpi/octopus_round.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher.png b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher.png new file mode 100644 index 00000000..196c1ef5 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher_foreground.png b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher_foreground.png 
new file mode 100644 index 00000000..34fc4e7e Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher_foreground.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher_round.png b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher_round.png new file mode 100644 index 00000000..984bb8d9 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/ic_launcher_round.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus.webp b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus.webp new file mode 100644 index 00000000..ad3daafc Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus_foreground.webp b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus_foreground.webp new file mode 100644 index 00000000..ca878a67 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus_foreground.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus_round.webp b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus_round.webp new file mode 100644 index 00000000..fd780d66 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxhdpi/octopus_round.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher.png b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher.png new file mode 100644 index 00000000..1f10f330 Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher_foreground.png b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher_foreground.png new file mode 100644 index 00000000..13f3147e Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher_foreground.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png new file mode 100644 index 00000000..b81a70ba Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus.webp b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus.webp new file mode 100644 index 00000000..ef8923ce Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus_foreground.webp b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus_foreground.webp new file mode 100644 index 00000000..e8b6489c Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus_foreground.webp differ diff --git a/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus_round.webp b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus_round.webp new file mode 100644 index 00000000..d0b5881a Binary files /dev/null and b/android/llama.android/app-java/src/main/res/mipmap-xxxhdpi/octopus_round.webp differ diff --git 
a/android/llama.android/app-java/src/main/res/values-night/themes.xml b/android/llama.android/app-java/src/main/res/values-night/themes.xml new file mode 100644 index 00000000..2bd72d37 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values-night/themes.xml @@ -0,0 +1,11 @@ + + + + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/values/colors.xml b/android/llama.android/app-java/src/main/res/values/colors.xml new file mode 100644 index 00000000..b15af47b --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/colors.xml @@ -0,0 +1,17 @@ + + + #FF000000 + #FFFFFFFF + #813BBA + #FF202020 + #17CE92 + #E5E5E5 + #0A1528 + #313D50 + #03070D + #03070D + #03070D + #03070D + #FFFFFF + #B00020 + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/values/font_certs.xml b/android/llama.android/app-java/src/main/res/values/font_certs.xml new file mode 100644 index 00000000..d2226ac0 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/font_certs.xml @@ -0,0 +1,17 @@ + + + + @array/com_google_android_gms_fonts_certs_dev + @array/com_google_android_gms_fonts_certs_prod + + + + MIIEqDCCA5CgAwIBAgIJANWFuGx90071MA0GCSqGSIb3DQEBBAUAMIGUMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNTW91bnRhaW4gVmlldzEQMA4GA1UEChMHQW5kcm9pZDEQMA4GA1UECxMHQW5kcm9pZDEQMA4GA1UEAxMHQW5kcm9pZDEiMCAGCSqGSIb3DQEJARYTYW5kcm9pZEBhbmRyb2lkLmNvbTAeFw0wODA0MTUyMzM2NTZaFw0zNTA5MDEyMzM2NTZaMIGUMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNTW91bnRhaW4gVmlldzEQMA4GA1UEChMHQW5kcm9pZDEQMA4GA1UECxMHQW5kcm9pZDEQMA4GA1UEAxMHQW5kcm9pZDEiMCAGCSqGSIb3DQEJARYTYW5kcm9pZEBhbmRyb2lkLmNvbTCCASAwDQYJKoZIhvcNAQEBBQADggENADCCAQgCggEBANbOLggKv+IxTdGNs8/TGFy0PTP6DHThvbbR24kT9ixcOd9W+EaBPWW+wPPKQmsHxajtWjmQwWfna8mZuSeJS48LIgAZlKkpFeVyxW0qMBujb8X8ETrWy550NaFtI6t9+u7hZeTfHwqNvacKhp1RbE6dBRGWynwMVX8XW8N1+UjFaq6GCJukT4qmpN2afb8sCjUigq0GuMwYXrFVee74bQgLHWGJwPmvmLHC69EH6kWr22ijx4OKXlSIx2xT1AsSHee70w5iDBiK4aph27yH3TxkXy9V89TDdexAcKk/cVHYNnDBapcavl7y0RiQ4biu8ymM8Ga/nmzhRKya6G0cGw8CAQOjgfwwgfkwHQYDVR0OBBYEFI0cxb6VTEM8YYY6FbBMvAPyT+CyMIHJBgNVHSMEgcEwgb6AFI0cxb6VTEM8YYY6FbBMvAPyT+CyoYGapIGXMIGUMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNTW91bnRhaW4gVmlldzEQMA4GA1UEChMHQW5kcm9pZDEQMA4GA1UECxMHQW5kcm9pZDEQMA4GA1UEAxMHQW5kcm9pZDEiMCAGCSqGSIb3DQEJARYTYW5kcm9pZEBhbmRyb2lkLmNvbYIJANWFuGx90071MAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEEBQADggEBABnTDPEF+3iSP0wNfdIjIz1AlnrPzgAIHVvXxunW7SBrDhEglQZBbKJEk5kT0mtKoOD1JMrSu1xuTKEBahWRbqHsXclaXjoBADb0kkjVEJu/Lh5hgYZnOjvlba8Ld7HCKePCVePoTJBdI4fvugnL8TsgK05aIskyY0hKI9L8KfqfGTl1lzOv2KoWD0KWwtAWPoGChZxmQ+nBli+gwYMzM1vAkP+aayLe0a1EQimlOalO762r0GXO0ks+UeXde2Z4e+8S/pf7pITEI/tP+MxJTALw9QUWEv9lKTk+jkbqxbsh8nfBUapfKqYn0eidpwq2AzVp3juYl7//fKnaPhJD9gs= + + + + + 
MIIEQzCCAyugAwIBAgIJAMLgh0ZkSjCNMA0GCSqGSIb3DQEBBAUAMHQxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRQwEgYDVQQKEwtHb29nbGUgSW5jLjEQMA4GA1UECxMHQW5kcm9pZDEQMA4GA1UEAxMHQW5kcm9pZDAeFw0wODA4MjEyMzEzMzRaFw0zNjAxMDcyMzEzMzRaMHQxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRQwEgYDVQQKEwtHb29nbGUgSW5jLjEQMA4GA1UECxMHQW5kcm9pZDEQMA4GA1UEAxMHQW5kcm9pZDCCASAwDQYJKoZIhvcNAQEBBQADggENADCCAQgCggEBAKtWLgDYO6IIrgqWbxJOKdoR8qtW0I9Y4sypEwPpt1TTcvZApxsdyxMJZ2JORland2qSGT2y5b+3JKkedxiLDmpHpDsz2WCbdxgxRczfey5YZnTJ4VZbH0xqWVW/8lGmPav5xVwnIiJS6HXk+BVKZF+JcWjAsb/GEuq/eFdpuzSqeYTcfi6idkyugwfYwXFU1+5fZKUaRKYCwkkFQVfcAs1fXA5V+++FGfvjJ/CxURaSxaBvGdGDhfXE28LWuT9ozCl5xw4Yq5OGazvV24mZVSoOO0yZ31j7kYvtwYK6NeADwbSxDdJEqO4k//0zOHKrUiGYXtqw/A0LFFtqoZKFjnkCAQOjgdkwgdYwHQYDVR0OBBYEFMd9jMIhF1Ylmn/Tgt9r45jk14alMIGmBgNVHSMEgZ4wgZuAFMd9jMIhF1Ylmn/Tgt9r45jk14aloXikdjB0MQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNTW91bnRhaW4gVmlldzEUMBIGA1UEChMLR29vZ2xlIEluYy4xEDAOBgNVBAsTB0FuZHJvaWQxEDAOBgNVBAMTB0FuZHJvaWSCCQDC4IdGZEowjTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBAUAA4IBAQBt0lLO74UwLDYKqs6Tm8/yzKkEu116FmH4rkaymUIE0P9KaMftGlMexFlaYjzmB2OxZyl6euNXEsQH8gjwyxCUKRJNexBiGcCEyj6z+a1fuHHvkiaai+KL8W1EyNmgjmyy8AW7P+LLlkR+ho5zEHatRbM/YAnqGcFh5iZBqpknHf1SKMXFh4dd239FJ1jWYfbMDMy3NS5CTMQ2XFI1MvcyUTdZPErjQfTbQe3aDQsQcafEQPD+nqActifKZ0Np0IS9L9kR/wbNvyz6ENwPiTrjV2KRkEjH78ZMcUQXg0L3BYHJ3lc69Vs5Ddf9uUGGMYldX3WfMBEmh/9iFBDAaTCK + + + diff --git a/android/llama.android/app-java/src/main/res/values/ic_launcher_background.xml b/android/llama.android/app-java/src/main/res/values/ic_launcher_background.xml new file mode 100644 index 00000000..c5d5899f --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/ic_launcher_background.xml @@ -0,0 +1,4 @@ + + + #FFFFFF + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/values/preloaded_fonts.xml b/android/llama.android/app-java/src/main/res/values/preloaded_fonts.xml new file mode 100644 index 00000000..56657f17 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/preloaded_fonts.xml @@ -0,0 +1,6 @@ + + + + @font/alegreya_sans_sc_extrabold + + diff --git a/android/llama.android/app-java/src/main/res/values/strings.xml b/android/llama.android/app-java/src/main/res/values/strings.xml new file mode 100644 index 00000000..2ff67712 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/strings.xml @@ -0,0 +1,4 @@ + + LayoutTest + User Message + \ No newline at end of file diff --git a/android/llama.android/app-java/src/main/res/values/styles.xml b/android/llama.android/app-java/src/main/res/values/styles.xml new file mode 100644 index 00000000..864fcf30 --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/styles.xml @@ -0,0 +1,16 @@ + + + + diff --git a/android/llama.android/app-java/src/main/res/values/themes.xml b/android/llama.android/app-java/src/main/res/values/themes.xml new file mode 100644 index 00000000..2ef46f0c --- /dev/null +++ b/android/llama.android/app-java/src/main/res/values/themes.xml @@ -0,0 +1,13 @@ + + + + + - """, - unsafe_allow_html=True, -) -st.title("Nexa AI Image Generation") -st.caption("Powered by Nexa AI SDK🐙") + default_model = sys.argv[1] + is_local_path = sys.argv[2].lower() == "true" + hf = sys.argv[3].lower() == "true" + + # UI setup: + st.set_page_config(page_title="Nexa AI Image Generation", layout="wide") + st.markdown( + r""" + + """, + unsafe_allow_html=True, + ) + st.title("Nexa AI Image Generation") + 
st.caption("Powered by Nexa AI SDK🐙") + + # force refresh model options on every page load: + if 'model_options' not in st.session_state: + st.session_state.model_options = get_model_options(specified_run_type, model_map) + else: + update_model_options(specified_run_type, model_map) -st.sidebar.header("Model Configuration") -model_path = st.sidebar.text_input("Model path", default_model) + # init session state variables: + if 'initialized' not in st.session_state: + st.session_state.current_model_path = None + st.session_state.current_local_path = None + st.session_state.current_hub_model = None + + if not is_local_path and not hf: + try: + with st.spinner(f"Loading model: {default_model}"): + st.session_state.nexa_model = load_model(default_model) + if st.session_state.nexa_model: + st.session_state.current_hub_model = default_model + except Exception as e: + st.error(f"Error loading default model: {str(e)}") + + if default_model not in st.session_state.model_options: + st.session_state.current_model_index = st.session_state.model_options.index("Use Model From Nexa Model Hub 🔍") + else: + try: + st.session_state.current_model_index = st.session_state.model_options.index(default_model) + except ValueError: + st.session_state.current_model_index = 0 + + st.session_state.initialized = True + + # model selection sidebar: + st.sidebar.header("Model Configuration") + + # update selectbox index based on current model + if 'nexa_model' in st.session_state: + if st.session_state.current_hub_model: + current_index = st.session_state.model_options.index("Use Model From Nexa Model Hub 🔍") + elif st.session_state.current_local_path: + current_index = st.session_state.model_options.index("Local Model 📁") + elif st.session_state.current_model_path: + current_index = st.session_state.model_options.index(st.session_state.current_model_path) + else: + current_index = st.session_state.current_model_index + else: + current_index = st.session_state.current_model_index -if not model_path: - st.warning( - "Please enter a valid path or identifier for the model in Nexa Model Hub to proceed." 
+ model_path = st.sidebar.selectbox( + "Select a Model", + st.session_state.model_options, + index=current_index, + key='model_selectbox' ) - st.stop() - -if ( - "current_model_path" not in st.session_state - or st.session_state.current_model_path != model_path -): - st.session_state.current_model_path = model_path - st.session_state.nexa_model = load_model(model_path) - if st.session_state.nexa_model is None: - st.stop() - -st.sidebar.header("Generation Parameters") -num_inference_steps = st.sidebar.slider( - "Number of Inference Steps", - 1, - 100, - st.session_state.nexa_model.params["num_inference_steps"], -) -height = st.sidebar.slider( - "Height", 64, 1024, st.session_state.nexa_model.params["height"] -) -width = st.sidebar.slider( - "Width", 64, 1024, st.session_state.nexa_model.params["width"] -) -guidance_scale = st.sidebar.slider( - "Guidance Scale", 0.0, 20.0, st.session_state.nexa_model.params["guidance_scale"] -) -random_seed = st.sidebar.slider( - "Random Seed", 0, 10000, st.session_state.nexa_model.params["random_seed"] -) -st.session_state.nexa_model.params.update( - { - "num_inference_steps": num_inference_steps, - "height": height, - "width": width, - "guidance_scale": guidance_scale, - "random_seed": random_seed, - } -) + # handle model path input: + if model_path == "Local Model 📁": + local_model_path = st.sidebar.text_input("Enter local model path") + if not local_model_path: + st.warning("Please enter a valid local model path to proceed.") + st.stop() + local_model_path = local_model_path.strip() # remove spaces -prompt = st.text_input("Enter your prompt:") -negative_prompt = st.text_input("Enter your negative prompt (optional):") + # handle local model path changes: + if 'nexa_model' not in st.session_state or st.session_state.current_local_path != local_model_path: + with st.spinner("Loading local model..."): + st.session_state.nexa_model = load_local_model(local_model_path) + st.session_state.current_local_path = local_model_path + + elif model_path == "Use Model From Nexa Model Hub 🔍": + initial_value = default_model if not is_local_path and not hf else "" + hub_model_name = st.sidebar.text_input( + "Enter model name from Nexa Model Hub", + value=initial_value + ) + + # empty string check: + if not hub_model_name: + st.warning(""" + How to add a model from Nexa Model Hub: + \n1. Visit [Nexa Model Hub](https://nexaai.com/models) + \n2. Find a vision model using the task filters + \n3. Select your desired model and copy either: + \n - The full nexa run command, or (e.g., nexa run stable-diffusion-v1-4:q4_0) + \n - Simply the model name (e.g., stable-diffusion-v1-4:q4_0) + \n4. 
Paste it into the field on the sidebar and press enter + """) + st.stop() + + # process the input after checking it's not empty: + if hub_model_name.startswith("nexa run"): + hub_model_name = hub_model_name.split("nexa run")[-1].strip() + else: + hub_model_name = hub_model_name.strip() + + # handle hub model name changes: + if 'nexa_model' not in st.session_state or st.session_state.current_hub_model != hub_model_name: + with st.spinner("Loading model from hub..."): + st.session_state.nexa_model = load_model(hub_model_name) + if st.session_state.nexa_model: # only update if load was successful + st.session_state.current_hub_model = hub_model_name -if st.button("Generate Image"): - if not prompt: - st.warning("Please enter a prompt to proceed.") else: - with st.spinner("Generating images..."): - images = generate_images( - st.session_state.nexa_model, prompt, negative_prompt - ) - st.success("Images generated successfully!") - for i, image in enumerate(images): - st.image(image, caption=f"Generated Image", use_column_width=True) - - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') - img_byte_arr = img_byte_arr.getvalue() - - st.download_button( - label=f"Download Image", - data=img_byte_arr, - file_name=f"generated_image.png", - mime="image/png" + # load selected model if it's not already loaded: + if ('nexa_model' not in st.session_state or getattr(st.session_state, 'current_model_path', None) != model_path): + with st.spinner(f"Loading model: {model_path}"): + st.session_state.nexa_model = load_model(model_path) + if st.session_state.nexa_model: # only update if load was successful + st.session_state.current_model_path = model_path + + # generation params: + if 'nexa_model' in st.session_state and st.session_state.nexa_model: + st.sidebar.header("Generation Parameters") + + model_to_check = (st.session_state.current_hub_model if st.session_state.current_hub_model else st.session_state.current_local_path if st.session_state.current_local_path else st.session_state.current_model_path) + + # get model specific defaults: + default_params = get_default_params(model_to_check) + + # adjust step range based on model type: + max_steps = 100 + if "lcm-dreamshaper" in model_to_check or "flux" in model_to_check: + max_steps = 8 # 4-8 steps + elif "sdxl-turbo" in model_to_check: + max_steps = 10 # 5-10 steps + + # adjust guidance scale range based on model type: + max_guidance = 20.0 + if "lcm-dreamshaper" in model_to_check or "flux" in model_to_check: + max_guidance = 2.0 # 1.0-2.0 + elif "sdxl-turbo" in model_to_check: + max_guidance = 10.0 # 5.0-10.0 + + num_inference_steps = st.sidebar.slider( + "Number of Inference Steps", + 1, + max_steps, + default_params["num_inference_steps"] + ) + height = st.sidebar.slider( + "Height", + 64, + 1024, + default_params["height"] + ) + width = st.sidebar.slider( + "Width", + 64, + 1024, + default_params["width"] + ) + guidance_scale = st.sidebar.slider( + "Guidance Scale", + 0.0, + max_guidance, + default_params["guidance_scale"] + ) + random_seed = st.sidebar.slider( + "Random Seed", + 0, + 10000, + default_params["random_seed"] + ) + + st.session_state.nexa_model.params.update({ + "num_inference_steps": num_inference_steps, + "height": height, + "width": width, + "guidance_scale": guidance_scale, + "random_seed": random_seed, + }) + + # image generation interface: + prompt = st.text_input("Enter your prompt:") + negative_prompt = st.text_input("Enter your negative prompt (optional):") + + if st.button("Generate Image"): + if not prompt: + 
st.warning("Please enter a prompt to proceed.") + else: + with st.spinner("Generating images..."): + images = generate_images( + st.session_state.nexa_model, + prompt, + negative_prompt ) + st.success("Images generated successfully!") + for i, image in enumerate(images): + st.image(image, caption=f"Generated Image", use_column_width=True) + + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='PNG') + img_byte_arr = img_byte_arr.getvalue() + + st.download_button( + label=f"Download Image", + data=img_byte_arr, + file_name=f"generated_image.png", + mime="image/png" + ) + +except Exception as e: + st.error(f"An unexpected error occurred: {str(e)}") + import traceback + st.error(f"Traceback: {traceback.format_exc()}") diff --git a/nexa/gguf/streamlit/streamlit_text_chat.py b/nexa/gguf/streamlit/streamlit_text_chat.py index 542e8059..4adf60e3 100644 --- a/nexa/gguf/streamlit/streamlit_text_chat.py +++ b/nexa/gguf/streamlit/streamlit_text_chat.py @@ -1,112 +1,313 @@ import sys -from typing import Iterator - +import subprocess +import re +from typing import Iterator, List import streamlit as st from nexa.general import pull_model from nexa.gguf.nexa_inference_text import NexaTextInference +from nexa.utils import ( + get_model_options, + update_model_options, +) +from nexa.constants import ( + DEFAULT_TEXT_GEN_PARAMS, + NEXA_RUN_MODEL_MAP_TEXT, +) -default_model = sys.argv[1] -is_local_path = False if sys.argv[2] == "False" else True -hf = False if sys.argv[3] == "False" else True - -@st.cache_resource -def load_model(model_path): - st.session_state.messages = [] - if is_local_path: - local_path = model_path - elif hf: - local_path, _ = pull_model(model_path, hf=True) - else: - local_path, run_type = pull_model(model_path) - nexa_model = NexaTextInference(model_path=model_path, local_path=local_path) - return nexa_model +specified_run_type = 'NLP' +model_map = NEXA_RUN_MODEL_MAP_TEXT + +# init: +DEFAULT_PARAMS = DEFAULT_TEXT_GEN_PARAMS.copy() + +@st.cache_resource(show_spinner=False) +def load_model(model_path: str, is_local: bool = False, is_hf: bool = False): + """Load model with proper error handling and state management.""" + try: + st.session_state.messages = [] + + if is_local: + local_path = model_path + elif is_hf: + try: + local_path, _ = pull_model(model_path, hf=True) + update_model_options(specified_run_type, model_map) # update options after successful pull + except Exception as e: + st.error(f"Error pulling HuggingFace model: {str(e)}") + return None + else: + try: + # model hub case: + local_path, run_type = pull_model(model_path) + if not local_path or not run_type: + st.error(f"Failed to pull model {model_path} from Nexa Model Hub") + return None + update_model_options(specified_run_type, model_map) # update options after successful pull + except ValueError as e: + st.error(f"Error pulling model from Nexa Model Hub: {str(e)}") + return None + except Exception as e: + st.error(f"Unexpected error while pulling model: {str(e)}") + return None + + try: + nexa_model = NexaTextInference( + model_path=model_path, + local_path=local_path, + **DEFAULT_PARAMS + ) + + # force refresh of model options after successful load: + update_model_options(specified_run_type, model_map) + + # reset the model index to include the new model: + if model_path in st.session_state.model_options: + st.session_state.current_model_index = st.session_state.model_options.index(model_path) + return nexa_model + + except Exception as e: + st.error(f"Error initializing model: {str(e)}") + return None + 
except Exception as e: + st.error(f"Error in load_model: {str(e)}") + return None + +@st.cache_resource(show_spinner=False) +def load_local_model(local_path: str): + """Load local model with default parameters.""" + try: + st.session_state.messages = [] + nexa_model = NexaTextInference( + model_path="local_model", + local_path=local_path, + **DEFAULT_PARAMS + ) + update_model_options(specified_run_type, model_map) # update options after successful local model load + return nexa_model + except Exception as e: + st.error(f"Error loading local model: {str(e)}") + return None def generate_response(nexa_model: NexaTextInference) -> Iterator: + """Generate response from the model.""" user_input = st.session_state.messages[-1]["content"] if hasattr(nexa_model, "chat_format") and nexa_model.chat_format: return nexa_model._chat(user_input) else: return nexa_model._complete(user_input) -st.markdown( - r""" - - """, - unsafe_allow_html=True, -) -st.title("Nexa AI Text Generation") -st.caption("Powered by Nexa AI SDK🐙") +# main execution: +try: + # get command line arguments with proper error handling: + if len(sys.argv) < 4: + st.error("Missing required command line arguments.") + sys.exit(1) # program terminated with an error -st.sidebar.header("Model Configuration") -model_path = st.sidebar.text_input("Model path", default_model) + default_model = sys.argv[1] + is_local_path = sys.argv[2].lower() == "true" + hf = sys.argv[3].lower() == "true" -if not model_path: - st.warning( - "Please enter a valid path or identifier for the model in Nexa Model Hub to proceed." + # UI setup: + st.set_page_config(page_title="Nexa AI Text Generation", layout="wide") + st.markdown( + r""" + + """, + unsafe_allow_html=True, ) - st.stop() - -if ( - "current_model_path" not in st.session_state - or st.session_state.current_model_path != model_path -): - st.session_state.current_model_path = model_path - st.session_state.nexa_model = load_model(model_path) - if st.session_state.nexa_model is None: - st.stop() - -st.sidebar.header("Generation Parameters") -temperature = st.sidebar.slider( - "Temperature", 0.0, 1.0, st.session_state.nexa_model.params["temperature"] -) -max_new_tokens = st.sidebar.slider( - "Max New Tokens", 1, 500, st.session_state.nexa_model.params["max_new_tokens"] -) -top_k = st.sidebar.slider("Top K", 1, 100, st.session_state.nexa_model.params["top_k"]) -top_p = st.sidebar.slider( - "Top P", 0.0, 1.0, st.session_state.nexa_model.params["top_p"] -) + st.title("Nexa AI Text Generation") + st.caption("Powered by Nexa AI SDK🐙") -st.session_state.nexa_model.params.update( - { - "temperature": temperature, - "max_new_tokens": max_new_tokens, - "top_k": top_k, - "top_p": top_p, - } -) + # force refresh model options on every page load: + if 'model_options' not in st.session_state: + st.session_state.model_options = get_model_options(specified_run_type, model_map) + else: + update_model_options(specified_run_type, model_map) + + # init session state variables: + if 'initialized' not in st.session_state: + st.session_state.messages = [] + st.session_state.current_model_path = None + st.session_state.current_local_path = None + st.session_state.current_hub_model = None + + if not is_local_path and not hf: + try: + with st.spinner(f"Loading model: {default_model}"): + st.session_state.nexa_model = load_model(default_model) + if st.session_state.nexa_model: + st.session_state.current_hub_model = default_model + except Exception as e: + st.error(f"Error loading default model: {str(e)}") + + # set to model hub 
option if not found in list: + if default_model not in st.session_state.model_options: + st.session_state.current_model_index = st.session_state.model_options.index("Use Model From Nexa Model Hub 🔍") + else: + try: + st.session_state.current_model_index = st.session_state.model_options.index(default_model) + except ValueError: + st.session_state.current_model_index = 0 + + st.session_state.initialized = True + + # model selection sidebar: + st.sidebar.header("Model Configuration") + + # update the selectbox index based on the currently loaded model: + if 'nexa_model' in st.session_state: + if st.session_state.current_hub_model: + # if we have a hub model loaded, select the hub option: + current_index = st.session_state.model_options.index("Use Model From Nexa Model Hub 🔍") + elif st.session_state.current_local_path: + # if we have a local model loaded, select the local option: + current_index = st.session_state.model_options.index("Local Model 📁") + elif st.session_state.current_model_path: + # if we have a listed model loaded, find its index: + current_index = st.session_state.model_options.index(st.session_state.current_model_path) + else: + current_index = st.session_state.current_model_index + else: + current_index = st.session_state.current_model_index + + model_path = st.sidebar.selectbox( + "Select a Model", + st.session_state.model_options, + index=current_index, + key='model_selectbox' + ) + + # update current model index when selection changes: + current_index = st.session_state.model_options.index(model_path) + if current_index != st.session_state.current_model_index: + st.session_state.current_model_index = current_index + if 'nexa_model' in st.session_state: + del st.session_state.nexa_model + st.session_state.messages = [] + st.session_state.current_model_path = None + st.session_state.current_local_path = None + st.session_state.current_hub_model = None + + # handle model loading based on selection: + if model_path == "Local Model 📁": + local_model_path = st.sidebar.text_input("Enter local model path") + if not local_model_path: + st.warning("Please enter a valid local model path to proceed.") + st.stop() + + local_model_path = local_model_path.strip() # remove spaces + if 'nexa_model' not in st.session_state or st.session_state.current_local_path != local_model_path: + with st.spinner("Loading local model..."): + st.session_state.nexa_model = load_local_model(local_model_path) + st.session_state.current_local_path = local_model_path + + elif model_path == "Use Model From Nexa Model Hub 🔍": + initial_value = default_model if not is_local_path and not hf else "" + hub_model_name = st.sidebar.text_input( + "Enter model name from Nexa Model Hub", + value=initial_value + ) + + # empty string check: + if not hub_model_name: + st.warning(f""" + How to add a model from Nexa Model Hub: + \n1. Visit [Nexa Model Hub](https://nexaai.com/models) + \n2. Find a NLP model using the task filters (chat, uncensored, etc.) + \n3. Select your desired model and copy either: + \n - The full nexa run command (e.g., nexa run Sao10K/MN-BackyardAI-Party-12B-v1:gguf-q4_K_M), or + \n - Simply the model name (e.g., Sao10K/MN-BackyardAI-Party-12B-v1:gguf-q4_K_M) + \n4. 
Paste it into the "Enter model name from Nexa Model Hub" field on the sidebar and press enter + """) + st.stop() + + # process the input after checking it's not empty: + if hub_model_name.startswith("nexa run"): + hub_model_name = hub_model_name.split("nexa run")[-1].strip() + else: + hub_model_name = hub_model_name.strip() + + if 'nexa_model' not in st.session_state or st.session_state.current_hub_model != hub_model_name: + with st.spinner("Loading model from hub..."): + st.session_state.nexa_model = load_model(hub_model_name) + if st.session_state.nexa_model: # only update if load was successful + st.session_state.current_hub_model = hub_model_name + + else: + # load selected model if it's not already loaded: + if ('nexa_model' not in st.session_state or + getattr(st.session_state, 'current_model_path', None) != model_path): + with st.spinner(f"Loading model: {model_path}"): + st.session_state.nexa_model = load_model(model_path) + if st.session_state.nexa_model: # only update if load was successful + st.session_state.current_model_path = model_path + + # generation params: + if 'nexa_model' in st.session_state and st.session_state.nexa_model: + st.sidebar.header("Generation Parameters") + model_params = st.session_state.nexa_model.params + + temperature = st.sidebar.slider( + "Temperature", 0.0, 1.0, model_params.get("temperature", DEFAULT_PARAMS["temperature"]) + ) + max_new_tokens = st.sidebar.slider( + "Max New Tokens", 1, 500, model_params.get("max_new_tokens", DEFAULT_PARAMS["max_new_tokens"]) + ) + top_k = st.sidebar.slider( + "Top K", 1, 100, model_params.get("top_k", DEFAULT_PARAMS["top_k"]) + ) + top_p = st.sidebar.slider( + "Top P", 0.0, 1.0, model_params.get("top_p", DEFAULT_PARAMS["top_p"]) + ) + nctx = st.sidebar.slider( + "Context length", 1000, 9999, model_params.get("nctx", DEFAULT_PARAMS["nctx"]) + ) + + st.session_state.nexa_model.params.update({ + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "top_k": top_k, + "top_p": top_p, + "nctx": nctx, + }) + + # chat interface: + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + if prompt := st.chat_input("Say something..."): + if 'nexa_model' not in st.session_state or not st.session_state.nexa_model: + st.error("Please wait for the model to load or select a valid model.") + else: + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.markdown(prompt) + + with st.chat_message("assistant"): + response_placeholder = st.empty() + full_response = "" + for chunk in generate_response(st.session_state.nexa_model): + choice = chunk["choices"][0] + if "delta" in choice: + delta = choice["delta"] + content = delta.get("content", "") + elif "text" in choice: + delta = choice["text"] + content = delta + + full_response += content + response_placeholder.markdown(full_response, unsafe_allow_html=True) + response_placeholder.markdown(full_response) + + st.session_state.messages.append({"role": "assistant", "content": full_response}) -if "messages" not in st.session_state: - st.session_state.messages = [] - -for message in st.session_state.messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) - -if prompt := st.chat_input("Say something..."): - st.session_state.messages.append({"role": "user", "content": prompt}) - with st.chat_message("user"): - st.markdown(prompt) - - with st.chat_message("assistant"): - response_placeholder = st.empty() - full_response = "" - for 
chunk in generate_response(st.session_state.nexa_model): - choice = chunk["choices"][0] - if "delta" in choice: - delta = choice["delta"] - content = delta.get("content", "") - elif "text" in choice: - delta = choice["text"] - content = delta - - full_response += content - response_placeholder.markdown(full_response, unsafe_allow_html=True) - response_placeholder.markdown(full_response) - - st.session_state.messages.append({"role": "assistant", "content": full_response}) +except Exception as e: + st.error(f"An unexpected error occurred: {str(e)}") + import traceback + st.error(f"Traceback: {traceback.format_exc()}") diff --git a/nexa/gguf/streamlit/streamlit_vlm.py b/nexa/gguf/streamlit/streamlit_vlm.py index 25f48d0e..a581b167 100644 --- a/nexa/gguf/streamlit/streamlit_vlm.py +++ b/nexa/gguf/streamlit/streamlit_vlm.py @@ -1,40 +1,65 @@ import sys import tempfile -from typing import Iterator - +import subprocess +import re +from typing import List, Iterator import streamlit as st from PIL import Image from nexa.general import pull_model from nexa.gguf.nexa_inference_vlm import NexaVLMInference +from nexa.utils import ( + get_model_options, + update_model_options, +) +from nexa.constants import NEXA_RUN_MODEL_MAP_VLM -default_model = sys.argv[1] -is_local_path = False if sys.argv[2] == "False" else True -hf = False if sys.argv[3] == "False" else True -projector_local_path = sys.argv[4] if len(sys.argv) > 4 else None +specified_run_type = 'Multimodal' +model_map = NEXA_RUN_MODEL_MAP_VLM +# init from command line args: +try: + default_model = sys.argv[1] + is_local_path = sys.argv[2].lower() == "true" + hf = sys.argv[3].lower() == "true" + projector_local_path = sys.argv[4] if len(sys.argv) > 4 else None +except IndexError: + st.error("Missing required command line arguments.") + sys.exit(1) # terminate with an error -@st.cache_resource -def load_model(model_path): - if is_local_path: - local_path = model_path - elif hf: - local_path, _ = pull_model(model_path, hf=True) - else: - local_path, run_type = pull_model(model_path) - - if is_local_path: - nexa_model = NexaVLMInference(model_path=model_path, local_path=local_path, projector_local_path=projector_local_path) - else: - nexa_model = NexaVLMInference(model_path=model_path, local_path=local_path) - return nexa_model - +@st.cache_resource(show_spinner=False) +def load_model(model_path, is_local=False, is_hf=False, projector_path=None): + """Load model with model mapping logic.""" + try: + if is_local: + local_path = model_path + nexa_model = NexaVLMInference( + model_path=model_path, + local_path=local_path, + projector_local_path=projector_path + ) + elif is_hf: + local_path, _ = pull_model(model_path, hf=True) + nexa_model = NexaVLMInference(model_path=model_path, local_path=local_path) + else: + # get the actual model name from the mapping if it exists: + if model_path in NEXA_RUN_MODEL_MAP_VLM: + real_model_path = NEXA_RUN_MODEL_MAP_VLM[model_path] + local_path, run_type = pull_model(real_model_path) + else: + local_path, run_type = pull_model(model_path) + nexa_model = NexaVLMInference(model_path=model_path, local_path=local_path) + return nexa_model + except Exception as e: + st.error(f"Error loading model: {str(e)}") + return None def generate_response( nexa_model: NexaVLMInference, image_path: str, user_input: str ) -> Iterator: return nexa_model._chat(user_input, image_path) - +# UI setup: +st.set_page_config(page_title="Nexa AI Multimodal Generation", layout="wide") st.markdown( r""" + """, + unsafe_allow_html=True, +) 
+st.title("Nexa AI Omni VLM Generation") +st.caption("Powered by Nexa AI SDK🐙") + +st.sidebar.header("Model Configuration") +model_path = st.sidebar.text_input("Model path", default_model) + +if not model_path: + st.warning( + "Please enter a valid path or identifier for the model in Nexa Model Hub to proceed." + ) + st.stop() + +if ( + "current_model_path" not in st.session_state + or st.session_state.current_model_path != model_path +): + st.session_state.current_model_path = model_path + st.session_state.nexa_model = load_model(model_path) + if st.session_state.nexa_model is None: + st.stop() + +user_input = st.text_input("Enter your text input:") +uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) + +generate_button = st.button("Send") +spinner_placeholder = st.empty() +success_label = st.empty() +response_placeholder = st.empty() +image_placeholder = st.empty() + +if uploaded_file: + image_placeholder.image(uploaded_file, caption="Uploaded Image") + +if generate_button: + if not user_input and not uploaded_file: + st.warning("Please enter text input and upload an image to proceed.") + else: + with spinner_placeholder: + with st.spinner("Generating description..."): + with tempfile.NamedTemporaryFile() as image_path: + full_path = None + if uploaded_file: + ext = uploaded_file.name.split(".")[-1] + full_path = f"{image_path.name}.{ext}" + with Image.open(uploaded_file) as img: + img.save(full_path) + + response = generate_response( + st.session_state.nexa_model, full_path, user_input + ) + + response_placeholder.write(response) + success_label.success("Response generated successfully.") \ No newline at end of file diff --git a/nexa/gguf/streamlit/streamlit_voice_chat.py b/nexa/gguf/streamlit/streamlit_voice_chat.py index 77c4b3c1..750217f9 100644 --- a/nexa/gguf/streamlit/streamlit_voice_chat.py +++ b/nexa/gguf/streamlit/streamlit_voice_chat.py @@ -2,30 +2,56 @@ import os import sys import tempfile - -import librosa +import subprocess +import re +from typing import List import streamlit as st from st_audiorec import st_audiorec - from nexa.general import pull_model from nexa.gguf.nexa_inference_voice import NexaVoiceInference +from nexa.utils import ( + get_model_options, + update_model_options, +) +from nexa.constants import NEXA_RUN_MODEL_MAP_VOICE -default_model = sys.argv[1] -is_local_path = False if sys.argv[2] == "False" else True -hf = False if sys.argv[3] == "False" else True - +specified_run_type = 'Audio' +model_map = NEXA_RUN_MODEL_MAP_VOICE -@st.cache_resource -def load_model(model_path): - if is_local_path: - local_path = model_path - elif hf: - local_path, _ = pull_model(model_path, hf=True) - else: - local_path, run_type = pull_model(model_path) - nexa_model = NexaVoiceInference(model_path=model_path, local_path=local_path) - return nexa_model +# init from command line args: +try: + default_model = sys.argv[1] + is_local_path = sys.argv[2].lower() == "true" + hf = sys.argv[3].lower() == "true" +except IndexError: + st.error("Missing required command line arguments.") + sys.exit(1) # terminate with an error +@st.cache_resource(show_spinner=False) +def load_model(model_path, is_local=False, is_hf=False): + """Load model with model mapping logic.""" + try: + if is_local: + # for local paths, use the path directly: + nexa_model = NexaVoiceInference(model_path=model_path, local_path=model_path) + else: + # for non-local paths: + if is_hf: + local_path, _ = pull_model(model_path, hf=True) + nexa_model = NexaVoiceInference(model_path=model_path, 
local_path=local_path) + else: + # handle Model Hub models: + if model_path in NEXA_RUN_MODEL_MAP_VOICE: + real_model_path = NEXA_RUN_MODEL_MAP_VOICE[model_path] + local_path, _ = pull_model(real_model_path) + nexa_model = NexaVoiceInference(model_path=real_model_path, local_path=local_path) + else: + local_path, _ = pull_model(model_path) + nexa_model = NexaVoiceInference(model_path=model_path, local_path=local_path) + return nexa_model + except Exception as e: + st.error(f"Error loading model: {str(e)}") + return None def transcribe_audio(nexa_model, audio_file): with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio: @@ -50,105 +76,223 @@ def transcribe_audio(nexa_model, audio_file): finally: os.unlink(temp_audio_path) - +# UI setup: +st.set_page_config(page_title="Nexa AI Voice Transcription", layout="wide") st.title("Nexa AI Voice Transcription") st.caption("Powered by Nexa AI SDK🐙") +# force refresh model options on every page load: +if 'model_options' not in st.session_state: + st.session_state.model_options = get_model_options(specified_run_type, model_map) +else: + update_model_options(specified_run_type, model_map) + +# init session state variables: +if 'initialized' not in st.session_state: + st.session_state.model_options = get_model_options(specified_run_type, model_map) + st.session_state.current_model_path = default_model + st.session_state.current_local_path = None + st.session_state.current_hub_model = None + + # init with default model: + if is_local_path: + try: + with st.spinner(f"Loading local model: {default_model}"): + st.session_state.nexa_model = load_model( + default_model, + is_local=True, + is_hf=hf + ) + if st.session_state.nexa_model: + st.session_state.current_local_path = default_model + st.session_state.current_model_path = default_model + except Exception as e: + st.error(f"Error loading local model: {str(e)}") + elif hf: + try: + with st.spinner(f"Loading HuggingFace model: {default_model}"): + st.session_state.nexa_model = load_model(default_model, is_hf=True) + if st.session_state.nexa_model: + st.session_state.current_hub_model = default_model + st.session_state.current_model_path = default_model + except Exception as e: + st.error(f"Error loading HuggingFace model: {str(e)}") + + else: + try: + with st.spinner(f"Loading model: {default_model}"): + st.session_state.nexa_model = load_model(default_model) + if st.session_state.nexa_model: + st.session_state.current_model_path = default_model + st.session_state.current_hub_model = default_model + except Exception as e: + st.error(f"Error loading model: {str(e)}") + + st.session_state.initialized = True + +# model selection UI: st.sidebar.header("Model Configuration") -model_path = st.sidebar.text_input("Model path", default_model) - -if not model_path: - st.warning("Please enter a valid S3 model filename to proceed.") - st.stop() - -# Initialize or update the model when the path changes -if ( - "current_model_path" not in st.session_state - or st.session_state.current_model_path != model_path -): - st.session_state.current_model_path = model_path - st.session_state.nexa_model = load_model(model_path) - if st.session_state.nexa_model is None: - st.stop() -# Add sidebar options for new parameters -st.sidebar.header("Transcription Parameters") -beam_size = st.sidebar.slider( - "Beam Size", - 1, 10, - st.session_state.nexa_model.params["beam_size"] -) -task = st.sidebar.selectbox( - "Task", - ["transcribe", "translate"], - index=0 if st.session_state.nexa_model.params["task"] == 
"transcribe" else 1 -) -temperature = st.sidebar.slider( - "Temperature", - 0.0, 1.0, - st.session_state.nexa_model.params["temperature"], - step=0.1 +# update selectbox index based on current model: +current_index = st.session_state.model_options.index("Use Model From Nexa Model Hub 🔍") +if 'nexa_model' in st.session_state: + if st.session_state.current_model_path in st.session_state.model_options: + current_index = st.session_state.model_options.index(st.session_state.current_model_path) + elif st.session_state.current_hub_model: + current_index = st.session_state.model_options.index("Use Model From Nexa Model Hub 🔍") + elif st.session_state.current_local_path: + current_index = st.session_state.model_options.index("Local Model 📁") + +selected_option = st.sidebar.selectbox( + "Select a Model", + st.session_state.model_options, + index=current_index ) -# Update model parameters -st.session_state.nexa_model.params.update( - { +# handle model selection: +if selected_option == "Local Model 📁": + model_path = st.sidebar.text_input( + "Enter local model path", + value=st.session_state.current_local_path if hasattr(st.session_state, 'current_local_path') else "", + help="Enter the full path to your local model directory (e.g., /home/user/.cache/nexa/hub/official/model-name)" + ) + + if not model_path: + st.warning("Please enter a valid local model path to proceed.") + st.stop() + + if (not hasattr(st.session_state, 'current_local_path') or + st.session_state.current_local_path != model_path): + with st.spinner("Loading local model..."): + st.session_state.nexa_model = load_model( + model_path, # use the user input path + is_local=True, + is_hf=hf + ) + if st.session_state.nexa_model: + st.session_state.current_local_path = model_path + st.session_state.current_model_path = model_path + +elif selected_option == "Use Model From Nexa Model Hub 🔍": + model_path = st.sidebar.text_input( + "Enter model name from Nexa Model Hub", + value=st.session_state.current_hub_model if hasattr(st.session_state, 'current_hub_model') else default_model + ) + if not model_path: + st.warning(""" + How to add a model from Nexa Model Hub: + \n1. Visit [Nexa Model Hub](https://nexaai.com/models) + \n2. Find an audio model using the task filters + \n3. Select your desired model and copy either: + \n - The full nexa run command (e.g., nexa run faster-whisper-tiny:bin-cpu-fp16), or + \n - Simply the model name (e.g., faster-whisper-tiny:bin-cpu-fp16) + \n4. 
Paste it into the field on the sidebar and press enter + """) + st.stop() + + # process the input after checking it's not empty: + if model_path.startswith("nexa run"): + model_path = model_path.split("nexa run")[-1].strip() + + if (not hasattr(st.session_state, 'current_hub_model') or st.session_state.current_hub_model != model_path): + with st.spinner("Loading model from hub..."): + st.session_state.nexa_model = load_model(model_path, is_local=False, is_hf=False) + if st.session_state.nexa_model: + st.session_state.current_hub_model = model_path + st.session_state.current_model_path = model_path + st.session_state.current_local_path = None # clear local path state when switching to hub model + +else: + model_path = selected_option + if (not hasattr(st.session_state, 'current_model_path') or + st.session_state.current_model_path != model_path): + with st.spinner(f"Loading model: {model_path}"): + st.session_state.nexa_model = load_model(model_path, is_local=False, is_hf=False) + if st.session_state.nexa_model: + st.session_state.current_model_path = model_path + st.session_state.current_local_path = None + st.session_state.current_hub_model = None + +# only show transcription parameters if model is loaded: +if hasattr(st.session_state, 'nexa_model') and st.session_state.nexa_model: + # transcription parameters: + st.sidebar.header("Transcription Parameters") + beam_size = st.sidebar.slider( + "Beam Size", + 1, 10, + st.session_state.nexa_model.params["beam_size"] + ) + task = st.sidebar.selectbox( + "Task", + ["transcribe", "translate"], + index=0 if st.session_state.nexa_model.params["task"] == "transcribe" else 1 + ) + temperature = st.sidebar.slider( + "Temperature", + 0.0, 1.0, + st.session_state.nexa_model.params["temperature"], + step=0.1 + ) + + # update model parameters: + st.session_state.nexa_model.params.update({ "beam_size": beam_size, "task": task, "temperature": temperature, - } -) + }) -# Option 1: Upload Audio File -st.header("Option 1: Upload Audio File") -uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3"]) - -if uploaded_file is not None: - st.audio(uploaded_file, format="audio/wav") - - if st.button("Transcribe Uploaded Audio"): - with st.spinner("Transcribing audio..."): - transcription = transcribe_audio(st.session_state.nexa_model, uploaded_file) - - if transcription: - st.subheader("Transcription:") - st.write(transcription) - - # Provide a download button for the transcription - transcription_bytes = transcription.encode() - st.download_button( - label="Download Transcription", - data=transcription_bytes, - file_name="transcription.txt", - mime="text/plain", - ) - else: - st.error( - "Transcription failed. Please try again with a different audio file." - ) - -# Option 2: Real-time Recording -st.header("Option 2: Record Audio") -wav_audio_data = st_audiorec() - -if wav_audio_data: - if st.button("Transcribe Recorded Audio"): - with st.spinner("Transcribing audio..."): - transcription = transcribe_audio(st.session_state.nexa_model, io.BytesIO(wav_audio_data)) - - if transcription: - st.subheader("Transcription:") - st.write(transcription) - - # Provide a download button for the transcription - transcription_bytes = transcription.encode() - st.download_button( - label="Download Transcription", - data=transcription_bytes, - file_name="transcription.txt", - mime="text/plain", - ) - else: - st.error("Transcription failed. 
Please try recording again.") + # Option 1: Upload Audio File + st.header("Option 1: Upload Audio File") + uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3"]) + + if uploaded_file is not None: + st.audio(uploaded_file, format="audio/wav") + + if st.button("Transcribe Uploaded Audio"): + with st.spinner("Transcribing audio..."): + transcription = transcribe_audio(st.session_state.nexa_model, uploaded_file) + + if transcription: + st.subheader("Transcription:") + st.write(transcription) + + # Provide a download button for the transcription + transcription_bytes = transcription.encode() + st.download_button( + label="Download Transcription", + data=transcription_bytes, + file_name="transcription.txt", + mime="text/plain", + ) + else: + st.error( + "Transcription failed. Please try again with a different audio file." + ) + + # Option 2: Real-time Recording + st.header("Option 2: Record Audio") + wav_audio_data = st_audiorec() + + if wav_audio_data: + if st.button("Transcribe Recorded Audio"): + with st.spinner("Transcribing audio..."): + transcription = transcribe_audio(st.session_state.nexa_model, io.BytesIO(wav_audio_data)) + + if transcription: + st.subheader("Transcription:") + st.write(transcription) + + # Provide a download button for the transcription + transcription_bytes = transcription.encode() + st.download_button( + label="Download Transcription", + data=transcription_bytes, + file_name="transcription.txt", + mime="text/plain", + ) + else: + st.error("Transcription failed. Please try recording again.") + else: + st.warning("No audio recorded. Please record some audio before transcribing.") else: - st.warning("No audio recorded. Please record some audio before transcribing.") + st.warning("Please select or load a model to proceed.") diff --git a/nexa/onnx/nexa_inference_image.py b/nexa/onnx/nexa_inference_image.py index 8392aa31..b38be335 100644 --- a/nexa/onnx/nexa_inference_image.py +++ b/nexa/onnx/nexa_inference_image.py @@ -64,10 +64,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs): self.params.update(kwargs) self.pipeline = None - def run(self): - if self.download_onnx_folder is None: - self.download_onnx_folder, run_type = pull_model(self.model_path, **kwargs) + self.download_onnx_folder, _ = pull_model(self.model_path, **kwargs) if self.download_onnx_folder is None: logging.error( @@ -76,17 +74,19 @@ def run(self): ) exit(1) - self._load_model(self.download_onnx_folder) + self._load_model() + + def run(self): self._dialogue_mode() @SpinningCursorAnimation() - def _load_model(self, model_path): + def _load_model(self): """ Load the model from the given model path using the appropriate pipeline. 
""" - logging.debug(f"Loading model from {model_path}") + logging.debug(f"Loading model from {self.download_onnx_folder}") try: - model_index_path = os.path.join(model_path, "model_index.json") + model_index_path = os.path.join(self.download_onnx_folder, "model_index.json") with open(model_index_path, "r") as index_file: model_index = json.load(index_file) @@ -96,7 +96,7 @@ def _load_model(self, model_path): PipelineClass = ORT_PIPELINES_MAPPING.get( pipeline_class_name, ORTStableDiffusionPipeline ) - self.pipeline = PipelineClass.from_pretrained(model_path) + self.pipeline = PipelineClass.from_pretrained(self.download_onnx_folder) logging.debug(f"Model loaded successfully using {pipeline_class_name}") except Exception as e: logging.error(f"Error loading model: {e}") diff --git a/nexa/onnx/nexa_inference_text.py b/nexa/onnx/nexa_inference_text.py index fdb6db5f..f9a767e9 100644 --- a/nexa/onnx/nexa_inference_text.py +++ b/nexa/onnx/nexa_inference_text.py @@ -53,9 +53,21 @@ def __init__(self, model_path=None, local_path=None, **kwargs): self.downloaded_onnx_folder = local_path self.timings = kwargs.get("timings", False) self.device = "cpu" + + if self.downloaded_onnx_folder is None: + self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs) + + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + + self._load_model_and_tokenizer() @SpinningCursorAnimation() - def _load_model_and_tokenizer(self) -> Tuple[Any, Any, Any, bool]: + def _load_model_and_tokenizer(self): logging.debug(f"Loading model from {self.downloaded_onnx_folder}") start_time = time.time() self.tokenizer = AutoTokenizer.from_pretrained(self.downloaded_onnx_folder) @@ -148,18 +160,6 @@ def run(self): if self.params.get("streamlit"): self.run_streamlit() else: - if self.downloaded_onnx_folder is None: - self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs) - - if self.downloaded_onnx_folder is None: - logging.error( - f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", - exc_info=True, - ) - exit(1) - - self._load_model_and_tokenizer() - if self.model is None or self.tokenizer is None or self.streamer is None: logging.error( "Failed to load model or tokenizer. Exiting.", exc_info=True diff --git a/nexa/onnx/nexa_inference_tts.py b/nexa/onnx/nexa_inference_tts.py index fb1f2f9a..26c6d3e4 100644 --- a/nexa/onnx/nexa_inference_tts.py +++ b/nexa/onnx/nexa_inference_tts.py @@ -50,8 +50,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs): self.downloaded_onnx_folder = local_path if self.downloaded_onnx_folder is None: - self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs) - + self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs) + if self.downloaded_onnx_folder is None: logging.error( f"Model ({model_path}) is not applicable. 
Please refer to our docs for proper usage.", @@ -69,12 +69,10 @@ def _load_model(self): logging.debug(f"Loading model from {self.downloaded_onnx_folder}") try: self.tokenizer = TTSTokenizer(self.config["token"]["list"]) - print(self.tokenizer) self.model = onnxruntime.InferenceSession( os.path.join(self.downloaded_onnx_folder, "model.onnx"), providers=["CPUExecutionProvider"], ) - print(self.model) logging.debug("Model and tokenizer loaded successfully") except Exception as e: logging.error(f"Error loading model or tokenizer: {e}") diff --git a/nexa/onnx/nexa_inference_voice.py b/nexa/onnx/nexa_inference_voice.py index e6d7d696..c0f56ab4 100644 --- a/nexa/onnx/nexa_inference_voice.py +++ b/nexa/onnx/nexa_inference_voice.py @@ -43,9 +43,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs): self.model = None self.processor = None - def run(self): if self.downloaded_onnx_folder is None: - self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs) + self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs) if self.downloaded_onnx_folder is None: logging.error( @@ -54,14 +53,16 @@ def run(self): ) exit(1) - self._load_model(self.downloaded_onnx_folder) + self._load_model() + + def run(self): self._dialogue_mode() - def _load_model(self, model_path): - logging.debug(f"Loading model from {model_path}") + def _load_model(self): + logging.debug(f"Loading model from {self.downloaded_onnx_folder}") try: - self.processor = AutoProcessor.from_pretrained(model_path) - self.model = ORTModelForSpeechSeq2Seq.from_pretrained(model_path) + self.processor = AutoProcessor.from_pretrained(self.downloaded_onnx_folder) + self.model = ORTModelForSpeechSeq2Seq.from_pretrained(self.downloaded_onnx_folder) logging.debug("Model and processor loaded successfully") except Exception as e: logging.error(f"Error loading model or processor: {e}") diff --git a/nexa/utils.py b/nexa/utils.py index e7761ece..28440cac 100644 --- a/nexa/utils.py +++ b/nexa/utils.py @@ -5,8 +5,97 @@ import time from functools import partial, wraps from importlib.metadata import PackageNotFoundError, distribution +from typing import Dict, List +import json +import logging +import streamlit as st +from nexa.constants import ( + EXIT_COMMANDS, + EXIT_REMINDER, + NEXA_MODEL_LIST_PATH, +) + + +def get_available_models() -> Dict[str, dict]: + """Get list of available computer vision (cv) models from the model list JSON file.""" + # check whether the model list file exists: + if not NEXA_MODEL_LIST_PATH.exists(): + st.error("Model list file not found") + return {} # empty dict + + try: + # read model list from the JSON file: + with open(NEXA_MODEL_LIST_PATH, "r") as f: + available_models = json.load(f) + return available_models + + except json.JSONDecodeError as e: + logging.error(f"Invalid JSON in model list file: {e}") + return {} + except Exception as e: + logging.error(f"Error loading available models: {e}") + return {} + + +def filter_available_models( + models: Dict[str, dict], + specified_run_type: str, + model_map: Dict[str, str] +) -> List[str]: + """Filter available models by run type and apply model mapping.""" + if not models: + return [] + + filtered_models = set() # to avoid duplicates + + for model_name, model_info in models.items(): + # skip if run_type doesn't match: + if model_info.get('run_type') != specified_run_type or 'projector' in model_name: + continue + + if model_name in model_map.values(): + # find short form from mapping: + for short_name, full_name in model_map.items(): + if 
full_name == model_name: + filtered_models.add(short_name) + break + else: + filtered_models.add(model_name) + + return sorted(list(filtered_models)) -from nexa.constants import EXIT_COMMANDS, EXIT_REMINDER + +def get_model_options( + specified_run_type: str, + model_map: Dict[str, str] +) -> List[str]: + """Get list of model options including special options.""" + available_models = get_available_models() + models_list = filter_available_models(available_models, specified_run_type, model_map) + # add special options at the end of the dropdown menu: + models_list.extend(["Use Model From Nexa Model Hub 🔍", "Local Model 📁"]) + return models_list + + +def update_model_options( + specified_run_type: str, + model_map: Dict[str, str] +) -> None: + """Update the model options in session state and force a refresh.""" + try: + fresh_options = get_model_options(specified_run_type, model_map) + st.session_state.model_options = fresh_options # update session state with new options + + if hasattr(st.session_state, 'current_model_path') and st.session_state.current_model_path: + if st.session_state.current_model_path in fresh_options: + st.session_state.current_model_index = fresh_options.index(st.session_state.current_model_path) + else: + # if current model not in list, reset to Model Hub option: + hub_index = fresh_options.index("Use Model From Nexa Model Hub 🔍") + st.session_state.current_model_index = hub_index + + except Exception as e: + logging.error(f"Error updating model options: {e}") def is_package_installed(package_name: str) -> bool: @@ -84,14 +173,14 @@ def nexa_prompt(placeholder: str = "Send a message ...") -> str: try: hint = light_text(placeholder) hint_length = len(strip_ansi(hint)) - + # Print the prompt with placeholder print(f">>> {hint}", end='', flush=True) - + # Move cursor back to the start of the line print('\r', end='', flush=True) print(">>> ", end='', flush=True) - + user_input = "" while True: char = msvcrt.getch().decode() @@ -108,7 +197,7 @@ def nexa_prompt(placeholder: str = "Send a message ...") -> str: else: user_input += char print(char, end='', flush=True) - + if len(user_input) == 1: # Clear hint after first character print('\r' + ' ' * (hint_length + 4), end='', flush=True) print(f'\r>>> {user_input}', end='', flush=True) diff --git a/pyproject.toml b/pyproject.toml index 071faca5..24b6ee35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["scikit-build-core"] +requires = ["scikit-build-core", "setuptools>=64.0"] build-backend = "scikit_build_core.build" [project] @@ -22,6 +22,7 @@ dependencies = [ "pydantic", "pillow", "huggingface_hub", + "modelscope", "prompt_toolkit", "tqdm", # Shared dependencies "tabulate", @@ -105,6 +106,7 @@ wheel.packages = [ "nexa.onnx.streamlit", "nexa.onnx.server", "nexa.eval", + "nexa.transformers", ] sdist.include = [ "CMakeLists.txt", diff --git a/requirements.txt b/requirements.txt index 978b8c1d..6e732a0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ pydantic pillow python-multipart huggingface_hub +modelscope # For onnx optimum[onnxruntime] # for CPU version diff --git a/swift/README.md b/swift/README.md new file mode 100644 index 00000000..0b284330 --- /dev/null +++ b/swift/README.md @@ -0,0 +1,84 @@ +# NexaSwift + +**NexaSwift** is a Swift wrapper for the [llama.cpp](https://github.com/ggerganov/llama.cpp.git) library. This repository provides a Swifty API, allowing Swift developers to easily integrate and use `llama.cpp` models in their projects. 
+**NOTE:** NexaSwift currently supports text inference only.
+
+## Installation
+
+To add NexaSwift to your Swift project, add the following dependency in your `Package.swift` file:
+
+```swift
+.package(url: "https://github.com/NexaAI/nexa-sdk.git", .branch("main"))
+```
+
+## Usage
+
+### 1. Initialize NexaSwift with a Model Path
+
+Create a configuration and initialize NexaSwift with the path to your model file:
+
+```swift
+let configuration = NexaSwift.Configuration(
+    maxNewToken: 128,
+    stopTokens: []
+)
+let modelPath = "path/to/your/model"
+let nexaSwift = try NexaSwift.NexaTextInference(modelPath: modelPath, modelConfiguration: configuration)
+```
+
+### 2. Chat Completion API
+
+#### Build messages
+
+```swift
+var messages: [ChatCompletionRequestMessage] = []
+let userMessage = ChatCompletionRequestMessage.user(
+    ChatCompletionRequestUserMessage(content: .text("user input"))
+)
+messages.append(userMessage)
+```
+
+#### Non-Streaming Mode
+
+In non-streaming mode, call `createChatCompletion` with your messages; the complete response is returned once generation finishes.
+
+```swift
+let response = try await nexaSwift.createChatCompletion(for: messages)
+print(response.choices[0].message.content ?? "")
+```
+
+#### Streaming Mode
+
+In streaming mode, you can process the response in real time as it is generated:
+
+```swift
+for try await response in await nexaSwift.createChatCompletionStream(for: messages) {
+    print(response.choices[0].delta.content ?? "")
+}
+```
+
+### 3. Completion API
+
+#### Non-Streaming Mode
+
+```swift
+let prompt = "your prompt text"
+if let response = try? await nexaSwift.createCompletion(for: prompt) {
+    print(response.choices[0].text)
+}
+```
+
+#### Streaming Mode
+
+```swift
+for try await response in await nexaSwift.createCompletionStream(for: prompt) {
+    print(response.choices[0].text)
+}
+```
+
+## Quick Start
+
+Open the [swift test project](../examples/swift-test/) folder in Xcode and run the project.
+
+## Download Models
+
+NexaSwift supports all models compatible with llama.cpp. You can download models from the [Nexa AI ModelHub](https://nexa.ai/models).
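+## End-to-End Example
+
+As a rough end-to-end sketch, the snippet below ties the pieces above together: it configures `NexaTextInference`, builds a single chat message, and streams the reply. The `@main` wrapper, the model path, and the prompt text are illustrative placeholders rather than part of the library; adapt them to your project.
+
+```swift
+import Foundation
+import NexaSwift
+
+@main
+struct ChatExample {
+    static func main() async throws {
+        // Generation settings; values here are illustrative defaults.
+        let configuration = NexaSwift.Configuration(
+            temperature: 0.7,
+            maxNewToken: 256,
+            stopTokens: []
+        )
+
+        // Path to a llama.cpp-compatible model file on disk (placeholder).
+        let modelPath = "path/to/your/model.gguf"
+        let nexaSwift = try NexaSwift.NexaTextInference(modelPath: modelPath, modelConfiguration: configuration)
+
+        // Build the conversation with a single user message.
+        var messages: [ChatCompletionRequestMessage] = []
+        messages.append(.user(
+            ChatCompletionRequestUserMessage(content: .text("Write a haiku about the ocean."))
+        ))
+
+        // Stream the chat completion and print each delta as it arrives.
+        for try await chunk in await nexaSwift.createChatCompletionStream(for: messages) {
+            print(chunk.choices[0].delta.content ?? "", terminator: "")
+        }
+        print()
+    }
+}
+```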
diff --git a/swift/Sources/NexaSwift/LlamaModel.swift b/swift/Sources/NexaSwift/LlamaModel.swift new file mode 100644 index 00000000..d8bc6f68 --- /dev/null +++ b/swift/Sources/NexaSwift/LlamaModel.swift @@ -0,0 +1,207 @@ +import Foundation +import llama + +class LlamaModel { + private let model: Model + public var configuration: Configuration + private let context: OpaquePointer + private var sampler: UnsafeMutablePointer + private var batch: Batch + private var tokens: [Token] + private var temporaryInvalidCChars: [CChar] = [] + private var generatedTokenAccount: Int32 = 0 + private var totalTokensProcessed: Int32 = 0 + private var ended = false + private let n_ctx: Int32 + public var arch: String { + return getModelDetails().arch + } + public var modelType: String { + return getModelDetails().modelType + } + public var modelFtype: String { + return getModelDetails().modelFtype + } + + var shouldContinue: Bool { + generatedTokenAccount < configuration.maxNewToken && !ended + } + + public func reset() { + generatedTokenAccount = 0 + ended = false + } + + init(path: String, configuration: Configuration = .init()) throws { + self.configuration = configuration + llama_backend_init() + llama_numa_init(GGML_NUMA_STRATEGY_DISABLED) + var model_params = llama_model_default_params() + #if os(iOS) || targetEnvironment(simulator) + model_params.n_gpu_layers = 0 + #endif + guard let model = llama_load_model_from_file(path, model_params) else { + throw NexaSwiftError.others("Cannot load model at path \(path)") + } + + self.model = model + + guard let context = llama_new_context_with_model(model, configuration.contextParameters) else { + throw NexaSwiftError.others("Cannot load model context") + } + self.context = context + self.n_ctx = Int32(llama_n_ctx(context)) + self.tokens = [] + self.sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) + self.batch = llama_batch_init(configuration.nTokens, 0, 1) + try checkContextLength() + } + + public func updateSampler() { + self.sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) + llama_sampler_chain_add(sampler, llama_sampler_init_temp(configuration.temperature)) + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(configuration.topK)) + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(configuration.topP, 1)) + llama_sampler_chain_add(sampler, llama_sampler_init_softmax()) + llama_sampler_chain_add(sampler, llama_sampler_init_dist(configuration.seed)) + } + + private func checkContextLength() throws { + let n_ctx_train = llama_n_ctx_train(model) + if n_ctx > n_ctx_train { + throw NexaSwiftError.others("Model was trained on \(n_ctx_train) context but tokens \(n_ctx) specified") + } + } + + private func getModelDetails() -> (arch: String, modelType: String, modelFtype: String) { + let bufSize = 256 + var buf = [CChar](repeating: 0, count: bufSize) + let result = llama_model_desc(model, &buf, bufSize) + + if result > 0 { + let modelDesc = String(cString: buf) + let components = modelDesc.components(separatedBy: " ") + let arch = components[0] ?? "Unknown" + let modelType = components[1] ?? "Unknown" + let modelFtype = components[2] ?? 
"Unknown" + return (arch, modelType, modelFtype) + } else { + return ("Unknown", "Unknown", "Unknown") + } + } + + func start(for prompt: String) throws { +// print("arch: \(arch), modelType: \(modelType), modelFtype: \(modelFtype)") + updateSampler() + ended = false + tokens = tokenize(text: prompt, addBos: true) + + // Check for token length + if tokens.count > n_ctx { + let originalCount = tokens.count + tokens = Array(tokens.prefix(Int(n_ctx))) + print(""" + WARNING: Input tokens (\(originalCount)) exceed context length (\(n_ctx)). + Truncating to first \(n_ctx) tokens. Some content at the end will be ignored. + Consider splitting your input into smaller chunks for better results. + """) + } + + temporaryInvalidCChars = [] + batch.clear() + + tokens.enumerated().forEach { index, token in + batch.add(token: token, position: Int32(index), seqIDs: [0], logit: false) + } + batch.logits[Int(batch.n_tokens) - 1] = 1 + + if llama_decode(context, batch) != 0 { + throw NexaSwiftError.decodeError + } + generatedTokenAccount = 0 + totalTokensProcessed = batch.n_tokens + } + + func `continue`() throws -> String { + if totalTokensProcessed >= n_ctx { + print("WARNING: Reached maximum context length (\(n_ctx)). Stopping generation.") + temporaryInvalidCChars.removeAll() + ended = true + return "" + } + + let newToken = llama_sampler_sample(sampler, context, batch.n_tokens - 1) + + if llama_token_is_eog(model, newToken) { + temporaryInvalidCChars.removeAll() + ended = true + return "" + } + + + let newTokenCChars = tokenToCChars(token: newToken) + temporaryInvalidCChars.append(contentsOf: newTokenCChars) + + let newTokenStr: String + if let validString = String(validating: temporaryInvalidCChars + [0], as: UTF8.self) { + newTokenStr = validString + temporaryInvalidCChars.removeAll() + } else if let suffixIndex = temporaryInvalidCChars.firstIndex(where: { $0 != 0 }), + let validSuffix = String(validating: Array(temporaryInvalidCChars.suffix(from: suffixIndex)) + [0], + as: UTF8.self) { + newTokenStr = validSuffix + temporaryInvalidCChars.removeAll() + } else { + newTokenStr = "" + } + + batch.clear() + batch.add(token: newToken, position: totalTokensProcessed, seqIDs: [0], logit: true) + generatedTokenAccount += 1 + totalTokensProcessed += 1 + + if llama_decode(context, batch) != 0 { + throw NexaSwiftError.decodeError + } + return newTokenStr.filter { $0 != "\0" } + } + + private func tokenToCChars(token: llama_token) -> [CChar] { + var length: Int32 = 8 + var piece = Array(repeating: 0, count: Int(length)) + + let nTokens = llama_token_to_piece(model, token, &piece, length, 0, false) + if nTokens >= 0 { + return Array(piece.prefix(Int(nTokens))) + } else { + length = -nTokens + piece = Array(repeating: 0, count: Int(length)) + let nNewTokens = llama_token_to_piece(model, token, &piece, length, 0, false) + return Array(piece.prefix(Int(nNewTokens))) + } + } + + private func tokenize(text: String, addBos: Bool) -> [Token] { + let utf8Count = text.utf8.count + let n_tokens = utf8Count + (addBos ? 
1 : 0) + 1 + + return Array(unsafeUninitializedCapacity: n_tokens) { buffer, initializedCount in + initializedCount = Int( + llama_tokenize(model, text, Int32(utf8Count), buffer.baseAddress, Int32(n_tokens), addBos, false) + ) + } + } + + func clear() { + tokens.removeAll() + temporaryInvalidCChars.removeAll() + llama_kv_cache_clear(context) + } + + deinit { + llama_batch_free(batch) + llama_free(context) + llama_free_model(model) + llama_backend_free() + } +} diff --git a/swift/Sources/NexaSwift/Models/Batch.swift b/swift/Sources/NexaSwift/Models/Batch.swift new file mode 100644 index 00000000..ca784716 --- /dev/null +++ b/swift/Sources/NexaSwift/Models/Batch.swift @@ -0,0 +1,23 @@ +import Foundation +import llama + +extension Batch { + mutating func clear() { + self.n_tokens = 0 + } + + mutating func add(token: Token, + position: Position, + seqIDs: [SeqID], + logit: Bool) { + let nextIndex = Int(n_tokens) + self.token[nextIndex] = token + self.pos[nextIndex] = position + self.n_seq_id[nextIndex] = Int32(seqIDs.count) + seqIDs.enumerated().forEach { index, id in + seq_id[nextIndex]?[index] = id + } + self.logits[nextIndex] = logit ? 1 : 0 + self.n_tokens += 1 + } +} diff --git a/swift/Sources/NexaSwift/Models/ChatCompletionMessage.swift b/swift/Sources/NexaSwift/Models/ChatCompletionMessage.swift new file mode 100644 index 00000000..bae89664 --- /dev/null +++ b/swift/Sources/NexaSwift/Models/ChatCompletionMessage.swift @@ -0,0 +1,517 @@ +import Foundation + + +public struct ChatCompletionRequestSystemMessage: Codable { + public var role: Role = .system + public let content: String? + + public init(content: String?) { + self.content = content + } +} + +public struct ChatCompletionRequestUserMessage: Codable { + public var role: Role = .user + public let content: UserMessageContent + + public init(content: UserMessageContent) { + self.content = content + } +} + +public enum UserMessageContent: Codable { + case text(String) + case image(ImageContent) + + enum CodingKeys: String, CodingKey { + case type, text, imageUrl + } + + enum ContentType: String, Codable { + case text + case imageUrl + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(ContentType.self, forKey: .type) + + switch type { + case .text: + let text = try container.decode(String.self, forKey: .text) + self = .text(text) + case .imageUrl: + let imageUrl = try container.decode(ImageContent.self, forKey: .imageUrl) + self = .image(imageUrl) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case .text(let text): + try container.encode(ContentType.text, forKey: .type) + try container.encode(text, forKey: .text) + case .image(let imageUrl): + try container.encode(ContentType.imageUrl, forKey: .type) + try container.encode(imageUrl, forKey: .imageUrl) + } + } +} + +public struct ImageContent: Codable { + public let url: String + public let detail: String? + + public init(url: String, detail: String? = nil) { + self.url = url + self.detail = detail + } +} + +public struct ChatCompletionRequestAssistantMessage: Codable { + public var role: Role = .assistant + public let content: String? + public let toolCalls: [ChatCompletionMessageToolCall]? + public let functionCall: ChatCompletionRequestAssistantMessageFunctionCall? + + public init(content: String?, toolCalls: [ChatCompletionMessageToolCall]? 
= nil, functionCall: ChatCompletionRequestAssistantMessageFunctionCall? = nil) { + self.content = content + self.toolCalls = toolCalls + self.functionCall = functionCall + } +} + +public struct ChatCompletionRequestToolMessage: Codable { + public var role: Role = .tool + public let content: String? + public let toolCallID: String + + public init(content: String?, toolCallID: String) { + self.content = content + self.toolCallID = toolCallID + } +} + +public struct ChatCompletionRequestFunctionMessage: Codable { + public var role: Role = .function + public let content: String? + public let name: String + + public init(content: String?, name: String) { + self.content = content + self.name = name + } +} + +public struct ChatCompletionRequestAssistantMessageFunctionCall: Codable { + public let name: String + public let arguments: String + + public init(name: String, arguments: String) { + self.name = name + self.arguments = arguments + } +} + + +class ChatFormatterRegistry { + private var formatters = [String: ChatFormatter]() + + init() { + register(name: ChatCompletionModel.octopusv2.rawValue, formatter: OctopusV2Formatter()) + register(name: ChatCompletionModel.llama.rawValue, formatter: LlamaFormatter()) + register(name: ChatCompletionModel.llama3.rawValue, formatter: Llama3Formatter()) + register(name: ChatCompletionModel.gemma.rawValue, formatter: GemmaFormatter()) + register(name: ChatCompletionModel.qwen.rawValue, formatter: QwenFormatter()) + register(name: ChatCompletionModel.mistral.rawValue, formatter: MistralFormatter()) + } + + func register(name: String, formatter: ChatFormatter) { + formatters[name] = formatter + } + + func getFormatter(name: String?) -> ChatFormatter? { + return formatters[getFormatterName(name: name)] + } + + func getFormatterName(name: String?) -> String { + return name ?? ChatCompletionModel.llama.rawValue + } +} + +//formatter +public struct ChatFormatterResponse { + let prompt: String + let stop: [String] +} + +public protocol ChatFormatter { + func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse +} + + +class OctopusV2Formatter: ChatFormatter { + private let systemMessage = """ + Below is the query from the users, please call the correct function and generate the parameters to call the function. + + """ + private let separator = "\n\n" + + func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse { + var formattedMessages = mapRoles(messages: messages) + + // Assuming the last message should be the assistant's response + formattedMessages.append(("Response:", nil)) + + var prompt = systemMessage + for (role, content) in formattedMessages { + if let content = content { + prompt += "\(role) \(content.trimmingCharacters(in: .whitespacesAndNewlines))\(separator)" + } else { + prompt += "\(role) " + } + } + + return ChatFormatterResponse(prompt: prompt.trimmingCharacters(in: .whitespacesAndNewlines), stop: [separator]) + } + + private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] { + var mappedMessages = [(String, String?)]() + let roleMapping: [Role: String] = [ + .user: "Query:", + .assistant: "Response:" + ] + + for message in messages { + var rolePrefix = "" + var content: String? = nil + + switch message { + case .system(let systemMessage): + // Include system message if necessary + continue + case .user(let userMessage): + rolePrefix = roleMapping[.user] ?? 
"Query:" + switch userMessage.content { + case .text(let text): + content = text + case .image(let imageContent): + content = imageContent.detail ?? imageContent.url + } + case .assistant(let assistantMessage): + rolePrefix = roleMapping[.assistant] ?? "Response:" + content = assistantMessage.content + case .tool(let toolMessage): + rolePrefix = "Tool:" + content = toolMessage.content + case .function(let functionMessage): + rolePrefix = "Function:" + content = functionMessage.content + } + + mappedMessages.append((rolePrefix, content)) + } + + return mappedMessages + } +} + + +//https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/ +class LlamaFormatter: ChatFormatter { + private let systemTemplate = "<>\n{system_message}\n<>\n\n" + private let roles: [String: String] = [ + "user": "[INST] ", + "assistant": " [/INST] " + ] + private let endToken = "" + + func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse { + let formattedMessages = mapRoles(messages: messages) + let systemMessage = getSystemMessage(messages) + let formattedSystemMessage = systemMessage.map { msg in + systemTemplate.replacingOccurrences(of: "{system_message}", with: msg) + } + let prompt = formatPrompt(systemMessage: formattedSystemMessage, messages: formattedMessages) + return ChatFormatterResponse(prompt: prompt, stop: [endToken]) + } + + private func getSystemMessage(_ messages: [ChatCompletionRequestMessage]) -> String? { + for message in messages { + if case .system(let systemMessage) = message { + return systemMessage.content + } + } + return nil + } + + private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] { + return messages.compactMap { message in + switch message { + case .system: + return nil + case .user(let userMessage): + let content: String? + switch userMessage.content { + case .text(let text): + content = text + case .image(let imageContent): + content = imageContent.detail + } + return (roles["user"] ?? "", content) + case .assistant(let assistantMessage): + return (roles["assistant"] ?? "", assistantMessage.content) + case .tool, .function: + return nil + } + } + } + + private func formatPrompt(systemMessage: String?, messages: [(String, String?)]) -> String { + var conversations: [String] = [] + var currentConversation = "" + + for (index, (role, content)) in messages.enumerated() { + if index % 2 == 0 { // User message + if !currentConversation.isEmpty { + conversations.append(currentConversation + " " + endToken) + } + currentConversation = role // [INST] + if index == 0 && systemMessage != nil { + currentConversation += systemMessage! + content! + } else { + currentConversation += content ?? "" + } + } else { // Assistant message + if let content = content { + currentConversation += role + content // [/INST] response + } + } + } + + // Add the last conversation if it's a user message without response + if messages.count % 2 != 0 { + currentConversation += roles["assistant"]! 
+            conversations.append(currentConversation)
+        } else if !currentConversation.isEmpty {
+            conversations.append(currentConversation + endToken)
+        }
+
+        return conversations.joined(separator: "\n")
+    }
+}
+
+//https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+class Llama3Formatter: ChatFormatter {
+    private let roles: [String: String] = [
+        "system": "<|start_header_id|>system<|end_header_id|>\n\n",
+        "user": "<|start_header_id|>user<|end_header_id|>\n\n",
+        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    ]
+    private let endToken = "<|eot_id|>"
+
+    func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
+        var formattedMessages = mapRoles(messages: messages)
+
+        formattedMessages.append((roles["assistant"] ?? "", nil))
+
+        let prompt = formatPrompt(formattedMessages)
+
+        return ChatFormatterResponse(prompt: prompt, stop: [endToken])
+    }
+
+    private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] {
+        return messages.map { message in
+            var rolePrefix = ""
+            var content: String? = ""
+
+            switch message {
+            case .system(let systemMessage):
+                rolePrefix = roles["system"] ?? ""
+                content = systemMessage.content
+            case .user(let userMessage):
+                rolePrefix = roles["user"] ?? ""
+                switch userMessage.content {
+                case .text(let text):
+                    content = text
+                case .image(let imageContent):
+                    content = imageContent.detail
+                }
+            case .assistant(let assistantMessage):
+                rolePrefix = roles["assistant"] ?? ""
+                content = assistantMessage.content
+            case .tool(let toolMessage):
+                rolePrefix = roles["tool"] ?? ""
+                content = toolMessage.content
+            case .function(let functionMessage):
+                rolePrefix = roles["function"] ?? ""
+                content = functionMessage.content
+            }
+
+            return (rolePrefix, content)
+        }
+    }
+
+    private func formatPrompt(_ formattedMessages: [(String, String?)]) -> String {
+        var prompt = "<|begin_of_text|>"
+        for (role, content) in formattedMessages {
+            if let content = content {
+                prompt += "\(role)\(content.trimmingCharacters(in: .whitespacesAndNewlines))\(endToken)"
+            } else {
+                prompt += "\(role) "
+            }
+        }
+        return prompt.trimmingCharacters(in: .whitespacesAndNewlines)
+    }
+}
+
+
+//https://ai.google.dev/gemma/docs/formatting
+class GemmaFormatter: ChatFormatter {
+    private let roles: [String: String] = [
+        "user": "<start_of_turn>user\n",
+        "assistant": "<start_of_turn>model\n"
+    ]
+
+    private let endToken = "<end_of_turn>"
+    private let separator = "<end_of_turn>\n"
+
+    func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
+        var formattedMessages = mapRoles(messages: messages)
+        formattedMessages.append((roles["assistant"]!, nil))
+        let prompt = formatPrompt(formattedMessages)
+
+        return ChatFormatterResponse(prompt: prompt, stop: [endToken])
+    }
+
+    private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] {
+        return messages.compactMap { message in
+            switch message {
+            case .system:
+                return nil
+            case .user(let userMessage):
+                let content: String?
+                switch userMessage.content {
+                case .text(let text):
+                    content = text
+                case .image(let imageContent):
+                    content = imageContent.detail
+                }
+                return (roles["user"] ?? "", content)
+            case .assistant(let assistantMessage):
+                return (roles["assistant"] ?? "", assistantMessage.content)
+            case .tool, .function:
+                return nil
+            }
+        }
+    }
+
+    private func formatPrompt(_ formattedMessages: [(String, String?)]) -> String {
+        var prompt = ""
+
+        for (index, (role, content)) in formattedMessages.enumerated() {
+            if index == formattedMessages.count - 1 {
+                prompt += role
+            } else if let content = content {
+                prompt += "\(role)\(content)\(separator)"
+            }
+        }
+        return prompt.trimmingCharacters(in: .whitespacesAndNewlines)
+    }
+}
+
+// https://qwen.readthedocs.io/zh-cn/latest/getting_started/concepts.html#control-tokens-chat-template
+class QwenFormatter: ChatFormatter {
+    private let roles: [String: String] = [
+        "user": "<|im_start|>user",
+        "assistant": "<|im_start|>assistant"
+    ]
+
+    private let systemTemplate = "<|im_start|>system\n{system_message}"
+    private let defaultSystemMessage = "You are a helpful assistant."
+    private let separator = "<|im_end|>"
+    private let endToken = "<|endoftext|>"
+
+    func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
+        let systemMessage = formatSystemMessage()
+        var formattedMessages = mapRoles(messages: messages)
+        formattedMessages.append((roles["assistant"]!, nil))
+        let prompt = formatChatML(systemMessage: systemMessage, messages: formattedMessages)
+        return ChatFormatterResponse(prompt: prompt, stop: [endToken])
+    }
+
+    private func formatSystemMessage() -> String {
+        return systemTemplate.replacingOccurrences(of: "{system_message}", with: defaultSystemMessage)
+    }
+
+    private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] {
+        return messages.compactMap { message in
+            switch message {
+            case .user(let userMessage):
+                let content: String?
+                switch userMessage.content {
+                case .text(let text):
+                    content = text
+                case .image(let imageContent):
+                    content = imageContent.detail
+                }
+                return (roles["user"]!, content)
+            case .assistant(let assistantMessage):
+                return (roles["assistant"]!, assistantMessage.content)
+            case .system, .tool, .function:
+                return nil
+            }
+        }
+    }
+
+    private func formatChatML(systemMessage: String, messages: [(String, String?)]) -> String {
+        var prompt = systemMessage.isEmpty ? "" : "\(systemMessage)\(separator)\n"
+        for (role, content) in messages {
+            if let content = content {
+                prompt += "\(role)\n\(content)\(separator)\n"
+            } else {
+                prompt += "\(role)\n"
+            }
+        }
+        return prompt.trimmingCharacters(in: .whitespacesAndNewlines)
+    }
+}
+
+// https://www.promptingguide.ai/models/mistral-7b#chat-template-for-mistral-7b-instruct
+class MistralFormatter: ChatFormatter {
+    private let endToken = "</s>"
+    private let conversationStart = "<s>"
+    private let instructStart = "[INST] "
+    private let instructEnd = " [/INST] "
+
+    func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
+        var prompt = conversationStart // Add only once at the start
+
+        for (index, message) in messages.enumerated() {
+            switch message {
+            case .user(let userMessage):
+                switch userMessage.content {
+                case .text(let text):
+                    prompt += "\(instructStart)\(text)"
+                case .image:
+                    continue
+                }
+
+            case .assistant(let assistantMessage):
+                if let content = assistantMessage.content {
+                    prompt += "\(instructEnd)\(content)\(endToken)"
+                }
+            default:
+                continue
+            }
+        }
+
+        // Add instructEnd if the last message was from user (waiting for AI response)
+        if messages.last.map({ if case .user = $0 { return true } else { return false } }) ??
false { + prompt += instructEnd + } + + return ChatFormatterResponse(prompt: prompt, stop: [endToken]) + } +} diff --git a/swift/Sources/NexaSwift/Models/ChatCompletionResponse.swift b/swift/Sources/NexaSwift/Models/ChatCompletionResponse.swift new file mode 100644 index 00000000..396dba5c --- /dev/null +++ b/swift/Sources/NexaSwift/Models/ChatCompletionResponse.swift @@ -0,0 +1,91 @@ +import Foundation + + +public struct ChatCompletionMessageToolCallFunction: Codable { + public let name: String + public let arguments: String + + public init(name: String, arguments: String) { + self.name = name + self.arguments = arguments + } +} + +public struct ChatCompletionMessageToolCall: Codable { + public let id: String + public var type: Role = .function + public let function: ChatCompletionMessageToolCallFunction + + public init(id: String, function: ChatCompletionMessageToolCallFunction) { + self.id = id + self.function = function + } +} + +public struct ChatCompletionResponseFunctionCall:Codable{ + public let name: String + public let arguments: String +} + +public struct ChatCompletionResponseMessage: Codable{ + public let content: String? + public let toolCalls: [ChatCompletionMessageToolCall]? + public let role: String? + public let functionCall: ChatCompletionResponseFunctionCall? +} + +public struct ChatCompletionResponseChoice: Codable{ + public let index: Int + public let message: ChatCompletionResponseMessage + public let logprobs: CompletionLogprobs? + public let finishReason: FinishReason? +} + +public struct ChatCompletionResponse: Codable { + public let id: String + public let object: String + public let created: Int + public let model: String + public let choices: [ChatCompletionResponseChoice] + public let usage: CompletionUsage? + + enum CodingKeys: String, CodingKey { + case id + case object + case created + case model + case choices + case usage + } +} + +public struct ChatCompletionStreamResponseDelta: Codable { + public var content: String? + public var functionCall: ChatCompletionStreamResponseDeltaFunctionCall? // DEPRECATED + public var toolCalls: [ChatCompletionMessageToolCallChunk]? + public var role: Role? + +} + +public struct ChatCompletionStreamResponseDeltaFunctionCall: Codable { + +} + +public struct ChatCompletionMessageToolCallChunk: Codable { + +} + +public struct ChatCompletionStreamResponseChoice: Codable { + public var index: Int + public var delta: ChatCompletionStreamResponseDelta + public var finishReason: FinishReason? + public var logprobs: CompletionLogprobs? 
+} + +public struct CreateChatCompletionStreamResponse: Codable { + public var id: String + public var model: String + public var object: String + public var created: Int + public var choices: [ChatCompletionStreamResponseChoice] +} diff --git a/swift/Sources/NexaSwift/Models/Common.swift b/swift/Sources/NexaSwift/Models/Common.swift new file mode 100644 index 00000000..56cb371c --- /dev/null +++ b/swift/Sources/NexaSwift/Models/Common.swift @@ -0,0 +1,28 @@ +public enum Role: String, Codable { + case system + case user + case assistant + case tool + case function +} + +public enum ChatCompletionRequestMessage { + case system(ChatCompletionRequestSystemMessage) + case user(ChatCompletionRequestUserMessage) + case assistant(ChatCompletionRequestAssistantMessage) + case tool(ChatCompletionRequestToolMessage) + case function(ChatCompletionRequestFunctionMessage) +} + +public enum FinishReason: String, Codable { + case stop, length, toolCalls = "tool_calls", functionCall = "function_call" +} + +public enum ChatCompletionModel: String, Codable { + case octopusv2 + case llama + case llama3 + case gemma + case qwen + case mistral +} diff --git a/swift/Sources/NexaSwift/Models/CompletionResponse.swift b/swift/Sources/NexaSwift/Models/CompletionResponse.swift new file mode 100644 index 00000000..b3d16519 --- /dev/null +++ b/swift/Sources/NexaSwift/Models/CompletionResponse.swift @@ -0,0 +1,52 @@ +import Foundation + +public struct CompletionUsage: Codable { + public let promptTokens: Int + public let completionTokens: Int + public let totalTokens: Int + + enum CodingKeys: String, CodingKey { + case promptTokens = "prompt_tokens" + case completionTokens = "completion_tokens" + case totalTokens = "total_tokens" + } +} + +public struct CompletionLogprobs: Codable { + public let textOffset: [Int]? + public let tokenLogprobs: [Float?]? + public let tokens: [String]? + public let topLogprobs: [Dictionary?]? +} + +public struct CompletionChoice: Codable { + public let text: String + public let index: Int + public let logprobs: CompletionLogprobs? + public let finishReason: FinishReason? + + enum CodingKeys: String, CodingKey { + case text + case index + case logprobs + case finishReason = "finish_reason" + } +} + +public struct CompletionResponse: Codable { + public let id: String + public let object: String + public let created: Int + public let model: String + public let choices: [CompletionChoice] + public let usage: CompletionUsage? + + enum CodingKeys: String, CodingKey { + case id + case object + case created + case model + case choices + case usage + } +} diff --git a/swift/Sources/NexaSwift/Models/Configuration.swift b/swift/Sources/NexaSwift/Models/Configuration.swift new file mode 100644 index 00000000..b8c3757c --- /dev/null +++ b/swift/Sources/NexaSwift/Models/Configuration.swift @@ -0,0 +1,53 @@ +import Foundation +import llama + +public struct Configuration { + public var nTokens:Int32 + public var embd: Int32 + public var nSeqMax: Int32 + public var seed: UInt32 + public var topK: Int32 + public var topP: Float + public var nCTX: Int + public var temperature: Float + public var maxNewToken: Int + public var batchSize: Int + public var stopTokens: [String] + + public init( + nTokens:Int32 = 2048, + embd:Int32 = 512, + nSeqMax:Int32 = 2, + seed: UInt32 = 1234, + topK: Int32 = 50, + topP: Float = 1.0, + nCTX: Int = 2048, + temperature: Float = 0.7, + batchSize: Int = 2048, + stopSequence: String? 
= nil, + maxNewToken: Int = 128, + stopTokens: [String] = []) { + self.nTokens = nTokens + self.embd = embd + self.nSeqMax = nSeqMax + self.seed = seed + self.topK = topK + self.topP = topP + self.nCTX = nCTX + self.batchSize = batchSize + self.temperature = temperature + self.maxNewToken = maxNewToken + self.stopTokens = stopTokens + } +} + +extension Configuration { + var contextParameters: ContextParameters { + var params = llama_context_default_params() + let processorCount = max(1, min(16, ProcessInfo.processInfo.processorCount - 2)) + params.n_ctx = max(8, UInt32(self.nCTX)) // minimum context size is 8 + params.n_threads = Int32(processorCount) + params.n_threads_batch = Int32(processorCount) + return params + } +} diff --git a/swift/Sources/NexaSwift/Models/SwiftLlamaError.swift b/swift/Sources/NexaSwift/Models/SwiftLlamaError.swift new file mode 100644 index 00000000..37c04296 --- /dev/null +++ b/swift/Sources/NexaSwift/Models/SwiftLlamaError.swift @@ -0,0 +1,6 @@ +import Foundation + +public enum NexaSwiftError: Error { + case decodeError + case others(String) +} diff --git a/swift/Sources/NexaSwift/Models/TypeAlias.swift b/swift/Sources/NexaSwift/Models/TypeAlias.swift new file mode 100644 index 00000000..4711efc8 --- /dev/null +++ b/swift/Sources/NexaSwift/Models/TypeAlias.swift @@ -0,0 +1,10 @@ +import Foundation +import llama + +typealias Batch = llama_batch +typealias Model = OpaquePointer +typealias Context = OpaquePointer +typealias Token = llama_token +typealias Position = llama_pos +typealias SeqID = llama_seq_id +typealias ContextParameters = llama_context_params diff --git a/swift/Sources/NexaSwift/NexaSwiftActor.swift b/swift/Sources/NexaSwift/NexaSwiftActor.swift new file mode 100644 index 00000000..b295dc71 --- /dev/null +++ b/swift/Sources/NexaSwift/NexaSwiftActor.swift @@ -0,0 +1,6 @@ +import Foundation + +@globalActor +public actor NexaSwiftActor { + public static let shared = NexaSwiftActor() +} diff --git a/swift/Sources/NexaSwift/NexaTextInference.swift b/swift/Sources/NexaSwift/NexaTextInference.swift new file mode 100644 index 00000000..7c137bc7 --- /dev/null +++ b/swift/Sources/NexaSwift/NexaTextInference.swift @@ -0,0 +1,362 @@ +import Foundation +import llama +import Combine + +public class NexaTextInference { + private let model: LlamaModel + private let modelPath: String + private var generatedTokenCache = "" + private var contentStarted = false + private let chatFormatterRegistry: ChatFormatterRegistry + + var maxLengthOfStopToken: Int { + model.configuration.stopTokens.map { $0.count }.max() ?? 0 + } + + public init(modelPath: String, + modelConfiguration: Configuration = .init()) throws { + if modelPath.isEmpty { + throw NSError(domain: "InvalidParameterError", code: 400, userInfo: [NSLocalizedDescriptionKey: "Either modelPath or localPath must be provided."]) + } + self.model = try LlamaModel(path: modelPath, configuration: modelConfiguration) + self.modelPath = modelPath + self.chatFormatterRegistry = .init() + } + + private func updateConfiguration( + temperature: Float?, + maxNewToken: Int?, + topK: Int32?, + topP: Float?, + stopTokens: [String]? 
+ ) { + if let temperature = temperature { + model.configuration.temperature = temperature + } + if let maxNewToken = maxNewToken { + model.configuration.maxNewToken = maxNewToken + } + if let topK = topK { + model.configuration.topK = topK + } + if let topP = topP { + model.configuration.topP = topP + } + if let stopTokens = stopTokens { + model.configuration.stopTokens = stopTokens + } + } + + private func getFormatterForModel() -> ChatFormatter? { + let modelArch = model.arch.lowercased() + let lowerModelPath = modelPath.lowercased() + + let modelType: ChatCompletionModel? = { + switch modelArch { + case _ where modelArch.contains("gemma"): + // For Gemma-based models, check the model path + if lowerModelPath.contains("octopus-v2") || lowerModelPath.contains("octopusv2") { + return .octopusv2 + } else { + return .gemma + } + case _ where modelArch.contains("qwen"): + return .qwen + case _ where modelArch.contains("llama"): + // For Llama-based models, check the model path + if lowerModelPath.contains("llama-2") || lowerModelPath.contains("llama2") { + return .llama + } else if lowerModelPath.contains("llama-3") || lowerModelPath.contains("llama3") { + return .llama3 + } else if lowerModelPath.contains("mistral") { + return .mistral + } else { + // If can't determine specific version, default to Llama2 + print("Warning: Unable to determine specific Llama model version from path: \(modelPath). Defaulting to Llama2 format.") + return .llama + } + default: + print("Warning: Unknown model architecture: \(modelArch). Defaulting to Llama2 format.") + return .llama + } + }() + + return chatFormatterRegistry.getFormatter(name: modelType?.rawValue) + } + + private func isStopToken() -> Bool { + model.configuration.stopTokens.reduce(false) { partialResult, stopToken in + generatedTokenCache.hasSuffix(stopToken) + } + } + + private func response(for prompt: String, output: (String) -> Void, finish: () -> Void) { + func finaliseOutput() { + model.configuration.stopTokens.forEach { + generatedTokenCache = generatedTokenCache.replacingOccurrences(of: $0, with: "") + } + output(generatedTokenCache) + finish() + generatedTokenCache = "" + } + defer { model.clear() } + do { + try model.start(for: prompt) + while model.shouldContinue { + var delta = try model.continue() + if contentStarted { // remove the prefix empty spaces + if needToStop(after: delta, output: output) { + finish() + break + } + } else { + delta = delta.trimmingCharacters(in: .whitespacesAndNewlines) + if !delta.isEmpty { + contentStarted = true + if needToStop(after: delta, output: output) { + finish() + break + } + } + } + } + finaliseOutput() + } catch { + finaliseOutput() + } + } + + private func needToStop(after delta: String, output: (String) -> Void) -> Bool { + guard maxLengthOfStopToken > 0 else { + output(delta) + return false + } + generatedTokenCache += delta + + if generatedTokenCache.count >= maxLengthOfStopToken * 2 { + if let stopToken = model.configuration.stopTokens.first(where: { generatedTokenCache.contains($0) }), + let index = generatedTokenCache.range(of: stopToken) { + let outputCandidate = String(generatedTokenCache[.. AsyncThrowingStream { + return .init { continuation in + Task { + response(for: prompt) { [weak self] delta in + continuation.yield(delta) + } finish: { [weak self] in + continuation.finish() + } + } + } + } + + @NexaSwiftActor + public func createCompletion( + for prompt: String, + temperature: Float? = nil, + maxNewToken: Int? = nil, + topK: Int32? = nil, + topP: Float? 
= nil, + stopTokens: [String]? = nil) async throws -> CompletionResponse { + updateConfiguration( + temperature: temperature, + maxNewToken: maxNewToken, + topK: topK, + topP: topP, + stopTokens: stopTokens + ) + model.reset() + var result = "" + for try await value in await run(for: prompt) { + result += value + } + + let completionResponse = CompletionResponse( + id: UUID().uuidString, + object: "text_completion", + created: Int(Date().timeIntervalSince1970), + model: "", + choices: [ + CompletionChoice( + text: result, + index: 0, + logprobs: nil, + finishReason: FinishReason.stop + ) + ], + usage: CompletionUsage( + promptTokens: 0, + completionTokens: 0, + totalTokens: 0 + ) + ) + return completionResponse + } + + @NexaSwiftActor + public func createCompletionStream( + for prompt: String, + temperature: Float? = nil, + maxNewToken: Int? = nil, + topK: Int32? = nil, + topP: Float? = nil, + stopTokens: [String]? = nil) -> AsyncThrowingStream { + updateConfiguration( + temperature: temperature, + maxNewToken: maxNewToken, + topK: topK, + topP: topP, + stopTokens: stopTokens + ) + model.reset() + return .init { continuation in + Task { + var index = 0 + response(for: prompt) { text in + let completionResponse = CompletionResponse( + id: UUID().uuidString, + object: "text_completion", + created: Int(Date().timeIntervalSince1970), + model: "", + choices: [ + CompletionChoice( + text: text, + index: 0, + logprobs: nil, + finishReason: FinishReason.stop + ) + ], + usage: CompletionUsage( + promptTokens: 0, + completionTokens: 0, + totalTokens: 0 + ) + ) + + index += 1 + continuation.yield(completionResponse) + } finish: { + continuation.finish() + } + } + } + } + + @NexaSwiftActor + public func createChatCompletion( + for messages: [ChatCompletionRequestMessage], + temperature: Float? = nil, + maxNewToken: Int? = nil, + topK: Int32? = nil, + topP: Float? = nil, + stopTokens: [String]? = nil, + modelType: ChatCompletionModel? = nil) async throws -> ChatCompletionResponse { + let formatter = modelType.map { chatFormatterRegistry.getFormatter(name: $0.rawValue) } ?? getFormatterForModel() + let chatFormatter: ChatFormatterResponse? = formatter?.format(messages: messages) + // let chatFormatter: ChatFormatterResponse? = chatFormatterRegistry.getFormatter(name: modelType?.rawValue)?.format(messages: messages) ?? nil + updateConfiguration( + temperature: temperature, + maxNewToken: maxNewToken, + topK: topK, + topP: topP, + stopTokens: stopTokens ?? (!model.configuration.stopTokens.isEmpty ? model.configuration.stopTokens : chatFormatter?.stop) ?? nil + ) + model.reset() + + var result = "" + for try await value in await run(for: chatFormatter?.prompt ?? "") { + result += value + } + + let response = ChatCompletionResponse( + id: UUID().uuidString, + object: "chat.completion", + created: Int(Date().timeIntervalSince1970), + model: chatFormatterRegistry.getFormatterName(name: modelType?.rawValue), + choices: [ + ChatCompletionResponseChoice( + index: 0, + message: ChatCompletionResponseMessage( + content: result, + toolCalls: nil, + role: nil, + functionCall: nil + ), + logprobs: nil, + finishReason: FinishReason.stop + ) + ], + usage: nil + ) + + return response + } + + @NexaSwiftActor + public func createChatCompletionStream( + for messages: [ChatCompletionRequestMessage], + temperature: Float? = nil, + maxNewToken: Int? = nil, + topK: Int32? = nil, + topP: Float? = nil, + stopTokens: [String]? = nil, + modelType: ChatCompletionModel? 
= nil + ) -> AsyncThrowingStream { + model.reset() + let formatter = modelType.map { chatFormatterRegistry.getFormatter(name: $0.rawValue) } ?? getFormatterForModel() + let chatFormatter: ChatFormatterResponse? = formatter?.format(messages: messages) +// let chatFormatter: ChatFormatterResponse? = chatFormatterRegistry.getFormatter(name: modelType?.rawValue)?.format(messages: messages) ?? nil + updateConfiguration( + temperature: temperature, + maxNewToken: maxNewToken, + topK: topK, + topP: topP, + stopTokens: stopTokens ?? (!model.configuration.stopTokens.isEmpty ? model.configuration.stopTokens : chatFormatter?.stop) ?? nil + ) + return .init { continuation in + Task { + var index = 0 + response(for: chatFormatter?.prompt ?? "") { text in + let response = CreateChatCompletionStreamResponse( + id: UUID().uuidString, + model: chatFormatterRegistry.getFormatterName(name: modelType?.rawValue), + object: "chat.completion.chunk", + created: Int(Date().timeIntervalSince1970), + choices: [ + ChatCompletionStreamResponseChoice( + index: index, + delta: ChatCompletionStreamResponseDelta( + content: text, + functionCall: nil, + toolCalls: nil, + role: nil + ), + finishReason: FinishReason.stop, + logprobs: nil + ) + ] + ) + + index += 1 + continuation.yield(response) + } finish: { + continuation.finish() + } + } + } + } +} diff --git a/swift/Tests/NexaSwiftTests/NexaSwiftTests.swift b/swift/Tests/NexaSwiftTests/NexaSwiftTests.swift new file mode 100644 index 00000000..9661ba10 --- /dev/null +++ b/swift/Tests/NexaSwiftTests/NexaSwiftTests.swift @@ -0,0 +1,6 @@ +import XCTest +@testable import NexaSwift + +final class NexaSwiftTests: XCTestCase { + +} diff --git a/tests/test_tts_generation.py b/tests/test_tts_generation.py index a03195d1..2dc9c526 100644 --- a/tests/test_tts_generation.py +++ b/tests/test_tts_generation.py @@ -1,22 +1,24 @@ -from nexa.gguf import NexaTTSInference +# Temporarily disabled since version v0.0.9.3 -def test_tts_generation(): - tts = NexaTTSInference( - model_path="bark-small", - local_path=None, - n_threads=4, - seed=42, - sampling_rate=24000, - verbosity=1 - ) +# from nexa.gguf import NexaTTSInference + +# def test_tts_generation(): +# tts = NexaTTSInference( +# model_path="bark-small", +# local_path=None, +# n_threads=4, +# seed=42, +# sampling_rate=24000, +# verbosity=2 +# ) - # Generate audio from prompt - prompt = "Hello, this is a test of the Bark text to speech system." - audio_data = tts.audio_generation(prompt) +# # Generate audio from prompt +# prompt = "Hello, this is a test of the Bark text to speech system." +# audio_data = tts.audio_generation(prompt) - # Save the generated audio - tts._save_audio(audio_data, tts.sampling_rate, "tts_output") - print("TTS generation test completed successfully!") +# # Save the generated audio +# tts._save_audio(audio_data, tts.sampling_rate, "tts_output") +# print("TTS generation test completed successfully!") -if __name__ == "__main__": - test_tts_generation() \ No newline at end of file +# if __name__ == "__main__": +# test_tts_generation() \ No newline at end of file