From 2125408a29947352dc32ac8f4502ca1443f9d64b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20B=C3=A4rwinkel?= Date: Sat, 25 Nov 2023 19:10:02 +0100 Subject: [PATCH] Add cuda hardware acceleration for textgen --- packages/llama-cpp-python/default.nix | 45 ++++++++++++++------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/packages/llama-cpp-python/default.nix b/packages/llama-cpp-python/default.nix index bcebf399..11b24c2d 100644 --- a/packages/llama-cpp-python/default.nix +++ b/packages/llama-cpp-python/default.nix @@ -1,14 +1,16 @@ -{ buildPythonPackage, fetchFromGitHub, lib, stdenv, darwin, cmake, ninja, pathspec, poetry-core, pyproject-metadata, scikit-build-core, setuptools, diskcache, numpy, typing-extensions }: -let - inherit (stdenv) isDarwin; - osSpecific = with darwin.apple_sdk.frameworks; if isDarwin then [ Accelerate CoreGraphics CoreVideo ] else [ ]; - llama-cpp-pin = fetchFromGitHub { - owner = "ggerganov"; - repo = "llama.cpp"; - rev = "a98b1633d5a94d0aa84c7c16e1f8df5ac21fc850"; - hash = "sha256-HNwyPJXsUL41zLA+90Yu7kCpihW0HBOeW2jDs8sw7qs="; - }; -in +{ buildPythonPackage +, fetchFromGitHub +, lib +, cudaPackages +, cmake +, ninja +, pathspec +, pyproject-metadata +, scikit-build-core +, diskcache +, numpy +, typing-extensions +}: buildPythonPackage rec { pname = "llama-cpp-python"; version = "0.2.7"; @@ -18,26 +20,25 @@ buildPythonPackage rec { owner = "abetlen"; repo = pname; rev = "refs/tags/v${version}"; - hash = "sha256-jL2jVTKwmTx6pSnoN5n4NtQ3hs3weXiQTKFQdjL172U="; + hash = "sha256-2uPWH8ik/YznJTNBCopz58YjDJ7i1l9hgp8t0Nwjm5Q="; + fetchSubmodules = true; }; - preConfigure = '' - cp -r ${llama-cpp-pin}/. ./vendor/llama.cpp - chmod -R +w ./vendor/llama.cpp - ''; - preBuild = '' - cd .. - ''; - buildInputs = osSpecific; + dontUseCmakeConfigure = true; + SKBUILD_CMAKE_ARGS = lib.strings.concatStringsSep ";" [ + "-DLLAMA_CUBLAS=on" + ]; + + buildInputs = [ + cudaPackages.cudatoolkit + ]; nativeBuildInputs = [ cmake ninja pathspec - poetry-core pyproject-metadata scikit-build-core - setuptools ]; propagatedBuildInputs = [