diff --git a/Cargo.toml b/Cargo.toml
index a8e5933d2fe..fd79e4acd1c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -95,7 +95,7 @@ strength_reduce = { version = "0.2", optional = true }
 multiversion = { version = "0.7.3", optional = true }
 
 # For support for odbc
-odbc-api = { version = "0.36", optional = true }
+odbc-api = { version = "4.1.0", optional = true }
 
 # Faster hashing
 ahash = "0.8"
diff --git a/arrow-odbc-integration-testing/src/read.rs b/arrow-odbc-integration-testing/src/read.rs
index a41b1388738..ad5a6c864ac 100644
--- a/arrow-odbc-integration-testing/src/read.rs
+++ b/arrow-odbc-integration-testing/src/read.rs
@@ -2,10 +2,10 @@ use stdext::function_name;
 
 use arrow2::array::{Array, BinaryArray, BooleanArray, Int32Array, Int64Array, Utf8Array};
 use arrow2::chunk::Chunk;
-use arrow2::datatypes::{DataType, Field, TimeUnit};
+use arrow2::datatypes::{DataType, TimeUnit};
 use arrow2::error::Result;
-use arrow2::io::odbc::api::{Connection, Cursor};
-use arrow2::io::odbc::read::{buffer_from_metadata, deserialize, infer_schema};
+use arrow2::io::odbc::api::ConnectionOptions;
+use arrow2::io::odbc::read::Reader;
 
 use super::{setup_empty_table, ENV, MSSQL};
 
@@ -138,45 +138,18 @@ fn test(
     insert: &str,
     table_name: &str,
 ) -> Result<()> {
-    let connection = ENV.connect_with_connection_string(MSSQL).unwrap();
+    let connection = ENV
+        .connect_with_connection_string(MSSQL, ConnectionOptions::default())
+        .unwrap();
     setup_empty_table(&connection, table_name, &[type_]).unwrap();
+
     connection
         .execute(&format!("INSERT INTO {table_name} (a) VALUES {insert}"), ())
         .unwrap();
 
-    // When
     let query = format!("SELECT a FROM {table_name} ORDER BY id");
-
-    let chunks = read(&connection, &query)?.1;
+    let chunks = Reader::new(MSSQL.to_string(), query, None, None).read()?;
 
     assert_eq!(chunks, expected);
     Ok(())
 }
-
-pub fn read(
-    connection: &Connection<'_>,
-    query: &str,
-) -> Result<(Vec<Field>, Vec<Chunk<Box<dyn Array>>>)> {
-    let mut a = connection.prepare(query).unwrap();
-    let fields = infer_schema(&a)?;
-
-    let max_batch_size = 100;
-    let buffer = buffer_from_metadata(&a, max_batch_size).unwrap();
-
-    let cursor = a.execute(()).unwrap().unwrap();
-    let mut cursor = cursor.bind_buffer(buffer).unwrap();
-
-    let mut chunks = vec![];
-    while let Some(batch) = cursor.fetch().unwrap() {
-        let arrays = (0..batch.num_cols())
-            .zip(fields.iter())
-            .map(|(index, field)| {
-                let column_view = batch.column(index);
-                deserialize(column_view, field.data_type.clone())
-            })
-            .collect::<Vec<_>>();
-        chunks.push(Chunk::new(arrays));
-    }
-
-    Ok((fields, chunks))
-}
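Reviewer note: the deleted `read` helper above is replaced by the new high-level `Reader`. A minimal sketch of the new call site, assuming a placeholder connection string and table name that are not part of this PR:

```rust
use arrow2::error::Result;
use arrow2::io::odbc::read::Reader;

fn read_example() -> Result<()> {
    // hypothetical connection string; any ODBC driver works the same way
    let connection_string = "Driver={SQLite3};Database=sqlite-test.db";

    // the two `None`s are the optional login timeout (seconds) and the
    // maximum batch size; `read` falls back to 100 rows per batch
    let reader = Reader::new(
        connection_string.to_string(),
        "SELECT a FROM my_table ORDER BY id".to_string(),
        None,
        None,
    );

    // every fetched batch becomes one Chunk<Box<dyn Array>>
    for chunk in reader.read()? {
        println!("read {} rows", chunk.len());
    }
    Ok(())
}
```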
diff --git a/arrow-odbc-integration-testing/src/write.rs b/arrow-odbc-integration-testing/src/write.rs
index bcf12761abd..3a26b501dbc 100644
--- a/arrow-odbc-integration-testing/src/write.rs
+++ b/arrow-odbc-integration-testing/src/write.rs
@@ -4,36 +4,34 @@ use arrow2::array::{Array, BinaryArray, BooleanArray, Int32Array, Utf8Array};
 use arrow2::chunk::Chunk;
 use arrow2::datatypes::{DataType, Field};
 use arrow2::error::Result;
-use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize};
 
-use super::read::read;
 use super::{setup_empty_table, ENV, MSSQL};
 
+use arrow2::io::odbc::api::ConnectionOptions;
+use arrow2::io::odbc::read::Reader;
+use arrow2::io::odbc::write::Writer;
+
 fn test(
     expected: Chunk<Box<dyn Array>>,
-    fields: Vec<Field>,
+    _fields: Vec<Field>,
     type_: &str,
     table_name: &str,
 ) -> Result<()> {
-    let connection = ENV.connect_with_connection_string(MSSQL).unwrap();
+    let connection = ENV
+        .connect_with_connection_string(MSSQL, ConnectionOptions::default())
+        .unwrap();
 
     setup_empty_table(&connection, table_name, &[type_]).unwrap();
 
-    let query = &format!("INSERT INTO {table_name} (a) VALUES (?)");
-    let mut a = connection.prepare(query).unwrap();
-
-    let mut buffer = buffer_from_description(infer_descriptions(&fields)?, expected.len());
+    let write_query = &format!("INSERT INTO {table_name} (a) VALUES (?)");
 
-    // write
-    buffer.set_num_rows(expected.len());
-    let array = &expected.columns()[0];
+    let mut writer = Writer::new(MSSQL.to_string(), write_query.to_string(), None);
 
-    serialize(array.as_ref(), &mut buffer.column_mut(0))?;
-
-    a.execute(&buffer).unwrap();
+    writer.write(&expected)?;
 
     // read
-    let query = format!("SELECT a FROM {table_name} ORDER BY id");
-    let chunks = read(&connection, &query)?.1;
+    let read_query = format!("SELECT a FROM {table_name} ORDER BY id");
+
+    let chunks = Reader::new(MSSQL.to_string(), read_query, None, None).read()?;
 
     assert_eq!(chunks[0], expected);
     Ok(())
 }
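For reference, a sketch of driving the new `Writer` outside the test harness; the connection string, table and column names below are placeholders, not taken from this PR:

```rust
use arrow2::array::{Array, Int32Array, Utf8Array};
use arrow2::chunk::Chunk;
use arrow2::error::Result;
use arrow2::io::odbc::write::Writer;

fn write_example(connection_string: &str) -> Result<()> {
    // a two-column chunk; the target table is assumed to exist already
    let chunk = Chunk::new(vec![
        Box::new(Int32Array::from_slice([1, 2, 3])) as Box<dyn Array>,
        Box::new(Utf8Array::<i32>::from([Some("Hello"), None, Some("World")])),
    ]);

    // the Writer now owns the connection string and the parameterized INSERT;
    // it connects and prepares the statement on each `write` call
    let mut writer = Writer::new(
        connection_string.to_string(),
        "INSERT INTO example (c1, c2) VALUES (?, ?)".to_string(),
        None, // optional login timeout in seconds
    );
    writer.write(&chunk)
}
```

The integration test above does the same thing and then reads the table back with `Reader` to compare against the expected chunk.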
diff --git a/examples/io_odbc.rs b/examples/io_odbc.rs
deleted file mode 100644
index 9305fab6e24..00000000000
--- a/examples/io_odbc.rs
+++ /dev/null
@@ -1,83 +0,0 @@
-//! Demo of how to write to, and read from, an ODBC connector
-//!
-//! On an Ubuntu, you need to run the following (to install the driver):
-//! ```bash
-//! sudo apt install libsqliteodbc sqlite3 unixodbc-dev
-//! sudo sed --in-place 's/libsqlite3odbc.so/\/usr\/lib\/x86_64-linux-gnu\/odbc\/libsqlite3odbc.so/' /etc/odbcinst.ini
-//! ```
-use arrow2::array::{Array, Int32Array, Utf8Array};
-use arrow2::chunk::Chunk;
-use arrow2::datatypes::{DataType, Field};
-use arrow2::error::Result;
-use arrow2::io::odbc::api;
-use arrow2::io::odbc::api::Cursor;
-use arrow2::io::odbc::read;
-use arrow2::io::odbc::write;
-
-fn main() -> Result<()> {
-    let connector = "Driver={SQLite3};Database=sqlite-test.db";
-    let env = api::Environment::new()?;
-    let connection = env.connect_with_connection_string(connector)?;
-
-    // let's create an empty table with a schema
-    connection.execute("DROP TABLE IF EXISTS example;", ())?;
-    connection.execute("CREATE TABLE example (c1 INT, c2 TEXT);", ())?;
-
-    // and now let's write some data into it (from arrow arrays!)
-    // first, we prepare the statement
-    let query = "INSERT INTO example (c1, c2) VALUES (?, ?)";
-    let prepared = connection.prepare(query).unwrap();
-
-    // secondly, we initialize buffers from odbc-api
-    let fields = vec![
-        // (for now) the types here must match the tables' schema
-        Field::new("unused", DataType::Int32, true),
-        Field::new("unused", DataType::LargeUtf8, true),
-    ];
-
-    // third, we initialize the writer
-    let mut writer = write::Writer::try_new(prepared, fields)?;
-
-    // say we have (or receive from a channel) a chunk:
-    let chunk = Chunk::new(vec![
-        Box::new(Int32Array::from_slice([1, 2, 3])) as Box<dyn Array>,
-        Box::new(Utf8Array::<i32>::from([Some("Hello"), None, Some("World")])),
-    ]);
-
-    // we write it like this
-    writer.write(&chunk)?;
-
-    // and we can later read from it
-    let chunks = read(&connection, "SELECT c1 FROM example")?;
-
-    // and the result should be the same
-    assert_eq!(chunks[0].columns()[0], chunk.columns()[0]);
-
-    Ok(())
-}
-
-/// Reads chunks from a query done against an ODBC connection
-pub fn read(connection: &api::Connection<'_>, query: &str) -> Result<Vec<Chunk<Box<dyn Array>>>> {
-    let mut a = connection.prepare(query)?;
-    let fields = read::infer_schema(&a)?;
-
-    let max_batch_size = 100;
-    let buffer = read::buffer_from_metadata(&a, max_batch_size)?;
-
-    let cursor = a.execute(())?.unwrap();
-    let mut cursor = cursor.bind_buffer(buffer)?;
-
-    let mut chunks = vec![];
-    while let Some(batch) = cursor.fetch()? {
-        let arrays = (0..batch.num_cols())
-            .zip(fields.iter())
-            .map(|(index, field)| {
-                let column_view = batch.column(index);
-                read::deserialize(column_view, field.data_type.clone())
-            })
-            .collect::<Vec<_>>();
-        chunks.push(Chunk::new(arrays));
-    }
-
-    Ok(chunks)
-}
diff --git a/src/io/odbc/read/deserialize.rs b/src/io/odbc/read/deserialize.rs
index be0a548e1a0..9aab1c0d27a 100644
--- a/src/io/odbc/read/deserialize.rs
+++ b/src/io/odbc/read/deserialize.rs
@@ -1,6 +1,4 @@
 use chrono::{NaiveDate, NaiveDateTime};
-use odbc_api::buffers::{BinColumnView, TextColumnView};
-use odbc_api::Bit;
 
 use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array};
 use crate::bitmap::{Bitmap, MutableBitmap};
@@ -9,77 +7,78 @@ use crate::datatypes::{DataType, TimeUnit};
 use crate::offset::{Offsets, OffsetsBuffer};
 use crate::types::NativeType;
 
-use super::super::api::buffers::AnyColumnView;
+use super::super::api::buffers::{AnySlice, BinColumnView, TextColumnView};
+use super::super::api::Bit;
 
-/// Deserializes a [`AnyColumnView`] into an array of [`DataType`].
+/// Deserializes a [`AnySlice`] into an array of [`DataType`].
 /// This is CPU-bounded
-pub fn deserialize(column: AnyColumnView, data_type: DataType) -> Box<dyn Array> {
+pub fn deserialize(column: AnySlice, data_type: DataType) -> Box<dyn Array> {
     match column {
-        AnyColumnView::Text(view) => Box::new(utf8(data_type, view)) as _,
-        AnyColumnView::WText(_) => todo!(),
-        AnyColumnView::Binary(view) => Box::new(binary(data_type, view)) as _,
-        AnyColumnView::Date(values) => Box::new(date(data_type, values)) as _,
-        AnyColumnView::Time(values) => Box::new(time(data_type, values)) as _,
-        AnyColumnView::Timestamp(values) => Box::new(timestamp(data_type, values)) as _,
-        AnyColumnView::F64(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::F32(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::I8(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::I16(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::I32(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::I64(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::U8(values) => Box::new(primitive(data_type, values)) as _,
-        AnyColumnView::Bit(values) => Box::new(bool(data_type, values)) as _,
-        AnyColumnView::NullableDate(slice) => Box::new(date_optional(
+        AnySlice::Text(view) => Box::new(utf8(data_type, view)) as _,
+        AnySlice::WText(_) => todo!(),
+        AnySlice::Binary(view) => Box::new(binary(data_type, view)) as _,
+        AnySlice::Date(values) => Box::new(date(data_type, values)) as _,
+        AnySlice::Time(values) => Box::new(time(data_type, values)) as _,
+        AnySlice::Timestamp(values) => Box::new(timestamp(data_type, values)) as _,
+        AnySlice::F64(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::F32(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::I8(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::I16(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::I32(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::I64(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::U8(values) => Box::new(primitive(data_type, values)) as _,
+        AnySlice::Bit(values) => Box::new(bool(data_type, values)) as _,
+        AnySlice::NullableDate(slice) => Box::new(date_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableTime(slice) => Box::new(time_optional(
+        AnySlice::NullableTime(slice) => Box::new(time_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableTimestamp(slice) => Box::new(timestamp_optional(
+        AnySlice::NullableTimestamp(slice) => Box::new(timestamp_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableF64(slice) => Box::new(primitive_optional(
+        AnySlice::NullableF64(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableF32(slice) => Box::new(primitive_optional(
+        AnySlice::NullableF32(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableI8(slice) => Box::new(primitive_optional(
+        AnySlice::NullableI8(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableI16(slice) => Box::new(primitive_optional(
+        AnySlice::NullableI16(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableI32(slice) => Box::new(primitive_optional(
+        AnySlice::NullableI32(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableI64(slice) => Box::new(primitive_optional(
+        AnySlice::NullableI64(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableU8(slice) => Box::new(primitive_optional(
+        AnySlice::NullableU8(slice) => Box::new(primitive_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
        )) as _,
-        AnyColumnView::NullableBit(slice) => Box::new(bool_optional(
+        AnySlice::NullableBit(slice) => Box::new(bool_optional(
            data_type,
            slice.raw_values().0,
            slice.raw_values().1,
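Not part of this diff, but useful context for the `*_optional` branches above: the nullable slices yield `(values, indicators)` pairs, and the deserializer derives the Arrow validity from the indicators. A hedged sketch of that convention, assuming the standard ODBC rule that an indicator of -1 (SQL_NULL_DATA) marks a NULL:

```rust
use arrow2::bitmap::{Bitmap, MutableBitmap};

/// Sketch: build an Arrow validity bitmap from ODBC length/NULL indicators.
/// Assumes -1 (SQL_NULL_DATA) marks a NULL; non-negative values are byte lengths.
fn validity_from_indicators(indicators: &[isize]) -> Bitmap {
    let mut validity = MutableBitmap::with_capacity(indicators.len());
    for indicator in indicators {
        validity.push(*indicator != -1);
    }
    validity.into()
}

fn main() {
    let indicators = [4isize, -1, 4];
    assert_eq!(validity_from_indicators(&indicators).unset_bits(), 1); // one NULL
}
```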
diff --git a/src/io/odbc/read/mod.rs b/src/io/odbc/read/mod.rs
index e8945759c65..b54623fd202 100644
--- a/src/io/odbc/read/mod.rs
+++ b/src/io/odbc/read/mod.rs
@@ -5,33 +5,87 @@ mod schema;
 pub use deserialize::deserialize;
 pub use schema::infer_schema;
 
-use super::api;
+pub use super::api::buffers::{BufferDesc, ColumnarAnyBuffer};
+pub use super::api::ColumnDescription;
+pub use super::api::Error;
+pub use super::api::ResultSetMetadata;
+
+use crate::array::Array;
+use crate::chunk::Chunk;
+use crate::error::Result;
+use crate::io::odbc::api::{Connection, ConnectionOptions, Cursor, Environment};
+
+pub struct Reader {
+    connection_string: String,
+    query: String,
+    login_timeout_sec: Option<u32>,
+    max_batch_size: Option<usize>,
+}
+
+impl Reader {
+    pub fn new(
+        connection_string: String,
+        query: String,
+        login_timeout_sec: Option<u32>,
+        max_batch_size: Option<usize>,
+    ) -> Self {
+        Self {
+            connection_string,
+            query,
+            login_timeout_sec,
+            max_batch_size,
+        }
+    }
+
+    pub fn read(&self) -> Result<Vec<Chunk<Box<dyn Array>>>> {
+        let env = Environment::new().unwrap();
+        let conn: Connection = env
+            .connect_with_connection_string(
+                self.connection_string.as_str(),
+                ConnectionOptions {
+                    login_timeout_sec: self.login_timeout_sec,
+                },
+            )
+            .unwrap();
+
+        let mut a = conn.prepare(self.query.as_str()).unwrap();
+        let fields = infer_schema(&mut a)?;
+
+        let buffer = buffer_from_metadata(&mut a, self.max_batch_size.unwrap_or(100)).unwrap();
+
+        let cursor = a.execute(()).unwrap().unwrap();
+        let mut cursor = cursor.bind_buffer(buffer).unwrap();
+
+        let mut chunks = vec![];
+        while let Some(batch) = cursor.fetch().unwrap() {
+            let arrays = (0..batch.num_cols())
+                .zip(fields.iter())
+                .map(|(index, field)| {
+                    let column_view = batch.column(index);
+                    deserialize(column_view, field.data_type.clone())
+                })
+                .collect::<Vec<_>>();
+            chunks.push(Chunk::new(arrays));
+        }
+
+        Ok(chunks)
+    }
+}
 
 /// Creates a [`api::buffers::ColumnarBuffer`] from the metadata.
 /// # Errors
 /// Iff the driver provides an incorrect [`api::ResultSetMetadata`]
 pub fn buffer_from_metadata(
-    resut_set_metadata: &impl api::ResultSetMetadata,
-    max_batch_size: usize,
-) -> std::result::Result<api::buffers::ColumnarBuffer<api::buffers::AnyColumnBuffer>, api::Error> {
-    let num_cols: u16 = resut_set_metadata.num_result_cols()? as u16;
-
-    let descs = (0..num_cols)
-        .map(|index| {
-            let mut column_description = api::ColumnDescription::default();
-
-            resut_set_metadata.describe_col(index + 1, &mut column_description)?;
-
-            Ok(api::buffers::BufferDescription {
-                nullable: column_description.could_be_nullable(),
-                kind: api::buffers::BufferKind::from_data_type(column_description.data_type)
-                    .unwrap(),
-            })
-        })
-        .collect::<std::result::Result<Vec<_>, api::Error>>()?;
-
-    Ok(api::buffers::buffer_from_description(
-        max_batch_size,
-        descs.into_iter(),
-    ))
+    result_set_metadata: &mut impl ResultSetMetadata,
+    capacity: usize,
+) -> std::result::Result<ColumnarAnyBuffer, Error> {
+    let num_cols: u16 = result_set_metadata.num_result_cols()? as u16;
+
+    let descs = (1..=num_cols).map(|i| {
+        let mut col_desc = ColumnDescription::default();
+        result_set_metadata.describe_col(i, &mut col_desc).unwrap();
+        BufferDesc::from_data_type(col_desc.data_type, col_desc.could_be_nullable()).unwrap()
+    });
+
+    Ok(ColumnarAnyBuffer::from_descs(capacity, descs))
 }
diff --git a/src/io/odbc/read/schema.rs b/src/io/odbc/read/schema.rs
index c679500b7ae..a97841368a8 100644
--- a/src/io/odbc/read/schema.rs
+++ b/src/io/odbc/read/schema.rs
@@ -5,13 +5,13 @@ use super::super::api;
 use super::super::api::ResultSetMetadata;
 
 /// Infers the Arrow [`Field`]s from a [`ResultSetMetadata`]
-pub fn infer_schema(resut_set_metadata: &impl ResultSetMetadata) -> Result<Vec<Field>> {
-    let num_cols: u16 = resut_set_metadata.num_result_cols().unwrap() as u16;
+pub fn infer_schema(result_set_metadata: &mut impl ResultSetMetadata) -> Result<Vec<Field>> {
+    let num_cols: u16 = result_set_metadata.num_result_cols().unwrap() as u16;
 
     let fields = (0..num_cols)
         .map(|index| {
             let mut column_description = api::ColumnDescription::default();
-            resut_set_metadata
+            result_set_metadata
                 .describe_col(index + 1, &mut column_description)
                 .unwrap();
 
@@ -58,7 +58,9 @@ fn column_to_data_type(data_type: &api::DataType) -> DataType {
         OdbcDataType::BigInt => DataType::Int64,
         OdbcDataType::TinyInt => DataType::Int8,
         OdbcDataType::Bit => DataType::Boolean,
-        OdbcDataType::Binary { length } => DataType::FixedSizeBinary(*length),
+        OdbcDataType::Binary { length } => length
+            .map(|l| DataType::FixedSizeBinary(l.get()))
+            .unwrap_or(DataType::FixedSizeBinary(0)),
         OdbcDataType::LongVarbinary { length: _ } | OdbcDataType::Varbinary { length: _ } => {
             DataType::Binary
         }
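As background for the buffer construction above and in the write path below: odbc-api 4.x folds the old `BufferDescription`/`BufferKind` pair into a single `BufferDesc`. A hedged sketch of the new call, with made-up values:

```rust
use odbc_api::buffers::BufferDesc;
use odbc_api::DataType;

fn main() {
    // odbc-api 0.36 required BufferDescription { nullable, kind } plus
    // BufferKind::from_data_type; 4.x folds both into BufferDesc::from_data_type,
    // which returns None for relational types it cannot buffer.
    let desc = BufferDesc::from_data_type(DataType::Integer, /* nullable: */ true);
    assert!(desc.is_some());
}
```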
diff --git a/src/io/odbc/write/mod.rs b/src/io/odbc/write/mod.rs
index 245f2455bb8..f71e5512217 100644
--- a/src/io/odbc/write/mod.rs
+++ b/src/io/odbc/write/mod.rs
@@ -2,70 +2,82 @@ mod schema;
 mod serialize;
 
-use crate::{array::Array, chunk::Chunk, datatypes::Field, error::Result};
+use crate::{array::Array, chunk::Chunk, error::Result};
 
 use super::api;
-pub use schema::infer_descriptions;
+use crate::io::odbc::api::{Connection, ConnectionOptions, Environment};
+pub use api::buffers::{BufferDesc, ColumnarAnyBuffer};
+pub use api::ColumnDescription;
+pub use schema::data_type_to;
 pub use serialize::serialize;
 
 /// Creates a [`api::buffers::ColumnarBuffer`] from [`api::ColumnDescription`]s.
 ///
 /// This is useful when separating the serialization (CPU-bounded) to writing to the DB (IO-bounded).
 pub fn buffer_from_description(
-    descriptions: Vec<api::ColumnDescription>,
+    descriptions: Vec<ColumnDescription>,
     capacity: usize,
-) -> api::buffers::ColumnarBuffer<api::buffers::AnyColumnBuffer> {
-    let descs = descriptions
-        .into_iter()
-        .map(|description| api::buffers::BufferDescription {
-            nullable: description.could_be_nullable(),
-            kind: api::buffers::BufferKind::from_data_type(description.data_type).unwrap(),
-        });
+) -> ColumnarAnyBuffer {
+    let descs = descriptions.into_iter().map(|description| {
+        BufferDesc::from_data_type(description.data_type, description.could_be_nullable()).unwrap()
+    });
 
-    api::buffers::buffer_from_description(capacity, descs)
+    ColumnarAnyBuffer::from_descs(capacity, descs)
 }
 
 /// A writer of [`Chunk`]s to an ODBC [`api::Prepared`] statement.
 /// # Implementation
 /// This struct mixes CPU-bounded and IO-bounded tasks and is not ideal
 /// for an `async` context.
-pub struct Writer<'a> {
-    fields: Vec<Field>,
-    buffer: api::buffers::ColumnarBuffer<api::buffers::AnyColumnBuffer>,
-    prepared: api::Prepared<'a>,
+pub struct Writer {
+    connection_string: String,
+    query: String,
+    login_timeout_sec: Option<u32>,
 }
 
-impl<'a> Writer<'a> {
-    /// Creates a new [`Writer`].
-    /// # Errors
-    /// Errors iff any of the types from [`Field`] is not supported.
-    pub fn try_new(prepared: api::Prepared<'a>, fields: Vec<Field>) -> Result<Self> {
-        let buffer = buffer_from_description(infer_descriptions(&fields)?, 0);
-        Ok(Self {
-            fields,
-            buffer,
-            prepared,
-        })
+impl Writer {
+    pub fn new(connection_string: String, query: String, login_timeout_sec: Option<u32>) -> Self {
+        Self {
+            connection_string,
+            query,
+            login_timeout_sec,
+        }
     }
 
     /// Writes a chunk to the writer.
     /// # Errors
     /// Errors iff the execution of the statement fails.
     pub fn write<A: AsRef<dyn Array>>(&mut self, chunk: &Chunk<A>) -> Result<()> {
-        if chunk.len() > self.buffer.num_rows() {
-            // if the chunk is larger, we re-allocate new buffers to hold it
-            self.buffer = buffer_from_description(infer_descriptions(&self.fields)?, chunk.len());
-        }
+        let env = Environment::new().unwrap();
+
+        let conn: Connection = env
+            .connect_with_connection_string(
+                self.connection_string.as_str(),
+                ConnectionOptions {
+                    login_timeout_sec: self.login_timeout_sec,
+                },
+            )
+            .unwrap();
+
+        let buf_descs = chunk.arrays().iter().map(|array| {
+            BufferDesc::from_data_type(
+                data_type_to(array.as_ref().data_type()).unwrap(),
+                array.as_ref().null_count() > 0,
+            )
+            .unwrap()
+        });
 
-        self.buffer.set_num_rows(chunk.len());
+        let prepared = conn.prepare(self.query.as_str()).unwrap();
+        let mut prebound = prepared
+            .into_column_inserter(chunk.len(), buf_descs)
+            .unwrap();
+        prebound.set_num_rows(chunk.len());
 
-        // serialize (CPU-bounded)
         for (i, column) in chunk.arrays().iter().enumerate() {
-            serialize(column.as_ref(), &mut self.buffer.column_mut(i))?;
+            serialize(column.as_ref(), &mut prebound.column_mut(i)).unwrap();
         }
+        prebound.execute().unwrap();
 
-        // write (IO-bounded)
-        self.prepared.execute(&self.buffer)?;
         Ok(())
     }
 }
diff --git a/src/io/odbc/write/schema.rs b/src/io/odbc/write/schema.rs
index 5ac7ebfaf82..cffad7bcc1c 100644
--- a/src/io/odbc/write/schema.rs
+++ b/src/io/odbc/write/schema.rs
@@ -1,38 +1,26 @@
 use super::super::api;
-use crate::datatypes::{DataType, Field};
-use crate::error::{Error, Result};
+use std::num::NonZeroUsize;
 
-/// Infers the [`api::ColumnDescription`] from the fields
-pub fn infer_descriptions(fields: &[Field]) -> Result<Vec<api::ColumnDescription>> {
-    fields
-        .iter()
-        .map(|field| {
-            let nullability = if field.is_nullable {
-                api::Nullability::Nullable
-            } else {
-                api::Nullability::NoNulls
-            };
-            let data_type = data_type_to(field.data_type())?;
-
-            Ok(api::ColumnDescription {
-                name: api::U16String::from_str(&field.name).into_vec(),
-                nullability,
-                data_type,
-            })
-        })
-        .collect()
-}
+use crate::datatypes::DataType;
+use crate::error::{Error, Result};
 
-fn data_type_to(data_type: &DataType) -> Result<api::DataType> {
+pub fn data_type_to(data_type: &DataType) -> Result<api::DataType> {
     Ok(match data_type {
         DataType::Boolean => api::DataType::Bit,
         DataType::Int16 => api::DataType::SmallInt,
         DataType::Int32 => api::DataType::Integer,
         DataType::Float32 => api::DataType::Float { precision: 24 },
         DataType::Float64 => api::DataType::Float { precision: 53 },
-        DataType::FixedSizeBinary(length) => api::DataType::Binary { length: *length },
-        DataType::Binary | DataType::LargeBinary => api::DataType::Varbinary { length: 0 },
-        DataType::Utf8 | DataType::LargeUtf8 => api::DataType::Varchar { length: 0 },
+        DataType::FixedSizeBinary(length) => api::DataType::Binary {
+            length: NonZeroUsize::new(*length),
+        },
+        DataType::Binary | DataType::LargeBinary => api::DataType::Varbinary {
+            length: NonZeroUsize::new(1),
+        },
+        DataType::Utf8 | DataType::LargeUtf8 => api::DataType::Varchar {
+            length: NonZeroUsize::new(1),
+        },
         other => return Err(Error::nyi(format!("{other:?} to ODBC"))),
     })
 }
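To make the new `Option<NonZeroUsize>` lengths concrete, a small sketch of the now-public `data_type_to` (assumes the crate is built with its ODBC feature enabled):

```rust
use std::num::NonZeroUsize;

use arrow2::datatypes::DataType;
use arrow2::error::Result;
use arrow2::io::odbc::api;
use arrow2::io::odbc::write::data_type_to;

fn main() -> Result<()> {
    // scalar types map one-to-one
    assert_eq!(data_type_to(&DataType::Int32)?, api::DataType::Integer);

    // lengths are now Option<NonZeroUsize>: FixedSizeBinary(16) carries Some(16),
    // while a zero length collapses to None
    assert_eq!(
        data_type_to(&DataType::FixedSizeBinary(16))?,
        api::DataType::Binary {
            length: NonZeroUsize::new(16)
        }
    );
    Ok(())
}
```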
diff --git a/src/io/odbc/write/serialize.rs b/src/io/odbc/write/serialize.rs
index f92326ba89c..babc5e73c77 100644
--- a/src/io/odbc/write/serialize.rs
+++ b/src/io/odbc/write/serialize.rs
@@ -1,5 +1,3 @@
-use api::buffers::{BinColumnWriter, TextColumnWriter};
-
 use crate::array::*;
 use crate::bitmap::Bitmap;
 use crate::datatypes::DataType;
@@ -7,18 +5,18 @@ use crate::error::{Error, Result};
 use crate::offset::Offset;
 use crate::types::NativeType;
 
-use super::super::api;
-use super::super::api::buffers::NullableSliceMut;
+use odbc_api as api;
+use odbc_api::buffers::{AnySliceMut, BinColumnSliceMut, NullableSliceMut, TextColumnSliceMut};
 
 /// Serializes an [`Array`] to [`api::buffers::AnyColumnViewMut`]
 /// This operation is CPU-bounded
-pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut) -> Result<()> {
+pub fn serialize(array: &dyn Array, column: &mut AnySliceMut) -> Result<()> {
     match array.data_type() {
         DataType::Boolean => {
-            if let api::buffers::AnyColumnViewMut::Bit(values) = column {
+            if let AnySliceMut::Bit(values) = column {
                 bool(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
-            } else if let api::buffers::AnyColumnViewMut::NullableBit(values) = column {
+            } else if let AnySliceMut::NullableBit(values) = column {
                 bool_optional(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -26,10 +24,10 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::Int16 => {
-            if let api::buffers::AnyColumnViewMut::I16(values) = column {
+            if let AnySliceMut::I16(values) = column {
                 primitive(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
-            } else if let api::buffers::AnyColumnViewMut::NullableI16(values) = column {
+            } else if let AnySliceMut::NullableI16(values) = column {
                 primitive_optional(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -37,10 +35,10 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::Int32 => {
-            if let api::buffers::AnyColumnViewMut::I32(values) = column {
+            if let AnySliceMut::I32(values) = column {
                 primitive(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
-            } else if let api::buffers::AnyColumnViewMut::NullableI32(values) = column {
+            } else if let AnySliceMut::NullableI32(values) = column {
                 primitive_optional(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -48,10 +46,10 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::Float32 => {
-            if let api::buffers::AnyColumnViewMut::F32(values) = column {
+            if let AnySliceMut::F32(values) = column {
                 primitive(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
-            } else if let api::buffers::AnyColumnViewMut::NullableF32(values) = column {
+            } else if let AnySliceMut::NullableF32(values) = column {
                 primitive_optional(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -59,10 +57,10 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::Float64 => {
-            if let api::buffers::AnyColumnViewMut::F64(values) = column {
+            if let AnySliceMut::F64(values) = column {
                 primitive(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
-            } else if let api::buffers::AnyColumnViewMut::NullableF64(values) = column {
+            } else if let AnySliceMut::NullableF64(values) = column {
                 primitive_optional(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -70,7 +68,7 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::Utf8 => {
-            if let api::buffers::AnyColumnViewMut::Text(values) = column {
+            if let AnySliceMut::Text(values) = column {
                 utf8::<i32>(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -78,7 +76,7 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::LargeUtf8 => {
-            if let api::buffers::AnyColumnViewMut::Text(values) = column {
+            if let AnySliceMut::Text(values) = column {
                 utf8::<i64>(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -86,7 +84,7 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::Binary => {
-            if let api::buffers::AnyColumnViewMut::Binary(values) = column {
+            if let AnySliceMut::Binary(values) = column {
                 binary::<i32>(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -94,7 +92,7 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
         }
         DataType::LargeBinary => {
-            if let api::buffers::AnyColumnViewMut::Binary(values) = column {
+            if let AnySliceMut::Binary(values) = column {
                 binary::<i64>(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -102,7 +100,7 @@ pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut)
             }
        }
        DataType::FixedSizeBinary(_) => {
-            if let api::buffers::AnyColumnViewMut::Binary(values) = column {
+            if let AnySliceMut::Binary(values) = column {
                 fixed_binary(array.as_any().downcast_ref().unwrap(), values);
                 Ok(())
             } else {
@@ -152,12 +150,25 @@ fn primitive_optional<T: NativeType>(array: &PrimitiveArray<T>, values: &mut Nul
     write_validity(array.validity(), indicators);
 }
 
-fn fixed_binary(array: &FixedSizeBinaryArray, writer: &mut BinColumnWriter) {
-    writer.set_max_len(array.size());
-    writer.write(array.iter())
+fn fixed_binary(array: &FixedSizeBinaryArray, writer: &mut BinColumnSliceMut) {
+    // Since the length of every element is identical and fixed to `array.size()`,
+    // we only need to reallocate and rebind the buffer once.
+    writer.ensure_max_element_length(array.size(), 0).unwrap();
+
+    for (row_index, value) in array
+        .values()
+        .chunks(array.size())
+        .collect::<Vec<_>>()
+        .iter()
+        .enumerate()
+    {
+        writer.set_cell(row_index, Some(value));
+    }
 }
 
-fn binary<O: Offset>(array: &BinaryArray<O>, writer: &mut BinColumnWriter) {
+fn binary<O: Offset>(array: &BinaryArray<O>, writer: &mut BinColumnSliceMut) {
+    // Get the largest element length across the array
+
     let max_len = array
         .offsets()
         .buffer()
@@ -165,11 +176,14 @@ fn binary<O: Offset>(array: &BinaryArray<O>, writer: &mut BinColumnWriter) {
         .map(|x| (x[1] - x[0]).to_usize())
         .max()
         .unwrap_or(0);
-    writer.set_max_len(max_len);
-    writer.write(array.iter())
+
+    writer.ensure_max_element_length(max_len, 0).unwrap();
+
+    (0..array.offsets().len_proxy()) // loop over the index of each element
+        .for_each(|row_idx| writer.set_cell(row_idx, array.get(row_idx)));
 }
 
-fn utf8<O: Offset>(array: &Utf8Array<O>, writer: &mut TextColumnWriter<u8>) {
+fn utf8<O: Offset>(array: &Utf8Array<O>, writer: &mut TextColumnSliceMut<u8>) {
     let max_len = array
         .offsets()
         .buffer()
@@ -177,6 +191,8 @@ fn utf8<O: Offset>(array: &Utf8Array<O>, writer: &mut TextColumnWriter<u8>) {
         .map(|x| (x[1] - x[0]).to_usize())
         .max()
         .unwrap_or(0);
-    writer.set_max_len(max_len);
-    writer.write(array.iter().map(|x| x.map(|x| x.as_bytes())))
+    writer.ensure_max_element_length(max_len, 0).ok();
+
+    (0..array.offsets().len_proxy()) // loop over the index of each element
+        .for_each(|row_idx| writer.set_cell(row_idx, array.get(row_idx).map(|s| s.as_bytes())));
 }