# feat(connect): sql (#3696)
## Description
Adds Spark's `df.createOrReplaceTempView` and `spark.sql` to enable SQL
workflows.

Example:

```py
my_table = spark.read.parquet('my_table.parquet')
my_table.createOrReplaceTempView('my_table')

spark.sql('select * from my_table')
```
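
`spark.sql` returns a regular DataFrame, so the result can be consumed like any
other query (a usage sketch against the stock PySpark API; `show` is a standard
Spark method, not something added by this PR):

```py
result = spark.sql('select * from my_table')
result.show()  # executes the query against the Daft connect server and prints a preview
```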

### Note for reviewers:

To make interop with Connect easier, I removed the `SQLCatalog` and
refactored the `SQLPlanner` to take in a `DaftCatalog` (renamed from
`DaftMetaCatalog` in this PR) instead. This partially closes
#3586. I think we still need
to do some work on the Python side of things, though.
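
For context, the Python side exposes the same registration path through
`daft.catalog.register_table` (see the docstring change in
`daft/catalog/__init__.py` below). A minimal sketch, assuming the table
registered in the global catalog is the one `daft.sql` consults:

```py
import daft

# Register a DataFrame as a named table in the global catalog.
df = daft.from_pydict({"foo": [1, 2, 3]})
daft.catalog.register_table("my_table", df)

# Query it by name via Daft's SQL entry point.
daft.sql("select * from my_table").collect()
```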
universalmind303 authored Jan 17, 2025
1 parent 5549d16 commit 6b302af
Showing 17 changed files with 281 additions and 96 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

(Generated file; diff not rendered by default.)

2 changes: 1 addition & 1 deletion daft/catalog/__init__.py
@@ -27,7 +27,7 @@
 ```python
 df = daft.from_pydict({"foo": [1, 2, 3]})
-daft.catalog.register_named_table(
+daft.catalog.register_table(
     "my_table",
     df,
 )
1 change: 1 addition & 0 deletions src/daft-catalog/python-catalog/src/python.rs
@@ -86,6 +86,7 @@ impl DataCatalogTable for PythonTable {
 }
 
 /// Wrapper around a `daft.catalog.python_catalog.PythonCatalog`
+#[derive(Debug)]
 pub struct PythonCatalog {
     python_catalog_pyobj: PyObject,
 }
2 changes: 1 addition & 1 deletion src/daft-catalog/src/data_catalog.rs
@@ -4,7 +4,7 @@ use crate::{data_catalog_table::DataCatalogTable, errors::Result};
 ///
 /// It allows registering and retrieving data sources, as well as querying their schemas.
 /// The catalog is used by the query planner to resolve table references in queries.
-pub trait DataCatalog: Sync + Send {
+pub trait DataCatalog: Sync + Send + std::fmt::Debug {
     /// Lists the fully-qualified names of tables in the catalog with the specified prefix
     fn list_tables(&self, prefix: &str) -> Result<Vec<String>>;
47 changes: 32 additions & 15 deletions src/daft-catalog/src/lib.rs
@@ -19,11 +19,11 @@ pub mod global_catalog {
 
     use lazy_static::lazy_static;
 
-    use crate::{DaftMetaCatalog, DataCatalog};
+    use crate::{DaftCatalog, DataCatalog};
 
     lazy_static! {
-        pub(crate) static ref GLOBAL_DAFT_META_CATALOG: RwLock<DaftMetaCatalog> =
-            RwLock::new(DaftMetaCatalog::new_from_env());
+        pub(crate) static ref GLOBAL_DAFT_META_CATALOG: RwLock<DaftCatalog> =
+            RwLock::new(DaftCatalog::new_from_env());
     }
 
     /// Register a DataCatalog with the global DaftMetaCatalog
@@ -50,7 +50,8 @@ static DEFAULT_CATALOG_NAME: &str = "default";
 ///
 /// Users of Daft can register various [`DataCatalog`] with Daft, enabling
 /// discovery of tables across various [`DataCatalog`] implementations.
-pub struct DaftMetaCatalog {
+#[derive(Debug, Clone, Default)]
+pub struct DaftCatalog {
     /// Map of catalog names to the DataCatalog impls.
     ///
     /// NOTE: The default catalog is always named "default"
@@ -60,11 +61,11 @@ pub struct DaftMetaCatalog {
     named_tables: HashMap<String, LogicalPlanBuilder>,
 }
 
-impl DaftMetaCatalog {
+impl DaftCatalog {
     /// Create a `DaftMetaCatalog` from the current environment
     pub fn new_from_env() -> Self {
         // TODO: Parse a YAML file to produce the catalog
-        DaftMetaCatalog {
+        DaftCatalog {
             data_catalogs: default::Default::default(),
             named_tables: default::Default::default(),
@@ -95,16 +96,25 @@ impl DaftMetaCatalog {
     }
 
     /// Registers a LogicalPlan with a name in the DaftMetaCatalog
-    pub fn register_named_table(&mut self, name: &str, view: LogicalPlanBuilder) -> Result<()> {
+    pub fn register_table(
+        &mut self,
+        name: &str,
+        view: impl Into<LogicalPlanBuilder>,
+    ) -> Result<()> {
         if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
             return Err(Error::InvalidTableName {
                 name: name.to_string(),
             });
         }
-        self.named_tables.insert(name.to_string(), view);
+        self.named_tables.insert(name.to_string(), view.into());
         Ok(())
     }
 
+    /// Check if a named table is registered in the DaftCatalog
+    pub fn contains_table(&self, name: &str) -> bool {
+        self.named_tables.contains_key(name)
+    }
+
     /// Provides high-level functionality for reading a table of data against a [`DaftMetaCatalog`]
     ///
     /// Resolves the provided table_identifier against the catalog:
@@ -146,6 +156,15 @@ impl DaftMetaCatalog {
             table_id: searched_table_name.to_string(),
         })
     }
+    /// Copy from another catalog, using tables from other in case of conflict
+    pub fn copy_from(&mut self, other: &Self) {
+        for (name, plan) in &other.named_tables {
+            self.named_tables.insert(name.clone(), plan.clone());
+        }
+        for (name, catalog) in &other.data_catalogs {
+            self.data_catalogs.insert(name.clone(), catalog.clone());
+        }
+    }
 }
 
 #[cfg(test)]
@@ -181,26 +200,24 @@ mod tests {
 
     #[test]
     fn test_register_and_unregister_named_table() {
-        let mut catalog = DaftMetaCatalog::new_from_env();
+        let mut catalog = DaftCatalog::new_from_env();
         let plan = LogicalPlanBuilder::from(mock_plan());
 
         // Register a table
-        assert!(catalog
-            .register_named_table("test_table", plan.clone())
-            .is_ok());
+        assert!(catalog.register_table("test_table", plan.clone()).is_ok());
 
         // Try to register a table with invalid name
         assert!(catalog
-            .register_named_table("invalid name", plan.clone())
+            .register_table("invalid name", plan.clone())
             .is_err());
     }
 
     #[test]
     fn test_read_registered_table() {
-        let mut catalog = DaftMetaCatalog::new_from_env();
+        let mut catalog = DaftCatalog::new_from_env();
         let plan = LogicalPlanBuilder::from(mock_plan());
 
-        catalog.register_named_table("test_table", plan).unwrap();
+        catalog.register_table("test_table", plan).unwrap();
 
         assert!(catalog.read_table("test_table").is_ok());
         assert!(catalog.read_table("non_existent_table").is_err());
2 changes: 1 addition & 1 deletion src/daft-catalog/src/python.rs
@@ -61,7 +61,7 @@ fn py_register_table(
     global_catalog::GLOBAL_DAFT_META_CATALOG
         .write()
         .unwrap()
-        .register_named_table(table_identifier, logical_plan.builder.clone())?;
+        .register_table(table_identifier, logical_plan.builder.clone())?;
     Ok(table_identifier.to_string())
 }
4 changes: 3 additions & 1 deletion src/daft-connect/Cargo.toml
@@ -4,6 +4,7 @@ async-stream = "0.3.6"
 common-daft-config = {workspace = true, optional = true, features = ["python"]}
 common-error = {workspace = true, optional = true, features = ["python"]}
 common-file-formats = {workspace = true, optional = true, features = ["python"]}
+daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"]}
 daft-core = {workspace = true, optional = true, features = ["python"]}
 daft-dsl = {workspace = true, optional = true, features = ["python"]}
 daft-local-execution = {workspace = true, optional = true, features = ["python"]}
@@ -43,7 +44,8 @@ python = [
   "dep:daft-scan",
   "dep:daft-schema",
   "dep:daft-sql",
-  "dep:daft-table"
+  "dep:daft-table",
+  "dep:daft-catalog"
 ]
 
 [lints]
16 changes: 11 additions & 5 deletions src/daft-connect/src/connect_service.rs
@@ -79,17 +79,23 @@ impl SparkConnectService for DaftSparkConnectService {
             }
             OpType::Command(command) => {
                 let command = command.command_type.required("command_type")?;
-
                 match command {
                     CommandType::WriteOperation(op) => {
                         let result = session.execute_write_operation(op, rb).await?;
                         Ok(Response::new(result))
                     }
+                    CommandType::CreateDataframeView(create_dataframe) => {
+                        let result = session
+                            .execute_create_dataframe_view(create_dataframe, rb)
+                            .await?;
+                        Ok(Response::new(result))
+                    }
+                    CommandType::SqlCommand(sql) => {
+                        let result = session.execute_sql_command(sql, rb).await?;
+                        Ok(Response::new(result))
+                    }
                     other => {
-                        return not_yet_implemented!(
-                            "Command type: {}",
-                            command_type_to_str(&other)
-                        )
+                        not_yet_implemented!("CommandType '{:?}'", command_type_to_str(&other))
                     }
                 }
             }
135 changes: 133 additions & 2 deletions src/daft-connect/src/execute.rs
@@ -9,7 +9,7 @@ use daft_logical_plan::LogicalPlanBuilder;
 use daft_micropartition::MicroPartition;
 use daft_ray_execution::RayEngine;
 use daft_table::Table;
-use eyre::bail;
+use eyre::{bail, Context};
 use futures::{
     stream::{self, BoxStream},
     StreamExt, TryFutureExt, TryStreamExt,
@@ -19,7 +19,8 @@ use pyo3::Python;
 use spark_connect::{
     relation::RelType,
     write_operation::{SaveMode, SaveType},
-    ExecutePlanResponse, Relation, ShowString, WriteOperation,
+    CreateDataFrameViewCommand, ExecutePlanResponse, Relation, ShowString, SqlCommand,
+    WriteOperation,
 };
 use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
 use tracing::debug;
@@ -236,6 +237,136 @@ impl Session {
         Ok(Box::pin(stream))
     }
 
+    pub async fn execute_create_dataframe_view(
+        &self,
+        create_dataframe: CreateDataFrameViewCommand,
+        rb: ResponseBuilder<ExecutePlanResponse>,
+    ) -> Result<ExecuteStream, Status> {
+        let CreateDataFrameViewCommand {
+            input,
+            name,
+            is_global,
+            replace,
+        } = create_dataframe;
+
+        if is_global {
+            return not_yet_implemented!("Global dataframe view");
+        }
+
+        let input = input.required("input")?;
+        let input = SparkAnalyzer::new(self)
+            .to_logical_plan(input)
+            .await
+            .map_err(|e| {
+                Status::internal(
+                    textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
+                )
+            })?;
+
+        {
+            let catalog = self.catalog.read().unwrap();
+            if !replace && catalog.contains_table(&name) {
+                return Err(Status::internal("Dataframe view already exists"));
+            }
+        }
+
+        let mut catalog = self.catalog.write().unwrap();
+
+        catalog.register_table(&name, input).map_err(|e| {
+            Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"))
+        })?;
+
+        let response = rb.result_complete_response();
+        let stream = stream::once(ready(Ok(response)));
+        Ok(Box::pin(stream))
+    }
+
+    #[allow(deprecated)]
+    pub async fn execute_sql_command(
+        &self,
+        SqlCommand {
+            sql,
+            args,
+            pos_args,
+            named_arguments,
+            pos_arguments,
+            input,
+        }: SqlCommand,
+        res: ResponseBuilder<ExecutePlanResponse>,
+    ) -> Result<ExecuteStream, Status> {
+        if !args.is_empty() {
+            return not_yet_implemented!("Named arguments");
+        }
+        if !pos_args.is_empty() {
+            return not_yet_implemented!("Positional arguments");
+        }
+        if !named_arguments.is_empty() {
+            return not_yet_implemented!("Named arguments");
+        }
+        if !pos_arguments.is_empty() {
+            return not_yet_implemented!("Positional arguments");
+        }
+
+        if input.is_some() {
+            return not_yet_implemented!("Input");
+        }
+
+        let catalog = self.catalog.read().unwrap();
+        let catalog = catalog.clone();
+
+        let mut planner = daft_sql::SQLPlanner::new(catalog);
+
+        let plan = planner
+            .plan_sql(&sql)
+            .wrap_err("Error planning SQL")
+            .map_err(|e| {
+                Status::internal(
+                    textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
+                )
+            })?;
+
+        let plan = LogicalPlanBuilder::from(plan);
+
+        // TODO: code duplication
+        let result_complete = res.result_complete_response();
+
+        let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);
+
+        let this = self.clone();
+
+        tokio::spawn(async move {
+            let execution_fut = async {
+                let mut result_stream = this.run_query(plan).await?;
+                while let Some(result) = result_stream.next().await {
+                    let result = result?;
+                    let tables = result.get_tables()?;
+                    for table in tables.as_slice() {
+                        let response = res.arrow_batch_response(table)?;
+                        if tx.send(Ok(response)).await.is_err() {
+                            return Ok(());
+                        }
+                    }
+                }
+                Ok(())
+            };
+            if let Err(e) = execution_fut.await {
+                let _ = tx.send(Err(e)).await;
+            }
+        });
+
+        let stream = ReceiverStream::new(rx);
+
+        let stream = stream
+            .map_err(|e| {
+                Status::internal(
+                    textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
+                )
+            })
+            .chain(stream::once(ready(Ok(result_complete))));
+
+        Ok(Box::pin(stream))
+    }
+
     async fn show_string(
         &self,
         show_string: ShowString,
8 changes: 7 additions & 1 deletion src/daft-connect/src/session.rs
@@ -1,5 +1,9 @@
-use std::{collections::BTreeMap, sync::Arc};
+use std::{
+    collections::BTreeMap,
+    sync::{Arc, RwLock},
+};
 
+use daft_catalog::DaftCatalog;
 use daft_micropartition::partitioning::InMemoryPartitionSetCache;
 use uuid::Uuid;
 
@@ -15,6 +19,7 @@ pub struct Session {
     /// MicroPartitionSet associated with this session
     /// this will be filled up as the user runs queries
     pub(crate) psets: Arc<InMemoryPartitionSetCache>,
+    pub(crate) catalog: Arc<RwLock<DaftCatalog>>,
 }
 
 impl Session {
@@ -34,6 +39,7 @@ impl Session {
             id,
             server_side_session_id,
             psets: Arc::new(InMemoryPartitionSetCache::empty()),
+            catalog: Arc::new(RwLock::new(DaftCatalog::default())),
         }
     }