diff --git a/.github/workflows/_build_wheel-macos.yaml b/.github/workflows/_build_wheel-macos.yaml index 1137d4878..c79f79a26 100644 --- a/.github/workflows/_build_wheel-macos.yaml +++ b/.github/workflows/_build_wheel-macos.yaml @@ -37,7 +37,7 @@ jobs: HOMEBREW_NO_INSTALL_UPGRADE: 1 HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK: 1 run: | - brew install libsndfile python@${{ inputs.py }} || true + brew install libjpeg libpng libsndfile python@${{ inputs.py }} || true - name: Create the Python virtual environment run: | /usr/local/bin/python${{ inputs.py }} -m venv ~/venv @@ -100,7 +100,7 @@ jobs: HOMEBREW_NO_INSTALL_UPGRADE: 1 HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK: 1 run: | - brew install libsndfile python@${{ inputs.py }} || true + brew install libjpeg libpng libsndfile python@${{ inputs.py }} || true - name: Download wheels and native tests from staging uses: actions/download-artifact@v3 with: diff --git a/INSTALL_FROM_SOURCE.md b/INSTALL_FROM_SOURCE.md index 40672b237..887f6f75d 100644 --- a/INSTALL_FROM_SOURCE.md +++ b/INSTALL_FROM_SOURCE.md @@ -53,21 +53,22 @@ reusing an existing one to avoid dependency conflicts. ## 3. Install Dependencies ### 3.1 System Dependencies -fairseq2 has a dependency on -[libsndfile](https://github.com/libsndfile/libsndfile) that can be installed via -the system package manager on most Linux distributions, or via Homebrew on +fairseq2 depends on [libjpeg](https://libjpeg.sourceforge.net), +[libpng](http://www.libpng.org/pub/png/libpng.html), and +[libsndfile](https://github.com/libsndfile/libsndfile), which can be installed +via the system package manager on most Linux distributions, or via Homebrew on macOS. 
For Ubuntu-based systems, run: ```sh -sudo apt install libsndfile-dev +sudo apt install libjpeg8-dev libpng-dev libsndfile-dev ``` Similarly, on Fedora, run: ```sh -sudo dnf install libsndfile-devel +sudo dnf install libjpeg-devel libpng-devel libsndfile-devel ``` For other Linux distributions, please consult its documentation on how to @@ -76,7 +77,7 @@ install packages. For macOS, you can use Homebrew: ```sh -brew install libsndfile +brew install libjpeg libpng libsndfile ``` ### 3.2 PyTorch diff --git a/README.md b/README.md index 97c982fe7..ee5271b90 100644 --- a/README.md +++ b/README.md @@ -59,19 +59,20 @@ fairseq2 is also used by various external projects such as: ## Installing on Linux ### System Dependencies -fairseq2 has a dependency on -[libsndfile](https://github.com/libsndfile/libsndfile) that can be installed via -the system package manager on most Linux distributions. For Ubuntu-based +fairseq2 depends on [libjpeg](https://libjpeg.sourceforge.net), +[libpng](http://www.libpng.org/pub/png/libpng.html), and +[libsndfile](https://github.com/libsndfile/libsndfile), which can be installed +via the system package manager on most Linux distributions. 
For Ubuntu-based systems, run: ```sh -sudo apt install libsndfile1 +sudo apt install libjpeg8 libpng16-16 libsndfile1 ``` Similarly, on Fedora, run: ```sh -sudo dnf install libsndfile +sudo dnf install libjpeg libpng libsndfile ``` For other Linux distributions, please consult its documentation on how to @@ -139,12 +140,13 @@ pip install fairseq2\ ## Installing on macOS ### System Dependencies -fairseq2 has a dependency on -[libsndfile](https://github.com/libsndfile/libsndfile) that can be installed via -Homebrew: +fairseq2 depends on [libjpeg](https://libjpeg.sourceforge.net), +[libpng](http://www.libpng.org/pub/png/libpng.html), and +[libsndfile](https://github.com/libsndfile/libsndfile), which can be installed +via Homebrew: ```sh -brew install libsndfile +brew install libjpeg libpng libsndfile ``` ### pip diff --git a/fairseq2n/CMakeLists.txt b/fairseq2n/CMakeLists.txt index b351d11ee..e7e9eef62 100644 --- a/fairseq2n/CMakeLists.txt +++ b/fairseq2n/CMakeLists.txt @@ -87,6 +87,13 @@ option(FAIRSEQ2N_TREAT_WARNINGS_AS_ERRORS OFF ) +option(FAIRSEQ2N_SUPPORT_IMAGE + #DESCRIPTION + "Supports JPEG/PNG decoding." + #VALUE + ON +) + option(FAIRSEQ2N_USE_LIBTORCH #DESCRIPTION "Uses libtorch instead of PyTorch." 
@@ -160,6 +167,12 @@ find_package(SndFile 1.0.25 REQUIRED) find_package(Threads REQUIRED) +if(FAIRSEQ2N_SUPPORT_IMAGE) + find_package(JPEG REQUIRED) + + find_package(PNG REQUIRED) +endif() + if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") find_package(TBB 2021.8 REQUIRED) endif() diff --git a/fairseq2n/cmake/summary.cmake b/fairseq2n/cmake/summary.cmake index 61ea927ff..f394b255d 100644 --- a/fairseq2n/cmake/summary.cmake +++ b/fairseq2n/cmake/summary.cmake @@ -34,6 +34,7 @@ function(fairseq2n_print_project_summary) message(STATUS " FAIRSEQ2N_SANITIZERS : ${FAIRSEQ2N_SANITIZERS}") endif() message(STATUS " FAIRSEQ2N_TREAT_WARNINGS_AS_ERRORS : ${FAIRSEQ2N_TREAT_WARNINGS_AS_ERRORS}") + message(STATUS " FAIRSEQ2N_SUPPORT_IMAGE : ${FAIRSEQ2N_SUPPORT_IMAGE}") message(STATUS " FAIRSEQ2N_USE_LIBTORCH : ${FAIRSEQ2N_USE_LIBTORCH}") message(STATUS " FAIRSEQ2N_USE_CUDA : ${FAIRSEQ2N_USE_CUDA}") if(FAIRSEQ2N_USE_CUDA) @@ -50,6 +51,10 @@ function(fairseq2n_print_project_summary) if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") message(STATUS " Intel oneTBB : ${TBB_VERSION}") endif() + if(FAIRSEQ2N_SUPPORT_IMAGE) + message(STATUS " libjpeg : ${JPEG_VERSION}") + message(STATUS " libpng : ${PNG_VERSION_STRING}") + endif() message(STATUS " libsndfile : ${SndFile_VERSION}") message(STATUS "") endfunction() diff --git a/fairseq2n/python/src/fairseq2n/__init__.py b/fairseq2n/python/src/fairseq2n/__init__.py index 820884bdd..26394a664 100644 --- a/fairseq2n/python/src/fairseq2n/__init__.py +++ b/fairseq2n/python/src/fairseq2n/__init__.py @@ -132,6 +132,13 @@ def supports_cuda() -> bool: return _supports_cuda() # type: ignore[no-any-return] +def supports_image() -> bool: + """Return ``True`` if fairseq2n supports JPEG/PNG decoding.""" + from fairseq2n.bindings import _supports_image # type: ignore[attr-defined] + + return _supports_image() # type: ignore[no-any-return] + + def cuda_version() -> Optional[Tuple[int, int]]: """Return the version of CUDA that fairseq2n supports. 
diff --git a/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt b/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt index 2a7d1e529..4c5b16cdc 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt +++ b/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt @@ -17,8 +17,9 @@ target_sources(py_bindings init.cc memory.cc data/audio.cc data/data_pipeline.cc + data/image.cc data/init.cc data/string.cc data/text/converters.cc data/text/init.cc diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/image.cc b/fairseq2n/python/src/fairseq2n/bindings/data/image.cc new file mode 100644 index 000000000..7263a9b3d --- /dev/null +++ b/fairseq2n/python/src/fairseq2n/bindings/data/image.cc @@ -0,0 +1,45 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "fairseq2n/bindings/module.h" + +#include +#include + +#include +#include + +#include +#include + +namespace py = pybind11; + +namespace fairseq2n { + +void +def_image(py::module_ &data_module) +{ + py::module_ m = data_module.def_submodule("image"); + + // ImageDecoder + py::class_>(m, "ImageDecoder") + .def( + py::init([]( + std::optional maybe_device, + bool pin_memory) + { + auto opts = image_decoder_options() + .maybe_device(maybe_device).pin_memory(pin_memory); + + return std::make_shared(opts); + }), + py::arg("device") = std::nullopt, + py::arg("pin_memory") = false) + .def("__call__", &image_decoder::operator(), py::call_guard{}); + + map_functors().register_(); +} +} // namespace fairseq2n diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/init.cc b/fairseq2n/python/src/fairseq2n/bindings/data/init.cc index acceffc5b..f0408b262 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/data/init.cc +++ b/fairseq2n/python/src/fairseq2n/bindings/data/init.cc @@ -40,6 +40,8 @@ def_data(py::module_ 
&base) def_audio(m); + def_image(m); + def_data_pipeline(m); def_string(m); diff --git a/fairseq2n/python/src/fairseq2n/bindings/init.cc b/fairseq2n/python/src/fairseq2n/bindings/init.cc index dea93b382..983a1bbe0 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/init.cc +++ b/fairseq2n/python/src/fairseq2n/bindings/init.cc @@ -24,6 +24,13 @@ PYBIND11_MODULE(bindings, m) return supports_cuda; }); + m.def( + "_supports_image", + [] + { + return supports_image; + }); + // See https://github.com/llvm/llvm-project/issues/57123. #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunreachable-code-return" diff --git a/fairseq2n/python/src/fairseq2n/bindings/module.h b/fairseq2n/python/src/fairseq2n/bindings/module.h index 2a79e9fdd..3cf686773 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/module.h +++ b/fairseq2n/python/src/fairseq2n/bindings/module.h @@ -22,6 +22,9 @@ namespace fairseq2n { void def_audio(pybind11::module_ &data_module); +void +def_image(pybind11::module_ &data_module); + void def_data(pybind11::module_ &base_module); diff --git a/fairseq2n/src/fairseq2n/CMakeLists.txt b/fairseq2n/src/fairseq2n/CMakeLists.txt index 9f752d58c..34d06c2bc 100644 --- a/fairseq2n/src/fairseq2n/CMakeLists.txt +++ b/fairseq2n/src/fairseq2n/CMakeLists.txt @@ -55,6 +55,7 @@ target_sources(fairseq2n data/audio/detail/sndfile.cc data/detail/file.cc data/detail/file_system.cc + data/image/image_decoder.cc data/text/string_splitter.cc data/text/string_to_int_converter.cc data/text/string_to_tensor_converter.cc @@ -69,6 +70,14 @@ target_sources(fairseq2n data/text/sentencepiece/sp_processor.cc ) +if(FAIRSEQ2N_SUPPORT_IMAGE) + target_sources(fairseq2n + PRIVATE + data/image/detail/jpeg_decompress_struct.cc + data/image/detail/png_read_struct.cc + ) +endif() + if(FAIRSEQ2N_USE_CUDA) target_sources(fairseq2n PRIVATE @@ -80,8 +89,8 @@ fairseq2n_set_compile_options(fairseq2n) target_compile_features(fairseq2n PUBLIC cxx_std_17) -if(FAIRSEQ2N_THREAD_LIB STREQUAL 
"tbb") - target_compile_definitions(fairseq2n PRIVATE FAIRSEQ2N_USE_TBB) +if(FAIRSEQ2N_SUPPORT_IMAGE) + target_compile_definitions(fairseq2n PRIVATE FAIRSEQ2N_SUPPORT_IMAGE) endif() if(FAIRSEQ2N_USE_CUDA) @@ -90,6 +99,10 @@ if(FAIRSEQ2N_USE_CUDA) target_compile_definitions(fairseq2n PRIVATE FAIRSEQ2N_USE_CUDA) endif() +if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") + target_compile_definitions(fairseq2n PRIVATE FAIRSEQ2N_USE_TBB) +endif() + if(PROJECT_IS_TOP_LEVEL) set(system) else() @@ -114,18 +127,22 @@ target_link_libraries(fairseq2n Threads::Threads sentencepiece-static SndFile::sndfile PUBLIC torch ) -if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") - target_link_libraries(fairseq2n PRIVATE TBB::tbb) +# JPEG/PNG are found by the top-level CMakeLists.txt only when image support is on. +if(FAIRSEQ2N_SUPPORT_IMAGE) + target_link_libraries(fairseq2n PRIVATE JPEG::JPEG PNG::PNG) endif() if(FAIRSEQ2N_USE_CUDA) target_link_libraries(fairseq2n PRIVATE CUDA::cudart) endif() +if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") + target_link_libraries(fairseq2n PRIVATE TBB::tbb) +endif() fairseq2n_set_link_options(fairseq2n) @@ -174,10 +191,20 @@ install( # Library Configuration # ------------------------------------------------------------ +if(FAIRSEQ2N_SUPPORT_IMAGE) + set(SUPPORTS_IMAGE "true") +else() + set(SUPPORTS_IMAGE "false") +endif() + if(FAIRSEQ2N_USE_CUDA) + set(USES_CUDA "true") + set(CUDA_VERSION_MAJOR "${CUDAToolkit_VERSION_MAJOR}") set(CUDA_VERSION_MINOR "${CUDAToolkit_VERSION_MINOR}") else() + set(USES_CUDA "false") + set(CUDA_VERSION_MAJOR "std::nullopt") set(CUDA_VERSION_MINOR "std::nullopt") endif() diff --git a/fairseq2n/src/fairseq2n/config.h.in b/fairseq2n/src/fairseq2n/config.h.in index 06fd2cdd3..276bb8901 100644 --- a/fairseq2n/src/fairseq2n/config.h.in +++ b/fairseq2n/src/fairseq2n/config.h.in @@ -15,9 +15,11 @@ constexpr std::int32_t version_major 
= @PROJECT_VERSION_MAJOR@; constexpr std::int32_t version_minor = @PROJECT_VERSION_MINOR@; constexpr std::int32_t version_patch = @PROJECT_VERSION_PATCH@; +constexpr bool supports_image = @SUPPORTS_IMAGE@; + +constexpr bool supports_cuda = @USES_CUDA@; + constexpr std::optional cuda_version_major = @CUDA_VERSION_MAJOR@; constexpr std::optional cuda_version_minor = @CUDA_VERSION_MINOR@; -constexpr bool supports_cuda = cuda_version_major.has_value(); - } // namespace fairseq2n diff --git a/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.cc b/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.cc new file mode 100644 index 000000000..5d56112b9 --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.cc @@ -0,0 +1,25 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "fairseq2n/data/image/detail/jpeg_decompress_struct.h" + +namespace fairseq2n::detail { + +jpeg_decompress::jpeg_decompress() : cinfo() { + jpeg_create_decompress(&cinfo); +} + +jpeg_decompress::~jpeg_decompress() { + if(cinfo.err != nullptr) { + jpeg_destroy_decompress(&cinfo); + } +} + +jpeg_decompress_struct& jpeg_decompress::get() { + return cinfo; +} + +} // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h b/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h new file mode 100644 index 000000000..1cdcddeb5 --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include +#include +// Forward declaration +//using FILE = struct _IO_FILE; +#include + +#include "fairseq2n/exception.h" +#include "fairseq2n/detail/exception.h" + +namespace fairseq2n::detail { + +class jpeg_decompress { +public: + jpeg_decompress(); + ~jpeg_decompress(); + jpeg_decompress_struct& get(); + jpeg_decompress(const jpeg_decompress&) = delete; + jpeg_decompress& operator=(const jpeg_decompress&) = delete; + +private: + jpeg_decompress_struct cinfo; +}; + +} // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc new file mode 100644 index 000000000..29e40e38e --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc @@ -0,0 +1,40 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "fairseq2n/data/image/detail/png_read_struct.h" + +#include "fairseq2n/exception.h" +#include "fairseq2n/detail/exception.h" + +namespace fairseq2n::detail { + +png_read::png_read() { + png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + if (png_ptr == nullptr) { + throw internal_error("Failed to create PNG read struct."); + } + info_ptr = png_create_info_struct(png_ptr); + if (info_ptr == nullptr) { + png_destroy_read_struct(&png_ptr, nullptr, nullptr); + throw internal_error("Failed to create PNG info struct."); + } +} + +png_read::~png_read() { + if (png_ptr != nullptr) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + } +} + +png_structp png_read::getPngPtr() const { + return png_ptr; +} + +png_infop png_read::getInfoPtr() const { + return info_ptr; +} + +} // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h new file mode 100644 index 000000000..3fdf2ad26 --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h @@ -0,0 +1,27 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +namespace fairseq2n::detail { + +class png_read{ +public: + png_read(); + ~png_read(); + png_structp getPngPtr() const; + png_infop getInfoPtr() const; + png_read(const png_read&) = delete; + png_read& operator=(const png_read&) = delete; + +private: + png_structp png_ptr{}; + png_infop info_ptr{}; +}; + +} // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/image/image_decoder.cc b/fairseq2n/src/fairseq2n/data/image/image_decoder.cc new file mode 100644 index 000000000..d2d9dd7cb --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/image/image_decoder.cc @@ -0,0 +1,266 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "fairseq2n/data/image/image_decoder.h" + +#ifdef FAIRSEQ2N_SUPPORT_IMAGE +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "fairseq2n/exception.h" +#include "fairseq2n/float.h" +#include "fairseq2n/fmt.h" +#include "fairseq2n/memory.h" +#include "fairseq2n/data/image/detail/png_read_struct.h" +#include "fairseq2n/data/image/detail/jpeg_decompress_struct.h" +#include "fairseq2n/data/detail/tensor_helpers.h" +#include "fairseq2n/detail/exception.h" + +using namespace fairseq2n::detail; + +namespace fairseq2n { + +image_decoder::image_decoder(image_decoder_options opts) + : opts_{opts} +{} + +bool +image_decoder::is_little_endian() { + uint32_t x = 1; + return (*reinterpret_cast(&x) == 1); +} + +// Decodes a JPEG or PNG image held in a memory block, picking the codec from +// the leading signature bytes. 
+data +image_decoder::operator()(data &&d) const +{ + if (!d.is_memory_block()) + throw_( + "The input data must be of type `memory_block`, but is of type `{}` instead.", d.type()); + + const memory_block &block = d.as_memory_block(); + if (block.empty()) + throw_( + "The input memory block has zero length and cannot be decoded."); + + auto data_ptr = block.data(); + + const 
std::array jpeg_signature = {255, 216, 255}; + const std::array png_signature = {137, 80, 78, 71}; + + // Compare a signature only when the block is long enough to hold it, so that + // short inputs are rejected below instead of being read out of bounds. For + // PNG, require the full 8 bytes that decode_png() skips as the signature. + if(block.size() >= jpeg_signature.size() && std::memcmp(jpeg_signature.data(), data_ptr, jpeg_signature.size()) == 0) + return decode_jpeg(block); + + if(block.size() >= 8 && std::memcmp(png_signature.data(), data_ptr, png_signature.size()) == 0) + return decode_png(block); + + throw_( + "Unsupported image file. Only jpeg and png are currently supported."); +} + +// Decodes a PNG image into a dict holding the pixel tensor (`image`) and the +// header fields read from the stream. +data +image_decoder::decode_png(const memory_block &block) const +{ + png_read pngReadStruct; + png_structp png_ptr = pngReadStruct.getPngPtr(); + png_infop info_ptr = pngReadStruct.getInfoPtr(); + + auto data_ptr = png_const_bytep(block.data()); + auto data_len = block.size(); + // If an error occurs, libpng will longjmp back to setjmp + // NOLINTNEXTLINE(cert-err52-cpp) + if (setjmp(png_jmpbuf(png_ptr))) { + throw_("libpng internal error."); + } + + struct Reader { + png_const_bytep ptr; + png_size_t count; + Reader(png_const_bytep p, png_size_t c) : ptr(p), count(c) {} + }; + + // The caller guarantees at least 8 bytes; start reading after the signature. + Reader reader(data_ptr + 8, data_len - 8); + + auto read_callback = [](png_structp png_ptr2, + png_bytep output, + png_size_t bytes) { + auto reader = static_cast(png_get_io_ptr(png_ptr2)); + // Fail on truncated input instead of reading past the end of the block. + // png_error() does not return; it longjmps to the setjmp point above. 
+ if (bytes > reader->count) { + png_error(png_ptr2, "Unexpected end of PNG data."); + } + std::copy(reader->ptr, reader->ptr + bytes, output); + reader->ptr += bytes; + reader->count -= bytes; + }; + + png_set_sig_bytes(png_ptr, 8); + png_set_read_fn(png_ptr, &reader, read_callback); + png_read_info(png_ptr, info_ptr); + + png_uint_32 width=0, height=0; + int bit_depth=0, color_type=0; + int interlace_type=0; + auto retval = png_get_IHDR( + png_ptr, + info_ptr, + &width, + &height, + &bit_depth, + &color_type, + &interlace_type, + nullptr, + nullptr); + + if (retval != 1) { + throw_("Could not read image metadata from content."); + } + + if (is_little_endian()) { + png_set_swap(png_ptr); + } + int channels = png_get_channels(png_ptr, info_ptr); + + at::ScalarType dtype = bit_depth <= 8 ? 
at::kByte : at::kShort; + at::Tensor image = at::empty({height, width, channels}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory())); + + size_t rowbytes = png_get_rowbytes(png_ptr, info_ptr); + writable_memory_span image_bits = get_raw_mutable_storage(image); + auto image_data = reinterpret_cast(image_bits.data()); + + // Read image data into tensor + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, image_data, nullptr); + image_data += rowbytes; + } + + at::Device device = opts_.maybe_device().value_or(at::kCPU); + if (device != at::kCPU) + image = image.to(device); + + // Pack png data and format as output + data_dict output{ + {"bit_depth", static_cast(bit_depth)}, {"color_type", static_cast(color_type)}, + {"channels", static_cast(channels)}, {"height", static_cast(height)}, + {"width", static_cast(width)}}; + + output.emplace("image", std::move(image)); + + return output; +} + +// Decodes a JPEG image into a dict holding the pixel tensor (`image`) and the +// dimensions reported by libjpeg. +data +image_decoder::decode_jpeg(const memory_block &block) const +{ + jpeg_decompress jpegDecompressStruct; + // Bind by reference: a copy would leave the wrapper's own struct without an + // error manager, so its destructor would skip jpeg_destroy_decompress(). 
+ jpeg_decompress_struct &cinfo = jpegDecompressStruct.get(); + + auto data_ptr = block.data(); + auto data_len = block.size(); + + struct custom_error_mgr { + struct jpeg_error_mgr pub; // Public fields + jmp_buf setjmp_buffer; // Return to caller + }; + struct custom_error_mgr jerr = {}; + using error_ptr = struct custom_error_mgr *; + cinfo.err = jpeg_std_error(&jerr.pub); + // error_exit is called by libjpeg when a fatal error occurs + jerr.pub.error_exit = [](j_common_ptr cinfo) { + // Coerce pointer to custom_error_mgr struct + auto myerr = reinterpret_cast(cinfo->err); + (*cinfo->err->output_message)(cinfo); + // Return control to the setjmp point + // NOLINTNEXTLINE(cert-err52-cpp) + longjmp(myerr->setjmp_buffer, 1); + }; + // If an error occurs, error_exit will longjmp back to setjmp + // NOLINTNEXTLINE(cert-err52-cpp) + if (setjmp(jerr.setjmp_buffer)) { + throw_("JPEG decompression failed."); + } + + // 
NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto mutable_data_ptr = const_cast(data_ptr); + // The decompress object has already been created by the jpeg_decompress wrapper. + jpeg_mem_src(&cinfo, reinterpret_cast(mutable_data_ptr), data_len); + jpeg_read_header(&cinfo, TRUE); + jpeg_start_decompress(&cinfo); + + auto width = cinfo.output_width; + auto height = cinfo.output_height; + auto channels = cinfo.output_components; + auto row_size = width * static_cast(channels); + int bit_depth = cinfo.data_precision; + + at::ScalarType dtype = bit_depth <= 8 ? at::kByte : at::kShort; + at::Tensor image = at::empty({height, width, channels}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory())); + writable_memory_span image_bits = get_raw_mutable_storage(image); + auto image_data = reinterpret_cast(image_bits.data()); + + // Read image into tensor + while (cinfo.output_scanline < cinfo.output_height) { + jpeg_read_scanlines(&cinfo, &image_data, 1); + image_data += row_size; + } + jpeg_finish_decompress(&cinfo); + + at::Device device = opts_.maybe_device().value_or(at::kCPU); + if (device != at::kCPU) + image = image.to(device); + + // Pack jpeg data and format as output. 
+ data_dict output{ + {{"channels", static_cast(channels)}, {"height", static_cast(height)}, + {"width", static_cast(width)}, {"bit_depth", static_cast(bit_depth)}}}; + + output.emplace("image", std::move(image)); + + return output; +} + +} // namespace fairseq2n + +#else + +#include "fairseq2n/exception.h" +#include "fairseq2n/detail/exception.h" + +namespace fairseq2n { + +image_decoder::image_decoder(image_decoder_options opts) + : opts_{opts} +{} + +data +image_decoder::operator()(data &&) const +{ + detail::throw_( + "fairseq2n is not built with JPEG/PNG decoding support."); +} + +} // namespace fairseq2n + +#endif diff --git a/fairseq2n/src/fairseq2n/data/image/image_decoder.h b/fairseq2n/src/fairseq2n/data/image/image_decoder.h new file mode 100644 index 000000000..c5864804d --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/image/image_decoder.h @@ -0,0 +1,79 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +#include "fairseq2n/api.h" +#include "fairseq2n/data/data.h" + +#include +#include + +namespace fairseq2n { + +class image_decoder_options { +public: + image_decoder_options + maybe_device(std::optional value) noexcept + { + auto tmp = *this; + + tmp.maybe_device_ = value; + + return tmp; + } + + std::optional + maybe_device() const noexcept + { + return maybe_device_; + } + + image_decoder_options + pin_memory(bool value) noexcept + { + auto tmp = *this; + + tmp.pin_memory_ = value; + + return tmp; + } + + bool + pin_memory() const noexcept + { + return pin_memory_; + } + +private: + std::optional maybe_device_{}; + bool pin_memory_ = false; +}; + +class FAIRSEQ2_API image_decoder { +public: + explicit + image_decoder(image_decoder_options opts = {}); + + data + operator()(data &&d) const; + +private: + image_decoder_options opts_; + + static bool + is_little_endian(); + + data + decode_png(const memory_block &block) const; + + data + decode_jpeg(const memory_block &block) const; +}; + +} // namespace fairseq2n diff --git a/src/fairseq2/data/image.py b/src/fairseq2/data/image.py new file mode 100644 index 000000000..2b54032f9 --- /dev/null +++ b/src/fairseq2/data/image.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import TYPE_CHECKING, Optional, TypedDict + +from torch import Tensor + +from fairseq2 import _DOC_MODE +from fairseq2.memory import MemoryBlock +from fairseq2.typing import Device + +if TYPE_CHECKING or _DOC_MODE: + + class ImageDecoder: + def __init__( + self, + device: Optional[Device] = None, + pin_memory: bool = False, + ) -> None: + ... + + def __call__(self, memory_block: MemoryBlock) -> "ImageDecoderOutput": + ... 
+ +else: + from fairseq2n.bindings.data.image import ImageDecoder as ImageDecoder + + def _set_module_name() -> None: + for t in [ImageDecoder]: + t.__module__ = __name__ + + _set_module_name() + + +class ImageDecoderOutput(TypedDict): + bit_depth: float + color_type: float + channels: float + height: float + width: float + image: Tensor diff --git a/tests/unit/data/image/__init__.py b/tests/unit/data/image/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/image/test.jpg b/tests/unit/data/image/test.jpg new file mode 100644 index 000000000..a3052b7c3 Binary files /dev/null and b/tests/unit/data/image/test.jpg differ diff --git a/tests/unit/data/image/test.png b/tests/unit/data/image/test.png new file mode 100644 index 000000000..5717d71d3 Binary files /dev/null and b/tests/unit/data/image/test.png differ diff --git a/tests/unit/data/image/test_corrupt.jpg b/tests/unit/data/image/test_corrupt.jpg new file mode 100644 index 000000000..420af2251 Binary files /dev/null and b/tests/unit/data/image/test_corrupt.jpg differ diff --git a/tests/unit/data/image/test_corrupt.png b/tests/unit/data/image/test_corrupt.png new file mode 100644 index 000000000..ee6292deb Binary files /dev/null and b/tests/unit/data/image/test_corrupt.png differ diff --git a/tests/unit/data/image/test_image_decoder.py b/tests/unit/data/image/test_image_decoder.py new file mode 100644 index 000000000..df613c365 --- /dev/null +++ b/tests/unit/data/image/test_image_decoder.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from pathlib import Path +from typing import Any, Final + +import pytest +import torch + +from fairseq2.data.image import ImageDecoder +from fairseq2.memory import MemoryBlock +from fairseq2n import supports_image +from tests.common import assert_close, device + +TEST_PNG_PATH: Final = Path(__file__).parent.joinpath("test.png") +TEST_JPG_PATH: Final = Path(__file__).parent.joinpath("test.jpg") +TEST_CORRUPT_JPG_PATH: Final = Path(__file__).parent.joinpath("test_corrupt.jpg") +TEST_CORRUPT_PNG_PATH: Final = Path(__file__).parent.joinpath("test_corrupt.png") + + +@pytest.mark.skipif( + not supports_image(), reason="fairseq2n is not built with JPEG/PNG decoding support" +) +class TestImageDecoder: + def test_init(self) -> None: + decoder = ImageDecoder() + assert isinstance(decoder, ImageDecoder) + + def test_call_works_on_png(self) -> None: + decoder = ImageDecoder(device=device) + + with TEST_PNG_PATH.open("rb") as fb: + block = MemoryBlock(fb.read()) + + output = decoder(block) + + assert output["bit_depth"] == 8.0 + + assert output["color_type"] == 6.0 + + assert output["channels"] == 4.0 + + assert output["height"] == 70.0 + + assert output["width"] == 70.0 + + image = output["image"] + + assert image.shape == torch.Size([70, 70, 4]) + + assert image.dtype == torch.uint8 + + assert image.device == device + + assert_close(image.sum(), torch.tensor(4656924, device=device)) + + def test_call_works_on_jpg(self) -> None: + decoder = ImageDecoder(device=device) + + with TEST_JPG_PATH.open("rb") as fb: + block = MemoryBlock(fb.read()) + + output = decoder(block) + + assert output["bit_depth"] == 8.0 + + assert output["channels"] == 3.0 + + assert output["height"] == 50.0 + + assert output["width"] == 50.0 + + image = output["image"] + + assert image.shape == torch.Size([50, 50, 3]) + + assert image.dtype == torch.uint8 + + assert image.device == device + + assert_close(image.sum(), torch.tensor(1747686, device=device)) + + def 
test_call_raises_error_when_input_is_corrupted_png(self) -> None: + decoder = ImageDecoder(device=device) + + with TEST_CORRUPT_PNG_PATH.open("rb") as fb: + block = MemoryBlock(fb.read()) + + with pytest.raises( + RuntimeError, + match="libpng internal error.", + ): + decoder(block) + + def test_call_raises_error_when_input_is_corrupted_jpg(self) -> None: + decoder = ImageDecoder(device=device) + + with TEST_CORRUPT_JPG_PATH.open("rb") as fb: + block = MemoryBlock(fb.read()) + + with pytest.raises( + RuntimeError, + match="JPEG decompression failed.", + ): + decoder(block) + + @pytest.mark.parametrize( + "value,type_name", [(None, "pyobj"), (123, "int"), ("s", "string")] + ) + def test_call_raises_error_when_input_is_not_memory_block( + self, value: Any, type_name: str + ) -> None: + decoder = ImageDecoder() + + with pytest.raises( + ValueError, + match=rf"^The input data must be of type `memory_block`, but is of type `{type_name}` instead\.$", + ): + decoder(value) + + def test_call_raises_error_when_input_is_empty(self) -> None: + decoder = ImageDecoder() + + empty_block = MemoryBlock() + + with pytest.raises( + ValueError, + match=r"^The input memory block has zero length and cannot be decoded\.$", + ): + decoder(empty_block) + + def test_call_raises_error_when_input_is_invalid(self) -> None: + decoder = ImageDecoder() + + block = MemoryBlock(b"foo") + + with pytest.raises( + ValueError, + match=r"^Unsupported image file. Only jpeg and png are currently supported\.$", + ): + decoder(block)