diff --git a/fairseq2n/CMakeLists.txt b/fairseq2n/CMakeLists.txt index b351d11ee..8eb256b8b 100644 --- a/fairseq2n/CMakeLists.txt +++ b/fairseq2n/CMakeLists.txt @@ -101,6 +101,13 @@ option(FAIRSEQ2N_USE_CUDA OFF ) +option(FAIRSEQ2N_SUPPORT_IMAGE + #DESCRIPTION + "Supports JPEG/PNG decoding." + #VALUE + ON +) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") set(default_thread_lib tbb) else() @@ -160,6 +167,12 @@ find_package(SndFile 1.0.25 REQUIRED) find_package(Threads REQUIRED) +if(FAIRSEQ2N_SUPPORT_IMAGE) + find_package(JPEG REQUIRED) + + find_package(PNG REQUIRED) +endif() + if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") find_package(TBB 2021.8 REQUIRED) endif() diff --git a/fairseq2n/cmake/summary.cmake b/fairseq2n/cmake/summary.cmake index 61ea927ff..b12d47cc7 100644 --- a/fairseq2n/cmake/summary.cmake +++ b/fairseq2n/cmake/summary.cmake @@ -41,6 +41,7 @@ function(fairseq2n_print_project_summary) message(STATUS " CUDA NVCC : ${CUDAToolkit_NVCC_EXECUTABLE}") message(STATUS " CUDA Architectures : ${CMAKE_CUDA_ARCHITECTURES}") endif() + message(STATUS " FAIRSEQ2N_SUPPORT_IMAGE : ${FAIRSEQ2N_SUPPORT_IMAGE}") message(STATUS " FAIRSEQ2N_BUILD_PYTHON_BINDINGS : ${FAIRSEQ2N_BUILD_PYTHON_BINDINGS}") if(FAIRSEQ2N_BUILD_PYTHON_BINDINGS) message(STATUS " FAIRSEQ2N_PYTHON_DEVEL : ${FAIRSEQ2N_PYTHON_DEVEL}") @@ -50,6 +51,10 @@ function(fairseq2n_print_project_summary) if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") message(STATUS " Intel oneTBB : ${TBB_VERSION}") endif() + if(FAIRSEQ2N_SUPPORT_IMAGE) + message(STATUS " libjpeg : ${JPEG_VERSION}") + message(STATUS " libpng : ${PNG_VERSION_STRING}") + endif() message(STATUS " libsndfile : ${SndFile_VERSION}") message(STATUS "") endfunction() diff --git a/fairseq2n/python/src/fairseq2n/__init__.py b/fairseq2n/python/src/fairseq2n/__init__.py index 820884bdd..26394a664 100644 --- a/fairseq2n/python/src/fairseq2n/__init__.py +++ b/fairseq2n/python/src/fairseq2n/__init__.py @@ -132,6 +132,13 @@ def supports_cuda() -> bool: return _supports_cuda() # type: ignore[no-any-return] +def supports_image() -> bool: + """Return ``True`` if fairseq2n supports JPEG/PNG decoding.""" + from fairseq2n.bindings import _supports_image # type: ignore[attr-defined] + + return _supports_image() # type: ignore[no-any-return] + + def cuda_version() -> Optional[Tuple[int, int]]: """Return the version of CUDA that fairseq2n supports. diff --git a/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt b/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt index f9806c7d4..5db36c0ff 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt +++ b/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt @@ -17,9 +17,9 @@ target_sources(py_bindings init.cc memory.cc data/audio.cc - data/image.cc data/data_pipeline.cc data/init.cc + data/image.cc data/string.cc data/text/converters.cc data/text/init.cc diff --git a/fairseq2n/python/src/fairseq2n/bindings/init.cc b/fairseq2n/python/src/fairseq2n/bindings/init.cc index dea93b382..983a1bbe0 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/init.cc +++ b/fairseq2n/python/src/fairseq2n/bindings/init.cc @@ -24,6 +24,13 @@ PYBIND11_MODULE(bindings, m) return supports_cuda; }); + m.def( + "_supports_image", + [] + { + return supports_image; + }); + // See https://github.com/llvm/llvm-project/issues/57123. #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunreachable-code-return" diff --git a/fairseq2n/src/fairseq2n/CMakeLists.txt b/fairseq2n/src/fairseq2n/CMakeLists.txt index 2b00cc157..abee139ec 100644 --- a/fairseq2n/src/fairseq2n/CMakeLists.txt +++ b/fairseq2n/src/fairseq2n/CMakeLists.txt @@ -56,8 +56,6 @@ target_sources(fairseq2n data/detail/file.cc data/detail/file_system.cc data/image/image_decoder.cc - data/image/detail/jpeg_decompress_struct.cc - data/image/detail/png_read_struct.cc data/text/string_splitter.cc data/text/string_to_int_converter.cc data/text/string_to_tensor_converter.cc @@ -72,6 +70,14 @@ target_sources(fairseq2n data/text/sentencepiece/sp_processor.cc ) +if(FAIRSEQ2N_SUPPORT_IMAGE) + target_sources(fairseq2n + PRIVATE + data/image/detail/jpeg_decompress_struct.cc + data/image/detail/png_read_struct.cc + ) +endif() + if(FAIRSEQ2N_USE_CUDA) target_sources(fairseq2n PRIVATE @@ -87,6 +93,10 @@ if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") target_compile_definitions(fairseq2n PRIVATE FAIRSEQ2N_USE_TBB) endif() +if(FAIRSEQ2N_SUPPORT_IMAGE) + target_compile_definitions(fairseq2n PRIVATE FAIRSEQ2N_SUPPORT_IMAGE) +endif() + if(FAIRSEQ2N_USE_CUDA) target_compile_features(fairseq2n PRIVATE cuda_std_17) @@ -105,23 +115,18 @@ target_include_directories(fairseq2n ${system} $ ) -find_package(PNG REQUIRED) -find_package(JPEG REQUIRED) - target_link_libraries(fairseq2n PRIVATE ${CMAKE_DL_LIBS} PRIVATE - fmt::fmt Iconv::Iconv + SndFile::sndfile + Threads::Threads + fmt::fmt kaldi-native-fbank::core kuba-zip natsort - Threads::Threads sentencepiece-static - SndFile::sndfile - PNG::PNG - JPEG::JPEG PUBLIC torch ) @@ -130,11 +135,14 @@ if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb") target_link_libraries(fairseq2n PRIVATE TBB::tbb) endif() +if(FAIRSEQ2N_SUPPORT_IMAGE) + target_link_libraries(fairseq2n PRIVATE JPEG::JPEG PNG::PNG) +endif() + if(FAIRSEQ2N_USE_CUDA) target_link_libraries(fairseq2n PRIVATE CUDA::cudart) endif() - fairseq2n_set_link_options(fairseq2n) set_target_properties(fairseq2n PROPERTIES diff --git a/fairseq2n/src/fairseq2n/config.h.in b/fairseq2n/src/fairseq2n/config.h.in index 06fd2cdd3..137336204 100644 --- a/fairseq2n/src/fairseq2n/config.h.in +++ b/fairseq2n/src/fairseq2n/config.h.in @@ -20,4 +20,10 @@ constexpr std::optional cuda_version_minor = @CUDA_VERSION_MINOR@; constexpr bool supports_cuda = cuda_version_major.has_value(); +#ifdef FAIRSEQ2N_SUPPORT_IMAGE +constexpr bool supports_image = true; +#else +constexpr bool supports_image = false; +#endif + } // namespace fairseq2n diff --git a/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h b/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h index 3b433e3a1..1cdcddeb5 100644 --- a/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h +++ b/fairseq2n/src/fairseq2n/data/image/detail/jpeg_decompress_struct.h @@ -7,8 +7,9 @@ #pragma once #include +#include // Forward declaration -using FILE = struct _IO_FILE; +//using FILE = struct _IO_FILE; #include #include "fairseq2n/exception.h" diff --git a/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc index 204249deb..29e40e38e 100644 --- a/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc +++ b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.cc @@ -11,7 +11,7 @@ namespace fairseq2n::detail { -png_read::png_read() : png_ptr(nullptr), info_ptr(nullptr) { +png_read::png_read() { png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); if (png_ptr == nullptr) { throw internal_error("Failed to create PNG read struct."); diff --git a/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h index 941b3a13c..3fdf2ad26 100644 --- a/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h +++ b/fairseq2n/src/fairseq2n/data/image/detail/png_read_struct.h @@ -20,8 +20,8 @@ class png_read{ png_read& operator=(const png_read&) = delete; private: - png_structp png_ptr{nullptr}; - png_infop info_ptr{nullptr}; + png_structp png_ptr{}; + png_infop info_ptr{}; }; } // namespace fairseq2n::detail diff --git a/fairseq2n/src/fairseq2n/data/image/image_decoder.cc b/fairseq2n/src/fairseq2n/data/image/image_decoder.cc index 31052da57..8639f4886 100644 --- a/fairseq2n/src/fairseq2n/data/image/image_decoder.cc +++ b/fairseq2n/src/fairseq2n/data/image/image_decoder.cc @@ -6,6 +6,7 @@ #include "fairseq2n/data/image/image_decoder.h" +#ifdef FAIRSEQ2N_SUPPORT_IMAGE #include #include #include @@ -13,6 +14,8 @@ #include #include #include +#include +#include #include "fairseq2n/exception.h" #include "fairseq2n/float.h" @@ -26,12 +29,12 @@ using namespace fairseq2n::detail; namespace fairseq2n { - + image_decoder::image_decoder(image_decoder_options opts) : opts_{opts} {} -bool +bool image_decoder::is_little_endian() { uint32_t x = 1; return (*reinterpret_cast(&x) == 1); @@ -43,55 +46,59 @@ image_decoder::operator()(data &&d) const if (!d.is_memory_block()) throw_( "The input data must be of type `memory_block`, but is of type `{}` instead.", d.type()); - + const memory_block &block = d.as_memory_block(); if (block.empty()) throw_( "The input memory block has zero length and cannot be decoded."); auto data_ptr = block.data(); - data output; - + + data output{}; + const std::array jpeg_signature = {255, 216, 255}; const std::array png_signature = {137, 80, 78, 71}; - if(std::memcmp(jpeg_signature.data(), data_ptr, jpeg_signature.size()) == 0) { + if(std::memcmp(jpeg_signature.data(), data_ptr, jpeg_signature.size()) == 0) return decode_jpeg(block); - } else if(std::memcmp(png_signature.data(), data_ptr, 4) == 0) { + + if(std::memcmp(png_signature.data(), data_ptr, png_signature.size()) == 0) return decode_png(block); - } + throw_( "Unsupported image file. Only jpeg and png are currently supported."); } data -image_decoder::decode_png(const memory_block &block) const +image_decoder::decode_png(const memory_block &block) const { - png_read pngReadStruct; + png_read pngReadStruct; png_structp png_ptr = pngReadStruct.getPngPtr(); png_infop info_ptr = pngReadStruct.getInfoPtr(); auto data_ptr = png_const_bytep(block.data()); auto data_len = block.size(); // If an error occurs, libpng will longjmp back to setjmp + // NOLINTNEXTLINE(cert-err52-cpp) if (setjmp(png_jmpbuf(png_ptr))) { throw_("libpng internal error."); } struct Reader { - png_const_bytep ptr; - png_size_t count; - Reader(png_const_bytep p, png_size_t c) : ptr(p), count(c) {} + png_const_bytep ptr; + png_size_t count; + Reader(png_const_bytep p, png_size_t c) : ptr(p), count(c) {} }; + Reader reader(data_ptr + 8, data_len - 8); - + auto read_callback = [](png_structp png_ptr2, png_bytep output, png_size_t bytes) { - auto reader = static_cast(png_get_io_ptr(png_ptr2)); - std::copy(reader->ptr, reader->ptr + bytes, output); - reader->ptr += bytes; - reader->count -= bytes; + auto reader = static_cast(png_get_io_ptr(png_ptr2)); + std::copy(reader->ptr, reader->ptr + bytes, output); + reader->ptr += bytes; + reader->count -= bytes; }; png_set_sig_bytes(png_ptr, 8); @@ -123,11 +130,11 @@ image_decoder::decode_png(const memory_block &block) const at::ScalarType dtype = bit_depth <= 8 ? at::kByte : at::kShort; at::Tensor image = at::empty({height, width, channels}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory())); - + size_t rowbytes = png_get_rowbytes(png_ptr, info_ptr); writable_memory_span image_bits = get_raw_mutable_storage(image); auto image_data = reinterpret_cast(image_bits.data()); - + // Read image data into tensor for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, image_data, nullptr); @@ -145,7 +152,7 @@ image_decoder::decode_png(const memory_block &block) const {"width", static_cast(width)}}; output.emplace("image", std::move(image)); - + return output; } @@ -157,7 +164,7 @@ image_decoder::decode_jpeg(const memory_block &block) const auto data_ptr = block.data(); auto data_len = block.size(); - + struct custom_error_mgr { struct jpeg_error_mgr pub; // Public fields jmp_buf setjmp_buffer; // Return to caller @@ -171,18 +178,20 @@ image_decoder::decode_jpeg(const memory_block &block) const auto myerr = reinterpret_cast(cinfo->err); (*cinfo->err->output_message)(cinfo); // Return control to the setjmp point + // NOLINTNEXTLINE(cert-err52-cpp) longjmp(myerr->setjmp_buffer, 1); }; // If an error occurs, error_exit will longjmp back to setjmp + // NOLINTNEXTLINE(cert-err52-cpp) if (setjmp(jerr.setjmp_buffer)) { throw_("JPEG decompression failed."); } - + //jpeg_create_decompress(&cinfo); jpeg_mem_src(&cinfo, reinterpret_cast(data_ptr), data_len); jpeg_read_header(&cinfo, TRUE); jpeg_start_decompress(&cinfo); - + auto width = cinfo.output_width; auto height = cinfo.output_height; auto channels = cinfo.output_components; @@ -204,14 +213,37 @@ image_decoder::decode_jpeg(const memory_block &block) const at::Device device = opts_.maybe_device().value_or(at::kCPU); if (device != at::kCPU) image = image.to(device); - + // Pack jpeg data and format as output. data_dict output{ {{"channels", static_cast(channels)}, {"height", static_cast(height)}, {"width", static_cast(width)}, {"bit_depth", static_cast(bit_depth)}}}; - + output.emplace("image", std::move(image)); - + return output; } + }; // namespace fairseq2n + +#else + +#include "fairseq2n/exception.h" +#include "fairseq2n/detail/exception.h" + +namespace fairseq2n { + +image_decoder::image_decoder(image_decoder_options opts) + : opts_{opts} +{} + +data +image_decoder::operator()(data &&) const +{ + detail::throw_( + "fairseq2n is not built with JPEG/PNG decoding support."); +} + +}; // namespace fairseq2n + +#endif diff --git a/fairseq2n/src/fairseq2n/data/image/image_decoder.h b/fairseq2n/src/fairseq2n/data/image/image_decoder.h index c4c0b297a..c5864804d 100644 --- a/fairseq2n/src/fairseq2n/data/image/image_decoder.h +++ b/fairseq2n/src/fairseq2n/data/image/image_decoder.h @@ -13,8 +13,6 @@ #include #include -#include -#include namespace fairseq2n { @@ -68,7 +66,7 @@ class FAIRSEQ2_API image_decoder { private: image_decoder_options opts_; - static bool + static bool is_little_endian(); data @@ -77,4 +75,5 @@ class FAIRSEQ2_API image_decoder { data decode_jpeg(const memory_block &block) const; }; + } // namespace fairseq2n diff --git a/tests/unit/data/image/test_image_decoder.py b/tests/unit/data/image/test_image_decoder.py index f83f79a20..aae6eeced 100644 --- a/tests/unit/data/image/test_image_decoder.py +++ b/tests/unit/data/image/test_image_decoder.py @@ -12,6 +12,7 @@ from fairseq2.data.image import ImageDecoder from fairseq2.memory import MemoryBlock +from fairseq2n import supports_image from tests.common import assert_close, device TEST_PNG_PATH: Final = Path(__file__).parent.joinpath("test.png") @@ -20,6 +21,9 @@ TEST_CORRUPT_PNG_PATH: Final = Path(__file__).parent.joinpath("test_corrupt.png") +@pytest.mark.skipif( + not supports_image(), reason="fairseq2n is not built with JPEG/PNG decoding support" +) class TestImageDecoder: def test_init(self) -> None: decoder = ImageDecoder()