Hi,
When I tried to run the demo code, I consistently got the following error:
```
python: /tmp/pip-install-nx20f2ug/mirage-project_c1ff8981742f441194cc66eb5385f40e/src/search/op_utils.cc:102: std::shared_ptr<mirage::search::AlgebraicPattern> mirage::search::get_pattern(mirage::type::KNOperatorType, const mirage::kernel::DTensor&, std::shared_ptr<mirage::search::AlgebraicPattern>): Assertion `false' failed.
Aborted (core dumped)
```
The native call stack looked like this:

This is the code I am running. Is there anything wrong with my setup?
```python
import mirage as mi
import numpy as np
import torch

if __name__ == "__main__":
    graph = mi.new_kernel_graph()
    X = graph.new_input(dims=(1, 4096), dtype=mi.float16)
    W = graph.new_input(dims=(4096, 4096), dtype=mi.float16)
    A = graph.new_input(dims=(4096, 16), dtype=mi.float16)
    B = graph.new_input(dims=(16, 4096), dtype=mi.float16)
    D = graph.matmul(X, A)
    E = graph.matmul(D, B)
    C = graph.matmul(X, W)
    O = graph.add(C, E)
    graph.mark_output(O)
    optimized_graph = graph.superoptimize(config="lora")

    input_tensors = [
        torch.randn(1, 4096, dtype=torch.float16, device='cuda:0'),
        torch.randn(4096, 4096, dtype=torch.float16, device='cuda:0'),
        torch.randn(4096, 16, dtype=torch.float16, device='cuda:0'),
        torch.randn(16, 4096, dtype=torch.float16, device='cuda:0')
    ]

    for _ in range(16):
        optimized_graph(inputs=input_tensors)

    torch.cuda.synchronize()
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    starter.record()
    for _ in range(1000):
        optimized_graph(inputs=input_tensors)
    ender.record()
    torch.cuda.synchronize()
    curr_time = starter.elapsed_time(ender)

    mean_syn = curr_time / 1000
    print(mean_syn)
```
My installed Python packages:

```text
accelerate==1.0.1
aenum==3.1.15
aiofiles==22.1.0
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
aiosqlite==0.20.0
annotated-types==0.7.0
anyio==4.6.2.post1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bidict==0.23.1
bleach==6.1.0
bson==0.5.10
build==1.2.2.post1
certifi==2024.2.2
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
click-option-group==0.5.6
cloudpickle==3.1.0
cmake==3.30.4
colored==2.2.4
coloredlogs==15.0.1
comm==0.2.2
crcmod==1.7
cryptography==38.0.4
cuda-python==12.6.0
Cython==3.0.11
datasets==3.0.1
debugpy==1.8.7
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.14
diffusers==0.30.3
dill==0.3.8
distlib==0.3.8
distro==1.9.0
dnspython==2.6.1
docker-pycreds==0.4.0
einops==0.8.0
entrypoints==0.4
evaluate==0.4.3
exceptiongroup==1.2.2
executing==2.1.0
fastjsonschema==2.20.0
filelock==3.16.1
flash-attn==2.6.3
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
grpcio==1.62.2
h11==0.14.0
h5py==3.10.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.26.1
humanfriendly==10.0
idna==3.7
importlib_metadata==8.5.0
iniconfig==2.0.0
ipaddress==1.0.23
iso8601==1.0.0
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.4
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
markdown-it-py==3.0.0
MarkupSafe==3.0.1
mdurl==0.1.2
mirage-project==0.2.1
mistune==3.0.2
mpi4py==4.0.1
mpmath==1.3.0
msgpack==1.0.8
multidict==6.1.0
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.4.1
ninja==1.11.1.1
notebook==6.5.7
notebook_shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-modelopt==0.15.1
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
onnx==1.17.0
openai==1.39.0
optimum==1.23.1
overrides==7.7.0
packaging==24.1
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
pathlib2==2.3.7.post1
pathtools==0.1.2
pexpect==4.8.0
pillow==10.3.0
platformdirs==4.3.6
pluggy==1.5.0
ply==3.11
polygraphy==0.49.9
prometheus_client==0.21.0
promise==2.3
prompt_toolkit==3.0.48
propcache==0.2.0
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
PuLP==2.9.0
pure_eval==0.2.3
py==1.11.0
py-spy==0.3.14
pyarrow==17.0.0
pycparser==2.22
pycryptodomex==3.21.0
pydantic==2.9.2
pydantic_core==2.23.4
Pygments==2.18.0
PyJWT==2.8.0
pynvml==11.5.3
pyOpenSSL==22.1.0
pyproject_hooks==1.2.0
pytest==6.2.5
python-consul==1.1.0
python-dateutil==2.9.0.post0
python-engineio==4.9.1
python-etcd==0.4.5
python-json-logger==2.0.7
python-socketio==5.11.4
pytz==2022.5
PyYAML==6.0.1
pyzmq==26.2.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.9.2
rpds-py==0.20.0
safetensors==0.4.5
schedule==1.2.1
scipy==1.14.1
Send2Trash==1.8.3
sentencepiece==0.2.0
sentry-sdk==2.0.1
setproctitle==1.3.3
shortuuid==1.0.13
simple-websocket==1.0.0
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
StrEnum==0.4.15
sympy==1.13.3
tensorrt==10.4.0
tensorrt-cu12==10.4.0
tensorrt-cu12-bindings==10.4.0
tensorrt-cu12-libs==10.4.0
tensorrt-llm==0.13.0
terminado==0.18.1
timm==1.0.10
tinycss2==1.3.0
tokenizers==0.15.2
toml==0.10.2
tomli==2.0.2
torch==2.4.0
torchvision==0.19.0
tornado==6.4.1
tox==3.28.0
tqdm==4.66.5
traitlets==5.14.3
transformers==4.37.2
triton==3.0.0
types-python-dateutil==2.9.0.20241003
typing_extensions==4.12.2
tzdata==2024.2
uri-template==1.3.0
urllib3==1.26.18
virtualenv==20.26.6
watchdog==5.0.3
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
wrapt==1.16.0
wsproto==1.2.0
xxhash==3.5.0
y-py==0.6.2
yarl==1.16.0
ypy-websocket==0.8.4
z3-solver==4.13.3.0
zipp==3.20.2
```
I cannot reproduce the issue on my end. Can you confirm that you have built the most recent version of Mirage from source using `pip install .`?
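One way to answer the version question is to check which `mirage` installation Python actually picks up in the environment that runs the failing script. The sketch below uses only the standard library (`importlib.metadata.version` and `importlib.util.find_spec`); the helper names `installed_version` and `module_origin` are hypothetical, and the distribution name `mirage-project` is taken from the package list in the report. A version string alone may not distinguish a fresh source build from an older install at the same version, so the sketch also prints the file the `mirage` package would be loaded from.

```python
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def module_origin(module_name: str) -> Optional[str]:
    """Return the file a top-level module would be loaded from, without importing it."""
    spec = find_spec(module_name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Distribution name as listed by pip; module name as used in `import mirage`.
    print("mirage-project version:", installed_version("mirage-project"))
    print("mirage loaded from:", module_origin("mirage"))
```

If the printed path does not belong to the environment where `pip install .` was run, a different installation is likely being imported.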