forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build_pytorch_amd.py
76 lines (64 loc) · 2.47 KB
/
build_pytorch_amd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from __future__ import absolute_import, division, print_function
import os
import subprocess
import sys
from functools import reduce
from pyHIPIFY import hipify_python
# Locate this script on disk (symlinks resolved) and derive the directories
# used by the rest of the build from it.
_this_script = os.path.realpath(__file__)
amd_build_dir = os.path.dirname(_this_script)

# The project root is two directory levels above the directory of this script.
_build_parent_dir = os.path.dirname(amd_build_dir)
proj_dir = os.path.dirname(_build_parent_dir)

# Glob-style patterns passed to hipify_python.hipify() as `includes=` and
# `ignores=` at the end of this script.
includes = ["aten/*", "torch/*"]
ignores = ["aten/src/ATen/core/*"]

# List of operators currently disabled; passed to hipify as `yaml_settings=`.
yaml_file = os.path.join(amd_build_dir, "disabled_features.yaml")
# Apply patch files in place.
#
# Each `git apply` is run to completion before the next one starts. The steps
# later in this script rewrite files under proj_dir in place, so a patch that
# is still being applied in the background could race with those rewrites.
patch_folder = os.path.join(amd_build_dir, "patches")
# os.listdir() returns entries in arbitrary order; sort so that patches are
# always applied in the same order.
for filename in sorted(os.listdir(patch_folder)):
    patch_path = os.path.join(patch_folder, filename)
    status = subprocess.call(["git", "apply", patch_path], cwd=proj_dir)
    if status != 0:
        # Keep going, as the script always has, but make the failure visible
        # instead of silently dropping the exit status.
        sys.stderr.write(
            "warning: 'git apply %s' exited with status %d\n" % (patch_path, status))
# HIPCC Compiler doesn't provide host defines - Automatically include them.
#
# Every .cu / .cuh file under <proj_dir>/aten/src/ATen gets the HIP runtime
# include prepended as its first line.
hip_runtime_include = '#include "hip/hip_runtime.h"\n'
for root, _, files in os.walk(os.path.join(proj_dir, "aten/src/ATen")):
    for filename in files:
        if not filename.endswith((".cu", ".cuh")):
            continue
        filepath = os.path.join(root, filename)
        with open(filepath, "r+") as f:
            txt = f.read()
            # Skip files that already start with the header, so running this
            # script again does not stack a second copy of the include.
            if txt.startswith(hip_runtime_include):
                continue
            # Add the include header!
            f.seek(0)
            f.write(hip_runtime_include + txt)
            f.truncate()
            # Flush to disk
            f.flush()
            os.fsync(f)
# Rewrite CUDA preprocessor symbols in the .cpp / .h files found under
# <proj_dir>/torch. These are plain substring replacements:
#   "USE_CUDA"     -> "USE_ROCM"
#   "CUDA_VERSION" -> "0"
# Files whose path ends with one of the entries in ignore_files are left alone.
ignore_files = [
    "csrc/autograd/profiler.h",
    "csrc/autograd/profiler.cpp",
    "csrc/cuda/cuda_check.h",
]
_ignored_suffixes = tuple(ignore_files)
for root, _directories, files in os.walk(os.path.join(proj_dir, "torch")):
    for filename in files:
        if not filename.endswith((".cpp", ".h")):
            continue
        source = os.path.join(root, filename)
        # Disabled files
        if source.endswith(_ignored_suffixes):
            continue
        # Update contents: read everything, substitute, then overwrite in place.
        with open(source, "r+") as f:
            contents = f.read()
            contents = contents.replace("USE_CUDA", "USE_ROCM").replace("CUDA_VERSION", "0")
            f.seek(0)
            f.write(contents)
            # Drop any leftover tail in case the new text is shorter.
            f.truncate()
            f.flush()
            os.fsync(f)
# Run hipify over the project. The project directory and the output directory
# are both proj_dir; the include/ignore patterns and the YAML settings file
# are the module-level values defined at the top of this script.
hipify_options = dict(
    project_directory=proj_dir,
    output_directory=proj_dir,
    includes=includes,
    ignores=ignores,
    yaml_settings=yaml_file,
    add_static_casts_option=True,
    show_progress=False,
)
hipify_python.hipify(**hipify_options)