Skip to content

Commit

Permalink
Replace c10::optional with std::optional
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniojkim committed Apr 9, 2024
1 parent 8d5e257 commit 1cb98f2
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 109 deletions.
15 changes: 9 additions & 6 deletions build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def __init__(self, binary_dir):
self.generated_path.mkdir(parents=True, exist_ok=True)

# Create symlink to match doc structure
generated_path = self.backend_path.joinpath("generated").resolve()
if not generated_path.exists():
generated_path.symlink_to(
os.path.relpath(self.generated_path, generated_path.parent),
target_is_directory=True,
)
generated_path = self.backend_path.joinpath("generated")
generated_path.unlink(missing_ok=True)
generated_path.symlink_to(
os.path.relpath(self.generated_path, generated_path.parent),
target_is_directory=True,
)

self.tensor_class = "torch::lazy::LazyTensor"

Expand Down Expand Up @@ -350,7 +350,10 @@ def generate_shape_inference(self):
def extract_signatures(text):
signatures = set()
for name, args in sig_re.findall(text):
# Remove all whitespace from signature
signature = re.sub(r"\s+", "", f"{name}({args})")
# Ignore optional's namespace
signature = re.sub(r":*\w*:*optional", "optional", signature)
global_signatures[signature] = (name, args)
signatures.add(signature)
return signatures
Expand Down
1 change: 1 addition & 0 deletions build_tools/ci/build_posix.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
-DLLVM_TARGETS_TO_BUILD=host \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_LTC=ON
echo "::endgroup::"

echo "::group::Build"
Expand Down
2 changes: 0 additions & 2 deletions build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,6 @@ function clean_build() {
}

function build_torch_mlir() {
# Disable LTC build for releases to avoid linker issues
export TORCH_MLIR_ENABLE_LTC=0
local torch_version="$1"
case $torch_version in
nightly)
Expand Down
2 changes: 1 addition & 1 deletion projects/ltc/csrc/base_lazy_backend/backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ TorchMlirBackendImpl::GetComputationDataFromNode(const Node *node) const {

at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const {
std::optional<at::ScalarType> logical_scalar_type) const {
PRINT_FUNCTION();

TorchMlirBackendData *torch_mlir_data =
Expand Down
4 changes: 2 additions & 2 deletions projects/ltc/csrc/base_lazy_backend/backend_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TORCH_API TorchMlirBackendData : public BackendData {
public:
struct Info : public BackendData::Info {
at::Tensor tensor;
c10::optional<at::Scalar> scalar;
std::optional<at::Scalar> scalar;
bool requires_grad;
std::string name;

Expand Down Expand Up @@ -111,7 +111,7 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {

virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const override;
std::optional<at::ScalarType> logical_scalar_type) const override;

/**
* Lowering, Compilation, Execution
Expand Down
2 changes: 1 addition & 1 deletion projects/ltc/csrc/base_lazy_backend/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct TorchMlirIrBuilder : IrBuilder {
NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) const override { return MakeNode<DeviceData>(data); }
NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode<Scalar>(value, type); }
NodePtr MakeExpand(const Value& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand) const override { return MakeNode<Expand>(input0, size, is_scalar_expand); }
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const std::optional<at::ScalarType>& stype = c10::nullopt) const override { return MakeNode<Cast>(input0, dtype, stype); }
NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode<TorchMlirTensorList>(inputs); }
NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const override { return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); }

Expand Down
46 changes: 23 additions & 23 deletions projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ at::Tensor to_meta(const at::Tensor &tensor) {
return out;
}

c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor> &tensor) {
std::optional<at::Tensor> to_meta(const std::optional<at::Tensor> &tensor) {
if (tensor.has_value()) {
return to_meta(*tensor);
}
Expand All @@ -76,9 +76,9 @@ c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor> &tensor) {
return outs;
}

c10::List<c10::optional<at::Tensor>>
to_meta(const c10::List<c10::optional<at::Tensor>> &t_list) {
c10::List<c10::optional<at::Tensor>> outs;
c10::List<std::optional<at::Tensor>>
to_meta(const c10::List<std::optional<at::Tensor>> &t_list) {
c10::List<std::optional<at::Tensor>> outs;
outs.reserve(t_list.size());
for (const auto &tensor : t_list) {
outs.push_back(to_meta(tensor));
Expand All @@ -94,16 +94,16 @@ namespace {

[[maybe_unused]] at::Tensor
CreateLtcTensor(const at::Tensor &tensor,
const c10::optional<torch::lazy::BackendDevice> &device) {
const std::optional<torch::lazy::BackendDevice> &device) {
if (tensor.defined() && device) {
return torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(tensor, *device));
}
return tensor;
}

[[maybe_unused]] c10::optional<torch::lazy::BackendDevice>
GetLtcDevice(const c10::optional<c10::Device> &device) {
[[maybe_unused]] std::optional<torch::lazy::BackendDevice>
GetLtcDevice(const std::optional<c10::Device> &device) {
if (!device) {
return c10::nullopt;
}
Expand Down Expand Up @@ -148,7 +148,7 @@ void copy_(torch::lazy::LazyTensorPtr &input, torch::lazy::LazyTensorPtr &src) {
// This should be safe to do, because every operator in the LT is functional.
at::Tensor
LazyNativeFunctions::clone(const at::Tensor &self,
c10::optional<at::MemoryFormat> memory_format) {
std::optional<at::MemoryFormat> memory_format) {
auto self_lt = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
Expand Down Expand Up @@ -234,10 +234,10 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor &self,
}

at::Tensor LazyNativeFunctions::_to_copy(
const at::Tensor &self, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory, bool non_blocking,
c10::optional<at::MemoryFormat> memory_format) {
const at::Tensor &self, std::optional<at::ScalarType> dtype,
std::optional<at::Layout> layout, std::optional<at::Device> device,
std::optional<bool> pin_memory, bool non_blocking,
std::optional<at::MemoryFormat> memory_format) {
PRINT_FUNCTION();
auto options = self.options();
if (dtype) {
Expand Down Expand Up @@ -482,7 +482,7 @@ LazyNativeFunctions::split_copy_symint(const at::Tensor &self,

at::Tensor LazyNativeFunctions::index(
const at::Tensor &self,
const c10::List<c10::optional<at::Tensor>> &indices) {
const c10::List<std::optional<at::Tensor>> &indices) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
Expand All @@ -491,7 +491,7 @@ at::Tensor LazyNativeFunctions::index(

std::vector<torch::lazy::Value> values;
for (const auto &it : indices) {
c10::optional<at::Tensor> tensor = it;
std::optional<at::Tensor> tensor = it;
LazyTensorPtr lazy_tensor =
torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
values.push_back(
Expand Down Expand Up @@ -532,7 +532,7 @@ at::Tensor LazyNativeFunctions::index(
}

at::Tensor LazyNativeFunctions::index_put(
const at::Tensor &self, const c10::List<c10::optional<at::Tensor>> &indices,
const at::Tensor &self, const c10::List<std::optional<at::Tensor>> &indices,
const at::Tensor &values, bool accumulate) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
Expand All @@ -544,7 +544,7 @@ at::Tensor LazyNativeFunctions::index_put(

std::vector<torch::lazy::Value> indices_vector;
for (const auto &it : indices) {
c10::optional<at::Tensor> tensor = it;
std::optional<at::Tensor> tensor = it;
LazyTensorPtr lazy_tensor =
torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
indices_vector.push_back(
Expand Down Expand Up @@ -616,9 +616,9 @@ at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
}
at::Tensor LazyNativeFunctions::new_empty_strided_symint(
const at::Tensor &self, c10::SymIntArrayRef size,
c10::SymIntArrayRef stride, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
c10::SymIntArrayRef stride, std::optional<at::ScalarType> dtype,
std::optional<at::Layout> layout, std::optional<at::Device> device,
std::optional<bool> pin_memory) {
if (!device || device->type() == c10::DeviceType::Lazy) {
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
new_empty_strided)>::call(self, size, stride, dtype, layout, device,
Expand All @@ -628,8 +628,8 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
// lazy_tensor.new_empty_strided(..., "cpu") we need to avoid explicit
// functionalization. To do that we create regular cpu tensors.
at::Tensor t = at::empty_symint(
size, (dtype ? dtype : c10::optional<at::ScalarType>(self.scalar_type())),
(layout ? layout : c10::optional<at::Layout>(self.layout())), device,
size, (dtype ? dtype : std::optional<at::ScalarType>(self.scalar_type())),
(layout ? layout : std::optional<at::Layout>(self.layout())), device,
pin_memory, c10::nullopt);
return t.as_strided_symint(size, stride, /*storage_offset=*/0);
}
Expand Down Expand Up @@ -679,8 +679,8 @@ at::Tensor LazyNativeFunctions::_trilinear(
unroll_dim);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
const at::Tensor &self, const c10::optional<at::Tensor> &atol,
const c10::optional<at::Tensor> &rtol, bool hermitian) {
const at::Tensor &self, const std::optional<at::Tensor> &atol,
const std::optional<at::Tensor> &rtol, bool hermitian) {
return at::functionalization::functionalize_aten_op<ATEN_OP2(
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}
Expand Down
2 changes: 1 addition & 1 deletion projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ c10::TensorType &cast_tensor_type(c10::TypePtr value_type) {
return *tensor_type.get();
}

c10::optional<std::vector<int64_t>>
std::optional<std::vector<int64_t>>
get_tensor_type_shape(c10::TensorType &tensor_type) {
auto &symbolic_shape = tensor_type.symbolic_sizes();
if (!symbolic_shape.rank()) {
Expand Down
20 changes: 10 additions & 10 deletions projects/ltc/csrc/base_lazy_backend/ops/to_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ namespace lazy {
class ToCopy : public torch::lazy::TorchMlirNode {
public:
ToCopy(const torch::lazy::Value &self,
const c10::optional<at::ScalarType> &dtype,
const c10::optional<at::Layout> &layout,
const c10::optional<at::Device> &device,
const c10::optional<bool> &pin_memory, const bool &non_blocking,
const c10::optional<at::MemoryFormat> &memory_format,
const std::optional<at::ScalarType> &dtype,
const std::optional<at::Layout> &layout,
const std::optional<at::Device> &device,
const std::optional<bool> &pin_memory, const bool &non_blocking,
const std::optional<at::MemoryFormat> &memory_format,
std::vector<torch::lazy::Shape> &&shapes)
: torch::lazy::TorchMlirNode(
torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes),
Expand Down Expand Up @@ -95,12 +95,12 @@ class ToCopy : public torch::lazy::TorchMlirNode {
return _to_copy_out;
}

c10::optional<at::ScalarType> dtype;
c10::optional<at::Layout> layout;
c10::optional<at::Device> device;
c10::optional<bool> pin_memory;
std::optional<at::ScalarType> dtype;
std::optional<at::Layout> layout;
std::optional<at::Device> device;
std::optional<bool> pin_memory;
bool non_blocking;
c10::optional<at::MemoryFormat> memory_format;
std::optional<at::MemoryFormat> memory_format;
};
} // namespace lazy
} // namespace torch
Loading

0 comments on commit 1cb98f2

Please sign in to comment.